Compare commits

...

63 Commits
b4727 ... b4790

Author SHA1 Message Date
Rémy O
438a83926a vulkan: add specific MMV kernels for IQ2 and IQ3 quants + optimizations (#11595)
* vulkan: implement specialized MMV kernels for IQ2 quantizations

* vulkan: add MMV kernels for IQ3 quants

* vulkan: Increase MMV batch size and unroll IQ LUT setup

* vulkan: fix init_iq_shmem for WG sizes larger than tables

* vulkan: common batch size for all I-quants
2025-02-28 09:42:52 +01:00
Johannes Gäßler
9c42b1718c CUDA: fix logic for V100 + GGML_CUDA_FORCE_MMQ (#12098) 2025-02-28 09:26:43 +01:00
Prashant Vithule
05e6f5aad0 ggml: aarch64: implement SVE kernels for q2_k_q8_k vector dot (#12064)
* Added SVE Support for Q2_K Quantized Models

* Use 4-space indentation in the switch cases

* removed comments lines

* Remove the loop Retain the curly bracess for better understanding of code

* Remove the comment like added for q3_k_q8_k kernel

---------

Co-authored-by: vithulep <p.m.vithule1517@gmail.com>
2025-02-28 09:36:12 +02:00
hipudding
673cfef9aa CANN: Fix build error with GCC 13 (#11990)
Remove unused header file that causes compilation failure on ARM
platform with GCC 13.
2025-02-28 15:23:47 +08:00
Eve
fbeda9002d vulkan: matmul dequantization improvements (#12015)
* faster dequant for old quants

* dont use unpack for iq4_nl

* vec2 unpack for q8
2025-02-28 08:20:08 +01:00
Daniele
581650b7ca vulkan: improve im2col (#11826)
* vulkan: improve im2col performance
2025-02-28 07:52:51 +01:00
Vladimir Vuksanovic
b95c8af37c cmake: Fix ggml backend dependencies and installation (#11818)
* Fix dependencies between ggml and backends

ggml backends link only to ggml-base and ggml links to all backends.

* Fix installation of ggml backends

Set up GNUInstallDirs before setting the installation directory of ggml backends
2025-02-27 09:42:48 +02:00
Ting Lou
a800ae46da llava : add struct for FFI bindgen (#12079)
Some checks failed
flake8 Lint / Lint (push) Has been cancelled
Python Type-Check / pyright type-check (push) Has been cancelled
* add struct for FFI bindgen

* Apply suggestions from code review

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2025-02-26 15:26:52 +01:00
Sigbjørn Skjæret
69050a11be Refactor gguf scripts to improve metadata handling (#11909)
* Refactor gguf scripts to improve metadata handling

Added contents method to ReaderField class
Added endianess property to GGUFReader class

* update scripts

* fix import

* remove unused import

* attempt to work around flake and pyright errors

* second attempt

* give up, ignore type

* bump version

* apply newbyteorder fixes
2025-02-26 08:04:48 -05:00
Aleksei Nikiforov
3567ee3a94 gguf-py: enable reading non-native endian files (#12081)
Currently self.byte_order is never used.
Actually use it to byteswap read data to
allow reading big endian files on little endian systems
and vice versa.

Now it's possible to convert little-endian model
into a big-endian model and back
on a little-endian system.
2025-02-26 11:39:27 +00:00
Kante Yin
53e4db1012 readme : update infra list (#9096)
Signed-off-by: kerthcet <kerthcet@gmail.com>
2025-02-26 09:49:36 +02:00
Olivier Chafik
d7cfe1ffe0 docs: add docs/function-calling.md to lighten server/README.md's plight (#12069) 2025-02-25 18:52:56 +00:00
Jeff Bolz
a82c9e7c23 vulkan: fix assertion when qy_needs_dequant (#12068)
Looks like a copy/paste bug from qx_needs_dequant.
2025-02-25 16:30:21 +01:00
rhjdvsgsgks
401af80b54 server: handle echo=false on /v1/completions (#12060) 2025-02-25 12:52:52 +01:00
Judd
c132239bfb add OP sigmoid (#12056)
Co-authored-by: Judd <foldl@boxvest.com>
2025-02-25 12:32:20 +01:00
Molly Sophia
393fca629e ggml-cpu: Fix build with sve (#12059)
* ggml-cpu: Fix build with sve

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml-cpu: Remove unused variable in sve q3_k vec dot

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
2025-02-25 19:28:22 +08:00
Rémy O
61d4f39dfe vulkan: implement more backpropagation operators (#11914)
* vulkan: implement GGML_OP_ROPE_BACK

* vulkan: implement GGML_OP_RMS_NORM_BACK

* vulkan: implement GGML_OP_SILU_BACK

* vulkan: implement GGML_OP_SOFTMAX_BACK
2025-02-25 12:04:45 +01:00
Olivier Chafik
0b52745649 server: support add_generation_prompt query param (#12062) 2025-02-25 10:40:22 +00:00
Alex Brooks
4d1051a40f Add Doc for Converting Granite Vision -> GGUF (#12006)
* Add example docs for granite vision

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
2025-02-25 10:46:05 +01:00
Vitali Lovich
3e9a2860e9 llama : expose llama_model_n_head_kv in the API (#11997)
It's useful to be able to have this from the library layer as it's a key
parameter of the model (e.g. to figure out how much KV cache memory is
needed).
2025-02-25 11:29:33 +02:00
Gian-Carlo Pascutto
58d07a8043 metal : copy kernels for quant to F32/F16 conversions (#12017)
metal: use dequantize_q templates

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-02-25 11:27:58 +02:00
lhez
34a846b584 opencl: fix for small models (#11950)
* opencl: fix small shape gemv, remove unused extensions

* opencl: fix `transpose_16`, `dump_tensor`, enforce subgroup size

* opencl: fix for token length < 4

* opencl: use wave size of 64 for all Adreno GPUs

---------

Co-authored-by: Shawn Gu <quic_shawngu@quicinc.com>
Co-authored-by: Skyler Szot <quic_sszot@quicinc.com>
2025-02-24 14:47:07 -07:00
Alex Brooks
7a2c913e66 llava : Add Granite Vision Support (#11794)
Some checks failed
flake8 Lint / Lint (push) Has been cancelled
Python Type-Check / pyright type-check (push) Has been cancelled
* Add super wip scripts for multimodal granite gguf

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Add example for converting mmgranite to gguf

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* remove hardcoded path

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Add vision feature layer to gguf params

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Clean up llava surgery and remove name substitution hacks

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Add transformers llava next tensor name mapping

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Make siglip / openclip mutuall exclusive

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Fix projector linear substitution

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Fix linear 2 substitution index

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Increase max flattened gridpoints to 64

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Fix hardcoded concat for multiple feature layers

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Pull vision feature layers out of gguf keys

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* fix num gridpoints and use all layers

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Avoid dropping last image encoder layer in llava models

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Use 10 for max number of patches

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Standardize vision feature layers

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Cleanup logs

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Update comment for vision feature layer init

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Update notes for alternative to legacy llm conversion script

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Fix notes rendering

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Add v prefix to vision feature layer log

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Use current defaults for feature layer

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Use constant for max gridpoints / feat layers, style fixes

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* clarify non-negative feature layers

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Remove CLIP_API from func signature

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* USE MAX_IMAGE_FEATURE_LAYERS const in layer calc

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Clarify feature layers are non negative ints and not uint

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Fix condition for reading feature layers

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* pop last llava layer when feature layers are unset

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Fix unset vision layer 0

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Update examples/llava/clip.cpp

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

* Reenable assertion for out of bounds get_rows

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Use std vector for gridpoints and feature layers

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Caculate max feature layer at load time

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Include base patch for granite vision allocation

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Fix trailing whitespace

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Add max num patches = 10 back for minicpmv

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Use unordered set to store feature layers

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Use max feature layer for postnorm

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Apply suggestions from code review

---------

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2025-02-24 17:09:51 +01:00
Neo Zhang Jianyu
08d5986290 [SYCL] Optimize mul_mat for Q4_0 on Intel GPU (#12035)
* opt performance by reorder for Intel GPU

* detect hw type and save opt feature, and print opt feature

* correct name

* support optimize graph once when compute graph, record the opt status in tensor->extra, make CI passed

* add env variable GGML_SYCL_DISABLE_OPT for debug

* use syclex::architecture replace the custom hw define, update the guide for GGML_SYCL_DISABLE_OPT

* add performance data

* mv getrows functions to separeted files

* fix global variables

---------

Co-authored-by: arthw <14088817+arthw@users.noreply.github.com>
2025-02-24 22:33:23 +08:00
Aleksei Nikiforov
651adf4b66 gguf_convert_endian.py: implement byteswapping for q4_k and q6_k (#11349) 2025-02-24 11:27:01 +00:00
Akarshan Biswas
8303e8b0fb SYCL: Fix GGML_SYCL_DEBUG macro (#11995) 2025-02-24 10:18:25 +00:00
Florent BENOIT
7ad0779f5d run: allow to customize prompt by env var LLAMA_PROMPT_PREFIX (#12041)
Signed-off-by: Florent Benoit <fbenoit@redhat.com>
2025-02-23 17:15:51 +00:00
Eric Curtin
f777a73e18 Some llama-run cleanups (#11973)
Use consolidated open function call from File class. Change
read_all to to_string(). Remove exclusive locking, the intent for
that lock is to avoid multiple processes writing to the same file,
it's not an issue for readers, although we may want to consider
adding a shared lock. Remove passing nullptr as reference,
references are never supposed to be null. clang-format the code
for consistent styling.

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
2025-02-23 13:14:32 +00:00
Aaron Teo
af7747c95a ggml-cpu: Support s390x SIMD Instruction Set (#12019)
* ggml: add s390x ARCH_FLAGS for compilation

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add SIMD for s390x using vector intrinsics

SIMD is activated for:
* ggml_vec_dot_f32
* ggml_vec_dot_f16
* ggml_vec_mad_f32
* ggml_vec_mad_f16
* ggml_vec_mad_f32_unroll
* ggml_vec_scale_f32
* ggml_vec_scale_f16

SIMD is NOT activated for:
* ggml_vec_dot_f16_unroll (pending bugfix)

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix missing escape character in GGML_F32x4_REDUCE

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add temporary patch for GGML_F32_ARR and GGML_F16_ARR

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix s390x GGML_F32x4_REDUCE

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: full SIMD activation for F32,F16 s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add option to disable s390x VXE/VXE2

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: change vecintrin.h include to ggml-cpu-impl

* add __VXE__ and __VXE2__ macros

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* cmake: add s390x target detection for VX/VXE/VXE2

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: move s390x vector intrinsics to ggml-cpu-impl.h

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x Q8_0 SIMD

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: correct documentation for Q8_0

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x reduce code complexity Q8_0

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x bugfix typo Q8_0

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activated for Q4_1

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x inline vec_reve

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for Q4_0

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add VXE backend feature

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: remove test.py

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for quantize_row_q8_0

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for quantize_row_q8_1

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for iq4_xs

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: bugfix iq4_xs

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for iq4_nl

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add float, double, and long vector data type

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: clean up iq4_xs SIMD

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix improper use of restrict keyword

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: update warning message for ggml_vec_tbl

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: untested implementation of ggml_vec_dot_iq2_xxs_q8_K

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: update ggml_vec_dot_q4_1_q8_1 to use typedefs

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: switch to restrict for iq4_nl

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: slight dot product speed improvement for q4_1_q8_1

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for q6_K

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add missing `_t` to ggml_int8x16x4_t

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix missing `_t` for ggml_vec_xl_s8x4

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix more missing `_t`

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add unroll and prefetch to Q8_0

increase of 3.86% for prompt processing and 32.22% for token generation

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: patch Q8_0 to use proper vector sizes

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: optimise Q8_0 dot prod compute kernel further

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add unroll and prefetch to Q4_1

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: refactor Q6_K variable naming for readability

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix Q6_K typos

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for Q5_K

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix wrong char*x16_t naming

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: Q5_K y0 wrong signness

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix Q5_K invalid uchar type

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix Q5_K invalid uchar type

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for Q4_K

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix Q4_K invalid vector intrinsics

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: simplify ggml_padd_s16 compute kernel

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: correct ggml-cpu vxe wording

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: change ggml_aligned_malloc alignment to 256

256 is the cache line size for s390x platforms

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: resolve pr merge via cherry-pick 225bbbf

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml : fix LoongArch compile error with 128-bit SIMD (#11701)

* ggml: resolve pr merge via cherry-pick 4571953

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: cmake remove fork when determining s390x machine type

thank you @ericcurtin

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

---------

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
Co-authored-by: Jinyang He <hejinyang@loongson.cn>
Co-authored-by: junchao-zhao <68935141+junchao-loongson@users.noreply.github.com>
2025-02-22 21:39:24 +00:00
Johannes Gäßler
a28e0d5eb1 CUDA: app option to compile without FlashAttention (#12025) 2025-02-22 20:44:34 +01:00
Ting Lou
36c258ee92 llava: build clip image from pixels (#11999)
Some checks failed
flake8 Lint / Lint (push) Has been cancelled
Python Type-Check / pyright type-check (push) Has been cancelled
* llava: export function `clip_build_img_from_pixels` to build image from pixels decoded by other libraries instead of stb_image.h for better performance

* Apply suggestions from code review

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2025-02-22 15:28:28 +01:00
Georgi Gerganov
f3e64859ed ci : fix arm upload artifacts (#12024)
* ci : fix arm upload artifacts

* cont : fix archive name to use matrix
2025-02-22 15:03:00 +02:00
Johannes Gäßler
5fa07c2f93 CUDA: optimize FA for GQA + large batches (#12014) 2025-02-22 12:20:17 +01:00
Rohanjames1997
335eb04a91 ci : Build on Github-hosted arm64 runners (#12009) 2025-02-22 11:48:57 +01:00
Georgi Gerganov
cf756d6e0a server : disable Nagle's algorithm (#12020) 2025-02-22 11:46:31 +01:00
Gian-Carlo Pascutto
d70908421f cuda: Add Q5_1, Q5_0, Q4_1 and Q4_0 to F32 conversion support. (#12000) 2025-02-22 09:43:24 +01:00
Daniel Bevenius
de8b5a3624 llama.swiftui : add "Done" dismiss button to help view (#11998)
The commit updates the help view in the llama.swiftui example to use a
NavigationView and a Done button to dismiss the help view.

The motivation for this is that without this change there is now way to
dimiss the help view.
2025-02-22 06:33:29 +01:00
Georgi Gerganov
51f311e057 llama : skip loading unused tensors (#12004)
* llama : assign unknown/unused tensors to host buffer type

ggml-ci

* llama : skip unused tensors

ggml-ci
2025-02-21 18:33:18 +02:00
Johannes Gäßler
586d5fe6eb doc: update contributing guidelines [no ci] (#11969) 2025-02-21 12:51:25 +01:00
PureJourney
ecc8e3aeff CUDA: correct the lowest Maxwell supported by CUDA 12 (#11984)
* CUDA: correct the lowest Maxwell supported by CUDA 12

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2025-02-21 12:21:05 +01:00
Bodhi
0b3863ff95 MUSA: support ARM64 and enable dp4a .etc (#11843)
* MUSA:  support ARM64 and enable __dp4a .etc

* fix cross entropy loss op for musa

* update

* add cc info log for musa

* add comment for the MUSA .cc calculation block

---------

Co-authored-by: Bodhi Hu <huaishun.hu@mthreads.com>
2025-02-21 09:46:23 +02:00
Alex Brooks
ee02ad02c5 clip : fix visual encoders with no CLS (#11982)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
2025-02-21 08:11:03 +02:00
momonga
c392e5094d server (webui): Fix Premature Submission During IME Conversion (#11971)
* fix skip ime composing

* fix npm rebuild

* fix warn

---------

Co-authored-by: momonga <115213907+mmnga@users.noreply.github.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-02-20 19:43:22 +01:00
Charles Xu
c5d91a7400 ggml-cpu: Add CPU backend support for KleidiAI library (#11390)
* ggml-cpu: Add CPU backend support for KleidiAI library

* Add environmental variable GGML_KLEIDIAI_SME

* Add support for multithread LHS conversion

* Switch kernel selection order to dotprod and i8mm

* updates for review comments

* More updates for review comments

* Reorganize and rename KleidiAI files

* Move ggml-cpu-traits.h to source file

* Update cmake for SME build and add alignment for SME

* Remove append GGML_USE_CPU_KLEIDIAI to the GGML_CDEF_PUBLIC list
2025-02-20 15:06:51 +02:00
Prashant Vithule
4806498bf1 ggml: aarch64: implement SVE kernels for q3_K_q8_K vector dot (#11917)
* Added SVE Implementation for Q3_K Kernel in ggml-cpu-quants.c file

* Improved Formating of code in  ggml-cpu-quants.c file

* style : minor fixes

* style : less whitespaces

* style : ptr spaceing

---------

Co-authored-by: vithulep <p.m.vithule1517@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-02-20 12:08:32 +02:00
Michael Engel
0d559580a0 run : add --chat-template-file (#11961)
Relates to: https://github.com/ggml-org/llama.cpp/issues/11178

Added --chat-template-file CLI option to llama-run. If specified, the file
will be read and the content passed for overwriting the chat template of
the model to common_chat_templates_from_model.

Signed-off-by: Michael Engel <mengel@redhat.com>
2025-02-20 10:35:11 +02:00
Johannes Gäßler
d04e7163c8 doc: add links to ggml examples [no ci] (#11958) 2025-02-19 20:45:17 +01:00
Daniel Bevenius
d07c621393 common : add llama.vim preset for Qwen2.5 Coder (#11945)
This commit adds a preset for llama.vim to use the default Qwen 2.5
Coder models.

The motivation for this change is to make it easier to start a server
suitable to be used with the llama.vim plugin. For example, the server
can be started with a command like the following:
```console
$ llama.vim --fim-qwen-1.5b-default
```

Refs: https://github.com/ggml-org/llama.cpp/issues/10932
2025-02-19 12:29:52 +01:00
Georgi Gerganov
abd4d0bc4f speculative : update default params (#11954)
* speculative : update default params

* speculative : do not discard the last drafted token
2025-02-19 13:29:42 +02:00
Daniel Bevenius
9626d9351a llama : fix indentation in llama-grammar [no ci] (#11943)
This commit adjusts the indentation for the functions `parse_sequence`
and `parse_rule` in src/llama-grammar.cpp.

The motivation is consistency and improve readability.
2025-02-19 06:16:23 +01:00
igardev
b58934c183 server : (webui) Enable communication with parent html (if webui is in iframe) (#11940)
Some checks failed
flake8 Lint / Lint (push) Has been cancelled
Python Type-Check / pyright type-check (push) Has been cancelled
* Webui: Enable communication with parent html (if webui is in iframe):
- Listens for "setText" command from parent with "text" and "context" fields. "text" is set in inputMsg, "context" is used as hidden context on the following requests to the llama.cpp server
- On pressing na Escape button sends command "escapePressed" to the parent

Example handling from the parent html side:
- Send command "setText" from parent html to webui in iframe:
const iframe = document.getElementById('askAiIframe');
if (iframe) {
	iframe.contentWindow.postMessage({ command: 'setText', text: text, context: context }, '*');
}

- Listen for Escape key from webui on parent html:
// Listen for escape key event in the iframe
window.addEventListener('keydown', (event) => {
	if (event.key === 'Escape') {
		// Process case when Escape is pressed inside webui
	}
});

* Move the extraContext from storage to app.context.

* Fix formatting.

* add Message.extra

* format + build

* MessageExtraContext

* build

* fix display

* rm console.log

---------

Co-authored-by: igardev <ivailo.gardev@akros.ch>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-02-18 23:01:44 +01:00
Olivier Chafik
63e489c025 tool-call: refactor common chat / tool-call api (+ tests / fixes) (#11900)
* tool-call refactoring: moved common_chat_* to chat.h, common_chat_templates_init return a unique_ptr to opaque type

* addressed clang-tidy lints in [test-]chat.*

* rm minja deps from util & common & move it to common/minja/

* add name & tool_call_id to common_chat_msg

* add common_chat_tool

* added json <-> tools, msgs conversions to chat.h

* fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens)

* fix deepseek r1 slow test (no longer <think> opening w/ new template)

* allow empty tools w/ auto + grammar

* fix & test server grammar & json_schema params w/ & w/o --jinja
2025-02-18 18:03:23 +00:00
Xuan-Son Nguyen
63ac128563 server : add TEI API format for /rerank endpoint (#11942)
Some checks are pending
flake8 Lint / Lint (push) Waiting to run
Python Type-Check / pyright type-check (push) Waiting to run
* server : add TEI API format for /rerank endpoint

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* fix

* also gitignore examples/server/*.gz.hpp

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-02-18 14:21:41 +01:00
MoonRide303
5137da7b8c scripts: corrected encoding when getting chat template (#11866) (#11907)
Signed-off-by: MoonRide303 <moonride303@gmail.com>
2025-02-18 10:30:16 +01:00
xiaobing318
09aaf4f1f5 docs : Fix duplicated file extension in test command (#11935)
This commit fixes an issue in the llama.cpp project where the command for testing the llama-server object contained a duplicated file extension. The original command was:
./tests.sh unit/test_chat_completion.py.py -v -x
It has been corrected to:
./tests.sh unit/test_chat_completion.py -v -x
This change ensures that the test script correctly locates and executes the intended test file, preventing test failures due to an incorrect file name.
2025-02-18 10:12:49 +01:00
Johannes Gäßler
73e2ed3ce3 CUDA: use async data loading for FlashAttention (#11894)
* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2025-02-17 14:03:24 +01:00
Eve
f7b1116af1 update release requirements (#11897) 2025-02-17 12:20:23 +01:00
Antoine Viallon
c4d29baf32 server : fix divide-by-zero in metrics reporting (#11915) 2025-02-17 11:25:12 +01:00
Rémy O
2eea03d86a vulkan: implement several ops relevant for ggml_opt (#11769)
* vulkan: support memset_tensor

* vulkan: support GGML_OP_SUM

* vulkan: implement GGML_OP_ARGMAX

* vulkan: implement GGML_OP_SUB

* vulkan: implement GGML_OP_COUNT_EQUAL

* vulkan: implement GGML_OP_OPT_STEP_ADAMW

* vulkan: fix check_results RWKV_WKV6 crash and memory leaks

* vulkan: implement GGML_OP_REPEAT_BACK

* tests: remove invalid test-backend-ops REPEAT_BACK tests

* vulkan: fix COUNT_EQUAL memset using a fillBuffer command
2025-02-17 07:55:57 +01:00
Xuan-Son Nguyen
0f2bbe6564 server : bump httplib to 0.19.0 (#11908) 2025-02-16 17:11:22 +00:00
standby24x7
fe163d5bf3 common : Fix a typo in help (#11899)
This patch fixes a typo in command help.
prefx -> prefix

Signed-off-by: Masanari Iida <standby24x7@gmail.com>
2025-02-16 10:51:13 +01:00
Xuan-Son Nguyen
818a340ea8 ci : fix (again) arm64 build fails (#11895)
* docker : attempt fixing arm64 build on ci

* qemu v7.0.0-28
2025-02-16 10:36:39 +01:00
Jeff Bolz
bf42a23d0a vulkan: support multi/vision rope, and noncontiguous rope (#11902) 2025-02-16 08:52:23 +01:00
161 changed files with 9753 additions and 3822 deletions

View File

@@ -173,7 +173,15 @@ jobs:
name: llama-bin-macos-x64.zip
ubuntu-cpu-cmake:
runs-on: ubuntu-22.04
strategy:
matrix:
include:
- build: 'x64'
os: ubuntu-22.04
- build: 'arm64'
os: ubuntu-22.04-arm
runs-on: ${{ matrix.os }}
steps:
- name: Clone
@@ -239,14 +247,14 @@ jobs:
run: |
cp LICENSE ./build/bin/
cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip ./build/bin/*
zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/*
- name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip
name: llama-bin-ubuntu-x64.zip
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip
name: llama-bin-ubuntu-${{ matrix.build }}.zip
ubuntu-latest-cmake-sanitizer:
runs-on: ubuntu-latest
@@ -374,6 +382,8 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
@@ -1373,8 +1383,10 @@ jobs:
needs:
- ubuntu-cpu-cmake
- ubuntu-22-cmake-vulkan
- windows-latest-cmake
- windows-2019-cmake-cuda
- windows-latest-cmake-sycl
- windows-latest-cmake-hip-release
- macOS-latest-cmake-arm64
- macOS-latest-cmake-x64

View File

@@ -51,6 +51,8 @@ jobs:
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
with:
image: tonistiigi/binfmt:qemu-v7.0.0-28
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

1
.gitignore vendored
View File

@@ -98,6 +98,7 @@ examples/server/*.css.hpp
examples/server/*.html.hpp
examples/server/*.js.hpp
examples/server/*.mjs.hpp
examples/server/*.gz.hpp
!build_64.sh
!examples/*.bat
!examples/*/*.kts

View File

@@ -1,10 +1,12 @@
# Pull requests (for contributors)
- llama.cpp uses the ggml tensor library for model evaluation. If you are unfamiliar with ggml, consider taking a look at the [examples in the ggml repository](https://github.com/ggml-org/ggml/tree/master/examples/). [simple](https://github.com/ggml-org/ggml/tree/master/examples/simple) shows the bare minimum for using ggml. [gpt-2](https://github.com/ggml-org/ggml/tree/master/examples/gpt-2) has minimal implementations for language model inference using GPT-2. [mnist](https://github.com/ggml-org/ggml/tree/master/examples/mnist) demonstrates how to train and evaluate a simple image classifier
- Test your changes:
- Execute [the full CI locally on your machine](ci/README.md) before publishing
- Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`)
- If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends)
- If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops`
- Create separate PRs for each feature or fix. Avoid combining unrelated changes in a single PR
- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
- If your PR becomes stale, don't hesitate to ping the maintainers in the comments

View File

@@ -680,6 +680,10 @@ ifdef GGML_CUDA_CCBIN
MK_NVCCFLAGS += -ccbin $(GGML_CUDA_CCBIN)
endif # GGML_CUDA_CCBIN
ifdef GGML_CUDA_NO_FA
MK_NVCCFLAGS += -DGGML_CUDA_NO_FA
endif # GGML_CUDA_NO_FA
ifdef GGML_CUDA_FA_ALL_QUANTS
MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
endif # GGML_CUDA_FA_ALL_QUANTS
@@ -800,6 +804,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY
endif # GGML_CUDA_NO_PEER_COPY
ifdef GGML_CUDA_NO_FA
HIPFLAGS += -DGGML_CUDA_NO_FA
endif # GGML_CUDA_NO_FA
OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o
OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
OBJ_GGML_EXT += $(OBJ_CUDA_TMPL)
@@ -847,7 +855,7 @@ ifdef GGML_MUSA
CXX := $(MUSA_PATH)/bin/clang++
MCC := $(CCACHE) $(MUSA_PATH)/bin/mcc
MUSAFLAGS = -x musa -mtgpu
MUSAFLAGS = -fsigned-char -x musa -mtgpu
MUSAFLAGS += $(foreach arch,$(subst ;, ,$(MUSA_ARCHITECTURES)),--cuda-gpu-arch=mp_$(arch))
ifdef GGML_CUDA_FORCE_MMQ
@@ -876,6 +884,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY
endif # GGML_CUDA_NO_PEER_COPY
ifdef GGML_CUDA_NO_FA
MUSAFLAGS += -DGGML_CUDA_NO_FA
endif # GGML_CUDA_NO_FA
ifdef GGML_CUDA_FA_ALL_QUANTS
MUSAFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
endif # GGML_CUDA_FA_ALL_QUANTS
@@ -1364,7 +1376,7 @@ llama-server: \
examples/server/index.html.hpp \
examples/server/loading.html.hpp \
common/chat.cpp \
common/chat.hpp \
common/chat.h \
common/chat-template.hpp \
common/json.hpp \
common/minja.hpp \

View File

@@ -219,7 +219,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly
- [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server
- [Kalavai](https://github.com/kalavai-net/kalavai-client) - Crowdsource end to end LLM deployment at any scale
- [llmaz](https://github.com/InftyAI/llmaz) - ☸️ Easy, advanced inference platform for large language models on Kubernetes.
</details>
<details>

View File

@@ -57,8 +57,7 @@ add_library(${TARGET} STATIC
arg.h
base64.hpp
chat.cpp
chat.hpp
chat-template.hpp
chat.h
common.cpp
common.h
console.cpp
@@ -68,7 +67,8 @@ add_library(${TARGET} STATIC
llguidance.cpp
log.cpp
log.h
minja.hpp
minja/chat-template.hpp
minja/minja.hpp
ngram-cache.cpp
ngram-cache.h
sampling.cpp

View File

@@ -2,6 +2,7 @@
#include "log.h"
#include "sampling.h"
#include "chat.h"
#include <algorithm>
#include <climits>
@@ -2247,7 +2248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_env("LLAMA_LOG_VERBOSITY"));
add_opt(common_arg(
{"--log-prefix"},
"Enable prefx in log messages",
"Enable prefix in log messages",
[](common_params &) {
common_log_set_prefix(common_log_main(), true);
}
@@ -2501,5 +2502,53 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--fim-qwen-1.5b-default"},
string_format("use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet)"),
[](common_params & params) {
params.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
params.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--fim-qwen-3b-default"},
string_format("use default Qwen 2.5 Coder 3B (note: can download weights from the internet)"),
[](common_params & params) {
params.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
params.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--fim-qwen-7b-default"},
string_format("use default Qwen 2.5 Coder 7B (note: can download weights from the internet)"),
[](common_params & params) {
params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
params.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
return ctx_arg;
}

View File

@@ -1,8 +1,433 @@
#include "chat.hpp"
#include "chat-template.hpp"
#include "chat.h"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "minja.hpp"
#include "minja/chat-template.hpp"
#include "minja/minja.hpp"
#include <optional>
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool has_explicit_template; // Model had builtin template or template overridde was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
struct templates_params {
json messages;
json tools;
common_chat_tool_choice tool_choice;
json json_schema;
bool parallel_tool_calls;
bool stream;
std::string grammar;
bool add_generation_prompt = true;
bool extract_reasoning = true;
};
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
if (tool_choice == "auto") {
return COMMON_CHAT_TOOL_CHOICE_AUTO;
}
if (tool_choice == "none") {
return COMMON_CHAT_TOOL_CHOICE_NONE;
}
if (tool_choice == "required") {
return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
}
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
}
template <>
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
std::vector<common_chat_msg> msgs;
try {
if (!messages.is_array()) {
throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
}
for (const auto & message : messages) {
if (!message.is_object()) {
throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
}
common_chat_msg msg;
if (!message.contains("role")) {
throw std::runtime_error("Missing 'role' in message: " + message.dump());
}
msg.role = message.at("role");
if (message.contains("content")) {
const auto & content = message.at("content");
if (content.is_string()) {
msg.content = content;
} else if (content.is_array()) {
for (const auto & part : content) {
if (!part.contains("type")) {
throw std::runtime_error("Missing content part type: " + part.dump());
}
const auto & type = part.at("type");
if (type != "text") {
throw std::runtime_error("Unsupported content part type: " + type.dump());
}
common_chat_msg_content_part msg_part;
msg_part.type = type;
msg_part.text = part.at("text");
msg.content_parts.push_back(msg_part);
}
} else if (!content.is_null()) {
throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
} else {
throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
if (message.contains("reasoning_content")) {
msg.reasoning_content = message.at("reasoning_content");
}
if (message.contains("name")) {
msg.tool_name = message.at("name");
}
if (message.contains("tool_call_id")) {
msg.tool_call_id = message.at("tool_call_id");
}
if (message.contains("tool_calls")) {
for (const auto & tool_call : message.at("tool_calls")) {
common_chat_tool_call tc;
if (!tool_call.contains("type")) {
throw std::runtime_error("Missing tool call type: " + tool_call.dump());
}
const auto & type = tool_call.at("type");
if (type != "function") {
throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
}
if (!tool_call.contains("function")) {
throw std::runtime_error("Missing tool call function: " + tool_call.dump());
}
const auto & fc = tool_call.at("function");
if (!fc.contains("name")) {
throw std::runtime_error("Missing tool call name: " + tool_call.dump());
}
tc.name = fc.at("name");
tc.arguments = fc.at("arguments");
if (tool_call.contains("id")) {
tc.id = tool_call.at("id");
}
msg.tool_calls.push_back(tc);
}
}
msgs.push_back(msg);
}
} catch (const std::exception & e) {
throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
}
return msgs;
}
template <>
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
json messages = json::array();
for (const auto & msg : msgs) {
if (!msg.content.empty() && !msg.content_parts.empty()) {
throw std::runtime_error("Cannot specify both content and content_parts");
}
json jmsg {
{"role", msg.role},
};
if (!msg.content.empty()) {
jmsg["content"] = msg.content;
} else if (!msg.content_parts.empty()) {
if (concat_typed_text) {
std::string text;
for (const auto & part : msg.content_parts) {
if (part.type != "text") {
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
continue;
}
if (!text.empty()) {
text += '\n';
}
text += part.text;
}
jmsg["content"] = text;
} else {
auto & parts = jmsg["content"] = json::array();
for (const auto & part : msg.content_parts) {
parts.push_back({
{"type", part.type},
{"text", part.text},
});
}
}
} else {
jmsg["content"] = json(); // null
}
if (!msg.reasoning_content.empty()) {
jmsg["reasoning_content"] = msg.reasoning_content;
}
if (!msg.tool_name.empty()) {
jmsg["name"] = msg.tool_name;
}
if (!msg.tool_call_id.empty()) {
jmsg["tool_call_id"] = msg.tool_call_id;
}
if (!msg.tool_calls.empty()) {
auto & tool_calls = jmsg["tool_calls"] = json::array();
for (const auto & tool_call : msg.tool_calls) {
json tc {
{"type", "function"},
{"function", {
{"name", tool_call.name},
{"arguments", tool_call.arguments},
}},
};
if (!tool_call.id.empty()) {
tc["id"] = tool_call.id;
}
tool_calls.push_back(tc);
}
}
messages.push_back(jmsg);
}
return messages;
}
template <>
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
return common_chat_msgs_parse_oaicompat(json::parse(messages));
}
template <>
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
std::vector<common_chat_tool> result;
try {
if (!tools.is_null()) {
if (!tools.is_array()) {
throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
}
for (const auto & tool : tools) {
if (!tool.contains("type")) {
throw std::runtime_error("Missing tool type: " + tool.dump());
}
const auto & type = tool.at("type");
if (!type.is_string() || type != "function") {
throw std::runtime_error("Unsupported tool type: " + tool.dump());
}
if (!tool.contains("function")) {
throw std::runtime_error("Missing tool function: " + tool.dump());
}
const auto & function = tool.at("function");
result.push_back({
/* .name = */ function.at("name"),
/* .description = */ function.at("description"),
/* .parameters = */ function.at("parameters").dump(),
});
}
}
} catch (const std::exception & e) {
throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
}
return result;
}
template <>
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
return common_chat_tools_parse_oaicompat(json::parse(tools));
}
template <>
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
if (tools.empty()) {
return json();
}
auto result = json::array();
for (const auto & tool : tools) {
result.push_back({
{"type", "function"},
{"function", {
{"name", tool.name},
{"description", tool.description},
{"parameters", json::parse(tool.parameters)},
}},
});
}
return result;
}
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
common_chat_msg msg;
msg.role = "user";
msg.content = "test";
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
common_chat_templates_inputs inputs;
inputs.messages = {msg};
common_chat_templates_apply(tmpls.get(), inputs);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_format_single(
const struct common_chat_templates * tmpls,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja) {
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
std::string fmt_past_msg;
if (!past_msg.empty()) {
inputs.messages = past_msg;
inputs.add_generation_prompt = false;
fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
}
std::ostringstream ss;
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
inputs.messages.push_back(new_msg);
inputs.add_generation_prompt = add_ass;
auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
auto add_simple_msg = [&](auto role, auto content) {
common_chat_msg msg;
msg.role = role;
msg.content = content;
inputs.messages.push_back(msg);
};
add_simple_msg("system", "You are a helpful assistant");
add_simple_msg("user", "Hello");
add_simple_msg("assistant", "Hi there");
add_simple_msg("user", "How are you?");
return common_chat_templates_apply(tmpls, inputs).prompt;
}
#define CHATML_TEMPLATE_SRC \
"{%- for message in messages -%}\n" \
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
"{%- endfor -%}\n" \
"{%- if add_generation_prompt -%}\n" \
" {{- '<|im_start|>assistant\n' -}}\n" \
"{%- endif -%}"
void common_chat_templates_free(struct common_chat_templates * tmpls) {
delete tmpls;
}
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
return tmpls->has_explicit_template;
}
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
if (variant != nullptr) {
if (strcmp(variant, "tool_use") == 0) {
if (tmpls->template_tool_use) {
return tmpls->template_tool_use->source().c_str();
}
return nullptr;
} else {
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
}
}
return tmpls->template_default->source().c_str();
}
common_chat_templates_ptr common_chat_templates_init(
const struct llama_model * model,
const std::string & chat_template_override,
const std::string & bos_token_override,
const std::string & eos_token_override)
{
std::string default_template_src;
std::string template_tool_use_src;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
GGML_ASSERT(model != nullptr);
const auto * str = llama_model_chat_template(model, /* name */ nullptr);
if (str) {
default_template_src = str;
has_explicit_template = true;
}
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) {
template_tool_use_src = str;
has_explicit_template = true;
}
} else {
default_template_src = chat_template_override;
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = CHATML_TEMPLATE_SRC;
}
}
std::string token_bos = bos_token_override;
std::string token_eos = eos_token_override;
if (model) {
const auto * vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
}
return std::string();
}
return common_token_to_piece(vocab, token, true);
};
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
}
common_chat_templates_ptr tmpls(new common_chat_templates());
tmpls->has_explicit_template = has_explicit_template;
try {
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
}
if (!template_tool_use_src.empty()) {
try {
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
}
}
return tmpls;
}
std::string common_chat_format_name(common_chat_format format) {
switch (format) {
@@ -38,22 +463,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
json_error_locator() : position(0), found_error(false) {}
bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
this->position = position - 1;
this->found_error = true;
return false;
}
bool null() override { return true; }
bool boolean(bool) override { return true; }
bool number_integer(number_integer_t) override { return true; }
bool number_unsigned(number_unsigned_t) override { return true; }
bool number_float(number_float_t, const string_t &) override { return true; }
bool string(string_t &) override { return true; }
bool binary(binary_t &) override { return true; }
bool start_object(std::size_t) override { return true; }
bool key(string_t &) override { return true; }
bool null() override { return true; } // NOLINT
bool boolean(bool) override { return true; } // NOLINT
bool number_integer(number_integer_t) override { return true; } // NOLINT
bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
bool string(string_t &) override { return true; } // NOLINT
bool binary(binary_t &) override { return true; } // NOLINT
bool start_object(std::size_t) override { return true; } // NOLINT
bool key(string_t &) override { return true; } // NOLINT
bool end_object() override { return true; }
bool start_array(std::size_t) override { return true; }
bool start_array(std::size_t) override { return true; } // NOLINT
bool end_array() override { return true; }
};
json_error_locator err_loc;
@@ -187,13 +612,20 @@ static std::string apply(
// tmpl_inputs.now = std::chrono::system_clock::now();
minja::chat_template_options tmpl_opts;
tmpl_opts.use_bos_token = false;
tmpl_opts.use_eos_token = false;
return tmpl.apply(tmpl_inputs, tmpl_opts);
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
if (string_starts_with(result, tmpl.bos_token())) {
result = result.substr(tmpl.bos_token().size());
}
if (string_ends_with(result, tmpl.eos_token())) {
result = result.substr(0, result.size() - tmpl.eos_token().size());
}
return result;
}
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto tool_call_schemas = json::array();
@@ -247,7 +679,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
{"required", json::array({"tool_call"})},
};
const auto schema =
inputs.tool_choice != "required"
inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
? json {
{"anyOf", json::array({
tool_call,
@@ -303,9 +735,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
return result;
}
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -348,9 +780,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -455,10 +887,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
const auto & parameters_required = parameters.at("required");
for (const auto & prop : expected_properties) {
if (!parameters_properties.contains(prop)) {
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
}
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
}
}
if (parameters_properties.size() != expected_properties.size()) {
@@ -466,18 +898,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
}
}
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
auto builtin_tools = json::array();
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
if (name == "wolfram_alpha") {
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "python" || name == "code_interpreter") {
@@ -489,7 +919,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
std::vector<std::string> kvs;
for (const auto & [key, value] : parameters.at("properties").items()) {
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
}
tool_rules.push_back(
@@ -560,34 +990,33 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
auto arg_value_str = raw_args.substr(it_eq + 1);
auto arg_value = json::parse(arg_value_str);
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ match[1],
/* .arguments = */ (json {
{arg_name, arg_value},
}).dump(),
/* .id = */ "",
},
},
};
common_chat_msg msg;
msg.role = "assistant";
msg.content = match.prefix().str();
msg.tool_calls.push_back({
/* .name = */ name,
/* .arguments = */ (json {
{arg_name, arg_value},
}).dump(),
/* .id = */ "",
});
return msg;
}
}
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n"
@@ -666,15 +1095,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input,
return msg;
}
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
fprintf(stderr, "%s\n", __func__);
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
LOG_DBG("%s\n", __func__);
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
});
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -712,14 +1141,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
@@ -727,6 +1156,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
@@ -795,14 +1225,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
}
}
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_params data;
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
std::string python_code_argument_name;
auto has_raw_python = false;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
@@ -814,7 +1244,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
throw std::runtime_error("Missing type in python tool");
}
has_raw_python = true;
auto type = parameters.at("type");
const auto & type = parameters.at("type");
if (type == "object") {
auto properties = parameters.at("properties");
for (auto it = properties.begin(); it != properties.end(); ++it) {
@@ -854,17 +1284,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
auto code = match[1].str();
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ "python",
/* .arguments = */ (json {{"code", code}}).dump(),
/* .id = */ "",
},
}
};
common_chat_msg msg;
msg.role = "assistant";
msg.content = match.prefix().str();
msg.tool_calls.push_back({
/* .name = */ "python",
/* .arguments = */ (json {{"code", code}}).dump(),
/* .id = */ "",
});
return msg;
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
@@ -872,10 +1300,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
@@ -908,20 +1336,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
common_chat_msg msg;
msg.role = "assistant";
auto end = input.end();
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
msg.content = input;
return msg;
}
common_chat_msg result;
result.role = "assistant";
result.content = rit->prefix();
msg.content = rit->prefix();
auto it = rit->suffix().first;
while (it != end) {
@@ -930,7 +1356,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
throw std::runtime_error("Failed to parse json tool call");
}
const auto & arguments = call.at("arguments");
result.tool_calls.push_back({
msg.tool_calls.push_back({
call.at("name"),
arguments.dump(),
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
@@ -947,17 +1373,17 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
break;
}
}
return result;
return msg;
} catch (const std::exception & e) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
common_chat_msg msg;
msg.role = "assistant";
msg.content = input;
return msg;
}
}
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -973,12 +1399,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
return data;
}
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_templates_apply_jinja(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
templates_params params;
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
? *tmpls->template_tool_use
: *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
params.add_generation_prompt = inputs.add_generation_prompt;
params.extract_reasoning = inputs.extract_reasoning;
params.tool_choice = inputs.tool_choice;
params.grammar = inputs.grammar;
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
if (inputs.tools.is_array()) {
if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
params.parallel_tool_calls = false;
} else {
params.parallel_tool_calls = inputs.parallel_tool_calls;
}
if (params.tools.is_array()) {
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
@@ -987,68 +1436,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
}
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
if (src.find("<tool▁calls▁begin>") != std::string::npos && inputs.json_schema.is_null()) {
return common_chat_params_init_deepseek_r1(tmpl, inputs);
if (src.find("<tool▁calls▁begin>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_deepseek_r1(tmpl, params);
}
// Command R7B: : use handler in all cases except json schema (thinking / tools).
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
return common_chat_params_init_command_r7b(tmpl, inputs);
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_command_r7b(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
return common_chat_params_init_generic(tmpl, inputs);
if ((params.tools.is_array() && params.json_schema.is_object())) {
return common_chat_params_init_generic(tmpl, params);
}
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
if (src.find(">>>all") != std::string::npos) {
return common_chat_params_init_functionary_v3_2(tmpl, inputs);
return common_chat_params_init_functionary_v3_2(tmpl, params);
}
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
if (src.find(" functools[") != std::string::npos) {
return common_chat_params_init_firefunction_v2(tmpl, inputs);
return common_chat_params_init_firefunction_v2(tmpl, params);
}
// Plain handler (no tools)
if (inputs.tools.is_null() || inputs.tool_choice == "none") {
return common_chat_params_init_without_tools(tmpl, inputs);
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return common_chat_params_init_without_tools(tmpl, params);
}
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos) {
return common_chat_params_init_hermes_2_pro(tmpl, inputs);
return common_chat_params_init_hermes_2_pro(tmpl, params);
}
// Functionary v3.1 (w/ tools)
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
}
// Llama 3.1, 3.2, 3.3 (w/ tools)
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
}
// Mistral Nemo (w/ tools)
if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_params_init_mistral_nemo(tmpl, inputs);
return common_chat_params_init_mistral_nemo(tmpl, params);
}
// Generic fallback
return common_chat_params_init_generic(tmpl, inputs);
return common_chat_params_init_generic(tmpl, params);
}
// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
static common_chat_params common_chat_templates_apply_legacy(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
int alloc_size = 0;
std::vector<llama_chat_message> chat;
std::vector<std::string> contents;
for (const auto & msg : inputs.messages) {
auto content = msg.content;
for (const auto & part : msg.content_parts) {
if (part.type != "text") {
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
continue;
}
if (!content.empty()) {
content += "\n";;
}
content += part.text;
}
contents.emplace_back(std::move(content));
}
for (size_t i = 0; i < contents.size(); ++i) {
const auto & msg = inputs.messages[i];
const auto & content = contents[i];
chat.push_back({msg.role.c_str(), content.c_str()});
alloc_size += (msg.role.size() + content.size()) * 1.25;
}
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
const auto & src = tmpls->template_default->source();
int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
}
common_chat_params params;
params.prompt = std::string(buf.data(), res);
if (!inputs.json_schema.empty()) {
params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
} else {
params.grammar = inputs.grammar;
}
return params;
}
common_chat_params common_chat_templates_apply(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
GGML_ASSERT(tmpls != nullptr);
return inputs.use_jinja
? common_chat_templates_apply_jinja(tmpls, inputs)
: common_chat_templates_apply_legacy(tmpls, inputs);
}
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
common_chat_msg msg;
msg.role = "assistant";
msg.content = input;
return msg;
}
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {

134
common/chat.h Normal file
View File

@@ -0,0 +1,134 @@
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
#pragma once
#include "common.h"
#include <string>
#include <vector>
struct common_chat_templates;
struct common_chat_tool_call {
std::string name;
std::string arguments;
std::string id;
};
struct common_chat_msg_content_part {
std::string type;
std::string text;
};
struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_chat_msg_content_part> content_parts = {};
std::vector<common_chat_tool_call> tool_calls = {};
std::string reasoning_content;
std::string tool_name;
std::string tool_call_id;
};
struct common_chat_tool {
std::string name;
std::string description;
std::string parameters;
};
enum common_chat_tool_choice {
COMMON_CHAT_TOOL_CHOICE_AUTO,
COMMON_CHAT_TOOL_CHOICE_REQUIRED,
COMMON_CHAT_TOOL_CHOICE_NONE,
};
enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
struct common_chat_templates_inputs {
std::vector<common_chat_msg> messages;
std::string grammar;
std::string json_schema;
bool add_generation_prompt = true;
bool use_jinja = true;
// Parameters below only supported when use_jinja is true
std::vector<common_chat_tool> tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false;
bool extract_reasoning = true;
};
struct common_chat_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::string prompt;
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
void common_chat_templates_free(struct common_chat_templates * tmpls);
struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
common_chat_templates_ptr common_chat_templates_init(
const struct llama_model * model,
const std::string & chat_template_override,
const std::string & bos_token_override = "",
const std::string & eos_token_override = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
struct common_chat_params common_chat_templates_apply(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
const struct common_chat_templates * tmpls,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(
const struct common_chat_templates * tmpls,
bool use_jinja);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
// Parses a JSON array of messages in OpenAI's chat completion API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);

View File

@@ -1,55 +0,0 @@
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
#pragma once
#include "common.h"
#include <json.hpp>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
struct common_chat_inputs {
json messages;
json tools;
json tool_choice;
json json_schema;
bool parallel_tool_calls;
bool stream;
std::string grammar;
bool add_generation_prompt = true;
bool extract_reasoning = true;
};
enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
struct common_chat_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
json prompt;
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};
struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);

View File

@@ -12,8 +12,6 @@
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "chat.hpp"
#include "chat-template.hpp"
#include <algorithm>
#include <cinttypes>
@@ -1768,174 +1766,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
return text;
}
//
// Chat template utils
//
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
common_chat_inputs inputs;
inputs.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
common_chat_params_init(chat_template, inputs);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & msgs,
bool add_ass,
bool use_jinja) {
if (use_jinja) {
auto messages = json::array();
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
common_chat_inputs inputs;
inputs.messages = messages;
inputs.add_generation_prompt = add_ass;
return common_chat_params_init(tmpl, inputs).prompt;
}
int alloc_size = 0;
std::vector<llama_chat_message> chat;
for (const auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
std::vector<common_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
std::vector<common_chat_msg> msgs = {
{"system", "You are a helpful assistant", {}},
{"user", "Hello", {}},
{"assistant", "Hi there", {}},
{"user", "How are you?", {}},
};
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
}
#define CHATML_TEMPLATE_SRC \
"{%- for message in messages -%}\n" \
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
"{%- endfor -%}\n" \
"{%- if add_generation_prompt -%}\n" \
" {{- '<|im_start|>assistant\n' -}}\n" \
"{%- endif -%}"
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
{
std::string default_template_src;
std::string template_tool_use_src;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
auto str = llama_model_chat_template(model, /* name */ nullptr);
if (str) {
default_template_src = str;
has_explicit_template = true;
}
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) {
template_tool_use_src = str;
has_explicit_template = true;
}
} else {
default_template_src = chat_template_override;
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = CHATML_TEMPLATE_SRC;
}
}
auto vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
}
return std::string();
} else {
return common_token_to_piece(vocab, token, true);
}
};
auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
try {
return {
has_explicit_template,
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
template_tool_use_src.empty()
? nullptr
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
};
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
return {
has_explicit_template,
std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
nullptr,
};
}
}
//
// KV cache utils
//

View File

@@ -178,10 +178,10 @@ struct common_params_speculative {
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.9f; // minimum speculative decoding probability (greedy)
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
@@ -616,62 +616,6 @@ std::string common_detokenize(
const std::vector<llama_token> & tokens,
bool special = true);
//
// Chat template utils
//
struct common_tool_call {
std::string name;
std::string arguments;
std::string id;
};
// same with llama_chat_message, but uses std::string
struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_tool_call> tool_calls;
std::string reasoning_content = "";
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
namespace minja {
class chat_template;
}
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool has_explicit_template; // Model had builtin template or template overridde was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & chat,
bool add_ass,
bool use_jinja);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(
const common_chat_template & tmpl, bool use_jinja);
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
//
// KV cache utils
//

View File

@@ -252,11 +252,6 @@ llama_tokens common_speculative_gen_draft(
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
break;
}
common_sampler_accept(smpl, id, true);
result.push_back(id);
@@ -265,6 +260,11 @@ llama_tokens common_speculative_gen_draft(
break;
}
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
break;
}
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
// evaluate the drafted tokens on the draft model

View File

@@ -9,7 +9,7 @@ struct common_speculative_params {
int n_draft = 16; // max drafted tokens
int n_reuse = 256;
float p_min = 0.9f; // min probability required to accept a token in the draft
float p_min = 0.75f; // min probability required to accept a token in the draft
};
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);

View File

@@ -42,6 +42,16 @@ The following release is verified with good quality:
## News
- 2025.2
- Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).
|GPU|Base tokens/s|Increased tokens/s|Percent|
|-|-|-|-|
|PVC 1550|39|73|+87%|
|Flex 170|39|50|+28%|
|Arc770|42|55|+30%|
|MTL|13|16|+23%|
|ARL-H|14|17|+21%|
- 2024.11
- Use syclcompat to improve the performance on some platforms. This requires to use oneAPI 2025.0 or newer.
@@ -97,8 +107,8 @@ SYCL backend supports Intel GPU Family:
| Intel Data Center Max Series | Support | Max 1550, 1100 |
| Intel Data Center Flex Series | Support | Flex 170 |
| Intel Arc Series | Support | Arc 770, 730M, Arc A750 |
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake |
| Intel iGPU | Support | iGPU in 13700k, i5-1250P, i7-1260P, i7-1165G7 |
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake |
| Intel iGPU | Support | iGPU in 13700k,iGPU in 13400, i5-1250P, i7-1260P, i7-1165G7 |
*Notes:*
@@ -660,8 +670,10 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| Name | Value | Function |
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
## Known Issues
- `Split-mode:[row]` is not supported.

View File

@@ -206,6 +206,14 @@ This provides GPU acceleration using the MUSA cores of your Moore Threads MTT GP
cmake --build build --config Release
```
For static build:
```bash
cmake -B build -DGGML_MUSA=ON \
-DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
cmake --build build --config Release
```
The environment variable [`MUSA_VISIBLE_DEVICES`](https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Z%E9%99%84%E5%BD%95/) can be used to specify which GPU(s) will be used.
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted.

390
docs/function-calling.md Normal file
View File

@@ -0,0 +1,390 @@
# Function Calling
[chat.h](../common/chat.h) (https://github.com/ggml-org/llama.cpp/pull/9639) adds support for [OpenAI-style function calling](https://platform.openai.com/docs/guides/function-calling) and is used in:
- `llama-server` when started w/ `--jinja` flag
- `llama-cli` (WIP: https://github.com/ggml-org/llama.cpp/pull/11556)
## Universal support w/ Native & Generic handlers
Function calling is supported for all models (see https://github.com/ggml-org/llama.cpp/pull/9639):
- Native tool call formats supported:
- Llama 3.1 / 3.3 (including builtin tools support - tool names for `wolfram_alpha`, `web_search` / `brave_search`, `code_interpreter`), Llama 3.2
- Functionary v3.1 / v3.2
- Hermes 2/3, Qwen 2.5
- Qwen 2.5 Coder (WIP: https://github.com/ggml-org/llama.cpp/pull/12034)
- Mistral Nemo
- Firefunction v2
- Command R7B
- DeepSeek R1 (WIP / seems reluctant to call any tools?)
- Generic tool call is supported when the template isn't recognized by native format handlers (you'll see `Chat format: Generic` in the logs).
- Use `--chat-template-file` to override the template when appropriate (see examples below)
- Generic support may consume more tokens and be less efficient than a model's native format.
<details>
<summary>Show some common templates and which format handler they use</summary>
| Template | Format |
|----------|--------|
| Almawave-Velvet-14B.jinja | Hermes 2 Pro |
| AtlaAI-Selene-1-Mini-Llama-3.1-8B.jinja | Llama 3.x |
| CohereForAI-aya-expanse-8b.jinja | Generic |
| CohereForAI-c4ai-command-r-plus-default.jinja | Generic |
| CohereForAI-c4ai-command-r-plus-rag.jinja | Generic |
| CohereForAI-c4ai-command-r-plus-tool_use.jinja | Generic |
| CohereForAI-c4ai-command-r7b-12-2024-default.jinja | Command R7B (extract reasoning) |
| CohereForAI-c4ai-command-r7b-12-2024-rag.jinja | Command R7B (extract reasoning) |
| CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja | Command R7B (extract reasoning) |
| CohereForAI-c4ai-command-r7b-12-2024.jinja | Generic |
| DavieLion-Llama-3.2-1B-SPIN-iter3.jinja | Generic |
| Delta-Vector-Rei-12B.jinja | Mistral Nemo |
| EpistemeAI-Mistral-Nemo-Instruct-12B-Philosophy-Math.jinja | Mistral Nemo |
| FlofloB-83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit.jinja | Hermes 2 Pro |
| FlofloB-test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit.jinja | Generic |
| HelpingAI-HAI-SER.jinja | Generic |
| HuggingFaceTB-SmolLM2-1.7B-Instruct.jinja | Generic |
| HuggingFaceTB-SmolLM2-135M-Instruct.jinja | Generic |
| HuggingFaceTB-SmolLM2-360M-Instruct.jinja | Generic |
| INSAIT-Institute-BgGPT-Gemma-2-27B-IT-v1.0.jinja | Generic |
| Ihor-Text2Graph-R1-Qwen2.5-0.5b.jinja | Hermes 2 Pro |
| Infinigence-Megrez-3B-Instruct.jinja | Generic |
| Josephgflowers-TinyLlama_v1.1_math_code-world-test-1.jinja | Generic |
| LGAI-EXAONE-EXAONE-3.5-2.4B-Instruct.jinja | Generic |
| LGAI-EXAONE-EXAONE-3.5-7.8B-Instruct.jinja | Generic |
| LatitudeGames-Wayfarer-12B.jinja | Generic |
| Magpie-Align-Llama-3-8B-Magpie-Align-v0.1.jinja | Generic |
| Magpie-Align-Llama-3.1-8B-Magpie-Align-v0.1.jinja | Generic |
| MaziyarPanahi-calme-3.2-instruct-78b.jinja | Generic |
| MiniMaxAI-MiniMax-Text-01.jinja | Generic |
| MiniMaxAI-MiniMax-VL-01.jinja | Generic |
| NaniDAO-deepseek-r1-qwen-2.5-32B-ablated.jinja | DeepSeek R1 (extract reasoning) |
| NexaAIDev-Octopus-v2.jinja | Generic |
| NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | Generic |
| NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | Hermes 2 Pro |
| NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | Generic |
| NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | Hermes 2 Pro |
| NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | Generic |
| NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | Hermes 2 Pro |
| NovaSky-AI-Sky-T1-32B-Flash.jinja | Hermes 2 Pro |
| NovaSky-AI-Sky-T1-32B-Preview.jinja | Hermes 2 Pro |
| OnlyCheeini-greesychat-turbo.jinja | Generic |
| Orenguteng-Llama-3.1-8B-Lexi-Uncensored-V2.jinja | Llama 3.x |
| OrionStarAI-Orion-14B-Chat.jinja | Generic |
| PowerInfer-SmallThinker-3B-Preview.jinja | Generic |
| PrimeIntellect-INTELLECT-1-Instruct.jinja | Generic |
| Qwen-QVQ-72B-Preview.jinja | Generic |
| Qwen-QwQ-32B-Preview.jinja | Hermes 2 Pro |
| Qwen-Qwen1.5-7B-Chat.jinja | Generic |
| Qwen-Qwen2-7B-Instruct.jinja | Generic |
| Qwen-Qwen2-VL-72B-Instruct.jinja | Generic |
| Qwen-Qwen2-VL-7B-Instruct.jinja | Generic |
| Qwen-Qwen2.5-0.5B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-1.5B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-14B-Instruct-1M.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-14B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-32B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-32B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-3B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-72B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-7B-Instruct-1M.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-7B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-7B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Coder-32B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Coder-7B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Math-1.5B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Math-7B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-VL-3B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-VL-72B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-VL-7B-Instruct.jinja | Hermes 2 Pro |
| RWKV-Red-Team-ARWKV-7B-Preview-0.1.jinja | Hermes 2 Pro |
| SakanaAI-TinySwallow-1.5B-Instruct.jinja | Hermes 2 Pro |
| SakanaAI-TinySwallow-1.5B.jinja | Hermes 2 Pro |
| Sao10K-70B-L3.3-Cirrus-x1.jinja | Llama 3.x |
| SentientAGI-Dobby-Mini-Leashed-Llama-3.1-8B.jinja | Llama 3.x |
| SentientAGI-Dobby-Mini-Unhinged-Llama-3.1-8B.jinja | Llama 3.x |
| Steelskull-L3.3-Damascus-R1.jinja | Llama 3.x |
| Steelskull-L3.3-MS-Nevoria-70b.jinja | Llama 3.x |
| Steelskull-L3.3-Nevoria-R1-70b.jinja | Llama 3.x |
| THUDM-glm-4-9b-chat.jinja | Generic |
| THUDM-glm-edge-1.5b-chat.jinja | Generic |
| Tarek07-Progenitor-V1.1-LLaMa-70B.jinja | Llama 3.x |
| TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | Generic |
| TinyLlama-TinyLlama-1.1B-Chat-v1.0.jinja | Generic |
| UCLA-AGI-Mistral7B-PairRM-SPPO-Iter3.jinja | Generic |
| ValiantLabs-Llama3.1-8B-Enigma.jinja | Llama 3.x |
| abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | Generic |
| ai21labs-AI21-Jamba-1.5-Large.jinja | Generic |
| allenai-Llama-3.1-Tulu-3-405B-SFT.jinja | Generic |
| allenai-Llama-3.1-Tulu-3-405B.jinja | Generic |
| allenai-Llama-3.1-Tulu-3-8B.jinja | Generic |
| arcee-ai-Virtuoso-Lite.jinja | Hermes 2 Pro |
| arcee-ai-Virtuoso-Medium-v2.jinja | Hermes 2 Pro |
| arcee-ai-Virtuoso-Small-v2.jinja | Hermes 2 Pro |
| avemio-GRAG-NEMO-12B-ORPO-HESSIAN-AI.jinja | Generic |
| bespokelabs-Bespoke-Stratos-7B.jinja | Hermes 2 Pro |
| bfuzzy1-acheron-m1a-llama.jinja | Generic |
| bofenghuang-vigogne-2-70b-chat.jinja | Generic |
| bytedance-research-UI-TARS-72B-DPO.jinja | Generic |
| bytedance-research-UI-TARS-7B-DPO.jinja | Generic |
| bytedance-research-UI-TARS-7B-SFT.jinja | Generic |
| carsenk-phi3.5_mini_exp_825_uncensored.jinja | Generic |
| cyberagent-DeepSeek-R1-Distill-Qwen-14B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
| cyberagent-DeepSeek-R1-Distill-Qwen-32B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
| databricks-dbrx-instruct.jinja | Generic |
| deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | Generic |
| deepseek-ai-DeepSeek-Coder-V2-Lite-Base.jinja | Generic |
| deepseek-ai-DeepSeek-Coder-V2-Lite-Instruct.jinja | Generic |
| deepseek-ai-DeepSeek-R1-Distill-Llama-70B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-14B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Zero.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-V2-Lite.jinja | Generic |
| deepseek-ai-DeepSeek-V2.5.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-V3.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-deepseek-coder-33b-instruct.jinja | Generic |
| deepseek-ai-deepseek-coder-6.7b-instruct.jinja | Generic |
| deepseek-ai-deepseek-coder-7b-instruct-v1.5.jinja | Generic |
| deepseek-ai-deepseek-llm-67b-chat.jinja | Generic |
| deepseek-ai-deepseek-llm-7b-chat.jinja | Generic |
| dicta-il-dictalm2.0-instruct.jinja | Generic |
| ehristoforu-Falcon3-8B-Franken-Basestruct.jinja | Hermes 2 Pro |
| fireworks-ai-llama-3-firefunction-v2.jinja | FireFunction v2 |
| godlikehhd-alpaca_data_sampled_ifd_new_5200.jinja | Hermes 2 Pro |
| godlikehhd-alpaca_data_score_max_0.7_2600.jinja | Hermes 2 Pro |
| google-gemma-2-27b-it.jinja | Generic |
| google-gemma-2-2b-it.jinja | Generic |
| google-gemma-2-2b-jpn-it.jinja | Generic |
| google-gemma-7b-it.jinja | Generic |
| huihui-ai-DeepSeek-R1-Distill-Llama-70B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Llama-8B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Qwen-14B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Qwen-32B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Qwen-7B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-Qwen2.5-14B-Instruct-1M-abliterated.jinja | Hermes 2 Pro |
| ibm-granite-granite-3.1-8b-instruct.jinja | Generic |
| indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | Generic |
| inflatebot-MN-12B-Mag-Mell-R1.jinja | Generic |
| jinaai-ReaderLM-v2.jinja | Generic |
| kms7530-chemeng_qwen-math-7b_24_1_100_1_nonmath.jinja | Hermes 2 Pro |
| knifeayumu-Cydonia-v1.3-Magnum-v4-22B.jinja | Mistral Nemo |
| langgptai-qwen1.5-7b-chat-sa-v0.1.jinja | Generic |
| lightblue-DeepSeek-R1-Distill-Qwen-7B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
| mattshumer-Reflection-Llama-3.1-70B.jinja | Generic |
| meetkai-functionary-medium-v3.1.jinja | Functionary v3.1 Llama 3.1 |
| meetkai-functionary-medium-v3.2.jinja | Functionary v3.2 |
| meta-llama-Llama-2-7b-chat-hf.jinja | Generic |
| meta-llama-Llama-3.1-8B-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.2-11B-Vision-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.2-1B-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.2-3B-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.3-70B-Instruct.jinja | Llama 3.x |
| meta-llama-Meta-Llama-3-8B-Instruct.jinja | Generic |
| meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | Llama 3.x |
| microsoft-Phi-3-medium-4k-instruct.jinja | Generic |
| microsoft-Phi-3-mini-4k-instruct.jinja | Generic |
| microsoft-Phi-3-small-8k-instruct.jinja | Generic |
| microsoft-Phi-3.5-mini-instruct.jinja | Generic |
| microsoft-Phi-3.5-vision-instruct.jinja | Generic |
| microsoft-phi-4.jinja | Generic |
| migtissera-Tess-3-Mistral-Nemo-12B.jinja | Generic |
| ministral-Ministral-3b-instruct.jinja | Generic |
| mistralai-Codestral-22B-v0.1.jinja | Generic |
| mistralai-Mistral-7B-Instruct-v0.1.jinja | Generic |
| mistralai-Mistral-7B-Instruct-v0.2.jinja | Generic |
| mistralai-Mistral-7B-Instruct-v0.3.jinja | Mistral Nemo |
| mistralai-Mistral-Large-Instruct-2407.jinja | Mistral Nemo |
| mistralai-Mistral-Large-Instruct-2411.jinja | Generic |
| mistralai-Mistral-Nemo-Instruct-2407.jinja | Mistral Nemo |
| mistralai-Mistral-Small-24B-Instruct-2501.jinja | Generic |
| mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | Generic |
| mkurman-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
| mlabonne-AlphaMonarch-7B.jinja | Generic |
| mlx-community-Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32.jinja | Hermes 2 Pro |
| mlx-community-Qwen2.5-VL-7B-Instruct-8bit.jinja | Hermes 2 Pro |
| mobiuslabsgmbh-DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1.jinja | DeepSeek R1 (extract reasoning) |
| netcat420-MFANNv0.20.jinja | Generic |
| netcat420-MFANNv0.24.jinja | Generic |
| netease-youdao-Confucius-o1-14B.jinja | Hermes 2 Pro |
| nvidia-AceMath-7B-RM.jinja | Hermes 2 Pro |
| nvidia-Eagle2-1B.jinja | Hermes 2 Pro |
| nvidia-Eagle2-9B.jinja | Hermes 2 Pro |
| nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | Llama 3.x |
| onnx-community-DeepSeek-R1-Distill-Qwen-1.5B-ONNX.jinja | DeepSeek R1 (extract reasoning) |
| open-thoughts-OpenThinker-7B.jinja | Hermes 2 Pro |
| openchat-openchat-3.5-0106.jinja | Generic |
| pankajmathur-orca_mini_v6_8b.jinja | Generic |
| princeton-nlp-Mistral-7B-Base-SFT-RDPO.jinja | Generic |
| princeton-nlp-Mistral-7B-Instruct-DPO.jinja | Generic |
| princeton-nlp-Mistral-7B-Instruct-RDPO.jinja | Generic |
| prithivMLmods-Bellatrix-Tiny-1.5B-R1.jinja | Hermes 2 Pro |
| prithivMLmods-Bellatrix-Tiny-1B-R1.jinja | Llama 3.x |
| prithivMLmods-Bellatrix-Tiny-1B-v3.jinja | Generic |
| prithivMLmods-Bellatrix-Tiny-3B-R1.jinja | Llama 3.x |
| prithivMLmods-Blaze-14B-xElite.jinja | Generic |
| prithivMLmods-Calcium-Opus-14B-Elite2-R1.jinja | Hermes 2 Pro |
| prithivMLmods-Calme-Ties-78B.jinja | Generic |
| prithivMLmods-Calme-Ties2-78B.jinja | Generic |
| prithivMLmods-Calme-Ties3-78B.jinja | Generic |
| prithivMLmods-ChemQwen2-vL.jinja | Generic |
| prithivMLmods-GWQ2b.jinja | Generic |
| prithivMLmods-LatexMind-2B-Codec.jinja | Generic |
| prithivMLmods-Llama-3.2-6B-AlgoCode.jinja | Llama 3.x |
| prithivMLmods-Megatron-Opus-14B-Exp.jinja | Hermes 2 Pro |
| prithivMLmods-Megatron-Opus-14B-Stock.jinja | Hermes 2 Pro |
| prithivMLmods-Megatron-Opus-7B-Exp.jinja | Hermes 2 Pro |
| prithivMLmods-Omni-Reasoner-Merged.jinja | Hermes 2 Pro |
| prithivMLmods-Omni-Reasoner4-Merged.jinja | Hermes 2 Pro |
| prithivMLmods-Primal-Opus-14B-Optimus-v1.jinja | Hermes 2 Pro |
| prithivMLmods-QwQ-Math-IO-500M.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen-7B-Distill-Reasoner.jinja | DeepSeek R1 (extract reasoning) |
| prithivMLmods-Qwen2.5-1.5B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen2.5-32B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen2.5-7B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
| prithivMLmods-Triangulum-v2-10B.jinja | Hermes 2 Pro |
| qingy2024-Falcon3-2x10B-MoE-Instruct.jinja | Hermes 2 Pro |
| rubenroy-Zurich-14B-GCv2-5m.jinja | Hermes 2 Pro |
| rubenroy-Zurich-7B-GCv2-5m.jinja | Hermes 2 Pro |
| silma-ai-SILMA-Kashif-2B-Instruct-v1.0.jinja | Generic |
| simplescaling-s1-32B.jinja | Hermes 2 Pro |
| sometimesanotion-Lamarck-14B-v0.7.jinja | Hermes 2 Pro |
| sonthenguyen-zephyr-sft-bnb-4bit-DPO-mtbr-180steps.jinja | Generic |
| sthenno-tempesthenno-icy-0130.jinja | Generic |
| sumink-qwft.jinja | Hermes 2 Pro |
| teknium-OpenHermes-2.5-Mistral-7B.jinja | Generic |
| thirdeyeai-elevate360m.jinja | Generic |
| tiiuae-Falcon3-10B-Instruct.jinja | Hermes 2 Pro |
| unsloth-DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit.jinja | DeepSeek R1 (extract reasoning) |
| unsloth-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) |
| unsloth-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) |
| unsloth-Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit.jinja | Generic |
| upstage-solar-pro-preview-instruct.jinja | Generic |
| whyhow-ai-PatientSeek.jinja | Generic |
| xwen-team-Xwen-72B-Chat.jinja | Hermes 2 Pro |
| xwen-team-Xwen-7B-Chat.jinja | Hermes 2 Pro |
This table can be generated with:
```bash
./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
```
</details>
# Usage - need tool-aware Jinja template
First, start a server with any model, but make sure it has a tools-enabled template: you can verify this by inspecting the `chat_template` or `chat_template_tool_use` properties in `http://localhost:8080/props`).
Here are some models known to work (w/ chat template override when needed):
```shell
# Native support:
llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
# Native support for DeepSeek R1 works best w/ our own template (official template buggy)
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
# Native support requires the right template for these GGUFs:
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
--chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
# Generic format support
llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0
llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K
```
> [!TIP]
> If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills)
Test in CLI (or with any library / software that can use OpenAI-compatible API backends):
```bash
curl http://localhost:8080/v1/chat/completions -d '{
"model": "gpt-3.5-turbo",
"tools": [
{
"type":"function",
"function":{
"name":"python",
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters":{
"type":"object",
"properties":{
"code":{
"type":"string",
"description":"The code to run in the ipython interpreter."
}
},
"required":["code"]
}
}
}
],
"messages": [
{
"role": "user",
"content": "Print a hello world message with python."
}
]
}'
```
<details>
<summary>Show output</summary>
```json
{
"choices": [
{
"finish_reason": "tool",
"index": 0,
"message": {
"content": null,
"tool_calls": [
{
"name": "python",
"arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}"
}
],
"role": "assistant"
}
}
],
"created": 1727287211,
"model": "gpt-3.5-turbo",
"object": "chat.completion",
"usage": {
"completion_tokens": 16,
"prompt_tokens": 44,
"total_tokens": 60
},
"id": "chatcmpl-Htbgh9feMmGM0LEH2hmQvwsCxq3c6Ni8"
}
```
</details>

View File

@@ -124,15 +124,26 @@ struct ContentView: View {
}
}
}.sheet(isPresented: $showingHelp) { // Sheet for help modal
VStack(alignment: .leading) {
NavigationView {
VStack(alignment: .leading) {
Text("1. Make sure the model is in GGUF Format")
.padding()
Text("2. Copy the download link of the quantized model")
.padding()
VStack(alignment: .leading) {
Text("1. Make sure the model is in GGUF Format")
.padding()
Text("2. Copy the download link of the quantized model")
.padding()
}
Spacer()
}
Spacer()
}
.navigationTitle("Help")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .navigationBarTrailing) {
Button("Done") {
showingHelp = false
}
}
}
}
}
}
}

View File

@@ -0,0 +1,183 @@
# Granite Vision
Download the model and point your `GRANITE_MODEL` environment variable to the path.
```bash
$ git clone https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview
$ export GRANITE_MODEL=./granite-vision-3.1-2b-preview
```
### 1. Running llava surgery v2.
First, we need to run the llava surgery script as shown below:
`python llava_surgery_v2.py -C -m $GRANITE_MODEL`
You should see two new files (`llava.clip` and `llava.projector`) written into your model's directory, as shown below.
```bash
$ ls $GRANITE_MODEL | grep -i llava
llava.clip
llava.projector
```
We should see that the projector and visual encoder get split out into the llava files. Quick check to make sure they aren't empty:
```python
import os
import torch
MODEL_PATH = os.getenv("GRANITE_MODEL")
if not MODEL_PATH:
raise ValueError("env var GRANITE_MODEL is unset!")
encoder_tensors = torch.load(os.path.join(MODEL_PATH, "llava.clip"))
projector_tensors = torch.load(os.path.join(MODEL_PATH, "llava.projector"))
assert len(encoder_tensors) > 0
assert len(projector_tensors) > 0
```
If you actually inspect the `.keys()` of the loaded tensors, you should see a lot of `vision_model` tensors in the `encoder_tensors`, and 5 tensors (`'multi_modal_projector.linear_1.bias'`, `'multi_modal_projector.linear_1.weight'`, `'multi_modal_projector.linear_2.bias'`, `'multi_modal_projector.linear_2.weight'`, `'image_newline'`) in the multimodal `projector_tensors`.
### 2. Creating the Visual Component GGUF
To create the GGUF for the visual components, we need to write a config for the visual encoder; make sure the config contains the correct `image_grid_pinpoints`
Note: we refer to this file as `$VISION_CONFIG` later on.
```json
{
"_name_or_path": "siglip-model",
"architectures": [
"SiglipVisionModel"
],
"image_grid_pinpoints": [
[384,768],
[384,1152],
[384,1536],
[384,1920],
[384,2304],
[384,2688],
[384,3072],
[384,3456],
[384,3840],
[768,384],
[768,768],
[768,1152],
[768,1536],
[768,1920],
[1152,384],
[1152,768],
[1152,1152],
[1536,384],
[1536,768],
[1920,384],
[1920,768],
[2304,384],
[2688,384],
[3072,384],
[3456,384],
[3840,384]
],
"mm_patch_merge_type": "spatial_unpad",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"layer_norm_eps": 1e-6,
"hidden_act": "gelu_pytorch_tanh",
"projection_dim": 0,
"vision_feature_layer": [-24, -20, -12, -1]
}
```
Create a new directory to hold the visual components, and copy the llava.clip/projector files, as well as the vision config into it.
```bash
$ ENCODER_PATH=$PWD/visual_encoder
$ mkdir $ENCODER_PATH
$ cp $GRANITE_MODEL/llava.clip $ENCODER_PATH/pytorch_model.bin
$ cp $GRANITE_MODEL/llava.projector $ENCODER_PATH/
$ cp $VISION_CONFIG $ENCODER_PATH/config.json
```
At which point you should have something like this:
```bash
$ ls $ENCODER_PATH
config.json llava.projector pytorch_model.bin
```
Now convert the components to GGUF; Note that we also override the image mean/std dev to `[.5,.5,.5]` since we use the siglip visual encoder - in the transformers model, you can find these numbers in the [preprocessor_config.json](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview/blob/main/preprocessor_config.json).
```bash
$ python convert_image_encoder_to_gguf.py \
-m $ENCODER_PATH \
--llava-projector $ENCODER_PATH/llava.projector \
--output-dir $ENCODER_PATH \
--clip-model-is-vision \
--clip-model-is-siglip \
--image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5
```
this will create the first GGUF file at `$ENCODER_PATH/mmproj-model-f16.gguf`; we will refer to the abs path of this file as the `$VISUAL_GGUF_PATH.`
### 3. Creating the LLM GGUF.
The granite vision model contains a granite LLM as its language model. For now, the easiest way to get the GGUF for LLM is by loading the composite model in `transformers` and exporting the LLM so that it can be directly converted with the normal conversion path.
First, set the `LLM_EXPORT_PATH` to the path to export the `transformers` LLM to.
```
$ export LLM_EXPORT_PATH=$PWD/granite_vision_llm
```
```python
import os
import transformers
MODEL_PATH = os.getenv("GRANITE_MODEL")
if not MODEL_PATH:
raise ValueError("env var GRANITE_MODEL is unset!")
LLM_EXPORT_PATH = os.getenv("LLM_EXPORT_PATH")
if not MODEL_PATH:
raise ValueError("env var LLM_EXPORT_PATH is unset!")
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_PATH)
# NOTE: granite vision support was added to transformers very recently (4.49);
# if you get size mismatches, your version is too old.
# If you are running with an older version, set `ignore_mismatched_sizes=True`
# as shown below; it won't be loaded correctly, but the LLM part of the model that
# we are exporting will be loaded correctly.
model = transformers.AutoModelForImageTextToText.from_pretrained(MODEL_PATH, ignore_mismatched_sizes=True)
tokenizer.save_pretrained(LLM_EXPORT_PATH)
model.language_model.save_pretrained(LLM_EXPORT_PATH)
```
Now you can convert the exported LLM to GGUF with the normal converter in the root of the llama cpp project.
```bash
$ LLM_GGUF_PATH=$LLM_EXPORT_PATH/granite_llm.gguf
...
$ python convert_hf_to_gguf.py --outfile $LLM_GGUF_PATH $LLM_EXPORT_PATH
```
### 4. Running the Model in Llama cpp
Build llama cpp normally; you should have a target binary named `llama-llava-cli`, which you can pass two binaries to. Sample usage:
Note - the test image shown below can be found [here](https://github-production-user-asset-6210df.s3.amazonaws.com/10740300/415512792-d90d5562-8844-4f34-a0a5-77f62d5a58b5.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20250221%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250221T054145Z&X-Amz-Expires=300&X-Amz-Signature=86c60be490aa49ef7d53f25d6c973580a8273904fed11ed2453d0a38240ee40a&X-Amz-SignedHeaders=host).
```bash
$ ./build/bin/llama-llava-cli -m $LLM_GGUF_PATH \
--mmproj $VISUAL_GGUF_PATH \
--image cherry_blossom.jpg \
-c 16384 \
-p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n\<image>\nWhat type of flowers are in this picture?\n<|assistant|>\n" \
--temp 0
```
Sample response: `The flowers in the picture are cherry blossoms, which are known for their delicate pink petals and are often associated with the beauty of spring.`

View File

@@ -101,8 +101,27 @@ python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow
```
**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)
**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)
**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way handle the LLM model conversion is to load the model in transformers, and export only the LLM from the llava next model.
```python
import os
import transformers
model_path = ...
llm_export_path = ...
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
model = transformers.AutoModelForImageTextToText.from_pretrained(model_path)
tokenizer.save_pretrained(llm_export_path)
model.language_model.save_pretrained(llm_export_path)
```
Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures.
## llava-cli templating and llava-1.6 prompting
llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."`

View File

@@ -40,6 +40,7 @@
#include <map>
#include <regex>
#include <stdexcept>
#include <unordered_set>
#include <vector>
#include <sstream>
#include <cinttypes>
@@ -120,6 +121,7 @@ static std::string format(const char * fmt, ...) {
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@@ -444,8 +446,9 @@ struct clip_hparams {
char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default)
int32_t image_grid_pinpoints[32];
std::vector<int32_t> image_grid_pinpoints;
int32_t image_crop_resolution;
std::unordered_set<int32_t> vision_feature_layer;
};
struct clip_layer {
@@ -585,6 +588,7 @@ struct clip_ctx {
struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP;
int32_t max_feature_layer;
float image_mean[3];
float image_std[3];
bool use_gelu = false;
@@ -651,7 +655,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
int n_layer = hparams.n_layer;
const float eps = hparams.eps;
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
@@ -752,13 +755,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
}
std::vector<struct ggml_tensor *> embedding_stack;
const auto & vision_feature_layer = hparams.vision_feature_layer;
// loop over layers
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
n_layer += 1;
}
for (int il = 0; il < n_layer - 1; il++) {
for (int il = 0; il < ctx->max_feature_layer; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
// If this is an embedding feature layer, save the output.
// NOTE: 0 index here refers to the input to the encoder.
if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
embedding_stack.push_back(embeddings);
}
//const size_t nb_q_w = model.layers[il].q_w->nb[0];
// layernorm1
@@ -846,7 +855,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
}
// post-layernorm
@@ -857,6 +865,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
}
// final layer is a vision feature layer
if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) {
embedding_stack.push_back(embeddings);
}
// If feature layers are explicitly set, stack them (if we have multiple)
if (!embedding_stack.empty()) {
embeddings = embedding_stack[0];
for (size_t i = 1; i < embedding_stack.size(); i++) {
embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
}
}
// llava projector
if (ctx->has_llava_projector) {
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
@@ -1443,14 +1464,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
int n = gguf_get_arr_n(ctx, idx);
const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx);
for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) {
hparams.image_grid_pinpoints[i] = pinpoints[i];
for (int i = 0; i < n; ++i) {
hparams.image_grid_pinpoints.push_back(pinpoints[i]);
}
if (n < 32)
hparams.image_grid_pinpoints[n] = 0;
} catch (std::runtime_error & /*e*/) {
hparams.image_grid_pinpoints[0]=0;
}
} catch (std::runtime_error & /*e*/) { }
// Load the vision feature layer indices if they are explicitly provided;
// if multiple vision feature layers are present, the values will be concatenated
// to form the final visual features.
// NOTE: gguf conversions should standardize the values of the vision feature layer to
// be non-negative, since we use -1 to mark values as unset here.
try {
int idx = get_key_idx(ctx, KEY_FEATURE_LAYER);
int n = gguf_get_arr_n(ctx, idx);
const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx);
for (int i = 0; i < n; ++i) {
hparams.vision_feature_layer.insert(vision_feature_layer[i]);
}
} catch (std::runtime_error & /*e*/) { }
try {
int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE);
@@ -1476,6 +1509,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
new_clip->image_std[i] = std_data[i];
}
// Calculate the deepest feature layer based on hparams and projector type
new_clip->max_feature_layer = get_deepest_feature_layer(new_clip);
if (verbosity >= 2) {
LOG_INF("\n%s: vision model hparams\n", __func__);
LOG_INF("image_size %d\n", hparams.image_size);
@@ -1489,8 +1525,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]);
LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]);
LOG_INF("v_image_grid_pinpoints: ");
for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) {
LOG_INF("%d ", hparams.image_grid_pinpoints[i]);
for (const auto & pp : hparams.image_grid_pinpoints) {
LOG_INF("%d ", pp);
}
LOG_INF("\n");
LOG_INF("v_vision_feature_layer: ");
for (const auto & feature_layer: hparams.vision_feature_layer) {
LOG_INF("%d ", feature_layer);
}
LOG_INF("\n");
LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type);
@@ -1729,11 +1770,11 @@ void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) {
}
}
static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) {
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
img->nx = nx;
img->ny = ny;
img->buf.resize(3 * nx * ny);
memcpy(img->buf.data(), data, img->buf.size());
memcpy(img->buf.data(), rgb_pixels, img->buf.size());
}
bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
@@ -1743,7 +1784,7 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
LOG_ERR("%s: failed to load image '%s'\n", __func__, fname);
return false;
}
build_clip_img_from_data(data, nx, ny, img);
clip_build_img_from_pixels(data, nx, ny, img);
stbi_image_free(data);
return true;
}
@@ -1755,7 +1796,7 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length
LOG_ERR("%s: failed to decode image bytes\n", __func__);
return false;
}
build_clip_img_from_data(data, nx, ny, img);
clip_build_img_from_pixels(data, nx, ny, img);
stbi_image_free(data);
return true;
}
@@ -2235,10 +2276,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
}
}
} else {
if (params.image_grid_pinpoints[0] != 0) {
if (!params.image_grid_pinpoints.empty()) {
// "spatial_unpad" with "anyres" processing for llava-1.6
std::vector<std::pair<int, int>> possible_resolutions;
for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) {
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
}
std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);
@@ -2404,7 +2445,14 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
}
const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.image_grid_pinpoints;
if (ctx->vision_model.hparams.image_grid_pinpoints.size()) {
return &ctx->vision_model.hparams.image_grid_pinpoints.front();
}
return nullptr;
}
size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.image_grid_pinpoints.size();
}
int clip_n_patches(const struct clip_ctx * ctx) {
@@ -2712,9 +2760,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
if (!ctx->has_glm_projector) {
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
// The patches vector is used to get rows to index into the embeds with;
// we should skip dim 0 only if we have CLS to avoid going out of bounds
// when retrieving the rows.
int patch_offset = ctx->has_class_embedding ? 1 : 0;
int* patches_data = (int*)malloc(ggml_nbytes(patches));
for (int i = 0; i < num_patches; i++) {
patches_data[i] = i + 1;
patches_data[i] = i + patch_offset;
}
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
free(patches_data);
@@ -2925,6 +2977,28 @@ bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
return ctx->has_qwen2vl_merger;
}
// Determine the number of encoder layers to iterate over
int get_deepest_feature_layer(const struct clip_ctx * ctx) {
// Get the index of the second to last layer; this is the
// default for models that have a llava projector
const auto & hparams = ctx->vision_model.hparams;
int n_layer = hparams.n_layer - 1;
int deepest_feature_layer = -1;
// Handle other projectors; incrementing here indicates that we
// should use the last encoder layer for the vision features.
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
n_layer += 1;
}
// If we set explicit vision feature layers, only go up to the deepest one
for (const auto & feature_layer : hparams.vision_feature_layer) {
if (feature_layer > deepest_feature_layer) {
deepest_feature_layer = feature_layer;
}
}
return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
}
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;

View File

@@ -55,6 +55,7 @@ CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
@@ -73,6 +74,12 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
/**
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
* The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
*/
CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
@@ -89,11 +96,13 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
#ifdef __cplusplus
}

View File

@@ -6,7 +6,7 @@ import re
import torch
import numpy as np
from gguf import *
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, SiglipVisionModel
TEXT = "clip.text"
VISION = "clip.vision"
@@ -37,6 +37,18 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b
def get_tensor_name(name: str) -> str:
# Standardize the transformers llava next keys for
# image newline / mm projector with the classes in haotian-liu LLaVA
if name == "image_newline":
return "model.image_newline"
if name.startswith("multi_modal_projector"):
name = name.replace("multi_modal_projector", "mm")
if "linear_1" in name:
name = name.replace("linear_1", "0")
if "linear_2" in name:
name = name.replace("linear_2", "2")
return name
if "projection" in name:
return name
if "mm_projector" in name:
@@ -83,8 +95,14 @@ ap.add_argument("--vision-only", action="store_true", required=False,
help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
# Selectable visual encoders that are compatible with this script
encoder_group = ap.add_mutually_exclusive_group()
encoder_group.add_argument("--clip-model-is-openclip", action="store_true", required=False,
help="The clip model is from openclip (for ViT-SO400M type))")
encoder_group.add_argument("--clip-model-is-siglip", action="store_true", required=False,
help="the visual encoder is Siglip.")
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
@@ -109,7 +127,12 @@ if args.use_f32:
# output in the same directory as the model if output_dir is None
dir_model = args.model_dir
if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
if (
args.clip_model_is_vision or
not os.path.exists(dir_model + "/vocab.json") or
args.clip_model_is_openclip or
args.clip_model_is_siglip
):
vocab = None
tokens = None
else:
@@ -137,7 +160,10 @@ ftype = 1
if args.use_f32:
ftype = 0
if args.clip_model_is_vision or args.clip_model_is_openclip:
if args.clip_model_is_siglip:
model = SiglipVisionModel.from_pretrained(dir_model)
processor = None
elif args.clip_model_is_vision or args.clip_model_is_openclip:
model = CLIPVisionModel.from_pretrained(dir_model)
processor = None
else:
@@ -187,26 +213,71 @@ else:
if has_text_encoder:
assert t_hparams is not None
assert tokens is not None
if args.clip_model_is_siglip:
text_projection_dim = 0
else:
text_projection_dim = t_hparams.get("projection_dim", config["projection_dim"])
# text_model hparams
fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"])
fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"]))
fout.add_uint32("clip.text.projection_dim", text_projection_dim)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"])
fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
fout.add_token_list(tokens)
def get_non_negative_vision_feature_layers(v_hparams):
"""
Determine the vision feature layer(s) for the llava model, which are indices into the
hidden states of the visual encoder. Note that the hidden states array generally takes the
form:
[<emb input>, <output of enc block 0>, ... <output of enc block num_hidden_layers>]
so feature indices should be offset as n+1 to get the output of encoder block n.
We convert all vision feature layers to non-negative so that -1 can be used in
the model as an unset value. If no vision feature layer is found, we leave it unset.
"""
num_hidden_layers = v_hparams["num_hidden_layers"]
to_non_negative = lambda layer_idx: layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
feature_layers_key = None
# Key used for llava models in transformers
if "vision_feature_layer" in config:
feature_layers_key = "vision_feature_layer"
# Key used for llava models in the original format
elif "mm_vision_select_layer" in config:
feature_layers_key = "mm_vision_select_layer"
if feature_layers_key is not None:
feature_layers = config[feature_layers_key]
if isinstance(feature_layers, int):
feature_layers = [feature_layers]
return [to_non_negative(feature_layer) for feature_layer in feature_layers]
# Determine if we have explicitly specified vision feature layers in our config
feature_layers = get_non_negative_vision_feature_layers(v_hparams)
if has_vision_encoder:
# vision_model hparams
# Siglip does not have a visual projector; set projection dim to 0
if args.clip_model_is_siglip:
visual_projection_dim = 0
else:
visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"])
# set vision_model hparams
fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"])
fout.add_uint32("clip.vision.projection_dim", v_hparams.get("projection_dim", config["projection_dim"]))
fout.add_uint32("clip.vision.projection_dim", visual_projection_dim)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
if feature_layers:
block_count = max(feature_layers)
else:
block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
# /**
# "image_grid_pinpoints": [
@@ -258,7 +329,8 @@ if has_vision_encoder:
fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"])
if "mm_projector_type" in v_hparams:
fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])
if feature_layers:
fout.add_array("clip.vision.feature_layer", feature_layers)
if processor is not None:
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue]
@@ -274,7 +346,13 @@ fout.add_bool("clip.use_gelu", use_gelu)
if has_llava_projector:
model.vision_model.encoder.layers.pop(-1)
# By default, we drop the last layer for llava projector
# models unless we have explicitly set vision feature layers
if feature_layers is None:
model.vision_model.encoder.layers.pop(-1)
else:
model.vision_model.encoder.layers = model.vision_model.encoder.layers[:max(feature_layers)]
projector = torch.load(args.llava_projector)
for name, data in projector.items():
name = get_tensor_name(name)

View File

@@ -353,9 +353,10 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
const int32_t * image_grid = clip_image_grid(ctx_clip);
const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip);
std::vector<std::pair<int, int>> grid_pinpoints;
for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) {
for (size_t i = 0; i < num_gridpoints; i += 2) {
grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
}
@@ -405,7 +406,8 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx *
}
bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
int num_max_patches = 6;
// Granite vision uses up to 10 patches + base patch
int num_max_patches = 11;
if (clip_is_minicpmv(ctx_clip)) {
num_max_patches = 10;
}

View File

@@ -33,6 +33,33 @@ def save_model(model, file_path, file_type):
else:
torch.save(model, file_path)
# Helpers to match weight names from specific components or
# determine if a saved shard contains that component
def is_vision_tower(weight_name):
return (
weight_name.startswith("model.vision_tower") or
weight_name.startswith("vit.") or
weight_name.startswith("vision_tower")
)
def is_newline(weight_name):
return (
weight_name.startswith("model.image_newline") or
weight_name.startswith("image_newline")
)
def is_mm_projector(weight_name):
return (
weight_name.startswith("model.mm_projector") or
weight_name.startswith("vision_proj.") or
weight_name.startswith("multi_modal_projector")
)
def newline_criteria(checkpoint):
return any(is_newline(k) for k in checkpoint.keys())
def proj_criteria(checkpoint):
return any(is_mm_projector(k) for k in checkpoint.keys())
# Adapted function to clean vision tower from checkpoint
def clean_vision_tower_from_checkpoint(checkpoint_path):
@@ -40,7 +67,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
# file_type = 'pytorch'
model_path = os.path.dirname(checkpoint_path)
print(f"Searching for vision tower tensors in {checkpoint_path}")
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
clip_tensors = [k for k, v in checkpoint.items() if is_vision_tower(k)]
if len(clip_tensors) > 0:
print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
@@ -84,12 +111,6 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
return newline_checkpoint_path, projector_checkpoint_path
def newline_criteria(checkpoint):
return any(k.startswith("model.image_newline") for k in checkpoint.keys())
def proj_criteria(checkpoint):
return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())
# Command-line interface setup
ap = argparse.ArgumentParser()
@@ -123,14 +144,14 @@ first_checkpoint = None
if newline_checkpoint_path is not None:
print(f"Taking newline from {newline_checkpoint_path}")
first_checkpoint, file_type = load_model(newline_checkpoint_path)
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)]
# Load the checkpoint
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
last_checkpoint, file_type = load_model(projector_checkpoint_path)
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)]
if len(mm_tensors) == 0:
if last_checkpoint is not None:
@@ -155,5 +176,5 @@ if len(projector) > 0:
save_model(projector, f"{args.model}/llava.projector", 'pytorch')
print("Done!")
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")

View File

@@ -4,7 +4,7 @@
#include "log.h"
#include "sampling.h"
#include "llama.h"
#include "chat-template.hpp"
#include "chat.h"
#include <cstdio>
#include <cstring>
@@ -158,7 +158,7 @@ int main(int argc, char ** argv) {
}
const llama_vocab * vocab = llama_model_get_vocab(model);
auto chat_templates = common_chat_templates_from_model(model, params.chat_template);
auto chat_templates = common_chat_templates_init(model, params.chat_template);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
@@ -201,7 +201,7 @@ int main(int argc, char ** argv) {
}
// auto enable conversation mode if chat template is available
const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get());
if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
if (has_chat_template) {
LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
// print chat template example in conversation mode
if (params.conversation_mode) {
if (params.enable_chat_template) {
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
} else {
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
}
@@ -264,9 +264,11 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp;
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
common_chat_msg new_msg{role, content, {}};
auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
chat_msgs.push_back({role, content, {}});
common_chat_msg new_msg;
new_msg.role = role;
new_msg.content = content;
auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
chat_msgs.push_back(new_msg);
LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted;
};
@@ -755,11 +757,14 @@ int main(int argc, char ** argv) {
// check for reverse prompt using special tokens
llama_token last_token = common_sampler_last(smpl);
if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) {
if (params.interactive) {
is_interacting = true;
for (auto token : antiprompt_token) {
if (token == last_token) {
if (params.interactive) {
is_interacting = true;
}
is_antiprompt = true;
break;
}
is_antiprompt = true;
}
if (is_antiprompt) {

View File

@@ -24,7 +24,7 @@
#include <string>
#include <vector>
#include "chat-template.hpp"
#include "chat.h"
#include "common.h"
#include "json.hpp"
#include "linenoise.cpp/linenoise.h"
@@ -113,6 +113,7 @@ class Opt {
llama_context_params ctx_params;
llama_model_params model_params;
std::string model_;
std::string chat_template_file;
std::string user;
bool use_jinja = false;
int context_size = -1, ngl = -1;
@@ -148,6 +149,16 @@ class Opt {
return 0;
}
int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
if (i + 1 >= argc) {
return 1;
}
option_value = argv[++i];
return 0;
}
int parse(int argc, const char ** argv) {
bool options_parsing = true;
for (int i = 1, positional_args_i = 0; i < argc; ++i) {
@@ -169,6 +180,11 @@ class Opt {
verbose = true;
} else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
use_jinja = true;
} else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0){
if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
return 1;
}
use_jinja = true;
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
help = true;
return 0;
@@ -207,6 +223,11 @@ class Opt {
"Options:\n"
" -c, --context-size <value>\n"
" Context size (default: %d)\n"
" --chat-template-file <path>\n"
" Path to the file containing the chat template to use with the model.\n"
" Only supports jinja templates and implicitly sets the --jinja flag.\n"
" --jinja\n"
" Use jinja templating for the chat template of the model\n"
" -n, -ngl, --ngl <value>\n"
" Number of GPU layers (default: %d)\n"
" --temp <value>\n"
@@ -261,13 +282,12 @@ static int get_terminal_width() {
#endif
}
#ifdef LLAMA_USE_CURL
class File {
public:
FILE * file = nullptr;
FILE * open(const std::string & filename, const char * mode) {
file = fopen(filename.c_str(), mode);
file = ggml_fopen(filename.c_str(), mode);
return file;
}
@@ -303,6 +323,20 @@ class File {
return 0;
}
std::string to_string() {
fseek(file, 0, SEEK_END);
const size_t size = ftell(file);
fseek(file, 0, SEEK_SET);
std::string out;
out.resize(size);
const size_t read_size = fread(&out[0], 1, size, file);
if (read_size != size) {
printe("Error reading file: %s", strerror(errno));
}
return out;
}
~File() {
if (fd >= 0) {
# ifdef _WIN32
@@ -327,6 +361,7 @@ class File {
# endif
};
#ifdef LLAMA_USE_CURL
class HttpClient {
public:
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -557,7 +592,7 @@ class LlamaData {
llama_model_ptr model;
llama_sampler_ptr sampler;
llama_context_ptr context;
std::vector<llama_chat_message> messages;
std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
std::list<std::string> msg_strs;
std::vector<char> fmtted;
@@ -834,44 +869,23 @@ static void add_message(const char * role, const std::string & text, LlamaData &
}
// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
if (use_jinja) {
json messages = json::array();
for (const auto & msg : llama_data.messages) {
messages.push_back({
{"role", msg.role},
{"content", msg.content},
});
}
try {
minja::chat_template_inputs tmpl_inputs;
tmpl_inputs.messages = messages;
tmpl_inputs.add_generation_prompt = append;
minja::chat_template_options tmpl_opts;
tmpl_opts.use_bos_token = false;
tmpl_opts.use_eos_token = false;
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
} catch (const std::exception & e) {
printe("failed to render the chat template: %s\n", e.what());
return -1;
}
}
int result = llama_chat_apply_template(
tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result);
result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size());
static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
common_chat_templates_inputs inputs;
for (const auto & msg : llama_data.messages) {
common_chat_msg cmsg;
cmsg.role = msg.role;
cmsg.content = msg.content;
inputs.messages.push_back(cmsg);
}
inputs.add_generation_prompt = append;
inputs.use_jinja = use_jinja;
return result;
auto chat_params = common_chat_templates_apply(tmpls, inputs);
// TODO: use other params for tool calls.
auto result = chat_params.prompt;
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
}
// Function to tokenize the prompt
@@ -963,7 +977,8 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
}
static int read_user_input(std::string & user_input) {
static const char * prompt_prefix = "> ";
static const char * prompt_prefix_env = std::getenv("LLAMA_PROMPT_PREFIX");
static const char * prompt_prefix = prompt_prefix_env ? prompt_prefix_env : "> ";
#ifdef WIN32
printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix);
@@ -1015,8 +1030,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
}
// Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
if (new_len < 0) {
printe("failed to apply the chat template\n");
return -1;
@@ -1074,40 +1089,68 @@ static int get_user_input(std::string & user_input, const std::string & user) {
return 0;
}
// Reads a chat template file to be used
static std::string read_chat_template_file(const std::string & chat_template_file) {
File file;
if (!file.open(chat_template_file, "r")) {
printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
return "";
}
return file.to_string();
}
static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
const common_chat_templates_ptr & chat_templates, int & prev_len,
const bool stdout_a_terminal) {
add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
int new_len;
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
return 1;
}
std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
std::string response;
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
return 1;
}
if (!opt.user.empty()) {
return 2;
}
add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
return 1;
}
return 0;
}
// Main chat loop function
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
static int chat_loop(LlamaData & llama_data, const Opt & opt) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
GGML_ASSERT(chat_templates.template_default);
std::string chat_template;
if (!opt.chat_template_file.empty()) {
chat_template = read_chat_template_file(opt.chat_template_file);
}
common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
std::string user_input;
if (get_user_input(user_input, user) == 1) {
if (get_user_input(user_input, opt.user) == 1) {
return 0;
}
add_message("user", user.empty() ? user_input : user, llama_data);
int new_len;
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
if (ret == 1) {
return 1;
}
std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
std::string response;
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
return 1;
}
if (!user.empty()) {
} else if (ret == 2) {
break;
}
add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
return 1;
}
}
return 0;
@@ -1165,7 +1208,7 @@ int main(int argc, const char ** argv) {
return 1;
}
if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
if (chat_loop(llama_data, opt)) {
return 1;
}

View File

@@ -13,6 +13,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
* Multimodal (wip)
* Monitoring endpoints
* Schema-constrained JSON response format
* [Function calling](../../docs/function-calling.md) / tool use for ~any model
The project is under active development, and we are [looking for feedback and contributors](https://github.com/ggml-org/llama.cpp/issues/4216).
@@ -1120,381 +1121,9 @@ curl http://localhost:8080/v1/chat/completions \
*Tool call support*
[Function calling](https://platform.openai.com/docs/guides/function-calling) is supported for all models (see https://github.com/ggml-org/llama.cpp/pull/9639):
[OpenAI-style function calling](https://platform.openai.com/docs/guides/function-calling) is supported with the `--jinja` flag (and may require a `--chat-template-file` override to get the right tool-use compatible Jinja template; worst case, `--chat-template chatml` may also work).
- Requires `--jinja` flag
- Native tool call formats supported:
- Llama 3.1 / 3.3 (including builtin tools support - tool names for `wolfram_alpha`, `web_search` / `brave_search`, `code_interpreter`), Llama 3.2
- Functionary v3.1 / v3.2
- Hermes 2/3, Qwen 2.5
- Mistral Nemo
- Firefunction v2
- Command R7B
- DeepSeek R1 (WIP / seems reluctant to call any tools?)
<details>
<summary>Show some common templates and which format handler they use</summary>
| Template | Format |
|----------|--------|
| Almawave-Velvet-14B.jinja | Hermes 2 Pro |
| AtlaAI-Selene-1-Mini-Llama-3.1-8B.jinja | Llama 3.x |
| CohereForAI-aya-expanse-8b.jinja | Generic |
| CohereForAI-c4ai-command-r-plus-default.jinja | Generic |
| CohereForAI-c4ai-command-r-plus-rag.jinja | Generic |
| CohereForAI-c4ai-command-r-plus-tool_use.jinja | Generic |
| CohereForAI-c4ai-command-r7b-12-2024-default.jinja | Command R7B (extract reasoning) |
| CohereForAI-c4ai-command-r7b-12-2024-rag.jinja | Command R7B (extract reasoning) |
| CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja | Command R7B (extract reasoning) |
| CohereForAI-c4ai-command-r7b-12-2024.jinja | Generic |
| DavieLion-Llama-3.2-1B-SPIN-iter3.jinja | Generic |
| Delta-Vector-Rei-12B.jinja | Mistral Nemo |
| EpistemeAI-Mistral-Nemo-Instruct-12B-Philosophy-Math.jinja | Mistral Nemo |
| FlofloB-83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit.jinja | Hermes 2 Pro |
| FlofloB-test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit.jinja | Generic |
| HelpingAI-HAI-SER.jinja | Generic |
| HuggingFaceTB-SmolLM2-1.7B-Instruct.jinja | Generic |
| HuggingFaceTB-SmolLM2-135M-Instruct.jinja | Generic |
| HuggingFaceTB-SmolLM2-360M-Instruct.jinja | Generic |
| INSAIT-Institute-BgGPT-Gemma-2-27B-IT-v1.0.jinja | Generic |
| Ihor-Text2Graph-R1-Qwen2.5-0.5b.jinja | Hermes 2 Pro |
| Infinigence-Megrez-3B-Instruct.jinja | Generic |
| Josephgflowers-TinyLlama_v1.1_math_code-world-test-1.jinja | Generic |
| LGAI-EXAONE-EXAONE-3.5-2.4B-Instruct.jinja | Generic |
| LGAI-EXAONE-EXAONE-3.5-7.8B-Instruct.jinja | Generic |
| LatitudeGames-Wayfarer-12B.jinja | Generic |
| Magpie-Align-Llama-3-8B-Magpie-Align-v0.1.jinja | Generic |
| Magpie-Align-Llama-3.1-8B-Magpie-Align-v0.1.jinja | Generic |
| MaziyarPanahi-calme-3.2-instruct-78b.jinja | Generic |
| MiniMaxAI-MiniMax-Text-01.jinja | Generic |
| MiniMaxAI-MiniMax-VL-01.jinja | Generic |
| NaniDAO-deepseek-r1-qwen-2.5-32B-ablated.jinja | DeepSeek R1 (extract reasoning) |
| NexaAIDev-Octopus-v2.jinja | Generic |
| NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | Generic |
| NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | Hermes 2 Pro |
| NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | Generic |
| NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | Hermes 2 Pro |
| NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | Generic |
| NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | Hermes 2 Pro |
| NovaSky-AI-Sky-T1-32B-Flash.jinja | Hermes 2 Pro |
| NovaSky-AI-Sky-T1-32B-Preview.jinja | Hermes 2 Pro |
| OnlyCheeini-greesychat-turbo.jinja | Generic |
| Orenguteng-Llama-3.1-8B-Lexi-Uncensored-V2.jinja | Llama 3.x |
| OrionStarAI-Orion-14B-Chat.jinja | Generic |
| PowerInfer-SmallThinker-3B-Preview.jinja | Generic |
| PrimeIntellect-INTELLECT-1-Instruct.jinja | Generic |
| Qwen-QVQ-72B-Preview.jinja | Generic |
| Qwen-QwQ-32B-Preview.jinja | Hermes 2 Pro |
| Qwen-Qwen1.5-7B-Chat.jinja | Generic |
| Qwen-Qwen2-7B-Instruct.jinja | Generic |
| Qwen-Qwen2-VL-72B-Instruct.jinja | Generic |
| Qwen-Qwen2-VL-7B-Instruct.jinja | Generic |
| Qwen-Qwen2.5-0.5B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-1.5B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-14B-Instruct-1M.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-14B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-32B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-32B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-3B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-72B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-7B-Instruct-1M.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-7B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-7B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Coder-32B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Coder-7B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Math-1.5B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Math-7B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-VL-3B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-VL-72B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-VL-7B-Instruct.jinja | Hermes 2 Pro |
| RWKV-Red-Team-ARWKV-7B-Preview-0.1.jinja | Hermes 2 Pro |
| SakanaAI-TinySwallow-1.5B-Instruct.jinja | Hermes 2 Pro |
| SakanaAI-TinySwallow-1.5B.jinja | Hermes 2 Pro |
| Sao10K-70B-L3.3-Cirrus-x1.jinja | Llama 3.x |
| SentientAGI-Dobby-Mini-Leashed-Llama-3.1-8B.jinja | Llama 3.x |
| SentientAGI-Dobby-Mini-Unhinged-Llama-3.1-8B.jinja | Llama 3.x |
| Steelskull-L3.3-Damascus-R1.jinja | Llama 3.x |
| Steelskull-L3.3-MS-Nevoria-70b.jinja | Llama 3.x |
| Steelskull-L3.3-Nevoria-R1-70b.jinja | Llama 3.x |
| THUDM-glm-4-9b-chat.jinja | Generic |
| THUDM-glm-edge-1.5b-chat.jinja | Generic |
| Tarek07-Progenitor-V1.1-LLaMa-70B.jinja | Llama 3.x |
| TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | Generic |
| TinyLlama-TinyLlama-1.1B-Chat-v1.0.jinja | Generic |
| UCLA-AGI-Mistral7B-PairRM-SPPO-Iter3.jinja | Generic |
| ValiantLabs-Llama3.1-8B-Enigma.jinja | Llama 3.x |
| abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | Generic |
| ai21labs-AI21-Jamba-1.5-Large.jinja | Generic |
| allenai-Llama-3.1-Tulu-3-405B-SFT.jinja | Generic |
| allenai-Llama-3.1-Tulu-3-405B.jinja | Generic |
| allenai-Llama-3.1-Tulu-3-8B.jinja | Generic |
| arcee-ai-Virtuoso-Lite.jinja | Hermes 2 Pro |
| arcee-ai-Virtuoso-Medium-v2.jinja | Hermes 2 Pro |
| arcee-ai-Virtuoso-Small-v2.jinja | Hermes 2 Pro |
| avemio-GRAG-NEMO-12B-ORPO-HESSIAN-AI.jinja | Generic |
| bespokelabs-Bespoke-Stratos-7B.jinja | Hermes 2 Pro |
| bfuzzy1-acheron-m1a-llama.jinja | Generic |
| bofenghuang-vigogne-2-70b-chat.jinja | Generic |
| bytedance-research-UI-TARS-72B-DPO.jinja | Generic |
| bytedance-research-UI-TARS-7B-DPO.jinja | Generic |
| bytedance-research-UI-TARS-7B-SFT.jinja | Generic |
| carsenk-phi3.5_mini_exp_825_uncensored.jinja | Generic |
| cyberagent-DeepSeek-R1-Distill-Qwen-14B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
| cyberagent-DeepSeek-R1-Distill-Qwen-32B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
| databricks-dbrx-instruct.jinja | Generic |
| deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | Generic |
| deepseek-ai-DeepSeek-Coder-V2-Lite-Base.jinja | Generic |
| deepseek-ai-DeepSeek-Coder-V2-Lite-Instruct.jinja | Generic |
| deepseek-ai-DeepSeek-R1-Distill-Llama-70B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-14B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Zero.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-V2-Lite.jinja | Generic |
| deepseek-ai-DeepSeek-V2.5.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-V3.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-deepseek-coder-33b-instruct.jinja | Generic |
| deepseek-ai-deepseek-coder-6.7b-instruct.jinja | Generic |
| deepseek-ai-deepseek-coder-7b-instruct-v1.5.jinja | Generic |
| deepseek-ai-deepseek-llm-67b-chat.jinja | Generic |
| deepseek-ai-deepseek-llm-7b-chat.jinja | Generic |
| dicta-il-dictalm2.0-instruct.jinja | Generic |
| ehristoforu-Falcon3-8B-Franken-Basestruct.jinja | Hermes 2 Pro |
| fireworks-ai-llama-3-firefunction-v2.jinja | FireFunction v2 |
| godlikehhd-alpaca_data_sampled_ifd_new_5200.jinja | Hermes 2 Pro |
| godlikehhd-alpaca_data_score_max_0.7_2600.jinja | Hermes 2 Pro |
| google-gemma-2-27b-it.jinja | Generic |
| google-gemma-2-2b-it.jinja | Generic |
| google-gemma-2-2b-jpn-it.jinja | Generic |
| google-gemma-7b-it.jinja | Generic |
| huihui-ai-DeepSeek-R1-Distill-Llama-70B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Llama-8B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Qwen-14B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Qwen-32B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Qwen-7B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-Qwen2.5-14B-Instruct-1M-abliterated.jinja | Hermes 2 Pro |
| ibm-granite-granite-3.1-8b-instruct.jinja | Generic |
| indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | Generic |
| inflatebot-MN-12B-Mag-Mell-R1.jinja | Generic |
| jinaai-ReaderLM-v2.jinja | Generic |
| kms7530-chemeng_qwen-math-7b_24_1_100_1_nonmath.jinja | Hermes 2 Pro |
| knifeayumu-Cydonia-v1.3-Magnum-v4-22B.jinja | Mistral Nemo |
| langgptai-qwen1.5-7b-chat-sa-v0.1.jinja | Generic |
| lightblue-DeepSeek-R1-Distill-Qwen-7B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
| mattshumer-Reflection-Llama-3.1-70B.jinja | Generic |
| meetkai-functionary-medium-v3.1.jinja | Functionary v3.1 Llama 3.1 |
| meetkai-functionary-medium-v3.2.jinja | Functionary v3.2 |
| meta-llama-Llama-2-7b-chat-hf.jinja | Generic |
| meta-llama-Llama-3.1-8B-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.2-11B-Vision-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.2-1B-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.2-3B-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.3-70B-Instruct.jinja | Llama 3.x |
| meta-llama-Meta-Llama-3-8B-Instruct.jinja | Generic |
| meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | Llama 3.x |
| microsoft-Phi-3-medium-4k-instruct.jinja | Generic |
| microsoft-Phi-3-mini-4k-instruct.jinja | Generic |
| microsoft-Phi-3-small-8k-instruct.jinja | Generic |
| microsoft-Phi-3.5-mini-instruct.jinja | Generic |
| microsoft-Phi-3.5-vision-instruct.jinja | Generic |
| microsoft-phi-4.jinja | Generic |
| migtissera-Tess-3-Mistral-Nemo-12B.jinja | Generic |
| ministral-Ministral-3b-instruct.jinja | Generic |
| mistralai-Codestral-22B-v0.1.jinja | Generic |
| mistralai-Mistral-7B-Instruct-v0.1.jinja | Generic |
| mistralai-Mistral-7B-Instruct-v0.2.jinja | Generic |
| mistralai-Mistral-7B-Instruct-v0.3.jinja | Mistral Nemo |
| mistralai-Mistral-Large-Instruct-2407.jinja | Mistral Nemo |
| mistralai-Mistral-Large-Instruct-2411.jinja | Generic |
| mistralai-Mistral-Nemo-Instruct-2407.jinja | Mistral Nemo |
| mistralai-Mistral-Small-24B-Instruct-2501.jinja | Generic |
| mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | Generic |
| mkurman-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
| mlabonne-AlphaMonarch-7B.jinja | Generic |
| mlx-community-Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32.jinja | Hermes 2 Pro |
| mlx-community-Qwen2.5-VL-7B-Instruct-8bit.jinja | Hermes 2 Pro |
| mobiuslabsgmbh-DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1.jinja | DeepSeek R1 (extract reasoning) |
| netcat420-MFANNv0.20.jinja | Generic |
| netcat420-MFANNv0.24.jinja | Generic |
| netease-youdao-Confucius-o1-14B.jinja | Hermes 2 Pro |
| nvidia-AceMath-7B-RM.jinja | Hermes 2 Pro |
| nvidia-Eagle2-1B.jinja | Hermes 2 Pro |
| nvidia-Eagle2-9B.jinja | Hermes 2 Pro |
| nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | Llama 3.x |
| onnx-community-DeepSeek-R1-Distill-Qwen-1.5B-ONNX.jinja | DeepSeek R1 (extract reasoning) |
| open-thoughts-OpenThinker-7B.jinja | Hermes 2 Pro |
| openchat-openchat-3.5-0106.jinja | Generic |
| pankajmathur-orca_mini_v6_8b.jinja | Generic |
| princeton-nlp-Mistral-7B-Base-SFT-RDPO.jinja | Generic |
| princeton-nlp-Mistral-7B-Instruct-DPO.jinja | Generic |
| princeton-nlp-Mistral-7B-Instruct-RDPO.jinja | Generic |
| prithivMLmods-Bellatrix-Tiny-1.5B-R1.jinja | Hermes 2 Pro |
| prithivMLmods-Bellatrix-Tiny-1B-R1.jinja | Llama 3.x |
| prithivMLmods-Bellatrix-Tiny-1B-v3.jinja | Generic |
| prithivMLmods-Bellatrix-Tiny-3B-R1.jinja | Llama 3.x |
| prithivMLmods-Blaze-14B-xElite.jinja | Generic |
| prithivMLmods-Calcium-Opus-14B-Elite2-R1.jinja | Hermes 2 Pro |
| prithivMLmods-Calme-Ties-78B.jinja | Generic |
| prithivMLmods-Calme-Ties2-78B.jinja | Generic |
| prithivMLmods-Calme-Ties3-78B.jinja | Generic |
| prithivMLmods-ChemQwen2-vL.jinja | Generic |
| prithivMLmods-GWQ2b.jinja | Generic |
| prithivMLmods-LatexMind-2B-Codec.jinja | Generic |
| prithivMLmods-Llama-3.2-6B-AlgoCode.jinja | Llama 3.x |
| prithivMLmods-Megatron-Opus-14B-Exp.jinja | Hermes 2 Pro |
| prithivMLmods-Megatron-Opus-14B-Stock.jinja | Hermes 2 Pro |
| prithivMLmods-Megatron-Opus-7B-Exp.jinja | Hermes 2 Pro |
| prithivMLmods-Omni-Reasoner-Merged.jinja | Hermes 2 Pro |
| prithivMLmods-Omni-Reasoner4-Merged.jinja | Hermes 2 Pro |
| prithivMLmods-Primal-Opus-14B-Optimus-v1.jinja | Hermes 2 Pro |
| prithivMLmods-QwQ-Math-IO-500M.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen-7B-Distill-Reasoner.jinja | DeepSeek R1 (extract reasoning) |
| prithivMLmods-Qwen2.5-1.5B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen2.5-32B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen2.5-7B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
| prithivMLmods-Triangulum-v2-10B.jinja | Hermes 2 Pro |
| qingy2024-Falcon3-2x10B-MoE-Instruct.jinja | Hermes 2 Pro |
| rubenroy-Zurich-14B-GCv2-5m.jinja | Hermes 2 Pro |
| rubenroy-Zurich-7B-GCv2-5m.jinja | Hermes 2 Pro |
| silma-ai-SILMA-Kashif-2B-Instruct-v1.0.jinja | Generic |
| simplescaling-s1-32B.jinja | Hermes 2 Pro |
| sometimesanotion-Lamarck-14B-v0.7.jinja | Hermes 2 Pro |
| sonthenguyen-zephyr-sft-bnb-4bit-DPO-mtbr-180steps.jinja | Generic |
| sthenno-tempesthenno-icy-0130.jinja | Generic |
| sumink-qwft.jinja | Hermes 2 Pro |
| teknium-OpenHermes-2.5-Mistral-7B.jinja | Generic |
| thirdeyeai-elevate360m.jinja | Generic |
| tiiuae-Falcon3-10B-Instruct.jinja | Hermes 2 Pro |
| unsloth-DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit.jinja | DeepSeek R1 (extract reasoning) |
| unsloth-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) |
| unsloth-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) |
| unsloth-Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit.jinja | Generic |
| upstage-solar-pro-preview-instruct.jinja | Generic |
| whyhow-ai-PatientSeek.jinja | Generic |
| xwen-team-Xwen-72B-Chat.jinja | Hermes 2 Pro |
| xwen-team-Xwen-7B-Chat.jinja | Hermes 2 Pro |
This table can be generated with:
```bash
./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
```
</details>
- Generic tool call is supported when the template isn't recognized by native format handlers (you'll see `Chat format: Generic` in the logs).
- Use `--chat-template-file` to override the template when appropriate (see examples below)
- Generic support may consume more tokens and be less efficient than a model's native format.
- Run with:
```shell
# Native support:
llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
# Native support for DeepSeek R1 works best w/ our own template (official template buggy)
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
# Native support requires the right template for these GGUFs:
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
--chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
# Generic format support
llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0
llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K
```
- Test in CLI:
```bash
curl http://localhost:8080/v1/chat/completions -d '{
"model": "gpt-3.5-turbo",
"tools": [
{
"type":"function",
"function":{
"name":"python",
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters":{
"type":"object",
"properties":{
"code":{
"type":"string",
"description":"The code to run in the ipython interpreter."
}
},
"required":["code"]
}
}
}
],
"messages": [
{
"role": "user",
"content": "Print a hello world message with python."
}
]
}'
```
<details>
<summary>Show output</summary>
```json
{
"choices": [
{
"finish_reason": "tool",
"index": 0,
"message": {
"content": null,
"tool_calls": [
{
"name": "python",
"arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}"
}
],
"role": "assistant"
}
}
],
"created": 1727287211,
"model": "gpt-3.5-turbo",
"object": "chat.completion",
"usage": {
"completion_tokens": 16,
"prompt_tokens": 44,
"total_tokens": 60
},
"id": "chatcmpl-Htbgh9feMmGM0LEH2hmQvwsCxq3c6Ni8"
}
```
</details>
**See our [Function calling](../../docs/function-calling.md) docs** for more details, supported native tool call styles (generic tool call style is used as fallback) / examples of use.
### POST `/v1/embeddings`: OpenAI-compatible embeddings API

File diff suppressed because it is too large Load Diff

Binary file not shown.

View File

@@ -274,7 +274,7 @@ struct server_task {
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
params.speculative.n_min = std::max(params.speculative.n_min, 2);
params.speculative.n_min = std::max(params.speculative.n_min, 0);
params.speculative.n_max = std::max(params.speculative.n_max, 0);
// Use OpenAI API logprobs only if n_probs wasn't provided
@@ -329,9 +329,6 @@ struct server_task {
}
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
}
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
@@ -1807,7 +1804,7 @@ struct server_context {
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
common_chat_templates chat_templates;
common_chat_templates_ptr chat_templates;
~server_context() {
// Clear any sampling context
@@ -1891,45 +1888,17 @@ struct server_context {
llama_init_dft.context.reset();
}
if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
chat_templates = common_chat_templates_init(model, params_base.chat_template);
try {
common_chat_format_example(chat_templates.get(), params.use_jinja);
} catch (const std::exception & e) {
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_from_model(model, "chatml");
} else {
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
chat_templates = common_chat_templates_init(model, "chatml");
}
GGML_ASSERT(chat_templates.template_default.get() != nullptr);
return true;
}
bool validate_builtin_chat_template(bool use_jinja) const {
llama_chat_message chat[] = {{"user", "test"}};
if (use_jinja) {
auto templates = common_chat_templates_from_model(model, "");
common_chat_inputs inputs;
inputs.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
GGML_ASSERT(templates.template_default);
try {
common_chat_params_init(*templates.template_default, inputs);
if (templates.template_tool_use) {
common_chat_params_init(*templates.template_tool_use, inputs);
}
return true;
} catch (const std::exception & e) {
SRV_ERR("failed to apply template: %s\n", e.what());
return false;
}
} else {
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0;
}
}
void init() {
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
@@ -3656,7 +3625,7 @@ int main(int argc, char ** argv) {
}, {
{"name", "n_busy_slots_per_decode"},
{"help", "Average number of busy slots per llama_decode() call"},
{"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total}
{"value", (float) res_metrics->n_busy_slots_total / std::max((float) res_metrics->n_decode_total, 1.f)}
}}},
{"gauge", {{
{"name", "prompt_tokens_seconds"},
@@ -3822,13 +3791,15 @@ int main(int argc, char ** argv) {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model },
{ "chat_template", ctx_server.chat_templates.template_default->source() },
{ "bos_token", ctx_server.chat_templates.template_default->bos_token() },
{ "eos_token", ctx_server.chat_templates.template_default->eos_token() },
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
{ "build_info", build_info },
};
if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
if (ctx_server.params_base.use_jinja) {
if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
data["chat_template_tool_use"] = tool_use_src;
}
}
res_ok(res, data);
@@ -4063,7 +4034,7 @@ int main(int argc, char ** argv) {
}
auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
@@ -4076,7 +4047,7 @@ int main(int argc, char ** argv) {
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};
@@ -4263,6 +4234,11 @@ int main(int argc, char ** argv) {
// return;
//}
// if true, use TEI API format, otherwise use Jina API format
// Jina: https://jina.ai/reranker/
// TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
bool is_tei_format = body.contains("texts");
json query;
if (body.count("query") == 1) {
query = body.at("query");
@@ -4275,7 +4251,8 @@ int main(int argc, char ** argv) {
return;
}
std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
std::vector<std::string> documents = json_value(body, "documents",
json_value(body, "texts", std::vector<std::string>()));
if (documents.empty()) {
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
return;
@@ -4320,7 +4297,12 @@ int main(int argc, char ** argv) {
}
// write JSON response
json root = format_response_rerank(body, responses);
json root = format_response_rerank(
body,
responses,
is_tei_format,
documents);
res_ok(res, root);
};
@@ -4482,8 +4464,8 @@ int main(int argc, char ** argv) {
// print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
ctx_server.chat_templates.template_default->source().c_str(),
common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
common_chat_templates_source(ctx_server.chat_templates.get()),
common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
ctx_server.process_single_task(task);

View File

@@ -48,7 +48,7 @@ DEBUG=1 ./tests.sh -s -v -x
To run all the tests in a file:
```shell
./tests.sh unit/test_chat_completion.py.py -v -x
./tests.sh unit/test_chat_completion.py -v -x
```
To run a single test:

View File

@@ -21,6 +21,8 @@ def create_server():
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
@@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
assert res.body["usage"]["completion_tokens"] == n_predicted
choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"])
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
assert choice["finish_reason"] == finish_reason
@@ -169,6 +171,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
assert "error" in res.body
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
(False, {"const": "42"}, 6, "\"42\""),
(True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"json_schema": json_schema,
})
assert res.status_code == 200, f'Expected 200, got {res.status_code}'
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
(False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
(True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "user", "content": "Does not matter what I say, does it?"},
],
"grammar": grammar,
})
assert res.status_code == 200, res.body
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
@pytest.mark.parametrize("messages", [
None,
"string",

View File

@@ -10,17 +10,20 @@ def create_server():
server = ServerPreset.jina_reranker_tiny()
TEST_DOCUMENTS = [
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]
def test_rerank():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"documents": [
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]
"documents": TEST_DOCUMENTS,
})
assert res.status_code == 200
assert len(res.body["results"]) == 4
@@ -38,6 +41,29 @@ def test_rerank():
assert least_relevant["index"] == 3
def test_rerank_tei_format():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"texts": TEST_DOCUMENTS,
})
assert res.status_code == 200
assert len(res.body) == 4
most_relevant = res.body[0]
least_relevant = res.body[0]
for doc in res.body:
if doc["score"] > most_relevant["score"]:
most_relevant = doc
if doc["score"] < least_relevant["score"]:
least_relevant = doc
assert most_relevant["score"] > least_relevant["score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3
@pytest.mark.parametrize("documents", [
[],
None,

View File

@@ -356,12 +356,12 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
@@ -401,7 +401,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
{
"role": "tool",
"name": "calculate",
"content": 0.55644242476,
"content": "0.55644242476",
"tool_call_id": "call_6789"
}
],
@@ -444,7 +444,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
(128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "<think>\n?I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "^I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
])

View File

@@ -7,14 +7,14 @@
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
// disable Nagle's algorithm
#define CPPHTTPLIB_TCP_NODELAY true
#include "httplib.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "minja.hpp"
#include "chat.hpp"
#include "chat-template.hpp"
#include "chat.h"
#include <random>
#include <sstream>
@@ -347,41 +347,6 @@ static llama_tokens format_infill(
return embd_inp;
}
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
std::vector<common_chat_msg> chat;
for (size_t i = 0; i < messages.size(); ++i) {
const auto & curr_msg = messages[i];
std::string role = json_value(curr_msg, "role", std::string(""));
std::string content;
if (curr_msg.contains("content")) {
if (curr_msg["content"].is_string()) {
content = curr_msg["content"].get<std::string>();
} else if (curr_msg["content"].is_array()) {
for (const auto & part : curr_msg["content"]) {
if (part.contains("text")) {
content += "\n" + part["text"].get<std::string>();
}
}
} else {
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
} else {
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
chat.push_back({role, content, /* tool_calls= */ {}});
}
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
return formatted_chat;
}
//
// base64 utils (TODO: move to common in the future)
//
@@ -556,8 +521,13 @@ static json oaicompat_completion_params_parse(const json & body) {
throw std::runtime_error("Only one completion choice is allowed");
}
// Handle "echo" field
if (json_value(body, "echo", false)) {
throw std::runtime_error("Only no echo is supported");
}
// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "best_of", "echo", "suffix" };
static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
for (const auto & param : unsupported_params) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);
@@ -579,12 +549,9 @@ static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */
bool use_jinja,
common_reasoning_format reasoning_format,
const common_chat_templates & chat_templates)
const struct common_chat_templates * tmpls)
{
json llama_params;
const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
? *chat_templates.template_tool_use
: *chat_templates.template_default;
auto tools = json_value(body, "tools", json());
auto stream = json_value(body, "stream", false);
@@ -610,62 +577,58 @@ static json oaicompat_completion_params_parse(
llama_params["stop"] = json_value(body, "stop", json::array());
}
auto json_schema = json_value(body, "json_schema", json());
auto grammar = json_value(body, "grammar", std::string());
if (!json_schema.is_null() && !grammar.empty()) {
throw std::runtime_error("Cannot use both json_schema and grammar");
}
// Handle "response_format" field
if (body.contains("response_format")) {
json response_format = json_value(body, "response_format", json::object());
std::string response_type = json_value(response_format, "type", std::string());
if (response_type == "json_object") {
llama_params["json_schema"] = json_value(response_format, "schema", json::object());
json_schema = json_value(response_format, "schema", json::object());
} else if (response_type == "json_schema") {
json json_schema = json_value(response_format, "json_schema", json::object());
llama_params["json_schema"] = json_value(json_schema, "schema", json::object());
json_schema = json_value(json_schema, "schema", json::object());
} else if (!response_type.empty() && response_type != "text") {
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
}
}
// Apply chat template to the list of messages
if (use_jinja) {
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
}
if (tool_choice != "none" && llama_params.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
common_chat_inputs inputs;
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
inputs.messages = body.at("messages");
inputs.tools = tools;
inputs.tool_choice = tool_choice;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
inputs.parallel_tool_calls = false;
}
inputs.stream = stream;
// TODO: support mixing schema w/ tools beyond generic format.
inputs.json_schema = json_value(llama_params, "json_schema", json());
auto chat_params = common_chat_params_init(tmpl, inputs);
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
inputs.tools = common_chat_tools_parse_oaicompat(tools);
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.use_jinja = use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
llama_params["chat_format"] = static_cast<int>(chat_params.format);
llama_params["prompt"] = chat_params.prompt;
llama_params["grammar"] = chat_params.grammar;
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
auto grammar_triggers = json::array();
for (const auto & trigger : chat_params.grammar_triggers) {
grammar_triggers.push_back({
{"word", trigger.word},
{"at_start", trigger.at_start},
});
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
} else {
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
// Apply chat template to the list of messages
auto chat_params = common_chat_templates_apply(tmpls, inputs);
llama_params["chat_format"] = static_cast<int>(chat_params.format);
llama_params["prompt"] = chat_params.prompt;
llama_params["grammar"] = chat_params.grammar;
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
auto grammar_triggers = json::array();
for (const auto & trigger : chat_params.grammar_triggers) {
grammar_triggers.push_back({
{"word", trigger.word},
{"at_start", trigger.at_start},
});
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
// Handle "n" field
@@ -737,29 +700,51 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
return res;
}
static json format_response_rerank(const json & request, const json & ranks) {
json data = json::array();
int32_t n_tokens = 0;
int i = 0;
for (const auto & rank : ranks) {
data.push_back(json{
{"index", i++},
{"relevance_score", json_value(rank, "score", 0.0)},
});
static json format_response_rerank(
const json & request,
const json & ranks,
bool is_tei_format,
std::vector<std::string> & texts) {
json res;
if (is_tei_format) {
// TEI response format
res = json::array();
bool return_text = json_value(request, "return_text", false);
for (const auto & rank : ranks) {
int index = json_value(rank, "index", 0);
json elem = json{
{"index", index},
{"score", json_value(rank, "score", 0.0)},
};
if (return_text) {
elem["text"] = std::move(texts[index]);
}
res.push_back(elem);
}
} else {
// Jina response format
json results = json::array();
int32_t n_tokens = 0;
for (const auto & rank : ranks) {
results.push_back(json{
{"index", json_value(rank, "index", 0)},
{"relevance_score", json_value(rank, "score", 0.0)},
});
n_tokens += json_value(rank, "tokens_evaluated", 0);
n_tokens += json_value(rank, "tokens_evaluated", 0);
}
res = json{
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json{
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"results", results}
};
}
json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json {
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"results", data}
};
return res;
}

View File

@@ -159,6 +159,35 @@ export default function ChatMessage({
</div>
</details>
)}
{msg.extra && msg.extra.length > 0 && (
<details
className={classNames({
'collapse collapse-arrow mb-4 bg-base-200': true,
'bg-opacity-10': msg.role !== 'assistant',
})}
>
<summary className="collapse-title">
Extra content
</summary>
<div className="collapse-content">
{msg.extra.map(
(extra, i) =>
extra.type === 'textFile' ? (
<div key={extra.name}>
<b>{extra.name}</b>
<pre>{extra.content}</pre>
</div>
) : extra.type === 'context' ? (
<div key={i}>
<pre>{extra.content}</pre>
</div>
) : null // TODO: support other extra types
)}
</div>
</details>
)}
<MarkdownDisplay
content={content}
isGenerating={isPending}

View File

@@ -1,10 +1,11 @@
import { useEffect, useMemo, useState } from 'react';
import { useEffect, useMemo, useRef, useState } from 'react';
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
import ChatMessage from './ChatMessage';
import { CanvasType, Message, PendingMessage } from '../utils/types';
import { classNames, throttle } from '../utils/misc';
import CanvasPyInterpreter from './CanvasPyInterpreter';
import StorageUtils from '../utils/storage';
import { useVSCodeContext } from '../utils/llama-vscode';
/**
* A message display is a message node with additional information for rendering.
@@ -81,6 +82,14 @@ export default function ChatScreen() {
replaceMessageAndGenerate,
} = useAppContext();
const [inputMsg, setInputMsg] = useState('');
const inputRef = useRef<HTMLTextAreaElement>(null);
const { extraContext, clearExtraContext } = useVSCodeContext(
inputRef,
setInputMsg
);
// TODO: improve this when we have "upload file" feature
const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined;
// keep track of leaf node for rendering
const [currNodeId, setCurrNodeId] = useState<number>(-1);
@@ -115,10 +124,20 @@ export default function ChatScreen() {
setCurrNodeId(-1);
// get the last message node
const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
if (!(await sendMessage(currConvId, lastMsgNodeId, inputMsg, onChunk))) {
if (
!(await sendMessage(
currConvId,
lastMsgNodeId,
inputMsg,
currExtra,
onChunk
))
) {
// restore the input message if failed
setInputMsg(lastInpMsg);
}
// OK
clearExtraContext();
};
const handleEditMessage = async (msg: Message, content: string) => {
@@ -129,6 +148,7 @@ export default function ChatScreen() {
viewingChat.conv.id,
msg.parent,
content,
msg.extra,
onChunk
);
setCurrNodeId(-1);
@@ -143,6 +163,7 @@ export default function ChatScreen() {
viewingChat.conv.id,
msg.parent,
null,
msg.extra,
onChunk
);
setCurrNodeId(-1);
@@ -203,9 +224,11 @@ export default function ChatScreen() {
<textarea
className="textarea textarea-bordered w-full"
placeholder="Type a message (Shift+Enter to add a new line)"
ref={inputRef}
value={inputMsg}
onChange={(e) => setInputMsg(e.target.value)}
onKeyDown={(e) => {
if (e.nativeEvent.isComposing || e.keyCode === 229) return;
if (e.key === 'Enter' && e.shiftKey) return;
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();

View File

@@ -25,6 +25,7 @@ interface AppContextValue {
convId: string | null,
leafNodeId: Message['id'] | null,
content: string,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
) => Promise<boolean>;
stopGenerating: (convId: string) => void;
@@ -32,6 +33,7 @@ interface AppContextValue {
convId: string,
parentNodeId: Message['id'], // the parent node of the message to be replaced
content: string | null,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
) => Promise<void>;
@@ -274,6 +276,7 @@ export const AppContextProvider = ({
convId: string | null,
leafNodeId: Message['id'] | null,
content: string,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
): Promise<boolean> => {
if (isGenerating(convId ?? '') || content.trim().length === 0) return false;
@@ -298,6 +301,7 @@ export const AppContextProvider = ({
convId,
role: 'user',
content,
extra,
parent: leafNodeId,
children: [],
},
@@ -324,6 +328,7 @@ export const AppContextProvider = ({
convId: string,
parentNodeId: Message['id'], // the parent node of the message to be replaced
content: string | null,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
) => {
if (isGenerating(convId)) return;
@@ -339,6 +344,7 @@ export const AppContextProvider = ({
convId,
role: 'user',
content,
extra,
parent: parentNodeId,
children: [],
},

View File

@@ -0,0 +1,62 @@
import { useEffect, useState } from 'react';
import { MessageExtraContext } from './types';
// Extra context when using llama.cpp WebUI from llama-vscode, inside an iframe
// Ref: https://github.com/ggml-org/llama.cpp/pull/11940
interface SetTextEvData {
text: string;
context: string;
}
/**
* To test it:
* window.postMessage({ command: 'setText', text: 'Spot the syntax error', context: 'def test()\n return 123' }, '*');
*/
export const useVSCodeContext = (
inputRef: React.RefObject<HTMLTextAreaElement>,
setInputMsg: (text: string) => void
) => {
const [extraContext, setExtraContext] = useState<MessageExtraContext | null>(
null
);
// Accept setText message from a parent window and set inputMsg and extraContext
useEffect(() => {
const handleMessage = (event: MessageEvent) => {
if (event.data?.command === 'setText') {
const data: SetTextEvData = event.data;
setInputMsg(data?.text);
if (data?.context && data.context.length > 0) {
setExtraContext({
type: 'context',
content: data.context,
});
}
inputRef.current?.focus();
}
};
window.addEventListener('message', handleMessage);
return () => window.removeEventListener('message', handleMessage);
}, [inputRef, setInputMsg]);
// Add a keydown listener that sends the "escapePressed" message to the parent window
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape') {
window.parent.postMessage({ command: 'escapePressed' }, '*');
}
};
window.addEventListener('keydown', handleKeyDown);
return () => window.removeEventListener('keydown', handleKeyDown);
}, []);
return {
extraContext,
// call once the user message is sent, to clear the extra context
clearExtraContext: () => setExtraContext(null),
};
};

View File

@@ -53,12 +53,23 @@ export const copyStr = (textToCopy: string) => {
/**
* filter out redundant fields upon sending to API
* also format extra into text
*/
export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
return messages.map((msg) => {
let newContent = '';
for (const extra of msg.extra ?? []) {
if (extra.type === 'context') {
newContent += `${extra.content}\n\n`;
}
}
newContent += msg.content;
return {
role: msg.role,
content: msg.content,
content: newContent,
};
}) as APIMessage[];
}

View File

@@ -42,11 +42,25 @@ export interface Message {
role: 'user' | 'assistant' | 'system';
content: string;
timings?: TimingReport;
extra?: MessageExtra[];
// node based system for branching
parent: Message['id'];
children: Message['id'][];
}
type MessageExtra = MessageExtraTextFile | MessageExtraContext; // TODO: will add more in the future
export interface MessageExtraTextFile {
type: 'textFile';
name: string;
content: string;
}
export interface MessageExtraContext {
type: 'context';
content: string;
}
export type APIMessage = Pick<Message, 'role' | 'content'>;
export interface Conversation {

View File

@@ -3,7 +3,7 @@
# MIT license
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: MIT
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
source /opt/intel/oneapi/setvars.sh
#export GGML_SYCL_DEBUG=1
@@ -13,7 +13,7 @@ source /opt/intel/oneapi/setvars.sh
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
MODEL_FILE=models/llama-2-7b.Q4_0.gguf
NGL=33
CONEXT=8192
CONEXT=4096
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1

View File

@@ -102,6 +102,7 @@ endif()
option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
@@ -121,6 +122,7 @@ endif()
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_VXE "ggml: enable vxe" ON)
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
@@ -150,6 +152,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"ggml: max. batch size for using peer access")
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON)
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
@@ -209,6 +212,8 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
include(GNUInstallDirs)
#
# build the library
#
@@ -232,7 +237,6 @@ endif ()
# install
#
include(GNUInstallDirs)
include(CMakePackageConfigHelpers)
# all public headers

View File

@@ -112,7 +112,7 @@ foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
if(is_cpu_variant)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base")
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
@@ -124,7 +124,7 @@ foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
endif()
else()
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base")
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
@@ -139,6 +139,11 @@ foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
endforeach()
list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}")
set_target_properties(ggml::ggml
PROPERTIES
INTERFACE_LINK_LIBRARIES "${GGML_INTERFACE_LINK_LIBRARIES}")
add_library(ggml::all INTERFACE IMPORTED)
set_target_properties(ggml::all
PROPERTIES

View File

@@ -95,9 +95,11 @@ extern "C" {
GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
GGML_BACKEND_API int ggml_cpu_has_sve (void);
GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
GGML_BACKEND_API int ggml_cpu_has_sme (void);
// other
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
GGML_BACKEND_API int ggml_cpu_has_vsx (void);
GGML_BACKEND_API int ggml_cpu_has_vxe (void);
GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
GGML_BACKEND_API int ggml_cpu_has_llamafile (void);

View File

@@ -1,7 +1,5 @@
#include "kernel_operator.h"
#include <cmath>
using namespace AscendC;
#define BUFFER_NUM 2
@@ -183,7 +181,7 @@ extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32(
copy_to_ub(output_ne_gm, output_ne_ub, 32);
copy_to_ub(output_nb_gm, output_nb_ub, 32);
DupByRows<float_t, float_t> op;
DupByRows<float, float> op;
op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
op.dup();
}
@@ -206,7 +204,7 @@ extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32_to_fp16(
copy_to_ub(output_ne_gm, output_ne_ub, 32);
copy_to_ub(output_nb_gm, output_nb_ub, 32);
DupByRows<float_t, half> op;
DupByRows<float, half> op;
op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
op.dup_with_cast();
}
@@ -230,7 +228,7 @@ extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16_to_fp32(
copy_to_ub(output_ne_gm, output_ne_ub, 32);
copy_to_ub(output_nb_gm, output_nb_ub, 32);
DupByRows<half, float_t> op;
DupByRows<half, float> op;
op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
op.dup_with_cast();
}

View File

@@ -111,14 +111,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
function(check_arm_feature tag code)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
check_cxx_source_runs(
"${code}"
GGML_MACHINE_SUPPORTS_${tag}
)
check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
if (GGML_MACHINE_SUPPORTS_${tag})
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
else()
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
if (GGML_MACHINE_SUPPORTS_no${tag})
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
endif()
endif()
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
endfunction()
@@ -126,6 +127,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
else()
@@ -150,7 +152,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
if (ARM_FEATURE_RESULT)
message(WARNING "Failed to get ARM features")
else()
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC)
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
if (NOT ${feature_pos} EQUAL -1)
message(STATUS "ARM feature ${feature} enabled")
@@ -308,6 +310,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
if (GGML_RVV)
list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
message(STATUS "s390x detected")
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
# TODO: Separation to determine activation of VX/VXE/VXE2
if (${S390X_M} MATCHES "8561|8562")
message(STATUS "z15 target")
list(APPEND ARCH_FLAGS -march=z15 -mtune=z15)
elseif (${S390X_M} MATCHES "3931")
message(STATUS "z16 target")
list(APPEND ARCH_FLAGS -march=z16 -mtune=z16)
else()
message(STATUS "Unknown target")
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
list(APPEND ARCH_FLAGS -march=native -mtune=native)
endif()
if (GGML_VXE)
list(APPEND ARCH_FLAGS -mvx -mzvector)
endif()
else()
message(STATUS "Unknown architecture")
endif()
@@ -316,6 +339,94 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64)
endif()
if (GGML_CPU_KLEIDIAI)
message(STATUS "Using KleidiAI optimized kernels if applicable")
# Disable the KleidiAI tests
set(KLEIDIAI_BUILD_TESTS OFF)
# Fetch KleidiAI sources:
include(FetchContent)
set(KLEIDIAI_COMMIT_TAG "v1.3.0")
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
set(KLEIDIAI_ARCHIVE_MD5 "060bd2dc64642b091f461cc8dd7426d9")
if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
endif()
FetchContent_Declare(KleidiAI_Download
URL ${KLEIDIAI_DOWNLOAD_URL}
DOWNLOAD_EXTRACT_TIMESTAMP NEW
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
FetchContent_MakeAvailable(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download
SOURCE_DIR KLEIDIAI_SRC
POPULATED KLEIDIAI_POPULATED)
if (NOT KLEIDIAI_POPULATED)
message(FATAL_ERROR "KleidiAI source downloaded failed.")
endif()
add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
# Remove kleidiai target after fetching it
if (TARGET kleidiai)
set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
endif()
list(APPEND GGML_CPU_SOURCES
ggml-cpu/kleidiai/kleidiai.cpp
ggml-cpu/kleidiai/kernels.cpp
ggml-cpu/kleidiai/kleidiai.h
ggml-cpu/kleidiai/kernels.h
)
# KleidiAI
include_directories(
${KLEIDIAI_SRC}/
${KLEIDIAI_SRC}/kai/
${KLEIDIAI_SRC}/kai/ukernels/
${KLEIDIAI_SRC}/kai/ukernels/matmul/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
if (NOT ARCH_FLAGS_TEMP)
string(REGEX MATCH "-march=[^ ]+" ARCH_FLAGS_TEMP "${CMAKE_C_FLAGS}")
endif()
string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
if (NOT DOTPROD_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
endif()
if (NOT I8MM_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
endif()
if (NOT SME_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
endif()
set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
endif()
message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}")
target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES})
target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})

View File

@@ -59,6 +59,15 @@ struct ggml_compute_params {
#endif
#endif
#if defined(__s390x__) && defined(__VEC__)
#ifndef __VXE__
#define __VXE__
#endif
#ifndef __VXE2__
#define __VXE2__
#endif
#endif
#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#include <sys/prctl.h>
@@ -359,6 +368,148 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
#endif
#endif
#if defined(__VXE__) || defined(__VXE2__)
#include <vecintrin.h>
#define vec_neg(a) (-(a)) // Vector Negate
#define vec_add(a, b) ((a) + (b)) // Vector Add
#define vec_sub(a, b) ((a) - (b)) // Vector Subtract
#define vec_mul(a, b) ((a) * (b)) // Vector Multiply
#define vec_div(a, b) ((a) / (b)) // Vector Divide
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
#define vec_sra(a, b) ((a) >> (b)) // Vector Shift Right
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
#define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet
#define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet
#ifndef vec_and
#define vec_and(a, b) ((a) & (b)) // Vector AND
#endif
#ifndef vec_or
#define vec_or(a, b) ((a) | (b)) // Vector OR
#endif
#ifndef vec_xor
#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
#endif
typedef signed char char8x16_t __attribute__((vector_size(16)));
typedef unsigned char uchar8x16_t __attribute__((vector_size(16)));
typedef int8_t int8x16_t __attribute__((vector_size(16)));
typedef int16_t int16x8_t __attribute__((vector_size(16)));
typedef int32_t int32x4_t __attribute__((vector_size(16)));
typedef uint8_t uint8x16_t __attribute__((vector_size(16)));
typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
typedef uint32_t uint32x4_t __attribute__((vector_size(16)));
typedef float float32x4_t __attribute__((vector_size(16)));
typedef double double64x2_t __attribute((vector_size(16)));
typedef signed long long long64x2_t __attribute((vector_size(16)));
typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));
typedef struct ggml_uint8x16x2_t {
uint8x16_t val[2];
} ggml_uint8x16x2_t;
inline static ggml_uint8x16x2_t ggml_vec_xl_u8x2(const uint8_t * ptr) {
ggml_uint8x16x2_t res;
res.val[0] = vec_xl( 0, ptr);
res.val[1] = vec_xl(16, ptr);
return res;
}
typedef struct ggml_uint8x16x4_t {
uint8x16_t val[4];
} ggml_uint8x16x4_t;
inline static ggml_uint8x16x4_t ggml_vec_xl_u8x4(const uint8_t * ptr) {
ggml_uint8x16x4_t res;
res.val[0] = vec_xl( 0, ptr);
res.val[1] = vec_xl(16, ptr);
res.val[2] = vec_xl(32, ptr);
res.val[3] = vec_xl(48, ptr);
return res;
}
typedef struct ggml_int8x16x4_t {
int8x16_t val[4];
} ggml_int8x16x4_t;
inline static ggml_int8x16x4_t ggml_vec_xl_s8x4(const int8_t * ptr) {
ggml_int8x16x4_t res;
res.val[0] = vec_xl( 0, ptr);
res.val[1] = vec_xl(16, ptr);
res.val[2] = vec_xl(32, ptr);
res.val[3] = vec_xl(48, ptr);
return res;
}
typedef struct ggml_int16x8x2_t {
int16x8_t val[2];
} ggml_int16x8x2_t;
inline static ggml_int16x8x2_t ggml_vec_xl_s16x2(const int16_t * ptr) {
ggml_int16x8x2_t res;
res.val[0] = vec_xl( 0, ptr);
res.val[1] = vec_xl(16, ptr);
return res;
}
/*
! WARNING: Very slow. Use vec_perm if possible. Refer to iq4_xs
! or iq4_nl for example implementation.
*/
inline static int8x16_t ggml_vec_tbl(int8x16_t a, uint8x16_t b) {
int8x16_t res;
res[ 0] = a[b[ 0]];
res[ 1] = a[b[ 1]];
res[ 2] = a[b[ 2]];
res[ 3] = a[b[ 3]];
res[ 4] = a[b[ 4]];
res[ 5] = a[b[ 5]];
res[ 6] = a[b[ 6]];
res[ 7] = a[b[ 7]];
res[ 8] = a[b[ 8]];
res[ 9] = a[b[ 9]];
res[10] = a[b[10]];
res[11] = a[b[11]];
res[12] = a[b[12]];
res[13] = a[b[13]];
res[14] = a[b[14]];
res[15] = a[b[15]];
return res;
}
inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
const uchar8x16_t v_maske = { 0, 1, 4, 5, 8, 9, 12, 13,
16, 17, 20, 21, 24, 25, 28, 29 };
const int16x8_t v_abo = vec_pack((int32x4_t)a, (int32x4_t)b);
const int16x8_t v_abe = vec_perm(a, b, v_maske);
return v_abo + v_abe;
}
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
return acc + (vec_unpackh(p) + vec_unpackl(p));
}
#endif
#if defined(__loongarch_asx)
/* float type data load instructions */
static __m128 __lsx_vreplfr2vr_s(const float val) {

File diff suppressed because it is too large Load Diff

View File

@@ -112,7 +112,8 @@ struct ggml_arm_arch_features_type {
int has_i8mm;
int has_sve;
int sve_cnt;
} ggml_arm_arch_features = {-1, -1, -1, -1, 0};
int has_sme;
} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
#endif
@@ -236,6 +237,8 @@ typedef pthread_t ggml_thread_t;
#else
#if defined(__POWER9_VECTOR__)
#define CACHE_LINE_SIZE 128
#elif defined(__VXE__) || defined(__VXE2__)
#define CACHE_LINE_SIZE 256
#else
#define CACHE_LINE_SIZE 64
#endif
@@ -1210,6 +1213,87 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
#elif defined(__VXE__) || defined(__VXE2__)
#define GGML_SIMD
// F32 s390x
#define GGML_F32_STEP 32
#define GGML_F32_EPR 4
#define GGML_F32x4 __vector float
#define GGML_F32x4_ZERO vec_splats(0.0f)
#define GGML_F32x4_SET1 vec_splats
#define GGML_F32x4_LOAD(p) vec_xl(0, p)
#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
#define GGML_F32x4_ADD vec_add
#define GGML_F32x4_MUL vec_mul
#define GGML_F32x4_REDUCE(res, x) \
{ \
int offset = GGML_F32_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vec_add(x[i], x[offset + i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vec_add(x[i], x[offset + i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vec_add(x[i], x[offset + i]); \
} \
res = vec_extract(x[0], 0) + \
vec_extract(x[0], 1) + \
vec_extract(x[0], 2) + \
vec_extract(x[0], 3); \
}
#define GGML_F32_VEC GGML_F32x4
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
// F16 s390x
#define GGML_F16_STEP GGML_F32_STEP
#define GGML_F16_EPR GGML_F32_EPR
static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) {
float tmp[4];
for (int i = 0; i < 4; i++) {
tmp[i] = GGML_FP16_TO_FP32(x[i]);
}
return vec_xl(0, tmp);
}
static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) {
float arr[4];
vec_xst(y, 0, arr);
for (int i = 0; i < 4; i++) {
x[i] = GGML_FP32_TO_FP16(arr[i]);
}
}
#define GGML_F16_VEC GGML_F32x4
#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
#define GGML_F16_VEC_SET1 GGML_F32x4_SET1
#define GGML_F16_VEC_LOAD(p, i) __lzs_f16cx4_load(p)
#define GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i])
#define GGML_F16_VEC_FMA GGML_F32x4_FMA
#define GGML_F16_VEC_ADD GGML_F32x4_ADD
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
#endif
// GGML_F32_ARR / GGML_F16_ARR
@@ -2381,15 +2465,20 @@ bool ggml_is_numa(void) {
#define HWCAP2_I8MM (1 << 13)
#endif
#if !defined(HWCAP2_SME)
#define HWCAP2_SME (1 << 23)
#endif
static void ggml_init_arm_arch_features(void) {
#if defined(__linux__) && defined(__aarch64__)
uint32_t hwcap = getauxval(AT_HWCAP);
uint32_t hwcap2 = getauxval(AT_HWCAP2);
ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME);
#if defined(__ARM_FEATURE_SVE)
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
@@ -2412,6 +2501,11 @@ static void ggml_init_arm_arch_features(void) {
}
ggml_arm_arch_features.has_i8mm = oldp;
if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
oldp = 0;
}
ggml_arm_arch_features.has_sme = oldp;
ggml_arm_arch_features.has_sve = 0;
ggml_arm_arch_features.sve_cnt = 0;
#else
@@ -2435,6 +2529,12 @@ static void ggml_init_arm_arch_features(void) {
ggml_arm_arch_features.has_sve = 0;
ggml_arm_arch_features.sve_cnt = 0;
#endif
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
ggml_arm_arch_features.has_sme = 1;
#else
ggml_arm_arch_features.has_sme = 0;
#endif
#endif
}
#endif
@@ -14402,6 +14502,14 @@ int ggml_cpu_has_vsx(void) {
#endif
}
int ggml_cpu_has_vxe(void) {
#if defined(__VXE__) || defined(__VXE2__)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_neon(void) {
#if defined(__ARM_ARCH) && defined(__ARM_NEON)
return ggml_arm_arch_features.has_neon;
@@ -14442,6 +14550,14 @@ int ggml_cpu_get_sve_cnt(void) {
#endif
}
int ggml_cpu_has_sme(void) {
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
return ggml_arm_arch_features.has_sme;
#else
return 0;
#endif
}
void ggml_cpu_init(void) {
// needed to initialize f16 tables
{

View File

@@ -14,6 +14,10 @@
#include "ggml-cpu-hbm.h"
#endif
#ifdef GGML_USE_CPU_KLEIDIAI
#include "kleidiai/kleidiai.h"
#endif
#if defined(__APPLE__)
#include <sys/types.h>
#include <sys/sysctl.h>
@@ -39,6 +43,12 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
}
#endif
#ifdef GGML_USE_CPU_KLEIDIAI
if (ggml_backend_cpu_kleidiai_buffer_type()) {
bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
}
#endif
#ifdef GGML_USE_CPU_AARCH64
if (ggml_backend_cpu_aarch64_buffer_type()) {
bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
@@ -538,12 +548,18 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
features.push_back({ "SVE_CNT", sve_cnt.c_str() });
}
if (ggml_cpu_has_sme()) {
features.push_back({ "SME", "1" });
}
if (ggml_cpu_has_riscv_v()) {
features.push_back({ "RISCV_V", "1" });
}
if (ggml_cpu_has_vsx()) {
features.push_back({ "VSX", "1" });
}
if (ggml_cpu_has_vxe()) {
features.push_back({ "VXE", "1" });
}
if (ggml_cpu_has_wasm_simd()) {
features.push_back({ "WASM_SIMD", "1" });
}
@@ -559,6 +575,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
#ifdef GGML_USE_OPENMP
features.push_back({ "OPENMP", "1" });
#endif
#ifdef GGML_USE_CPU_KLEIDIAI
features.push_back({ "KLEIDIAI", "1" });
#endif
#ifdef GGML_USE_CPU_AARCH64
features.push_back({ "AARCH64_REPACK", "1" });
#endif

View File

@@ -0,0 +1,259 @@
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
// KleidiAI micro-kernels
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
#include "kai_common.h"
#include "kernels.h"
#define NELEMS(x) sizeof(x) / sizeof(*x)
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#if defined(__ARM_FEATURE_SME)
{
/* SME GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
},
/* SME GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
/* .require_aligned_m_idx = */ true,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
},
/* .required_cpu = */ CPU_FEATURE_SME,
},
#endif
#if defined(__APPLE__)
#if defined(__ARM_FEATURE_DOTPROD)
{
/* DOTPROD GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
},
/* DOTPROD GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
/* .require_aligned_m_idx = */ false,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
},
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
{
/* i8mm GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
},
/* i8mm GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
/* .require_aligned_m_idx = */ false,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
},
#endif
#else
#if defined(__ARM_FEATURE_MATMUL_INT8)
{
/* i8mm GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
},
/* i8mm GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
/* .require_aligned_m_idx = */ false,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
},
#endif
#if defined(__ARM_FEATURE_DOTPROD)
{
/* DOTPROD GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
},
/* DOTPROD GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
/* .require_aligned_m_idx = */ false,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
},
#endif
#endif
};
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
ggml_kleidiai_kernels * kernels = nullptr;
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
kernels = &gemm_gemv_kernels[i];
break;
}
}
return kernels;
}

View File

@@ -0,0 +1,61 @@
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
#pragma once
enum cpu_feature {
CPU_FEATURE_NONE = 0,
CPU_FEATURE_DOTPROD = 1,
CPU_FEATURE_I8MM = 2,
CPU_FEATURE_SVE = 4,
CPU_FEATURE_SME = 8
};
inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
lhs = static_cast<cpu_feature>(lhs | rhs);
return lhs;
}
inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) {
return static_cast<cpu_feature>(static_cast<int>(lhs) | static_cast<int>(rhs));
}
struct kernel_info {
size_t (*get_m_step)(void);
size_t (*get_n_step)(void);
size_t (*get_mr)(void);
size_t (*get_nr)(void);
size_t (*get_kr)(void);
size_t (*get_sr)(void);
size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
size_t (*get_dst_size)(size_t m, size_t n);
void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
};
struct lhs_packing_info {
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
size_t lhs_stride, void* lhs_packed);
bool require_aligned_m_idx;
};
struct rhs_packing_info {
size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
};
struct ggml_kleidiai_kernels {
kernel_info gemm;
kernel_info gemv;
lhs_packing_info lhs_info;
rhs_packing_info rhs_info;
cpu_feature required_cpu;
};
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);

View File

@@ -0,0 +1,287 @@
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
#include <arm_neon.h>
#include <assert.h>
#include <cfloat>
#include <stdint.h>
#include <string.h>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
#elif defined(__APPLE__)
#include <string_view>
#include <sys/sysctl.h>
#include <sys/types.h>
#elif defined(_WIN32)
#include <windows.h>
#include <excpt.h>
#endif
#include "kleidiai.h"
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-threading.h"
#include "ggml-cpu-traits.h"
#include "kernels.h"
#include "kai_common.h"
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
struct ggml_kleidiai_context {
ggml_kleidiai_kernels * kernels;
} static ctx = { NULL };
static void init_kleidiai_context(void) {
ggml_critical_section_start();
static bool initialized = false;
if (!initialized) {
initialized = true;
const char *env_var = getenv("GGML_KLEIDIAI_SME");
int sme_enabled = 0;
cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
if (env_var) {
sme_enabled = atoi(env_var);
}
if (sme_enabled != 0) {
features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
}
ctx.kernels = ggml_kleidiai_select_kernels(features);
}
ggml_critical_section_end();
}
static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
return tensor->ne[dim];
}
namespace ggml::cpu::kleidiai {
class tensor_traits : public ggml::cpu::tensor_traits {
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
GGML_ASSERT(ctx.kernels);
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
size_t k = op->src[0]->ne[0];
size_t m = op->src[1]->ne[1];
size_t mr = kernel->get_mr();
size_t kr = kernel->get_kr();
size_t sr = kernel->get_sr();
size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
return true;
}
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
if (dst->op == GGML_OP_MUL_MAT) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(ctx.kernels);
kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
GGML_ASSERT(kernel);
const int ith = params->ith;
const int nth = params->nth;
const size_t k = ne00;
const size_t m = ne11;
const size_t n = ne01;
const size_t n_step = kernel->get_n_step();
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
const size_t n_start = ith * num_n_per_thread;
size_t n_to_process = num_n_per_thread;
if ((n_start + n_to_process) > n) {
n_to_process = n - n_start;
}
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
uint8_t * lhs_packed = (uint8_t*)params->wdata;
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
size_t mr = kernel->get_mr();
size_t kr = kernel->get_kr();
size_t sr = kernel->get_sr();
// Calculate number of columns to be processed per thread
const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true;
const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m;
const size_t m_start = ith * num_m_per_thread;
size_t m_to_process = num_m_per_thread;
if ((m_start + m_to_process) > m) {
m_to_process = m - m_start;
}
if(m_start < m) {
// Transform LHS
const size_t src_stride = src1->nb[1];
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
}
ggml_barrier(params->threadpool);
// Perform the operation
const size_t dst_stride = dst->nb[1];
const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
return true;
}
return false;
}
public:
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
GGML_ASSERT(ctx.kernels);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
size_t nr = ctx.kernels->gemm.get_nr();
size_t kr = ctx.kernels->gemm.get_kr();
size_t sr = ctx.kernels->gemm.get_sr();
#ifndef NDEBUG
const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
#endif
struct kai_rhs_pack_qs4cxs1s0_param params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params);
return 0;
GGML_UNUSED(data_size);
}
};
static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
static tensor_traits traits;
return &traits;
}
} // namespace ggml::cpu::kleidiai
static void ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
const void * data, size_t offset, size_t size) {
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
auto OK = tensor_traits->repack(tensor, data, size);
GGML_ASSERT(OK == 0);
GGML_UNUSED(buffer);
}
static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "CPU_KLEIDIAI";
GGML_UNUSED(buft);
}
static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
if (buffer == nullptr) {
return nullptr;
}
buffer->buft = buft;
buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
buffer->iface.set_tensor = ggml_backend_cpu_kleidiai_buffer_set_tensor;
buffer->iface.get_tensor = nullptr;
buffer->iface.cpy_tensor = nullptr;
return buffer;
}
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return TENSOR_ALIGNMENT;
GGML_UNUSED(buft);
}
namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
if ( op->op == GGML_OP_MUL_MAT &&
op->src[0]->type == GGML_TYPE_Q4_0 &&
op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) &&
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
) {
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
if (op->src[1]->type == GGML_TYPE_F32 &&
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
return true;
}
}
return false;
}
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
if (op->op == GGML_OP_MUL_MAT) {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}
}
return nullptr;
}
};
} // namespace ggml::cpu::kleidiai
ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
static ggml::cpu::kleidiai::extra_buffer_type ctx;
static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
/* .iface = */ {
/* .get_name = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
/* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
/* .is_host = */ nullptr,
},
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
/* .context = */ &ctx,
};
init_kleidiai_context();
return &ggml_backend_cpu_buffer_type_kleidiai;
}

View File

@@ -0,0 +1,17 @@
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
#pragma once
#include "ggml-alloc.h"
#ifdef __cplusplus
extern "C" {
#endif
ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void);
#ifdef __cplusplus
}
#endif

View File

@@ -7,7 +7,7 @@ if (CUDAToolkit_FOUND)
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# native == GPUs available at build time
# 52 == Maxwell, lowest CUDA 12 standard
# 50 == Maxwell, lowest CUDA 12 standard
# 60 == P100, FP16 CUDA intrinsics
# 61 == Pascal, __dp4a instruction (per-byte integer dot product)
# 70 == V100, FP16 tensor cores
@@ -17,7 +17,7 @@ if (CUDAToolkit_FOUND)
elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80")
else()
set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;80")
set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80")
endif()
endif()
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
add_compile_definitions(GGML_CUDA_NO_VMM)
endif()
if (NOT GGML_CUDA_FA)
add_compile_definitions(GGML_CUDA_NO_FA)
endif()
if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
add_compile_definitions(GGML_CUDA_F16)
endif()

View File

@@ -41,12 +41,13 @@
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
// GCN/CNDA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
@@ -199,9 +200,13 @@ typedef float2 dfloat2;
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
@@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}
static bool cp_async_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
return __AMDGCN_WAVEFRONT_SIZE;
@@ -402,13 +411,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
const int8_t * a8 = (const int8_t *) &a;
const int8_t * b8 = (const int8_t *) &b;
return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}

View File

@@ -0,0 +1,46 @@
// Simplified API for asynchronous data loading.
#include "common.cuh"
// Copies data from global to shared memory, cg == cache global.
// Both the src and dst pointers must be aligned to 16 bit.
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements.
template <int preload>
static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
#ifdef CP_ASYNC_AVAILABLE
#if CUDART_VERSION >= 11040
if (preload == 256) {
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else if (preload == 128) {
asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else if (preload == 64) {
asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else
#endif // CUDART_VERSION >= 11040
{
asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
: : "r"(dst), "l"(src));
}
#else
GGML_UNUSED(dst);
GGML_UNUSED(src);
NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}
// Makes each thread wait until its asynchronous data copies are done.
// This does NOT provide any additional synchronization.
// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
static __device__ __forceinline__ void cp_async_wait_all() {
#ifdef CP_ASYNC_AVAILABLE
asm volatile("cp.async.wait_all;");
#else
NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}

View File

@@ -1,4 +1,5 @@
#include "cpy.cuh"
#include "dequantize.cuh"
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
@@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
}
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
const block_q8_0 * xi = (const block_q8_0 *) cxi;
float * dsti = (float *) cdsti;
float * cdstf = (float *)(cdsti);
const float d = (float)xi->d;
for (int j = 0; j < QK8_0; j++) {
dsti[j] = xi->qs[j] * d;
#pragma unroll
for (int j = 0; j < QK8_0; j += 2) {
dfloat2 dq;
dequantize_q8_0(cxi, 0, j, dq);
*(cdstf + j) = dq.x;
*(cdstf + j + 1) = dq.y;
}
}
@@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
memcpy(dsti->qh, &qh, sizeof(qh));
}
template<dequantize_kernel_t dequant, int qk>
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
float * cdstf = (float *)(cdsti);
#pragma unroll
for (int j = 0; j < qk/2; j++) {
dfloat2 dq;
dequant(cxi, 0, j, dq);
*(cdstf + j) = dq.x;
*(cdstf + j + qk/2) = dq.y;
}
}
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
@@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q4_0_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q4_1_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q5_0_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q5_1_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {

View File

@@ -123,13 +123,13 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
if (nbytes_shared <= smpbo) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
} else {
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
@@ -175,13 +175,13 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
if (nbytes_shared <= smpbo) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
} else {
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);

View File

@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
nullptr;
}
// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template<int D, int ncols, int KQ_stride> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
const int iter_k = ne11 / KQ_stride;
const int iter_j = (ne01 + (ncols - 1)) / ncols;
constexpr int ncols = ncols1*ncols2;
const int bidx0 = blockIdx.x;
const int j = blockIdx.y;
const int c = blockIdx.z;
const int jc = j*ncols2 + c;
const int tid = threadIdx.x;
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup(
const int channel = kbc0 / (iter_k*iter_j);
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
dst += jt*ncols*ne02*D + channel*D;
if (jt*ncols1 + j >= ne01) {
return;
}
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
// Load the partial result that needs a fixup:
float dst_val[ncols] = {0.0f};
float max_val[ncols] = {0.0f};
float rowsum[ncols] = {0.0f};
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
break;
}
dst_val[j] = dst[j*ne02*D + threadIdx.x];
float dst_val = 0.0f;
float max_val = 0.0f;
float rowsum = 0.0f;
{
dst_val = *dst;
const float2 tmp = dst_fixup[bidx0*ncols + j];
max_val[j] = tmp.x;
rowsum[j] = tmp.y;
const float2 tmp = dst_fixup[bidx0*ncols + jc];
max_val = tmp.x;
rowsum = tmp.y;
}
// Iterate over previous blocks and compute the combined results.
@@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup(
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
continue;
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
break;
}
const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
// Scale the current and new value accumulators depending on the max. values.
const float max_val_new = fmaxf(max_val[j], tmp.x);
// Scale the current and new value accumulators depending on the max. values.
const float max_val_new = fmaxf(max_val, tmp.x);
const float diff_val = max_val[j] - max_val_new;
const float diff_add = tmp.x - max_val_new;
const float diff_val = max_val - max_val_new;
const float diff_add = tmp.x - max_val_new;
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
dst_val = scale_val*dst_val + scale_add*dst_add;
rowsum = scale_val*rowsum + scale_add*tmp.y;
max_val[j] = max_val_new;
}
max_val = max_val_new;
// If this block started in a previous tile we are done and don't need to combine additional partial results.
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
}
// Write back final result:
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
return;
}
dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
}
*dst = dst_val / rowsum;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
}
// parallel_blocks == 0 is stream-k decomposition
template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
) {
constexpr int ncols = ncols1 * ncols2;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
@@ -716,7 +700,9 @@ void launch_fattn(
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;
ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
@@ -761,24 +747,26 @@ void launch_fattn(
nb23 = nb23*bs*sizeof(half)/ts;
}
const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
const dim3 block_dim(WARP_SIZE, nwarps, 1);
dim3 blocks_num;
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
const bool short_context = K->ne[1] < 4096;
const int max_blocks = 2*nsm;
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
const int nblocks_stream_k = 2*nsm;
const int nblocks_stream_k = max_blocks;
blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
blocks_num.y = 1;
blocks_num.z = 1;
dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
} else {
blocks_num.x = parallel_blocks*ntiles_x;
blocks_num.y = Q->ne[2];
@@ -790,7 +778,6 @@ void launch_fattn(
}
}
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
@@ -827,11 +814,11 @@ void launch_fattn(
CUDA_CHECK(cudaGetLastError());
if constexpr (parallel_blocks == 0) {
if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine = blocks_num;
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
}

File diff suppressed because it is too large Load Diff

View File

@@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
#ifndef FLASH_ATTN_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
@@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
}
#else
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
@@ -302,14 +297,14 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

View File

@@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
#ifndef FLASH_ATTN_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
@@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
@@ -296,14 +296,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

View File

@@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
#ifndef FLASH_ATTN_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
}
#else
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
@@ -310,7 +305,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>

View File

@@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
#ifndef FLASH_ATTN_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
@@ -290,7 +290,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>

View File

@@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
@@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16(
}
#else
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
constexpr int get_max_power_of_2(int x) {
@@ -478,7 +478,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
@@ -493,7 +493,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
return;
}
constexpr int parallel_blocks = 1;
@@ -507,7 +507,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
}
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@@ -8,28 +8,50 @@
#include "fattn-wmma-f16.cuh"
#include "fattn.cuh"
template <int cols_per_block>
template <int D, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
if (Q->ne[1] <= 8/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
return;
}
if (Q->ne[1] <= 16/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
return;
}
if (Q->ne[1] <= 32/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
}
template <int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
@@ -38,24 +60,35 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
}
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * mask = dst->src[3];
if (Q->ne[1] <= 8) {
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
const float use_gqa_opt = mask && max_bias == 0.0f;
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (use_gqa_opt && gqa_ratio % 8 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
return;
}
if (Q->ne[1] <= 16) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
if (use_gqa_opt && gqa_ratio == 4) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
return;
}
if (Q->ne[1] <= 32) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
if (use_gqa_opt && gqa_ratio == 2) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
}
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
@@ -209,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
}
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -252,7 +288,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
return;
}
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
const int gqa_ratio = Q->ne[2] / K->ne[2];
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
if (prec == GGML_PREC_DEFAULT) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
return;

View File

@@ -261,6 +261,12 @@ static ggml_cuda_device_info ggml_cuda_init() {
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
device_vmm ? "yes" : "no", prop.warpSize);
#elif defined(GGML_USE_MUSA)
// TODO: refine the .cc to reflect MUSA's actual CC capabilities
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
#else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
@@ -1782,9 +1788,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
}
}
#else
#ifdef GGML_USE_MUSA
GGML_ASSERT(false);
#else // !GGML_USE_MUSA
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
@@ -1827,7 +1830,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
#endif // GGML_USE_MUSA
#endif
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
@@ -3073,15 +3075,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
return true;
}
if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
return true;
}
if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
return true;
}
if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
return true;
}
if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true;
}
@@ -3189,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
return false;
}

View File

@@ -4,11 +4,12 @@
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
//
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
// A is a row-major matrix with shape I x K.
// B is a column-major matrix with shape K x J.
// C is a column-major matrix with shape I x J.
// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// A is a row-major matrix with shape M x K.
// B is a column-major matrix with shape K x N.
// C is a column-major matrix with shape M x N.
// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
// Note that J is measured in physical 32 bit elements instead of logical elements.
// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per warp.
//
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
@@ -23,7 +24,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
#ifdef NEW_MMA_AVAILABLE
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "+r"(ret) : "r"(x));
: "=r"(ret) : "r"(x));
#else
NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
@@ -52,407 +53,342 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
#endif // CUDART_VERSION >= 11080
static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
half2 ret;
*((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
return ret;
}
template <typename T>
struct mma_A_I16K4 {
static_assert(sizeof(T) == 4, "bad type size");
namespace ggml_cuda_mma {
static constexpr int I = 16;
static constexpr int K = 4;
static constexpr int ne = 2;
template <int I_, int J_, typename T>
struct tile {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;
T x[ne] = {0};
T x[ne];
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && (J == 4 || J == 8)) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 8 + threadIdx.x / 4;
} else if constexpr (I == 16 && J == 16) {
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 8 && J == 8) {
return 4 * l + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x % 4) + l % 2;
} else if constexpr (I == 16 && J == 16) {
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
};
static __device__ __forceinline__ int get_k(const int /* l */) {
const int ret = threadIdx.x % K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
template <int I_, int J_>
struct tile<I_, J_, half2> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
return l * 8 + threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l % 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
return l * 4 + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 4 + threadIdx.x % 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
};
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
tile<I, J/2, half2> ret;
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
}
return ret;
}
static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
tile<8, 8, half2> ret;
ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
return ret;
}
template <int I, int J, typename T>
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(xi[0]), "+r"(xi[1])
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "=r"(xi[0]), "=r"(xi[1])
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "=r"(xi[0]), "=r"(xi[1])
: "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
template <typename T>
struct mma_A_I16K8 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int I = 16;
static constexpr int K = 8;
static constexpr int ne = 4;
T x[ne];
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_k(const int l) {
const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int * ) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix_trans(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
: "l"(xs));
#else
GGML_UNUSED(t);
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int * ) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
: "l"(xs));
#else
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void transpose() {
int * xi = (int *) x;
xi[0] = ggml_cuda_movmatrix(xi[0]);
const int tmp = ggml_cuda_movmatrix(xi[1]);
xi[1] = ggml_cuda_movmatrix(xi[2]);
xi[2] = tmp;
xi[3] = ggml_cuda_movmatrix(xi[3]);
}
};
template <typename T>
struct mma_B_J8K4 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int J = 8;
static constexpr int K = 4;
static constexpr int ne = 1;
T x[ne];
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
static __device__ __forceinline__ int get_k(const int /* l */) {
const int ret = threadIdx.x % K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
: "+r"(xi[0]) : "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
template <typename T>
struct mma_B_J8K8 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int J = 8;
static constexpr int K = 8;
static constexpr int ne = 2;
T x[ne];
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
static __device__ __forceinline__ int get_k(const int l) {
const int ret = l * (K/2) + threadIdx.x % (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(xi[0]), "+r"(xi[1])
: "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
template <typename T>
struct mma_C_I16J8 {};
template <>
struct mma_C_I16J8<int> {
static constexpr int I = 16;
static constexpr int J = 8;
static constexpr int ne = 4;
int x[ne] = {0};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int l) {
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
__device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
: "+r"(D.x[0]), "+r"(D.x[1])
: "r"(A.x[0]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
: "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[1]), "r"(B.x[0]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
#else
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
: "+r"(D.x[0]), "+r"(D.x[1])
: "r"(A.x[0]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
: "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[1]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[2]), "r"(mma_B.x[1]));
: "+r"(D.x[0]), "+r"(D.x[1])
: "r"(A.x[2]), "r"(B.x[1]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[3]), "r"(mma_B.x[1]));
: "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[3]), "r"(B.x[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
};
template <>
struct mma_C_I16J8<half2> {
static constexpr int I = 16;
static constexpr int J = 4;
static constexpr int ne = 2;
half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = l * (I/2) + threadIdx.x / J;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x % J;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
static __device__ __forceinline__ void mma(
tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
int * Axi = (int *) mma_A.x;
int * Bxi = (int *) mma_B.x;
int * xi = (int *) x;
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
mma_B_J8K8<half2> mma_B;
int * xi = (int *) x;
int * Bxi = (int *) mma_B.x;
Bxi[0] = ggml_cuda_movmatrix(xi[0]);
Bxi[1] = ggml_cuda_movmatrix(xi[1]);
return mma_B;
}
};
template <>
struct mma_C_I16J8<float> {
static constexpr int I = 16;
static constexpr int J = 8;
static constexpr int ne = 4;
float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int l) {
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
static __device__ __forceinline__ void mma(
tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
int * Axi = (int *) mma_A.x;
int * Bxi = (int *) mma_B.x;
int * xi = (int *) x;
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
: "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
mma_B_J8K8<half2> mma_B;
mma_B.x[0] = make_half2(x[0], x[1]);
mma_B.x[1] = make_half2(x[2], x[3]);
int * Bxi = (int *) mma_B.x;
Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
return mma_B;
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_i(l)];
}
}
};
}

View File

@@ -7,6 +7,8 @@
#include <climits>
#include <cstdint>
using namespace ggml_cuda_mma;
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_NWARPS 8
@@ -107,9 +109,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#ifdef GGML_CUDA_FORCE_MMQ
return MMQ_DP4A_MAX_BATCH_SIZE;
#else // GGML_CUDA_FORCE_MMQ
return 128;
#else // GGML_CUDA_FORCE_MMQ
return MMQ_DP4A_MAX_BATCH_SIZE;
#endif // GGML_CUDA_FORCE_MMQ
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
@@ -647,15 +649,15 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
typedef mma_A_I16K8<int> mma_A;
typedef mma_B_J8K8<int> mma_B;
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 8, int> tile_A;
typedef tile< 8, 8, int> tile_B;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
@@ -663,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const float * y_df = (const float *) y;
const half2 * y_ds = (const half2 *) y;
mma_A A[ntx][WARP_SIZE/QI8_0];
float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
tile_A A[ntx][WARP_SIZE/QI8_0];
float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
@@ -674,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
const int k0 = k00 + k01;
A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
@@ -691,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
mma_B B;
float dB[mma_C::ne/2];
tile_B B;
float dB[tile_C::ne/2];
B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -712,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C C;
C.mma(A[n][k01/QI8_0], B);
tile_C C;
mma(C, A[n][k01/QI8_0], B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
}
}
}
@@ -758,23 +760,23 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
typedef mma_A_I16K8<int> mma_A;
typedef mma_B_J8K8<int> mma_B;
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 8, int> tile_A;
typedef tile< 8, 8, int> tile_B;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
const int * y_qs = (const int *) y + 4;
const half2 * y_dm = (const half2 *) y;
mma_A A[ntx][WARP_SIZE/QI8_1];
float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
tile_A A[ntx][WARP_SIZE/QI8_1];
float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
@@ -784,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
const int k0 = k00 + k01;
A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
@@ -801,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
mma_B B;
float2 dsB[mma_C::ne/2];
tile_B B;
float2 dsB[tile_C::ne/2];
B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C C;
C.mma(A[n][k01/QI8_1], B);
tile_C C;
mma(C, A[n][k01/QI8_1], B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
}
}
}
@@ -868,26 +870,26 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
typedef mma_A_I16K4<int> mma_A;
typedef mma_A_I16K8<int> mma_A_K8;
typedef mma_B_J8K4<int> mma_B;
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 4, int> tile_A;
typedef tile<16, 8, int> tile_A_8;
typedef tile< 8, 4, int> tile_B;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
mma_A A[ntx][8];
float dA[ntx][mma_C::ne/2][8];
tile_A A[ntx][8];
float dA[ntx][tile_C::ne/2][8];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
@@ -895,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
const int k0 = k00 + k01;
((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
@@ -912,32 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
mma_B B[2];
float dB[mma_C::ne/2];
tile_B B[2];
float dB[tile_C::ne/2];
// Here load_generic is faster than load_ldmatrix.
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C C[2];
C[0].mma(A[n][k01/4 + 0], B[0]);
C[1].mma(A[n][k01/4 + 1], B[1]);
tile_C C[2];
mma(C[0], A[n][k01/4 + 0], B[0]);
mma(C[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
}
}
}
@@ -1056,27 +1058,27 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
typedef mma_A_I16K4<int> mma_A;
typedef mma_A_I16K8<int> mma_A_K8;
typedef mma_B_J8K4<int> mma_B;
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 4, int> tile_A;
typedef tile<16, 8, int> tile_A_8;
typedef tile< 8, 4, int> tile_B;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
mma_A A[ntx][8];
float dA[ntx][mma_C::ne/2][8];
float mA[ntx][mma_C::ne/2][8];
tile_A A[ntx][8];
float dA[ntx][tile_C::ne/2][8];
float mA[ntx][tile_C::ne/2][8];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
@@ -1084,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
const int k0 = k00 + k01;
((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
}
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
@@ -1107,58 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
float2 dB[mma_C::ne/2];
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
float2 dB[tile_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
}
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
mma_B B[2];
tile_B B[2];
// Here load_generic is faster than load_ldmatrix.
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
mma_C Cm[2];
tile_C Cm[2];
if (k01 >= WARP_SIZE * 3/4) {
mma_A A1;
tile_A A1;
A1.x[0] = 0x01010101;
A1.x[1] = 0x01010101;
Cm[0].mma(A1, B[0]);
Cm[1].mma(A1, B[1]);
mma(Cm[0], A1, B[0]);
mma(Cm[1], A1, B[1]);
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C Cd[2];
tile_C Cd[2];
Cd[0].mma(A[n][k01/4 + 0], B[0]);
Cd[1].mma(A[n][k01/4 + 1], B[1]);
mma(Cd[0], A[n][k01/4 + 0], B[0]);
mma(Cd[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
for (int l = 0; l < tile_C::ne; ++l) {
float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
if (k01 >= WARP_SIZE * 3/4) {
tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
}
sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
}
}
}
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
float2 sB[mma_C::ne/2];
float2 sB[tile_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
@@ -1166,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
}
}
}
@@ -1708,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
typedef mma_A_I16K4<int> mma_A;
typedef mma_B_J8K4<int> mma_B;
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 4, int> tile_A;
typedef tile< 8, 4, int> tile_B;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
@@ -1724,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
mma_A A[ntx][8];
int scA[ntx][mma_C::ne/2][8];
float dA[ntx][mma_C::ne/2];
tile_A A[ntx][8];
int scA[ntx][tile_C::ne/2][8];
float dA[ntx][tile_C::ne/2];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
@@ -1736,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
const int k0 = k00 + k01;
A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
}
#pragma unroll
@@ -1745,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int k0 = k00 + k01;
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
const int8_t * sc = (const int8_t *) &sc_packed;
@@ -1759,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
float tmp[ntx][mma_C::ne] = {{0.0f}};
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
float tmp[ntx][tile_C::ne] = {{0.0f}};
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
mma_B B[2];
float dB[mma_C::ne/2];
tile_B B[2];
float dB[tile_C::ne/2];
// Here load_generic is faster than load_ldmatrix.
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C C[2];
C[0].mma(A[n][k01/4 + 0], B[0]);
C[1].mma(A[n][k01/4 + 1], B[1]);
tile_C C[2];
mma(C[0], A[n][k01/4 + 0], B[0]);
mma(C[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
for (int l = 0; l < tile_C::ne; ++l) {
tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
}
}
@@ -1802,8 +1804,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
}
}
}
@@ -2312,36 +2314,36 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_mma(
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
#ifdef NEW_MMA_AVAILABLE
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
#endif // NEW_MMA_AVAILABLE
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne; ++l) {
const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
if (j > j_max) {
continue;
}
const int i = i0 + n*mma_C::I + mma_C::get_i(l);
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
if (need_check && i > i_max) {
continue;
}
dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
}
}
}

View File

@@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16);
DECL_FATTN_MMA_F16_CASE(80, 16);
DECL_FATTN_MMA_F16_CASE(96, 16);
DECL_FATTN_MMA_F16_CASE(112, 16);
DECL_FATTN_MMA_F16_CASE(128, 16);
DECL_FATTN_MMA_F16_CASE(256, 16);

View File

@@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 32);
DECL_FATTN_MMA_F16_CASE(80, 32);
DECL_FATTN_MMA_F16_CASE(96, 32);
DECL_FATTN_MMA_F16_CASE(112, 32);
DECL_FATTN_MMA_F16_CASE(128, 32);
DECL_FATTN_MMA_F16_CASE(256, 32);

View File

@@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 64);
DECL_FATTN_MMA_F16_CASE(80, 64);
DECL_FATTN_MMA_F16_CASE(96, 64);
DECL_FATTN_MMA_F16_CASE(112, 64);
DECL_FATTN_MMA_F16_CASE(128, 64);
DECL_FATTN_MMA_F16_CASE(256, 64);

View File

@@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8);
DECL_FATTN_MMA_F16_CASE(80, 8);
DECL_FATTN_MMA_F16_CASE(96, 8);
DECL_FATTN_MMA_F16_CASE(112, 8);
DECL_FATTN_MMA_F16_CASE(128, 8);
DECL_FATTN_MMA_F16_CASE(256, 8);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 1, 8);
DECL_FATTN_MMA_F16_CASE(80, 1, 8);
DECL_FATTN_MMA_F16_CASE(96, 1, 8);
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
DECL_FATTN_MMA_F16_CASE(256, 1, 8);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16, 1);
DECL_FATTN_MMA_F16_CASE(80, 16, 1);
DECL_FATTN_MMA_F16_CASE(96, 16, 1);
DECL_FATTN_MMA_F16_CASE(112, 16, 1);
DECL_FATTN_MMA_F16_CASE(128, 16, 1);
DECL_FATTN_MMA_F16_CASE(256, 16, 1);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16, 2);
DECL_FATTN_MMA_F16_CASE(80, 16, 2);
DECL_FATTN_MMA_F16_CASE(96, 16, 2);
DECL_FATTN_MMA_F16_CASE(112, 16, 2);
DECL_FATTN_MMA_F16_CASE(128, 16, 2);
DECL_FATTN_MMA_F16_CASE(256, 16, 2);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16, 4);
DECL_FATTN_MMA_F16_CASE(80, 16, 4);
DECL_FATTN_MMA_F16_CASE(96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 16, 4);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 2, 4);
DECL_FATTN_MMA_F16_CASE(80, 2, 4);
DECL_FATTN_MMA_F16_CASE(96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 2, 4);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 2, 8);
DECL_FATTN_MMA_F16_CASE(80, 2, 8);
DECL_FATTN_MMA_F16_CASE(96, 2, 8);
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
DECL_FATTN_MMA_F16_CASE(256, 2, 8);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 32, 1);
DECL_FATTN_MMA_F16_CASE(80, 32, 1);
DECL_FATTN_MMA_F16_CASE(96, 32, 1);
DECL_FATTN_MMA_F16_CASE(112, 32, 1);
DECL_FATTN_MMA_F16_CASE(128, 32, 1);
DECL_FATTN_MMA_F16_CASE(256, 32, 1);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 32, 2);
DECL_FATTN_MMA_F16_CASE(80, 32, 2);
DECL_FATTN_MMA_F16_CASE(96, 32, 2);
DECL_FATTN_MMA_F16_CASE(112, 32, 2);
DECL_FATTN_MMA_F16_CASE(128, 32, 2);
DECL_FATTN_MMA_F16_CASE(256, 32, 2);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 4, 2);
DECL_FATTN_MMA_F16_CASE(80, 4, 2);
DECL_FATTN_MMA_F16_CASE(96, 4, 2);
DECL_FATTN_MMA_F16_CASE(112, 4, 2);
DECL_FATTN_MMA_F16_CASE(128, 4, 2);
DECL_FATTN_MMA_F16_CASE(256, 4, 2);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 4, 4);
DECL_FATTN_MMA_F16_CASE(80, 4, 4);
DECL_FATTN_MMA_F16_CASE(96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 4, 4);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 4, 8);
DECL_FATTN_MMA_F16_CASE(80, 4, 8);
DECL_FATTN_MMA_F16_CASE(96, 4, 8);
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
DECL_FATTN_MMA_F16_CASE(256, 4, 8);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 64, 1);
DECL_FATTN_MMA_F16_CASE(80, 64, 1);
DECL_FATTN_MMA_F16_CASE(96, 64, 1);
DECL_FATTN_MMA_F16_CASE(112, 64, 1);
DECL_FATTN_MMA_F16_CASE(128, 64, 1);
DECL_FATTN_MMA_F16_CASE(256, 64, 1);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 1);
DECL_FATTN_MMA_F16_CASE(80, 8, 1);
DECL_FATTN_MMA_F16_CASE(96, 8, 1);
DECL_FATTN_MMA_F16_CASE(112, 8, 1);
DECL_FATTN_MMA_F16_CASE(128, 8, 1);
DECL_FATTN_MMA_F16_CASE(256, 8, 1);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 2);
DECL_FATTN_MMA_F16_CASE(80, 8, 2);
DECL_FATTN_MMA_F16_CASE(96, 8, 2);
DECL_FATTN_MMA_F16_CASE(112, 8, 2);
DECL_FATTN_MMA_F16_CASE(128, 8, 2);
DECL_FATTN_MMA_F16_CASE(256, 8, 2);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 4);
DECL_FATTN_MMA_F16_CASE(80, 8, 4);
DECL_FATTN_MMA_F16_CASE(96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 8, 4);

View File

@@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 8);
DECL_FATTN_MMA_F16_CASE(80, 8, 8);
DECL_FATTN_MMA_F16_CASE(96, 8, 8);
DECL_FATTN_MMA_F16_CASE(112, 8, 8);
DECL_FATTN_MMA_F16_CASE(128, 8, 8);
DECL_FATTN_MMA_F16_CASE(256, 8, 8);

View File

@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
"""
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@@ -57,12 +57,18 @@ for vkq_size in [16, 32]:
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
for cols_per_block in [8, 16, 32, 64]:
with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
f.write(SOURCE_FATTN_MMA_START)
for ncols in [8, 16, 32, 64, 128]:
for ncols2 in [1, 2, 4, 8]:
ncols1 = ncols // ncols2
if ncols == 128:
continue # Too much register pressure.
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
f.write(SOURCE_FATTN_MMA_START)
for head_size in [64, 80, 96, 112, 128, 256]:
f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
for head_size in [64, 80, 96, 112, 128, 256]:
if ncols == 128 and head_size == 256:
continue # Needs too much shared memory.
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:

View File

@@ -107,6 +107,10 @@ if (GGML_HIP_NO_VMM)
add_compile_definitions(GGML_HIP_NO_VMM)
endif()
if (NOT GGML_CUDA_FA)
add_compile_definitions(GGML_CUDA_NO_FA)
endif()
if (CXX_IS_HIPCC)
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
target_link_libraries(ggml-hip PRIVATE hip::device)

View File

@@ -16,7 +16,7 @@
#include <arm_sve.h>
#endif // __ARM_FEATURE_SVE
#if defined(__ARM_NEON) && !defined(__CUDACC__)
#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
//
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/

View File

@@ -407,6 +407,16 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
GGML_METAL_KERNEL_TYPE_CONCAT,
GGML_METAL_KERNEL_TYPE_SQR,
GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1022,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
@@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
default:
return false;
}
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
switch (op->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
return true;
default:
return false;
}
default:
return false;
};
@@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node(
case GGML_OP_CPY:
case GGML_OP_CONT:
{
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
id<MTLComputePipelineState> pipeline = nil;
switch (src0t) {
@@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node(
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
default: GGML_ABORT("not implemented");
};
} break;
case GGML_TYPE_Q4_0:
{
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
default: GGML_ABORT("not implemented");
};
} break;
case GGML_TYPE_Q4_1:
{
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
default: GGML_ABORT("not implemented");
};
} break;
case GGML_TYPE_Q5_0:
{
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
default: GGML_ABORT("not implemented");
};
} break;
case GGML_TYPE_Q5_1:
{
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
default: GGML_ABORT("not implemented");
};
} break;
case GGML_TYPE_Q8_0:
{
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
default: GGML_ABORT("not implemented");
};
} break;
default: GGML_ABORT("not implemented");
@@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_SET:
{

View File

@@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl(
}
}
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_cpy_q_f32(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
T4x4 temp;
dequantize_func(src_data + i00/nl, i00%nl, temp);
dst_data[i00] = temp;
}
}
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
kernel void kernel_concat(
constant ggml_metal_kargs_concat & args,
device const char * src0,

Some files were not shown because too many files have changed in this diff Show More