Compare commits

...

53 Commits
b8424 ... b8477

Author SHA1 Message Date
bssrdf
ec2b787ebe mtmd: Add dynamic high-resolution image preprocessing for InternVL model (#20847)
* added support for internvl's dynamic high-resolution (Qianfan-OCR needed)

* add min/max dynamic patch to gguf meta

* clean up

* simplified handling min/max dynamic patch

* reuse llava_uhd logic for slice images

* provide default values for older models

* flake8

* prevent writing 0 value to gguf

* remove duplicated resolution candidates with a better algorithm

* fix indentation

* format

* add protection from divide by zero

* change to 0 to be safe

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-03-23 01:06:30 +01:00
DorianRudolph
d3ac030a5d mtmd : fix LightOnOCR image preprocessing (#20877) 2026-03-23 01:04:14 +01:00
Xuan-Son Nguyen
49bfddeca1 server: allow router to report child instances sleep status (#20849)
* server: allow router to report child instances sleep status

* refactor

* move sleeping to state

* nits
2026-03-22 18:33:52 +01:00
Johannes Gäßler
bd3f1d9d65 CUDA: fix BF16 FA compilation (#20865) 2026-03-22 17:53:33 +01:00
Sigbjørn Skjæret
23c9182ce8 jinja : refactor token advancement (#20864)
* refactor token advancement

* exercise sub-expressions
2026-03-22 17:45:10 +01:00
Evgeny Kurnevsky
81bc4d3ddc server: fix Host header (#20843)
It should include port when it's not default.
2026-03-22 22:29:22 +08:00
Neo Zhang
f40a80b4f3 support bf16 and quantized type (#20803) 2026-03-22 22:06:27 +08:00
Patrick Buckley
db9d8aa428 ggml-cuda: native bf16 flash attention for vec kernel (#20525)
* ggml-cuda: native bf16 flash attention for vec and tile kernels

mma kernel still converts bf16 to fp16 before launch, native mma bf16 todo

* ggml-cuda: address code owner review feedback

reverted tile kernel changes to avoid larger refactor

* fix ci failures on turing and hip

* fix bf16 vec kernel compile on hip v_dot2 platforms

* add comments

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-03-22 11:05:51 +01:00
Gaurav Garg
ccb87fa3ee [CUDA] Increase number of output elements per-thread block if the K-dimension is small (#20635)
* Increase per-thread work if the K-dimension is small

With tensor parallelism, the K-dimension of the FFN-down matrices is split, which makes it quite small, especially for MOEs. For example, Qwen3-30b-A3B has a K-dimension of 768, and Qwen3235B-A22B has k-dimension of 1536.
The current heuristic uses a group of 4 warps irrespective of K-dimension size, resulting in some of the threads being idle. This results in poor performance for these matrices.

This change increases the number of output elements per block for such cases.

* Limit this change to ncols_dst = 1

* tab to space
2026-03-22 16:49:35 +08:00
ddh0
3306dbaef7 misc : prefer ggml-org models in docs and examples (#20827)
* misc : prefer ggml-org models in docs and examples

Prefer referring to known-good quantizations under ggml-org rather than
3rd-party uploaders.

* remove accidentally committed file
2026-03-21 22:00:26 +01:00
Andrea Arcangeli
990e4d9698 common/grammar: fix grammar parsing issues to prevent stack overflow and hangs (#18604)
* grammar: add test case for nullable symbol loop

Reproduce stack overflow (or OOM) with ( [x]* )* found while adding
GBNF support to ripgrep-edit.

llama-server reproducer:

curl \
  -X POST \
  -d '{
    "messages": [{ "role": "user", "content": "write yes" }],
    "grammar": "root ::= ( [x]* )*"
  }' \
  -H "Content-Type: application/json" \
  http://localhost:8811/v1/chat/completions

* grammar: prevent stack overflow with nullable symbol loop

Fix a potential stack overflow in llama_grammar_advance_stack that
could occur when processing grammars with nullable symbols that lead
to infinite derivations of empty strings. The fix introduces cycle
detection by tracking visited stacks to prevent infinite recursion.

rg-edit regexp: llama_grammar_advance_stack
rg-edit extra-args: -A20
rg-edit directive: """Rewrite: fix the following segfault:

[..]
 Testing segfault. Grammar:
            root ::= ( [x]* )*

            root ::= ( [x]* )*

Segmentation fault         build/bin/test-grammar-integration"""

gptel-context:
(("~/llama.cpp/src/llama-grammar.cpp")
 ("~/llama.cpp/tests/test-grammar-integration.cpp")
 ("~/llama.cpp/grammars/./list.gbnf")
 ("~/llama.cpp/grammars/./json_arr.gbnf")
 ("~/llama.cpp/grammars/./json.gbnf")
 ("~/llama.cpp/grammars/./japanese.gbnf")
 ("~/llama.cpp/grammars/./english.gbnf")
 ("~/llama.cpp/grammars/./chess.gbnf")
 ("~/llama.cpp/grammars/./c.gbnf")
 ("~/llama.cpp/grammars/./arithmetic.gbnf")
 ("~/llama.cpp/grammars/./README.md"))

* grammar: convert recursive llama_grammar_advance_stack to iterative

This change converts the function to an iterative approach using
explicit stacks, which prevents deep recursion and eliminates the risk
of stack overflow.

rg-edit regexp: llama_grammar_advance_stack
rg-edit extra-args: -A30
rg-edit directive: """Rewrite: fix the following segfault:

[..]
 Testing segfault. Grammar:
            root ::= ( [x]* )*

            root ::= ( [x]* )*

Segmentation fault         build/bin/test-grammar-integration

convert from recursive to interactive"""

gptel-context:
(("~/llama.cpp/src/llama-grammar.cpp")
 ("~/llama.cpp/tests/test-grammar-integration.cpp")
 ("~/llama.cpp/grammars/./list.gbnf")
 ("~/llama.cpp/grammars/./json_arr.gbnf")
 ("~/llama.cpp/grammars/./json.gbnf")
 ("~/llama.cpp/grammars/./japanese.gbnf")
 ("~/llama.cpp/grammars/./english.gbnf")
 ("~/llama.cpp/grammars/./chess.gbnf")
 ("~/llama.cpp/grammars/./c.gbnf")
 ("~/llama.cpp/grammars/./arithmetic.gbnf")
 ("~/llama.cpp/grammars/./README.md"))

v2: Added a `std::set` to perform tree-based lookups with O(N log N)
complexity. Testing with a parallel run of `test-grammar-integration`
shows a double-digit percentage increase in runtime. An
`unordered_set` with O(1) hashing was also evaluated, but the overhead
of constructing hash keys from pointers made it significantly slower
than the rbtree implementation that only requires an ordering
operator. The performance regression in the test suite appears
justified by the overall reduction in algorithmic complexity.

Co-developed-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>

* grammar: add test case for hang in repetition grammar processing

This commit adds a new test case to the grammar integration tests that
specifically targets a hang scenario in the repetition grammar parser
found while adding GBNF support to ripgrep-edit.

llama-server reproducer:

curl \
  -X POST \
  -d '{
    "messages": [{ "role": "user", "content": "write yes" }],
    "grammar": "root ::= (([^x]*){0,99}){0,99}"
  }' \
  -H "Content-Type: application/json" \
  http://localhost:8811/v1/chat/completions

* grammar: add repetition threshold check

The change introduces a maximum repetition threshold to avoid
excessive rule expansion during grammar parsing. When parsing
repetition patterns like {m,n}, the parser now calculates the
potential number of rules that would be generated and throws an error
if the product of previous rules and new rules exceeds the threshold.

A test case was added to verify the threshold is properly enforced for
deeply nested repetition patterns that would otherwise cause hangs.
2026-03-21 18:43:35 +01:00
Tom Hillbrunner
212f4521b0 context : use n_embd_out for pooled embedding extraction (#20840)
The MEAN/CLS/LAST pooling paths in encode() and decode() used
n_embd_inp() (16384 for qwen3vl with deepstack) to read from the
pooled embedding tensor, which only has n_embd_out() (4096) floats
per sequence. This caused a tensor read out of bounds assertion.

Fixes embedding mode for Qwen3-VL-Embedding models.
2026-03-21 19:35:00 +02:00
Xuan-Son Nguyen
568aec82d2 docs : explicit about banning accounts that violates policy (#19593) 2026-03-21 15:50:16 +01:00
y198
2bcdddd5e3 fix(rpc): prevent division by zero in deserialize_tensor (#20712)
rpc : prevent division by zero in deserialize_tensor

When receiving an RPC message with a deprecated tensor type (e.g., type 4 or 5 where `blck_size == 0`), `ggml_row_size()` will trigger a division by zero (SIGFPE) and crash the rpc-server. 

This patch adds a simple validation check in `deserialize_tensor` to return `nullptr` if the requested tensor type has a block size of 0.

(Note: This was originally reported via Security Advisory and maintainer suggested dropping a patch here).

* style: remove trailing whitespace
2026-03-21 15:59:43 +02:00
Michael Wand
eac9c6ea83 Convert: Make NVFP4 and MXFP4 HF conversions say NVFP4/MXFP4 instead of BF16 (#20730)
* Corrected convert script for NVFP4 naming and updated gguf constants

* Add mostly_MXFP4 to FileType

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* simplify

* set initial value [no ci]

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-21 13:35:21 +02:00
Sigbjørn Skjæret
29b28a9824 ci : switch from pyright to ty (#20826)
* type fixes

* switch to ty

* tweak rules

* tweak more rules

* more tweaks

* final tweak

* use common import-not-found rule
2026-03-21 08:54:34 +01:00
Matt Corallo
cea560f483 Add shader count for Intel Arc Pro B60 (#20818) 2026-03-21 05:22:51 +01:00
Piotr Wilkin (ilintar)
b1c70e2e54 common/parser: fix nasty bug causing subtle corruption of generation prompt (#20825) 2026-03-21 00:19:04 +01:00
shalinib-ibm
e6ec21e62f ggml-cpu: add always_inline to tinyBLAS_PPC accumulator saves (#20791)
Explicitly mark save_acc and add_save_Acc with always_inline
in tinyBLAS_PPC. This ensures the compiler keeps MMA accumulator
disassembly within kernel's register context, preventing un-necessary
stask spills.

Signed-off-by: Shalini Salomi Bodapati <Shalini.Salomi.Bodapati@ibm.com>
2026-03-21 07:11:45 +08:00
Georgi Gerganov
4cb7e0bd61 ai : limit runtime of the agent (#20816) 2026-03-20 20:31:25 +02:00
James O'Leary
149b2493c0 common : fix typo in debug log ('extracft' -> 'extract') (#20807) 2026-03-20 18:23:18 +01:00
Georgi Gerganov
b31b30f31d ai : do not run bash commands in the prompt (#20810) 2026-03-20 19:06:33 +02:00
Victor Villar
58c81f7e81 model : fix Granite Hybrid type check for 7B.A1B (#20795)
* Check granite hybriid expert count to set type as LLM_TYPE_7B_A1B or LLM_TYPE_1B

* Use feed fwd dim instead of num of experts

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-20 15:16:09 +01:00
Xuan-Son Nguyen
fb78ad29bb server: (doc) clarify in-scope and out-scope features (#20794)
* server: (doc) clarify in-scope and out-scope features

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-03-20 14:03:50 +01:00
Jeff Bolz
e06c3ab2bc vulkan: change gated_delta_net to shard a column across a subgroup (#20662)
* vulkan: change gated_delta_net to shard a column across a subgroup

This is based on https://github.com/ggml-org/llama.cpp/pull/20391, I used an
LLM to port the CUDA code to Vulkan, and guided to it to make various fixes to
work with Vulkan (e.g. handling different subgroup sizes, unknown mapping of
subgroup to invocation id, using subgroupAdd optionally, etc.).

This fixes a perf regression from the transposing of the values in memory
(!20443).

* vulkan: Spread columns across fewer lanes to reduce the number of workgroups
2026-03-20 12:17:15 +01:00
Ruikai Peng
dc6592431b context: zero output buffer on allocation (#20781)
* context: zero output buffer on allocation

Address GHSA-wqq9-25mr-rw76.

The logits output buffer allocated in output_reserve() uses
posix_memalign(), which does not zero memory. The buffer is only
written during decode when needs_raw_logits() returns true. When
backend samplers cover all output sequences, needs_raw_logits()
returns false and the buffer is never written, but
llama_get_logits() still returns a pointer to it, exposing stale
heap content.

Zero the buffer after allocation to prevent information disclosure
through the public logits API.

Found-by: Pwno

* Update src/llama-context.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-03-20 11:31:34 +02:00
Ruikai Peng
3adbef7776 model: assert nextn_predict_layers to prevent underflow (#20783)
Address GHSA-645x-v54x-34w8.

When nextn_predict_layers >= n_layer, n_layer - nextn_predict_layers
can underflow (unsigned wrap), which corrupts n_layer_kv_from_start.

Assert nextn_predict_layers immediately after parsing the GGUF key.

Found-by: Pwno
2026-03-20 10:17:58 +01:00
Georgi Gerganov
ab9d4c3678 server : improve mtmd ctx checkpoints (#20726)
* server : improve mtmd ctx checkpoints

* server : fix off-by-one in pos_min_thold
2026-03-20 11:13:12 +02:00
hipudding
1af9dab32b CANN: add BF16 support for core operators (#20152)
* CANN: add BF16 support for core operators

Add BF16 (bfloat16) type support to the CANN backend for the following
operators: MUL_MAT, MUL_MAT_ID, GET_ROWS, SET_ROWS, CPY, CONT, and
OUT_PROD. This enables BF16 models to run on Ascend NPUs.

* CANN: skip NZ weight format for BF16 and add 310P compile guards

NZ weight format conversion does not support BF16 tensors, skip it
in set_tensor, get_alloc_size and mul_mat. Remove BF16 from MUL_MAT_ID
and OUT_PROD as there are no BF16 use cases. Add #ifndef ASCEND_310P
guards for all BF16 operator support since 310P does not support BF16.
2026-03-20 17:08:39 +08:00
Seyoung Jeong
6d99b44c7e docs : fix Metal backend op support status in ops.md (#20779)
Regenerate docs/ops/Metal.csv using test-backend-ops on Apple M5
and rebuild docs/ops.md via scripts/create_ops_docs.py.

Five ops were incorrectly marked as not supported () for Metal:
- DIAG:           
- POOL_1D:        
- SET:            
- SOLVE_TRI:      
- GATED_DELTA_NET:🟡 (partial, depends on head_size % 32)
2026-03-20 11:06:38 +02:00
Georgi Gerganov
464fd0e71f ai : update find-related action (#20790)
* ai : update "related issues" prompt

* cont

* cont

* cont
2026-03-20 10:28:14 +02:00
Ruikai Peng
21c8045214 jinja : fix heap OOB read in value equality comparison (#20782)
Address GHSA-q9j6-4hhc-rq9p and GHSA-2q4c-9gq5-5vfp.

The three-iterator overload of std::equal in value_array_t::equivalent()
and value_object_t::equivalent() reads past the end of the shorter
container when comparing arrays or objects of different lengths.

Use the four-iterator overload (C++14) which checks both range lengths.

Found-by: Pwno
2026-03-20 07:15:17 +01:00
James O'Leary
c46583b86b common/parser : fix out_of_range crash in throw path (#20424 regression) (#20777)
* chat : fix out_of_range crash in throw path (#20424 regression)

#20424 introduced effective_input = generation_prompt + input, but the
throw path uses input.substr(result.end) where result.end is a position
within effective_input. Every thinking model with a non-empty
generation_prompt crashes with std::out_of_range instead of the intended
error message.

Test crashes on unpatched master, passes with fix:

  cmake -B build -DLLAMA_BUILD_TESTS=ON -DLLAMA_BUILD_TOOLS=OFF
  cmake --build build --target test-chat
  ./build/bin/test-chat

* Update test-chat.cpp

* Update test-chat.cpp

* Update test-chat.cpp

---------

Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
2026-03-20 02:37:22 +01:00
Ben Racicot
c1b911654a server: fix router mode deadlock on child crash and TOCTOU race in models_max (#20763)
Two bugs in `server_models::load()` that affect router mode reliability:

**Bug 1: Deadlock when child process crashes**

When a child process is killed (e.g., SIGKILL from OS code signature
validation), the monitoring thread deadlocks on `stopping_thread.join()`
because the stopping_thread's wait predicate (`is_stopping`) is never
satisfied — the model name was never inserted into `stopping_models`.
`update_status()` is never reached and the model stays stuck in LOADING
state permanently.

Fix: extend the stopping_thread's wait predicate to also wake when the
child process is no longer alive (`!subprocess_alive()`). When woken by
a dead child, the thread skips the shutdown sequence and returns
immediately. The original `stopping_models.erase()` logic is preserved
for normal unloads.

**Bug 2: TOCTOU race bypasses `--models-max` (ref #20137)**

`unload_lru()` is called outside the mutex, then `load()` acquires the
lock afterward. Under concurrent requests, multiple threads observe
capacity and all proceed to load, exceeding the limit.

Fix: re-check capacity under the lock after `unload_lru()` returns.
If another thread filled the slot in the window between `unload_lru()`
and the lock acquisition, reject with an error instead of silently
exceeding the limit.
2026-03-19 22:16:05 +01:00
Tomeamis
b739738dad docs: Update server README to reflect PR #20297 (#20560) 2026-03-19 21:28:44 +01:00
Sundaram krishnan
a0bbcdd9b6 ggml: guard KleidiAI DOWNLOAD_EXTRACT_TIMESTAMP for cmake < 3.24 (#20767) 2026-03-19 21:36:23 +02:00
Georgi Gerganov
6c72646a61 ci : improve action for duplicate issue (#20772)
* ci : show thinking traces of the agent

* cont : increase thinking

* cont : remove agent files

* cont : move the model selection to the provider
2026-03-19 21:11:53 +02:00
Rail Chabdarov
340807273b hip: Avoid compiler bug in RDNA code generation during debug builds on Windows (#20655) 2026-03-19 19:14:08 +01:00
Ryan Goulden
26c9ce1288 server: Add cached_tokens info to oaicompat responses (#19361)
* tests : fix fetch_server_test_models.py

* server: to_json_oaicompat cached_tokens

Adds OpenAI and Anthropic compatible information about the
number of cached prompt tokens used in a response.
2026-03-19 19:09:33 +01:00
James O'Leary
76f2dc70c3 chat : handle tool calls with no required args in TAG_WITH_TAGGED format (#20764)
* chat : handle tool calls with no required args in TAG_WITH_TAGGED format

* Update tests/test-chat.cpp [no ci]

Co-authored-by: Aldehir Rojas <hello@alde.dev>

---------

Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
Co-authored-by: Aldehir Rojas <hello@alde.dev>
2026-03-19 17:53:11 +01:00
Georgi Gerganov
900efd531d ci : clarify gh command for viewing issues (#20766) 2026-03-19 18:43:54 +02:00
Yiwei Shao
74c42ee1f4 hexagon: add Matrix Extensions (HMX) for Hexagon NPU backend (#20693)
* migrate(vtcm): unify VTCM management for HMX merge

- Add HMX fields to htp_context (#ifdef HTP_HAS_HMX): hmx_enabled,
  hmx_dma, vtcm_scratch_size, exp2_table
- Add HTP_VTCM_SESSION_HOLD CMake option (default ON): hold VTCM for
  entire session instead of per-op acquire/release
- Add vtcm_op_acquire/vtcm_op_release inline wrappers: no-op in
  session-hold mode, delegate in per-op mode
- Add VTCM tail reservation for precompute tables (256KB, 64KB aligned)
  in htp_iface_start under HTP_HAS_HMX
- Add HMX init/cleanup hooks in htp_iface_start/stop
- Add precompute table recovery in vtcm_acquire after VTCM preemption
- Do NOT migrate vtcm_mgr from htp-ops-lib (replaced by tail reservation)

* migrate(repack): replace x4x2 with HMX tile-permuted super-block format

- Add hmx_block_q4_0/q8_0 struct definitions (scales-first + sequential quants)
- Implement forward repack: repack_q4_0_to_hmx_superblock, repack_q8_0_to_hmx_superblock, repack_f16_to_tile_permuted
- Implement inverse repack for get_tensor debug verification
- Route set_tensor/get_tensor via opt_arch >= 73 to HMX path, else existing HVX x4x2
- MXFP4 on v73+ falls back to HVX x4x2 repack (not memcpy)
- Extend supports_op: add IQ4_NL for v73+, F16 tile alignment checks
- Tail blocks (K not multiple of 256): repack to x4x2 via pad-repack-truncate
- Add CMake GGML_HEXAGON_HMX_TAIL_HVX option (default ON); OFF rejects non-256-aligned K in supports_op

* migrate(dma): add dma_queue_push_1d() convenience wrapper for HMX ops

Add 1D linear DMA transfer helper to hex-dma.h for upcoming HMX op
migration. Reuses existing dma_queue_flush() for sync points instead
of adding redundant dma_queue_drain().

* migrate(hmx): reorganize HMX files into htp/hmx/ and simplify HMX locking

Move all 14 HMX-related files from htp/ to htp/hmx/ subdirectory for
cleaner separation between HVX and HMX code. Simplify HMX hardware
locking by replacing the two-level lock design (SHARED HAP lock +
custom asm spin-lock) with direct HAP_compute_res_hmx_lock/unlock
on the existing vtcm_rctx, which already has HMX capability.

Key changes:
- Create htp/hmx/ subdirectory with all HMX infrastructure and ops
- Replace hmx_mgr_ctx_id + spin-lock with HAP_compute_res_hmx_lock(vtcm_rctx)
- Remove hmx_manager_enable/disable_execution() (SHARED lock no longer needed)
- Add hmx_set_vtcm_state() call in main.c (was missing, caused null globals)
- Update main.c includes to use hmx/ prefix
- Clean up duplicate declarations from hmx-worker-pool.h

* migrate(hmx-infra): consolidate HMX infrastructure into htp_context

- Remove hmx-mgr.c/h: eliminate global HMX state singleton, thread htp_context through all HMX ops
- Remove hmx-worker-pool.c/h: replace separate HMX worker pool with main worker_pool API (worker_pool_run_func)
- Replace hmx_unit_acquire/release with direct HAP_compute_res_hmx_lock/unlock on ctx->vtcm_rctx
- Remove HTP_VTCM_SESSION_HOLD compile option: always use per-op vtcm_acquire/release
- Remove hmx_dma from htp_context: HMX ops use ctx->dma[0] instead of separate DMA queue
- Simplify main.c init/cleanup: remove hmx_manager_setup/reset and vtcm_op_acquire/release wrappers
- Delete upstream llama.cpp AGENTS.md (not applicable to fork)

* migrate(flash-attn): remove HTP_EXP2_TABLE_COPIES, use single exp2 table

- Remove HTP_EXP2_TABLE_COPIES compile definition and CMake cache variable
- Remove table duplication loop in precompute-table.c
- Remove worker_index % N sub-table indexing in hmx-flash-attn-ops.c
- Fix table_size to 65536 (single 64 KB copy) in main.c

The exp2 lookup table is read-only; concurrent VTCM reads do not cause
bank conflicts, so duplicating the table wastes 192 KB of VTCM for no
benefit.

* migrate(dsp-main): add HMX priority dispatch in packet_callback

- Add proc_hmx_matmul_req() wrapper for HMX mat_mul (F16 and quantized types)
- Add proc_hmx_flash_attn_req() wrapper for HMX simple_flash_attn (FP16 only, falls back to HVX for non-FP16)
- Add proc_hmx_rms_norm_req() wrapper using hvx_rms_norm_f32
- Route MUL_MAT, FLASH_ATTN_EXT, RMS_NORM through HMX path when ctx->hmx_enabled
- Split RMS_NORM and SCALE into separate case blocks for independent dispatch
- All HMX wrappers guarded by #ifdef HTP_HAS_HMX

* migrate(cmake-dsp): add HMX source files and -mhmx for v73+ skels

Add HTP_VTCM_SESSION_HOLD option (default ON) and v73+ HMX build
integration: compile hmx-matmul-ops, hmx-flash-attn-ops,
hmx-rms-norm-ops and precompute-table into v73/v75/v79/v81 skels
with -mhmx flag and HTP_HAS_HMX=1 definition. v68/v69 skels remain
unchanged.

* migrate(hmx-ops): fix compile errors in HMX ops for ggml struct compatibility

- hmx-matmul-ops.c: include ggml-common.h for block_q4_0/block_q8_0 definitions
- hmx-matmul-ops.c: rename quants->qs, scale->d to match upstream ggml field names
- hmx-flash-attn-ops.c: suppress -Wunused-function/-Wunused-variable warnings
- hmx-flash-attn-ops.c: inline ctx->n_threads, remove unused n_workers variable

* hmx: set Q/O element type to fp16 for flash attention

The llama.cpp integration passes fp16 Q/O tensors, so qo_fp32_element
should be false to match the actual data layout.

* hexagon: unify HMX weight format to x4x2, add IQ4_NL and DSP-side fallback

Remove the v73+ HMX-specific super-block/tile-permuted weight format
and unify all architectures on the HVX x4x2 packed format. The DSP now
decides at runtime whether to use the HMX or HVX matmul path based on
dimension constraints (M%32, N%32, K%256 alignment), rather than the
host rejecting ops in supports_op. This simplifies the host repack
logic, eliminates ~400 lines of HMX super-block code, and adds IQ4_NL
quantization support across host and DSP.

Key changes:
- Remove hmx_block_q4_0/q8_0 types, repack functions, and F16 tile
  permutation (ggml-hexagon.cpp, hmx-quants.h)
- Simplify set_tensor/get_tensor to always use x4x2 repack, add IQ4_NL
- Force is_host=false so tensor copies go through format conversion
- Add HTP_TYPE_IQ4_NL to DSP message protocol (htp-msg.h)
- Rewrite DSP dequantizers to work directly on x4x2 layout
  (hmx-matmul-ops.c)
- Fix mxclracc.hf placement: clear per output tile, not once globally
- Move HMX eligibility checks to DSP proc_hmx_matmul_req (main.c)
- Remove dma_queue_push_1d wrapper, use 2D DMA for weight sub-blocks
- Add VTCM allocation overflow asserts
- Remove GGML_HEXAGON_HMX_TAIL_HVX build option (CMakeLists.txt)

* Enhance HMX debugging capabilities with new tile dumping functions

- Introduced hmx_dump_tile_mem and hmx_dump_fp32_tile_region for improved memory layout visualization of tile data.
- Updated hmx_dump_tile_rows to provide raw memory output for debugging.
- Added debug logging for activation and weight tile pairs during processing to facilitate troubleshooting.
- Refined existing macros for dumping HVX vector values to streamline debugging output.

These changes aim to enhance the debugging experience for HMX matmul operations, ensuring better visibility into data handling and transformations.

* OK for small mat mul

* hexagon: fix UDMA roiwidth 16-bit overflow in HMX matmul DMA transfers

The UDMA descriptor roiwidth field is 16-bit (max 65535), but large matrix
DMA transfers (e.g. 32×2304 = 73728 bytes) exceeded this limit, causing
truncated transfers and NaN results. Fix by using 2D DMA (per-row stride ×
n_rows) instead of 1D (total_size × 1) for all 4 DMA push calls in both
x4x2 and fp16 weight paths.

Also includes:
- Use standard vlut16 instead of _nomatch variant for dequantization
- Add per-tile vscatter drain barrier for correctness
- Add compile-time HMX_DEBUG_TRACE_VALUES instrumentation (disabled by default)

* hexagon: remove HMX RMS norm fallback and re-enable matmul pipeline

Remove hmx-rms-norm-ops.c as the HVX RMS norm offers no benefit over
the generic unary path. Re-enable DMA pipeline mode for QK matmul.

* hexagon: guard all HMX matmul DMA transfers against UDMA 16-bit field overflow

All UDMA type1 descriptor fields (roiwidth, roiheight, srcstride, dststride)
are 16-bit (max 65535). Commit 40d2a9cc fixed roiwidth overflow in the
non-pipeline path by switching from 1D to 2D DMA, but the pipeline path
(3 call sites) was left unchanged and still used 1D DMA with
chunk_size = n_cols * row_stride as roiwidth, which overflows for any
practical matrix size when the pipeline is active.

Add a local hmx_dma_push_safe() helper that transparently handles overflow:
- Fast path (zero overhead): all params fit in 16 bits -> direct call.
- Contiguous block: reshapes into a single 2D descriptor with sub_width
  that fits in 16 bits, preserving async DMA behavior.
- Stride overflow: row-by-row fallback for future large-k models where
  per-row stride itself exceeds 65535.

Convert all 8 external dma_queue_push calls in hmx-matmul-ops.c to use
the safe helper, including the 3 pipeline sites (1D -> 2D fix), the
FP16 and x4x2 weight paths, qweight_fetch sub-block DMA, and the
output-stationary activation fetch.

* hexagon: multithread activation/output transfer and add HMX matmul fallback

- Replace single-threaded transfer_activation_chunk_fp32_to_fp16 with
  transfer_activation_chunk_multithread across all HMX matmul paths
- Add multi-threaded transfer_output_chunk_multithread for FP16-to-FP32
  output store, following the same worker pool pattern
- Rename transfer_activation_chunk_no_prefetch back to
  transfer_activation_chunk_fp32_to_fp16 and clean up stale comments
- Add HVX fallback in proc_hmx_matmul_req when HMX matmul returns error

* [todo]: dynamic alloc vtcm, cause prefill regression.

* hexagon: constrain HMX mxmem tile load region to avoid VTCM bank boundary faults

Set activation/weight mxmem Rt to 2047 for single-tile loads and document the 4MB VTCM bank boundary constraint, preventing precise bus errors when dynamic VTCM allocation places tiles near bank edges.

* hexagon: split unaligned-M HMX matmul into HMX+HVX phases

- keep HMX for the 32-aligned head rows and process tail rows with HVX
- force re-quantization for HVX tail after HMX phase to avoid stale VTCM state
- preserve fallback behavior when N is unaligned or no aligned M rows exist

* hexagon: batch-4 Q4_0 dequantize fast path and remove debug traces

Add dequantize_x4x2_q4_0_x4groups_hvx() that processes 4 contiguous
K-tiles with a single vmemu + vlut16 per row, reducing per-tile overhead.
The dequantize loop now takes the batch-4 path when 4 aligned K-tiles
are available within the same column tile, falling back to the original
single-tile path otherwise.

Also removes HMX_DEBUG_TRACE_VALUES instrumentation blocks that are no
longer needed.

* hexagon: abort on DSP error and fix HMX-to-HVX fallback quantize flag

Promote DSP response error from log to GGML_ABORT for fail-fast
behavior. Clear SKIP_QUANTIZE flag when falling back from HMX to HVX
matmul so the HVX path correctly re-quantizes activations.

* hexagon: support batch matmul. This fix perplexity issue
The problem comes from Grouped-Query Attention(GQA).  Strides between batches are not well respected
TODO: optimize batch matmul to reuse weights between batches.

* hexagon: reuse weights in fp16 batch matmul

* hexagon: remove unused HMX flash attention operations and precomputation table, remove the log system for test

* hexagon: remove unused HVX math helpers, debug infrastructure, and stale build options

* hexagon: fix HMX not enabled due to missing force_hvx parameter in IDL

* hexagon: remove the unnecessary changes not related to HMX

* hexagon: bypass HMX by default

* hexagon: add upstream repo link to htp-ops-lib ported file headers

* hexagon: restore host buffer support

* hexagon: add HMX=1 option for the adb scripts

* hex-hmx: improve DMA pipelining

* hex-hmx: further improvements to dma pipelining

* hex-hmx: minor cleanup

* hex-hmx: move hmx lock out of inner loops/calls

* hex-hmx: remove unnecessary state and wrappers

* hex-hmx: remove hmx dir and unify f32 to f16 conversions

* hex-hmx: further unify hvx conversions

* hex-hmx: revert f16 converter to the original for now

* hex-hmx: minor cleanup for f16 to f32 converter

* hex-mm: replace incorrect fp16-to-fp32 hmx converter and reformated related code

* hex-dma: move chanied dma push into hex-dma.h header and update hmx-mm

* hex-mm: use hex_is_aligned instead of a duplicated hmx_is_aligned

* hex-mm: use hvx_vec_splat_f16 in the hmx code

* hex-mm: use VLEN and HTP types in hmx-code

* hex-mm: remove duplicate QK and defs

* hexagon: pre-shuffle quants before vlut16

* hexagon: enable HMX by default

* hex-mm: code indent fixes for hmx-matmul

* hexagon: update hex-utils to include align/smin/etc helpers and use that in hmx mm

* hex-mm: more formatting fixes

* hex-mm: minor naming updates in hmx code

* hex-mm: remove leftover from rebase conflict

* Fix the incorrect indents

---------

Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
2026-03-19 09:11:06 -07:00
uvos
b49d8b8757 ci : add hip quality check (#20430)
* CI: add hip quality check

* Update scripts/hip/gcn-cdna-vgpr-check.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update .github/workflows/hip-quality-check.yml

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update .github/workflows/hip-quality-check.yml

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update .github/workflows/hip-quality-check.yml

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/hip/gcn-cdna-vgpr-check.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/hip/gcn-cdna-vgpr-check.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/hip/gcn-cdna-vgpr-check.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/hip/gcn-cdna-vgpr-check.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Revert "Update .github/workflows/hip-quality-check.yml"

This reverts commit efa0bfcdb0.

* scripts: gcn-cdna-vgpr-check.py: enforce int type for total_vgprs

* scripts: gcn-cdna-vgpr-check.py: add flash attention instances to ignore list

* Bump ccache version

* Add mssing seperators to list

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-19 17:05:44 +01:00
Piotr Wilkin (ilintar)
5e54d51b19 common/parser: add proper reasoning tag prefill reading (#20424)
* Implement proper prefill extraction

* Refactor cli parameters, update docs, move reasoning budget sampler part to common/reasoning-budget.cpp

* Update tools/server/server-task.cpp

* refactor: move grammars to variant, remove grammar_external, handle exception internally

* Make code less C++y

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-03-19 16:58:21 +01:00
Reese Levine
c1258830b2 ggml webgpu: ops support for qwen3.5 (SET, TRI_SOLVE, SSM_CONV, GATED_DELTA_NET) + GET_ROWS optimization (#20687)
* Implement l2_norm, set, tri

* Add DIAG/SOLVE_TRI

* Add SSM_CONV

* Better get_rows and gated_delta_net to support qwen3.5

* Clean up, update ops.md

* Fix binding_index type for wasm

* Fix read write annotations

* cleanups
2026-03-19 08:45:28 -07:00
ddh0
922b90e567 common : add LLAMA_ARG_SPEC_TYPE (#20744) 2026-03-19 16:16:55 +01:00
Georgi Gerganov
f071ce67c9 ci : add action for finding duplicate issues (#20756)
* ci : add action for finding duplicates issues

* cont : gen info

* cont : formatting

* cont : fix

* cont : instructions

* cont : bump checkout action

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-19 16:17:37 +02:00
Pascal
4065c1a3a6 Server becomes the source of truth for sampling parameter defaults (#20558)
* webui: make server the source of truth for sampling defaults

* webui: fix Custom badge for sampling parameters

* webui: log user overrides after server sync

* chore: update webui build output

* fix: Default values for sampling settings config object

* chore: update webui build output

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2026-03-19 13:20:39 +01:00
Xuan-Son Nguyen
1e64534570 mtmd: add clip_graph::build_mm() (#20751)
* clip: add build_mm()

* apply to all models

* add TODO for bias overload
2026-03-19 13:11:39 +01:00
Pascal
cd708db0cc WebUI: Persist the on/off state of the MCP servers for new conversations (#20750)
* webui: add persistent storage for MCP server on/off state in new chats

* webui: simplify MCP enabled checks, remove dead server.enabled fallback

* chore: update webui build output

* chore: update webui build output

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2026-03-19 12:54:06 +01:00
Aleksander Grygier
512bba6ee0 webui: Improve model parsing logic + add unit tests (#20749)
* add tests for model id parser

* add test case having activated params

* add structured tests for model id parser

* add ToDo

* feat: Improve model parsing logic + tests

* chore: update webui build output

---------

Co-authored-by: bluemoehre <bluemoehre@gmx.de>
2026-03-19 12:25:50 +01:00
Dowon
b486c17b3e convert : support is_causal hyperparameter (#20746)
* convert : support is_causal hyperparameter

Check for the `is_causal` attribute in the Hugging Face model configuration and include it in the GGUF metadata.

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* style: fix F541 f-string is missing placeholders

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-19 11:41:11 +01:00
Aldehir Rojas
1b9bbaa357 common : fix gpt-oss content removal (#20745) 2026-03-19 11:40:39 +01:00
170 changed files with 21771 additions and 26875 deletions

87
.github/workflows/ai-issues.yml vendored Normal file
View File

@@ -0,0 +1,87 @@
name: AI review (issues)
on:
issues:
types: [opened]
jobs:
find-related:
if: github.event.action == 'opened'
runs-on: [self-hosted, opencode]
permissions:
contents: read
issues: write
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 1
- name: Find related
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
OPENCODE_PERMISSION: |
{
"bash": {
"*": "deny",
"gh issue*": "allow",
"gh search issues*": "allow"
},
"webfetch": "deny"
}
run: |
rm AGENTS.md
rm CLAUDE.md
timeout 5m opencode run -m llama.cpp-dgx/ai-review-issues-find-similar --thinking "A new issue has been created:
Issue number: ${{ github.event.issue.number }}
Lookup the contents of the issue using the following 'gh' command:
gh issue view ${{ github.event.issue.number }} --json title,body,url,number
Next, perform the following task and then post a SINGLE comment (if needed).
---
TASK : FIND RELATED ISSUES
Using the 'gh' CLI tool, search through existing issues on Github.
Find related or similar issues to the newly created one and list them.
Do not list the new issue itself (it is #${{ github.event.issue.number }}).
Consider:
1. Similar titles or descriptions
2. Same error messages or symptoms
3. Related functionality or components
4. Similar feature requests
---
POSTING YOUR COMMENT:
Based on your findings, post a SINGLE comment on issue #${{ github.event.issue.number }}. Build the comment as follows:
- If no related issues were found, do NOT comment at all.
- If related issues were found, include a section listing them with links using the following format:
[comment]
This issue might be similar or related to the following issue(s):
- #[related_issue_number]: [brief description of how they are related]
- #[related_issue_number]: [brief description of how they are related]
...
_This comment was auto-generated locally using **$GA_ENGINE** on **$GA_MACHINE**_
[/comment]
Remember:
- Do not include the comment tags in your actual comment.
- Post at most ONE comment combining all findings.
- If you didn't find issues that are related enough, post nothing.
- You have access only to the 'gh' CLI tool - don't try to use other tools.
- If the output from a tool call is too long, try to limit down the search.
"

80
.github/workflows/hip-quality-check.yml vendored Normal file
View File

@@ -0,0 +1,80 @@
name: HIP quality check
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/hip-quality-check.yml',
'**/*.cu',
'**/*.cuh'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'.github/workflows/hip-quality-check.yml',
'**/*.cu',
'**/*.cuh'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
ubuntu-22-hip-quality-check:
runs-on: ubuntu-22.04
container: rocm/dev-ubuntu-22.04:7.2
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev python3
- name: ccache
uses: ggml-org/ccache-action@v1.2.21
with:
key: ubuntu-22-hip-quality-check
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with Werror
id: cmake_build
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGPU_TARGETS=gfx908 \
-DGGML_HIP=ON \
-DGGML_HIP_EXPORT_METRICS=Off \
-DCMAKE_HIP_FLAGS="-Werror -Wno-tautological-compare" \
-DCMAKE_BUILD_TYPE=Release
cd build
make -j $(nproc)
- name: Check for major VGPR spills
id: vgpr_check
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGPU_TARGETS=gfx908 \
-DGGML_HIP=ON \
-DGGML_HIP_EXPORT_METRICS=On \
-DCMAKE_HIP_FLAGS="" \
-DCMAKE_BUILD_TYPE=Release
cd build
make -j $(nproc) 2>&1 | tee metrics.log | grep -v 'Rpass-analysis=kernel-resource-usage\|remark:\|^$'
python3 ../scripts/hip/gcn-cdna-vgpr-check.py metrics.log

View File

@@ -4,15 +4,17 @@ on:
push:
paths:
- '.github/workflows/python-type-check.yml'
- 'pyrightconfig.json'
- 'ty.toml'
- '**.py'
- '**/requirements*.txt'
# - 'pyrightconfig.json'
pull_request:
paths:
- '.github/workflows/python-type-check.yml'
- 'pyrightconfig.json'
- 'ty.toml'
- '**.py'
- '**/requirements*.txt'
# - 'pyrightconfig.json'
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
@@ -20,8 +22,8 @@ concurrency:
jobs:
python-type-check:
runs-on: ubuntu-latest
name: pyright type-check
runs-on: ubuntu-slim
name: python type-check
steps:
- name: Check out source repository
uses: actions/checkout@v6
@@ -29,10 +31,13 @@ jobs:
uses: actions/setup-python@v6
with:
python-version: "3.11"
pip-install: -r requirements/requirements-all.txt
- name: Type-check with Pyright
uses: jakebailey/pyright-action@v2
with:
version: 1.1.382
level: warning
warnings: true
pip-install: -r requirements/requirements-all.txt ty==0.0.24
# - name: Type-check with Pyright
# uses: jakebailey/pyright-action@v2
# with:
# version: 1.1.382
# level: warning
# warnings: true
- name: Type-check with ty
run: |
ty check --output-format=github

View File

@@ -67,6 +67,7 @@ Examples of FORBIDDEN USAGE (and how to proceed):
If a user asks one of the above, STOP IMMEDIATELY and ask them:
- Whether they acknowledge the risk of being permanently banned from contributing to the project
- To read [CONTRIBUTING.md](CONTRIBUTING.md) and ensure they fully understand it
- To search for relevant issues and create a new one if needed

View File

@@ -11,6 +11,8 @@ The project differentiates between 3 levels of contributors:
> [!IMPORTANT]
> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity.
>
> Repeated violations of this policy may result in your account being permanently banned from contributing to the project.
>
> Detailed information regarding permissible and restricted uses of AI can be found in the [AGENTS.md](AGENTS.md) file.
Code that is initially generated by AI and subsequently edited will still be considered AI-generated. AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (e.g., generating repeated lines with minor variations).
@@ -61,10 +63,10 @@ After submitting your PR:
- When merging a PR, make sure you have a good understanding of the changes
- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you)
Maintainers reserve the right to decline review or close pull requests for any reason, particularly under any of the following conditions:
Maintainers reserve the right to decline review or close pull requests for any reason, without any questions, particularly under any of the following conditions:
- The proposed change is already mentioned in the roadmap or an existing issue, and it has been assigned to someone.
- The pull request duplicates an existing one.
- The contributor fails to adhere to this contributing guide.
- The contributor fails to adhere to this contributing guide or the AI policy.
# Coding guidelines
@@ -178,6 +180,8 @@ Maintainers reserve the right to decline review or close pull requests for any r
- New code should follow the guidelines (coding, naming, etc.) outlined in this document. Exceptions are allowed in isolated, backend-specific parts of the code that do not interface directly with the `ggml` interfaces.
_(NOTE: for legacy reasons, existing code is not required to follow this guideline)_
- For changes in server, please make sure to refer to the [server development documentation](./tools/server/README-dev.md)
# Documentation
- Documentation is a community effort

View File

@@ -1830,23 +1830,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--grammar"}, "GRAMMAR",
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),
"BNF-like grammar to constrain generations (see samples in grammars/ dir)",
[](common_params & params, const std::string & value) {
params.sampling.grammar = value;
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, value};
}
).set_sparam());
add_opt(common_arg(
{"--grammar-file"}, "FNAME",
"file to read grammar from",
[](common_params & params, const std::string & value) {
params.sampling.grammar = read_file(value);
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, read_file(value)};
}
).set_sparam());
add_opt(common_arg(
{"-j", "--json-schema"}, "SCHEMA",
"JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
[](common_params & params, const std::string & value) {
params.sampling.grammar = json_schema_to_grammar(json::parse(value));
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(value))};
}
).set_sparam());
add_opt(common_arg(
@@ -1863,7 +1863,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
std::istreambuf_iterator<char>(),
std::back_inserter(schema)
);
params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(schema))};
}
).set_sparam());
add_opt(common_arg(
@@ -2583,7 +2583,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
"mmproj is also downloaded automatically if available. to disable, add --no-mmproj\n"
"example: unsloth/phi-4-GGUF:q4_k_m\n"
"example: ggml-org/GLM-4.7-Flash-GGUF:Q4_K_M\n"
"(default: unused)",
[](common_params & params, const std::string & value) {
params.model.hf_repo = value;
@@ -3494,7 +3494,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
throw std::invalid_argument("unknown speculative decoding type without draft model");
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SPEC_TYPE"));
add_opt(common_arg(
{"--spec-ngram-size-n"}, "N",
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),

View File

@@ -1,3 +1,4 @@
#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h"
#include "chat-peg-parser.h"
#include "chat.h"
@@ -23,13 +24,13 @@ static void foreach_function(const json & tools, const std::function<void(const
namespace autoparser {
parser_build_context::parser_build_context(common_chat_peg_builder & p, const templates_params & inputs) :
parser_build_context::parser_build_context(common_chat_peg_builder & p, const generation_params & inputs) :
p(p),
inputs(inputs),
reasoning_parser(p.eps()) {}
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs) {
const struct generation_params & inputs) {
// Run differential analysis to extract template structure
struct autoparser autoparser;
autoparser.analyze_template(tmpl);
@@ -37,17 +38,16 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
}
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs,
const struct generation_params & inputs,
const autoparser & autoparser) {
// Build the parser using the analysis results
auto parser = autoparser.build_parser(inputs);
// Create the result structure
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = autoparser.preserved_tokens;
data.parser = parser.save();
auto parser = autoparser.build_parser(inputs);
data.parser = parser.save();
// Build grammar if tools are present
bool has_tools =
@@ -82,44 +82,38 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
return data;
}
common_peg_arena autoparser::build_parser(const templates_params & inputs) const {
common_peg_arena autoparser::build_parser(const generation_params & inputs) const {
if (!analysis_complete) {
throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
}
return build_chat_peg_parser([&](common_chat_peg_builder & p) {
// If the template uses Python dict format (single-quoted strings in JSON structures),
// pre-register a json-string rule that accepts both quote styles. This must happen
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
if (tools.format.uses_python_dicts) {
p.rule("json-string", p.quoted_string());
}
parser_build_context ctx(p, inputs);
bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
bool enable_thinking = inputs.enable_thinking;
ctx.extracting_reasoning = extract_reasoning && enable_thinking && reasoning.mode != reasoning_mode::NONE;
ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
ctx.content = &content;
// Build reasoning parser
ctx.reasoning_parser = reasoning.build_parser(ctx);
auto parser = p.eps();
bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
if (has_response_format) {
auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
return ctx.reasoning_parser + p.space() + p.choice({
parser = ctx.reasoning_parser + p.space() + p.choice({
p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
response_format
}) + p.end();
} else if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
parser = tools.build_parser(ctx);
} else {
parser = content.build_parser(ctx);
}
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
return tools.build_parser(ctx);
}
return content.build_parser(ctx);
parser = wrap_for_generation_prompt(p, parser, inputs, reasoning.start);
return parser;
});
}
@@ -130,24 +124,15 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
return p.eps();
}
bool thinking_forced_open = (mode == reasoning_mode::FORCED_OPEN);
bool thinking_forced_closed = (mode == reasoning_mode::FORCED_CLOSED);
if (thinking_forced_open || thinking_forced_closed) {
// Thinking is forced open OR forced closed with enable_thinking=true
// In both cases, expect only the closing tag (opening was in template)
// However, since we might have incorrectly detected the open/close pattern,
// we admit an optional starting marker
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
}
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)
// Both use the same tag-based pattern if markers are available
if (!start.empty() && !end.empty()) {
return p.optional(start + p.reasoning(p.until(end)) + end);
if (!end.empty()) {
if (!start.empty()) {
// Standard tag-based: optional(<think>reasoning</think>)
return p.optional(start + p.reasoning(p.until(end)) + end + p.space());
}
// Delimiter-style (empty start)
return p.optional(p.reasoning(p.until(end)) + end + p.space());
}
} else if (mode == reasoning_mode::DELIMITER) {
return p.optional(p.reasoning(p.until(end)) + end);
}
return p.eps();
@@ -335,7 +320,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
"tool-" + name + "-arg-" + param_name + "-schema",
param_schema, true)) :
p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) +
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
p.space()) +
p.tool_arg_close(p.literal(arguments.value_suffix)));
@@ -384,7 +369,9 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section) + p.space() + args_seq;
matched_atomic = true;
} else if (!arguments.name_prefix.empty() && properties.size() > 0) {
} else if (!arguments.name_prefix.empty() && !required_parsers.empty()) {
// Only peek for an arg tag when there are required args that must follow.
// When all args are optional, the model may emit no arg tags at all (#20650).
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
matched_atomic = true;

View File

@@ -1,9 +1,11 @@
#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h"
#include "chat-peg-parser.h"
#include "chat.h"
#include "log.h"
#include "nlohmann/json.hpp"
#include "peg-parser.h"
#include <cctype>
#include <numeric>
@@ -186,6 +188,21 @@ diff_split calculate_diff_split(const std::string & left, const std::string & ri
result.suffix = "";
// pick prefix = all as representation
}
// When left has no unique content (result.left is empty), left is entirely
// shared with right. The simultaneous prefix/suffix segment matching can
// incorrectly consume trailing segments of left as suffix when those same
// segments also appear at the end of right (e.g. "\n" at the end of both
// the shared content and the generation prompt). This rotates the diff.
// Fix: if left is a prefix of right, enforce that directly.
if (result.left.empty() && !result.right.empty() &&
left.size() <= right.size() &&
right.substr(0, left.size()) == left) {
result.prefix = left;
result.suffix = "";
result.right = right.substr(left.size());
}
return result;
}
@@ -291,10 +308,26 @@ std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segm
return result;
}
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
const common_peg_parser & prs,
const autoparser::generation_params & inputs,
const std::string & reasoning_start) {
auto parser = prs;
if (!inputs.generation_prompt.empty()) {
size_t end_pos = inputs.generation_prompt.size();
if (!reasoning_start.empty() && inputs.generation_prompt.find(reasoning_start) != std::string::npos) {
end_pos = inputs.generation_prompt.find(reasoning_start);
}
std::string cut_genprompt = inputs.generation_prompt.substr(0, end_pos);
parser = p.literal(cut_genprompt) + parser;
}
return parser;
}
namespace autoparser {
std::string apply_template(const common_chat_template & tmpl, const template_params & params) {
templates_params tmpl_params;
generation_params tmpl_params;
tmpl_params.messages = params.messages;
tmpl_params.tools = params.tools;
tmpl_params.add_generation_prompt = params.add_generation_prompt;

View File

@@ -1,6 +1,7 @@
#pragma once
#include "chat-auto-parser.h"
#include "peg-parser.h"
#include <functional>
#include <optional>
#include <string>
@@ -57,6 +58,11 @@ std::vector<segment> segmentize_markers(const std::string & text);
// (MARKER, "</function>"), (MARKER, "</tool_call>") ]
std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments);
// Wrap parser with generation prompt parser
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
const common_peg_parser & prs,
const autoparser::generation_params & inputs,
const std::string & reasoning_start = {});
namespace autoparser {
// Apply a template with the given parameters, returning the rendered string (empty on failure)

View File

@@ -50,7 +50,7 @@ namespace autoparser {
// High-level params for parser generation
// ============================================================================
struct templates_params {
struct generation_params {
json messages;
json tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
@@ -62,6 +62,7 @@ struct templates_params {
bool add_generation_prompt = false;
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::string generation_prompt;
json extra_context;
bool add_bos = false;
bool add_eos = false;
@@ -77,11 +78,7 @@ struct templates_params {
// Reasoning handling mode (derived from R1-R3 comparisons)
enum class reasoning_mode {
NONE, // No reasoning markers detected
TAG_BASED, // Standard tag-based: <think>...</think>
DELIMITER, // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter)
FORCED_OPEN, // Template ends with open reasoning tag (empty start, non-empty end)
FORCED_CLOSED, // Template ends with open reasoning tag on enabled thinking but
// with both opened and closed tag for disabled thinking
TAG_BASED, // Tag-based: <think>...</think> (start can be empty for delimiter-style)
TOOLS_ONLY // Only reason on tool calls, not on normal content
};
@@ -91,12 +88,6 @@ inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode)
return os << "NONE";
case reasoning_mode::TAG_BASED:
return os << "TAG_BASED";
case reasoning_mode::DELIMITER:
return os << "DELIMITER";
case reasoning_mode::FORCED_OPEN:
return os << "FORCED_OPEN";
case reasoning_mode::FORCED_CLOSED:
return os << "FORCED_CLOSED";
case reasoning_mode::TOOLS_ONLY:
return os << "TOOLS_ONLY";
default:
@@ -184,7 +175,6 @@ struct tool_format_analysis {
bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "<funname>": { ... arguments ... } }
bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...]
bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings)
std::string function_field = "function";
std::string name_field = "name";
@@ -225,12 +215,12 @@ struct analyze_content;
struct parser_build_context {
common_chat_peg_builder & p;
const templates_params & inputs;
const generation_params & inputs;
common_peg_parser reasoning_parser;
bool extracting_reasoning = false;
const analyze_content * content = nullptr;
parser_build_context(common_chat_peg_builder & p, const templates_params & inputs);
parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);
};
// ============================================================================
@@ -260,6 +250,7 @@ struct analyze_reasoning : analyze_base {
analyze_reasoning() = default;
analyze_reasoning(const common_chat_template & tmpl, bool supports_tools);
analyze_reasoning(std::string start_, std::string end_) : start(std::move(start_)), end(std::move(end_)) {}
common_peg_parser build_parser(parser_build_context & ctx) const override;
@@ -381,7 +372,7 @@ struct autoparser {
void analyze_template(const common_chat_template & tmpl);
// Build the PEG parser for this template
common_peg_arena build_parser(const templates_params & inputs) const;
common_peg_arena build_parser(const generation_params & inputs) const;
private:
// Collect tokens from entire analysis to preserve
@@ -395,10 +386,10 @@ struct autoparser {
class peg_generator {
public:
static common_chat_params generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs);
const struct generation_params & inputs);
static common_chat_params generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs,
const struct generation_params & inputs,
const autoparser & autoparser);
};

View File

@@ -2,6 +2,7 @@
#include "chat-auto-parser-helpers.h"
#include "chat-peg-parser.h"
#include "chat.h"
#include "common.h"
#include "log.h"
#include "nlohmann/json.hpp"
#include "peg-parser.h"
@@ -31,8 +32,9 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
if (tmpl.src.find("content.split('</think>')") != std::string::npos &&
tmpl.src.find("reasoning_content") == std::string::npos &&
tmpl.src.find("<SPECIAL_12>") == std::string::npos &&
analysis.reasoning.mode == reasoning_mode::NONE) {
analysis.reasoning.mode = reasoning_mode::FORCED_OPEN;
analysis.reasoning.mode = reasoning_mode::TAG_BASED;
analysis.reasoning.start = "<think>";
analysis.reasoning.end = "</think>";
analysis.preserved_tokens.push_back("<think>");
@@ -185,7 +187,6 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
LOG_DBG("python_dict_format: %s\n", tools.format.uses_python_dicts ? "true" : "false");
LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());
@@ -295,16 +296,12 @@ void analyze_reasoning::compare_reasoning_presence() {
}
if (result.result.success()) {
if (!result.tags["pre"].empty() && !result.tags["post"].empty()) {
if (parser_wrapped.parse_anywhere_and_extract(diff.right).result.success()) { // both tags in the diff = no forced close
mode = reasoning_mode::TAG_BASED;
} else {
mode = reasoning_mode::FORCED_CLOSED;
}
mode = reasoning_mode::TAG_BASED;
start = trim_whitespace(result.tags["pre"]);
end = result.tags["post"];
end = trim_trailing_whitespace(result.tags["post"]);
} else if (!result.tags["post"].empty()) {
mode = reasoning_mode::DELIMITER;
end = result.tags["post"];
mode = reasoning_mode::TAG_BASED;
end = trim_trailing_whitespace(result.tags["post"]);
}
}
}
@@ -331,53 +328,30 @@ void analyze_reasoning::compare_thinking_enabled() {
const auto & diff = comparison->diff;
std::string left_trimmed = trim_whitespace(diff.left);
std::string right_trimmed = trim_whitespace(diff.right);
if (left_trimmed.empty() && !diff.right.empty()) {
std::string right_trimmed = trim_whitespace(diff.right);
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
if (start.empty()) {
start = right_trimmed;
mode = reasoning_mode::FORCED_OPEN;
mode = reasoning_mode::TAG_BASED;
}
}
} else if (right_trimmed.empty() && !diff.left.empty()) {
if (!left_trimmed.empty() && string_ends_with(comparison->output_A, left_trimmed)) {
if (end.empty()) {
auto seg = prune_whitespace_segments(segmentize_markers(comparison->output_A));
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
start = seg[seg.size() - 2].value;
}
end = left_trimmed;
mode = reasoning_mode::TAG_BASED;
}
}
}
if (start.empty() && !end.empty()) {
mode = reasoning_mode::DELIMITER;
}
// Check for FORCED_CLOSED: when enable_thinking=false produces both start and end markers,
// but enable_thinking=true produces only the start marker
if (!comparison->output_A.empty() && !comparison->output_B.empty()) {
auto parser_start = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.literal(start) + p.space() + p.literal(end) + p.rest();
});
auto parser_start_end = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.tag("pre", p.literal(start)) + p.space() + p.negate(p.literal(end)) + p.rest();
});
if (!start.empty() && parser_start_end.parse_anywhere_and_extract(comparison->output_A).result.success() &&
parser_start.parse_anywhere_and_extract(comparison->output_B).result.success()) {
mode = reasoning_mode::FORCED_CLOSED;
} else if (!end.empty()) { // we extract the starting marker now since we didn't get it earlier
auto result = parser_start_end.parse_anywhere_and_extract(comparison->output_A);
if (result.result.success()) {
start = result.tags["pre"];
mode = reasoning_mode::FORCED_CLOSED;
}
}
}
if (start.empty() && end.empty()) { // we might still have the case of "just open" and "just close"
if (!diff.left.empty() && !diff.right.empty()) {
auto seg_A = segmentize_markers(trim_trailing_whitespace(diff.left));
auto seg_B = segmentize_markers(trim_trailing_whitespace(diff.right));
if (seg_A.size() == 1 && seg_B.size() == 1) {
mode = reasoning_mode::FORCED_CLOSED;
start = seg_B[0].value;
end = seg_A[0].value;
}
}
if (mode == reasoning_mode::NONE && start.empty() && !end.empty()) {
mode = reasoning_mode::TAG_BASED;
}
}
@@ -426,16 +400,16 @@ void analyze_reasoning::compare_reasoning_scope() {
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
if (result.result.success()) {
start = result.tags["pre"];
end = result.tags["post"];
end = trim_trailing_whitespace(result.tags["post"]);
} else {
auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())));
});
result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B);
if (result.result.success()) {
end = result.tags["post"];
end = trim_trailing_whitespace(result.tags["post"]);
} else {
LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
mode = reasoning_mode::NONE;
}
}
@@ -600,33 +574,23 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
return;
}
enum class json_quote_style { NONE, DOUBLE_QUOTES, SINGLE_QUOTES };
auto in_json_haystack = [&haystack](const std::string & needle) -> json_quote_style {
auto in_json_haystack = [&haystack](const std::string & needle) -> bool {
auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({
p.tag("sq", p.literal("'") + p.literal(needle) + p.literal("'")),
p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) });
});
auto result = parser.parse_anywhere_and_extract(haystack);
if (!result.result.success()) {
return json_quote_style::NONE;
}
return result.tags.count("sq") && !result.tags["sq"].empty()
? json_quote_style::SINGLE_QUOTES
: json_quote_style::DOUBLE_QUOTES;
return result.result.success();
};
auto fun_quote = in_json_haystack(fun_name_needle);
auto arg_quote = in_json_haystack(arg_name_needle);
if (fun_quote != json_quote_style::NONE) {
if (fun_quote) {
// no need to check further, we're in JSON land
format.mode = tool_format::JSON_NATIVE;
format.uses_python_dicts = (fun_quote == json_quote_style::SINGLE_QUOTES);
} else if (arg_quote != json_quote_style::NONE) {
} else if (arg_quote) {
format.mode = tool_format::TAG_WITH_JSON;
format.uses_python_dicts = (arg_quote == json_quote_style::SINGLE_QUOTES);
} else {
format.mode = tool_format::TAG_WITH_TAGGED;
}

View File

@@ -229,6 +229,20 @@ void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
result.tool_calls.push_back(pending_tool_call.value());
pending_tool_call.reset();
}
// Discard whitespace-only reasoning content (e.g. from <think></think> prefill)
if (!result.reasoning_content.empty()) {
bool all_whitespace = true;
for (char c : result.reasoning_content) {
if (c != ' ' && c != '\n' && c != '\r' && c != '\t') {
all_whitespace = false;
break;
}
}
if (all_whitespace) {
result.reasoning_content.clear();
}
}
}
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {

View File

@@ -1,5 +1,6 @@
#include "chat.h"
#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h"
#include "chat-peg-parser.h"
#include "common.h"
@@ -22,6 +23,7 @@
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
using json = nlohmann::ordered_json;
@@ -760,7 +762,7 @@ static void foreach_parameter(const json &
std::string common_chat_template_direct_apply(
const common_chat_template & tmpl,
const autoparser::templates_params & inputs,
const autoparser::generation_params & inputs,
const std::optional<json> & messages_override,
const std::optional<json> & tools_override,
const std::optional<json> & additional_context) {
@@ -811,7 +813,7 @@ std::string common_chat_template_direct_apply(
}
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
// Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja
@@ -876,8 +878,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
// Response format parser
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
// Ministral wants to emit json surrounded by code fences
return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema))
<< "```";
return wrap_for_generation_prompt(p, reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```",
inputs, "[THINK]");
}
// Tool call parser
@@ -897,12 +899,13 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls));
return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls;
return wrap_for_generation_prompt(p, reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls,
inputs, "[THINK]");
}
// Content only parser
include_grammar = false;
return reasoning << p.content(p.rest());
return wrap_for_generation_prompt(p, reasoning << p.content(p.rest()), inputs, "[THINK]");
});
data.parser = parser.save();
@@ -928,7 +931,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
}
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
// Copy reasoning to the "thinking" field as expected by the gpt-oss template
@@ -936,7 +939,9 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
for (auto msg : inputs.messages) {
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
msg["thinking"] = msg.at("reasoning_content");
msg.erase("content");
if (msg.contains("tool_calls") && msg.at("tool_calls").is_array() && !msg.at("tool_calls").empty()) {
msg.erase("content");
}
}
adjusted_messages.push_back(msg);
}
@@ -986,7 +991,8 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
p.literal("<|channel|>final") + constraint + p.literal("<|message|>") +
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
return response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format);
return wrap_for_generation_prompt(p, response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format),
inputs, "<|channel|>");
}
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
@@ -1018,10 +1024,12 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
return tool_call | ( any + p.zero_or_more(start + any) + start + tool_call);
}
return tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg));
return wrap_for_generation_prompt(p, tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg)),
inputs, "<|channel|>");
}
return final_msg | (any + p.zero_or_more(start + any) + start + final_msg);
return wrap_for_generation_prompt(p, final_msg | (any + p.zero_or_more(start + any) + start + final_msg),
inputs, "<|channel|>");
});
data.parser = parser.save();
@@ -1049,7 +1057,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
// Functionary v3.2 - uses recipient-based format: >>>recipient\n{content}
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
@@ -1070,13 +1078,13 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Build content parser for >>>all\n{content}
// When tools are present, content stops before the next ">>>" (tool call)
// When no tools, content goes until end
auto content_until_tool = p.literal(">>>all\n") + p.content(p.until(">>>"));
auto content_until_end = p.literal(">>>all\n") + p.content(p.rest());
auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>"));
auto content_until_end = p.literal("all\n") + p.content(p.rest());
// If no tools or tool_choice is NONE, just parse content
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
// When no tools, just match the prefix and capture everything after
return content_until_end + p.end();
return wrap_for_generation_prompt(p, content_until_end + p.end(), inputs);
}
// Build tool call parsers for each available function
@@ -1088,7 +1096,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Tool format: >>>function_name\n{json_args}
auto tool_parser = p.tool(
p.tool_open(p.literal(">>>") + p.tool_name(p.literal(name)) + p.literal("\n")) +
p.tool_open(p.tool_name(p.literal(name)) + p.literal("\n")) +
p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))
);
@@ -1099,17 +1107,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice));
auto content_and_tools = content_until_tool + tools_only;
auto ret = p.eps();
if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
if (inputs.parallel_tool_calls) {
return p.choice({ content_and_tools, tools_only }) + p.end();
ret = p.choice({ content_and_tools, tools_only }) + p.end();
} else {
ret = p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
}
return p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
} else if (inputs.parallel_tool_calls) {
ret = p.choice({ content_and_tools, content_only, tools_only }) + p.end();
} else {
auto content_and_tool = content_until_tool + tool_choice;
ret = p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
}
if (inputs.parallel_tool_calls) {
return p.choice({ content_and_tools, content_only, tools_only }) + p.end();
}
auto content_and_tool = content_until_tool + tool_choice;
return p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
return wrap_for_generation_prompt(p, ret, inputs);
});
data.parser = parser.save();
@@ -1139,14 +1150,12 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Kimi K2 Thinking - uses unique tool call ID format: functions.<name>:<index>
// The ID contains both the function name and an incrementing counter
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.thinking_start_tag = "<think>";
data.thinking_end_tag = "</think>";
data.preserved_tokens = {
"<|tool_calls_section_begin|>",
"<|tool_calls_section_end|>",
@@ -1161,6 +1170,18 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
const std::string SECTION_END = "<|tool_calls_section_end|>";
const std::string CALL_BEGIN = "<|tool_call_begin|>";
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
const std::string CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
// Kimi K2 Thinking format:
// - Reasoning: <think>{reasoning}</think>
@@ -1172,16 +1193,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
// <|tool_calls_section_end|>
// The ID format is: functions.<function_name>:<counter> where counter is 0, 1, 2, ...
// Tool call markers
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
const std::string SECTION_END = "<|tool_calls_section_end|>";
const std::string CALL_BEGIN = "<|tool_call_begin|>";
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
const std::string CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
// Tool call markers
auto end = p.end();
// Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny.
@@ -1193,7 +1205,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
// Content only parser (no tools)
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return reasoning + p.content(p.rest()) + end;
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end,
inputs, THINK_START);
}
// Build tool call parsers for each available function
@@ -1229,7 +1242,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN }));
return reasoning + content_before_tools + tool_calls + end;
return wrap_for_generation_prompt(p, reasoning + content_before_tools + tool_calls + end,
inputs, THINK_START);
});
data.parser = parser.save();
@@ -1259,7 +1273,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
// Tool calls can appear multiple times (parallel tool calls)
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
@@ -1278,13 +1292,15 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
const std::string TOOL_CALL_START = "<|tool_call_start|>";
const std::string TOOL_CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto end = p.end();
auto reasoning = p.eps();
@@ -1293,7 +1309,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
}
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return reasoning + p.content(p.rest()) + end;
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end, inputs,
THINK_START);
}
auto tool_calls = p.rule("tool-calls",
@@ -1305,7 +1322,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
auto content = p.content(p.until(TOOL_CALL_START));
return reasoning + content + tool_calls + end;
return wrap_for_generation_prompt(p, reasoning + content + tool_calls + end, inputs,
THINK_START);
});
data.parser = parser.save();
@@ -1331,7 +1349,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
static common_chat_params common_chat_params_init_gigachat_v3(
const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
@@ -1345,9 +1363,10 @@ static common_chat_params common_chat_params_init_gigachat_v3(
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
auto tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto ret = p.eps();
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
// Build a choice of all available tools
auto tool_choice = p.choice();
@@ -1370,13 +1389,14 @@ static common_chat_params common_chat_params_init_gigachat_v3(
auto tool_call = p.rule("tool-call", p.literal(tool_call_start_prefix) + tool_choice);
auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
return p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
ret = p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
} else {
// Content only parser
include_grammar = false;
ret = p.content(p.rest());
}
// Content only parser
include_grammar = false;
return p.content(p.rest());
return wrap_for_generation_prompt(p, ret, inputs);
});
data.parser = parser.save();
@@ -1471,87 +1491,10 @@ static json common_chat_extra_context() {
return ctx;
}
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs) {
autoparser::templates_params params;
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
? *tmpls->template_tool_use
: *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.add_generation_prompt = inputs.add_generation_prompt;
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
workaround::map_developer_role_to_system(params.messages);
}
if (!tmpl.original_caps().supports_system_role) {
workaround::system_message_not_supported(params.messages);
}
if (tmpl.original_caps().supports_tool_calls) {
// some templates will require the content field in tool call messages
// to still be non-null, this puts an empty string everywhere where the
// content field is null
workaround::requires_non_null_content(params.messages);
}
if (tmpl.original_caps().supports_object_arguments) {
workaround::func_args_not_string(params.messages);
}
params.extra_context = common_chat_extra_context();
for (auto el : inputs.chat_template_kwargs) {
params.extra_context[el.first] = json::parse(el.second);
}
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
// if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
// LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
// params.parallel_tool_calls = false;
// } else {
params.parallel_tool_calls = inputs.parallel_tool_calls;
//}
if (params.tools.is_array()) {
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
LOG_WRN(
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
}
}
if (inputs.force_pure_content) {
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
// Create the result structure
common_chat_params data;
auto params_copy = params;
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
auto parser = build_chat_peg_parser([](common_chat_peg_builder &p) {
return p.content(p.rest());
});
data.parser = parser.save();
return data;
}
static std::optional<common_chat_params> try_specialized_template(
const common_chat_template & tmpl,
const std::string & src,
const autoparser::generation_params & params) {
// Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser
// Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them
if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos &&
@@ -1592,14 +1535,105 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
// GigaChatV3 format detection
if (src.find("<|role_sep|>") != std::string::npos &&
src.find("<|message_sep|>") != std::string::npos &&
src.find("<|function_call|>") == std::string::npos
) {
src.find("<|function_call|>") == std::string::npos) {
LOG_DBG("Using specialized template: GigaChatV3\n");
return common_chat_params_init_gigachat_v3(tmpl, params);
}
return std::nullopt;
}
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs) {
autoparser::generation_params params;
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
const auto & tmpl =
params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
workaround::map_developer_role_to_system(params.messages);
}
if (!tmpl.original_caps().supports_system_role) {
workaround::system_message_not_supported(params.messages);
}
if (tmpl.original_caps().supports_tool_calls) {
// some templates will require the content field in tool call messages
// to still be non-null, this puts an empty string everywhere where the
// content field is null
workaround::requires_non_null_content(params.messages);
}
if (tmpl.original_caps().supports_object_arguments) {
workaround::func_args_not_string(params.messages);
}
params.add_generation_prompt = false;
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
params.add_generation_prompt = true;
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
params.generation_prompt = diff.right;
params.add_generation_prompt = inputs.add_generation_prompt;
params.extra_context = common_chat_extra_context();
for (auto el : inputs.chat_template_kwargs) {
params.extra_context[el.first] = json::parse(el.second);
}
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
params.parallel_tool_calls = inputs.parallel_tool_calls;
if (params.tools.is_array()) {
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
LOG_WRN(
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
}
}
if (inputs.force_pure_content) {
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
// Create the result structure
common_chat_params data;
auto params_copy = params;
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.generation_prompt = params.generation_prompt;
auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) {
return wrap_for_generation_prompt(p, p.content(p.rest()), params);
});
data.parser = parser.save();
return data;
}
if (auto result = try_specialized_template(tmpl, src, params)) {
result->generation_prompt = params.generation_prompt;
return *result;
}
try {
LOG_DBG("Using differential autoparser\n");
LOG_DBG("%s: using differential autoparser\n", __func__);
struct autoparser::autoparser autoparser;
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
@@ -1607,13 +1641,11 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
if (auto_params.supports_thinking) {
auto_params.thinking_start_tag = autoparser.reasoning.start;
auto_params.thinking_end_tag = autoparser.reasoning.end;
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
// but forces <think> open when thinking is enabled)
auto_params.thinking_forced_open =
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
}
auto_params.generation_prompt = params.generation_prompt;
common_peg_arena arena;
arena.load(auto_params.parser);
LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str());
return auto_params;
} catch (const std::exception & e) {
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
@@ -1711,14 +1743,18 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
LOG_DBG("No parser definition detected, assuming pure content parser.");
}
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
const std::string effective_input = params.generation_prompt.empty()
? input
: params.generation_prompt + input;
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
if (params.debug) {
flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
}
common_peg_parse_context ctx(input, flags);
common_peg_parse_context ctx(effective_input, flags);
auto result = parser.parse(ctx);
if (result.fail()) {
@@ -1738,7 +1774,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
return msg;
}
throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " +
input.substr(result.end));
effective_input.substr(result.end));
}
common_chat_msg msg;

View File

@@ -24,7 +24,7 @@ using json = nlohmann::ordered_json;
struct common_chat_templates;
namespace autoparser {
struct templates_params;
struct generation_params;
} // namespace autoparser
struct common_chat_tool_call {
@@ -212,7 +212,7 @@ struct common_chat_params {
std::string prompt;
std::string grammar;
bool grammar_lazy = false;
bool thinking_forced_open = false;
std::string generation_prompt;
bool supports_thinking = false;
std::string thinking_start_tag; // e.g., "<think>"
std::string thinking_end_tag; // e.g., "</think>"
@@ -229,14 +229,14 @@ struct common_chat_parser_params {
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
std::string generation_prompt;
bool parse_tool_calls = true;
bool debug = false; // Enable debug output for PEG parser
common_peg_arena parser = {};
common_chat_parser_params() = default;
common_chat_parser_params(const common_chat_params & chat_params) {
format = chat_params.format;
thinking_forced_open = chat_params.thinking_forced_open;
format = chat_params.format;
generation_prompt = chat_params.generation_prompt;
}
};
@@ -302,7 +302,7 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
std::string common_chat_template_direct_apply(
const common_chat_template & tmpl,
const autoparser::templates_params & inputs,
const autoparser::generation_params & inputs,
const std::optional<json> & messages_override = std::nullopt,
const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt);

View File

@@ -3,12 +3,14 @@
#pragma once
#include "ggml-opt.h"
#include "ggml.h"
#include "llama-cpp.h"
#include <set>
#include <sstream>
#include <string>
#include <string_view>
#include <variant>
#include <vector>
#include <map>
@@ -178,6 +180,43 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
};
// Grammar type enumeration
enum common_grammar_type {
COMMON_GRAMMAR_TYPE_NONE, // no grammar set
COMMON_GRAMMAR_TYPE_USER, // user-provided GBNF (--grammar / "grammar" API field)
COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, // auto-generated from JSON schema (--json-schema / "json_schema" API field)
COMMON_GRAMMAR_TYPE_TOOL_CALLS, // auto-generated by chat template parser for function calling
};
// Grammar variant struct with type and grammar string
struct common_grammar {
common_grammar_type type = COMMON_GRAMMAR_TYPE_NONE;
std::string grammar;
// Default constructor - no grammar
common_grammar() = default;
// Constructor with type and grammar string
common_grammar(common_grammar_type t, std::string g) : type(t), grammar(std::move(g)) {
GGML_ASSERT(type != COMMON_GRAMMAR_TYPE_NONE || !grammar.empty());
}
// Check if a grammar is set
bool empty() const { return type == COMMON_GRAMMAR_TYPE_NONE || grammar.empty(); }
};
// Returns the raw grammar string, or empty string if no grammar is set.
inline const std::string & common_grammar_value(const common_grammar & g) {
return g.grammar;
}
// Returns true when the generation_prompt should be prefilled into the grammar sampler.
// Only output-format and tool-call grammars need prefill; user-supplied grammars must not be prefilled.
inline bool common_grammar_needs_prefill(const common_grammar & g) {
return g.type == COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT
|| g.type == COMMON_GRAMMAR_TYPE_TOOL_CALLS;
}
// sampling parameters
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -228,7 +267,7 @@ struct common_params_sampling {
COMMON_SAMPLER_TYPE_TEMPERATURE,
};
std::string grammar; // optional BNF-like grammar to constrain sampling
common_grammar grammar; // optional grammar constraint (user / output-format / tool-calls)
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
std::set<llama_token> preserved_tokens;
@@ -236,10 +275,15 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// The assistant generation prompt already prefilled into the prompt.
// Fed to the grammar sampler (to advance past pre-existing tokens) and used
// to determine the reasoning budget sampler's initial state.
// Only applied when the grammar is of output-format or tool-calls type.
std::string generation_prompt;
// reasoning budget sampler parameters
// these are populated by the server/CLI based on chat template params
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
bool reasoning_budget_activate_immediately = false;
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)

View File

@@ -53,6 +53,13 @@ private:
return tokens[current + offset];
}
const token & next() {
if (current >= tokens.size()) {
throw parser_exception("Parser Error: Unexpected EOF", source, tokens.empty() ? 0 : tokens.back().pos);
}
return tokens[current++];
}
token expect(token::type type, const std::string& error) {
const auto & t = peek();
if (t.t != type) {
@@ -90,9 +97,9 @@ private:
size_t start_pos = current;
switch (peek().t) {
case token::comment:
return mk_stmt<comment_statement>(start_pos, tokens[current++].value);
return mk_stmt<comment_statement>(start_pos, next().value);
case token::text:
return mk_stmt<string_literal>(start_pos, tokens[current++].value);
return mk_stmt<string_literal>(start_pos, next().value);
case token::open_statement:
return parse_jinja_statement();
case token::open_expression:
@@ -119,8 +126,7 @@ private:
}
size_t start_pos = current;
std::string name = peek().value;
current++; // consume identifier
std::string name = next().value;
statement_ptr result;
if (name == "set") {
@@ -202,7 +208,7 @@ private:
// Ignore generation blocks (transformers-specific)
// See https://github.com/huggingface/transformers/pull/30650 for more information.
result = mk_stmt<noop_statement>(start_pos);
current++;
++current;
} else {
throw std::runtime_error("Unknown statement: " + name);
@@ -217,7 +223,7 @@ private:
statements body;
if (is(token::equals)) {
current++;
++current;
value = parse_expression_sequence();
} else {
// parsing multiline set here
@@ -280,7 +286,7 @@ private:
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
bool is_tuple = is(token::comma);
while (is(token::comma)) {
current++; // consume comma
++current; // consume comma
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
}
return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
@@ -290,7 +296,7 @@ private:
// e.g., `message` in `for message in messages`
auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
current++;
++current; // consume 'in'
// `messages` in `for message in messages`
auto iterable = parse_expression();
@@ -305,7 +311,8 @@ private:
}
if (is_statement({"else"})) {
current += 2;
++current; // consume {%
++current; // consume 'else'
expect(token::close_statement, "Expected %}");
while (!is_statement({"endfor"})) {
alternate.push_back(parse_any());
@@ -347,7 +354,7 @@ private:
auto left = parse_logical_and_expression();
while (is_identifier("or")) {
size_t start_pos = current;
token op = tokens[current++];
token op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
}
return left;
@@ -357,7 +364,7 @@ private:
auto left = parse_logical_negation_expression();
while (is_identifier("and")) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
}
return left;
@@ -367,7 +374,7 @@ private:
// Try parse unary operators
if (is_identifier("not")) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
}
return parse_comparison_expression();
@@ -382,11 +389,12 @@ private:
size_t start_pos = current;
if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
op = {token::identifier, "not in", tokens[current].pos};
current += 2;
++current; // consume 'not'
++current; // consume 'in'
} else if (is_identifier("in")) {
op = tokens[current++];
op = next();
} else if (is(token::comparison_binary_operator)) {
op = tokens[current++];
op = next();
} else break;
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
}
@@ -397,7 +405,7 @@ private:
auto left = parse_multiplicative_expression();
while (is(token::additive_binary_operator)) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
}
return left;
@@ -407,7 +415,7 @@ private:
auto left = parse_test_expression();
while (is(token::multiplicative_binary_operator)) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
}
return left;
@@ -417,9 +425,9 @@ private:
auto operand = parse_filter_expression();
while (is_identifier("is")) {
size_t start_pos = current;
current++;
++current; // consume 'is'
bool negate = false;
if (is_identifier("not")) { current++; negate = true; }
if (is_identifier("not")) { ++current; negate = true; }
auto test_id = parse_primary_expression();
// FIXME: tests can also be expressed like this: if x is eq 3
if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
@@ -432,7 +440,7 @@ private:
auto operand = parse_call_member_expression();
while (is(token::pipe)) {
size_t start_pos = current;
current++;
++current; // consume pipe
auto filter = parse_primary_expression();
if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
@@ -490,7 +498,7 @@ private:
statement_ptr parse_member_expression(statement_ptr object) {
size_t start_pos = current;
while (is(token::dot) || is(token::open_square_bracket)) {
auto op = tokens[current++];
auto op = next();
bool computed = op.t == token::open_square_bracket;
statement_ptr prop;
if (computed) {
@@ -536,7 +544,7 @@ private:
statement_ptr parse_primary_expression() {
size_t start_pos = current;
auto t = tokens[current++];
auto t = next();
switch (t.t) {
case token::numeric_literal:
if (t.value.find('.') != std::string::npos) {
@@ -547,7 +555,7 @@ private:
case token::string_literal: {
std::string val = t.value;
while (is(token::string_literal)) {
val += tokens[current++].value;
val += next().value;
}
return mk_stmt<string_literal>(start_pos, val);
}
@@ -562,9 +570,9 @@ private:
statements vals;
while (!is(token::close_square_bracket)) {
vals.push_back(parse_expression());
if (is(token::comma)) current++;
if (is(token::comma)) ++current;
}
current++;
++current;
return mk_stmt<array_literal>(start_pos, std::move(vals));
}
case token::open_curly_bracket: {
@@ -573,9 +581,9 @@ private:
auto key = parse_expression();
expect(token::colon, "Expected :");
pairs.push_back({std::move(key), parse_expression()});
if (is(token::comma)) current++;
if (is(token::comma)) ++current;
}
current++;
++current;
return mk_stmt<object_literal>(start_pos, std::move(pairs));
}
default:

View File

@@ -451,7 +451,7 @@ struct value_array_t : public value_t {
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence());
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), other.val_arr.end(), value_equivalence());
}
};
using value_array = std::shared_ptr<value_array_t>;
@@ -587,7 +587,7 @@ struct value_object_t : public value_t {
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence());
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), other.val_obj.end(), value_equivalence());
}
};
using value_object = std::shared_ptr<value_object_t>;

View File

@@ -163,9 +163,15 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
ctx->force_pos = 0;
}
// forward declaration for use in clone
static struct llama_sampler * common_reasoning_budget_init_state(
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
int32_t budget, common_reasoning_budget_state initial_state);
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
return common_reasoning_budget_init(
return common_reasoning_budget_init_state(
ctx->vocab,
ctx->start_matcher.tokens,
ctx->end_matcher.tokens,
@@ -191,13 +197,13 @@ static struct llama_sampler_i common_reasoning_budget_i = {
/* .backend_set_input = */ nullptr,
};
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
static struct llama_sampler * common_reasoning_budget_init_state(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
// promote COUNTING with budget <= 0 to FORCING
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
initial_state = REASONING_BUDGET_FORCING;
@@ -217,3 +223,41 @@ struct llama_sampler * common_reasoning_budget_init(
}
);
}
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
const std::vector<llama_token> & prefill_tokens) {
// Determine initial state from prefill: COUNTING if the prefill begins with
// the start sequence but does not also contain the end sequence after it.
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
if (!prefill_tokens.empty() && !start_tokens.empty() &&
prefill_tokens.size() >= start_tokens.size() &&
std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
initial_state = REASONING_BUDGET_COUNTING;
// If the end sequence also follows the start in the prefill, reasoning
// was opened and immediately closed — stay IDLE.
if (!end_tokens.empty() &&
prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
initial_state = REASONING_BUDGET_IDLE;
}
}
}
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
}
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
}

View File

@@ -24,14 +24,26 @@ enum common_reasoning_budget_state {
// DONE: passthrough forever
//
// Parameters:
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
// note: COUNTING with budget <= 0 is promoted to FORCING
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block
// prefill_tokens - tokens already present in the prompt (generation prompt);
// used to determine the initial state: COUNTING if they begin
// with start_tokens (but don't also end with end_tokens),
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
//
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
const std::vector<llama_token> & prefill_tokens = {});
// Variant that takes an explicit initial state (used by tests and clone).
// COUNTING with budget <= 0 is promoted to FORCING.
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,

View File

@@ -1,13 +1,16 @@
#include "sampling.h"
#include "common.h"
#include "ggml.h"
#include "log.h"
#include "reasoning-budget.h"
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstring>
#include <unordered_map>
#include <vector>
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
@@ -189,9 +192,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
std::vector<llama_sampler *> samplers;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
const std::string & grammar_str = common_grammar_value(params.grammar);
if (grammar_str.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
grmr = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
@@ -240,17 +244,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
trigger_patterns_c.push_back(regex.c_str());
}
if (!params.grammar.empty()) {
if (!grammar_str.empty()) {
if (params.grammar_lazy) {
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size());
} else {
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
grmr = llama_sampler_init_grammar(vocab, grammar_str.c_str(), "root");
}
}
}
// Feed generation prompt tokens to the grammar sampler so it advances past
// tokens the template already placed in the prompt.
// Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
std::vector<llama_token> prefill_tokens;
if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
GGML_ASSERT(vocab != nullptr);
prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
if (!prefill_tokens.empty()) {
std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
// Some tokenizers will add a space before the first special token, need to remove
prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
}
}
if (grmr) {
try {
for (const auto & token : prefill_tokens) {
llama_sampler_accept(grmr, token);
LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
}
} catch (std::exception &e) {
LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
throw e;
}
}
}
// reasoning budget sampler — added first so it can force tokens before other samplers
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
samplers.push_back(common_reasoning_budget_init(
@@ -259,7 +292,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
params.reasoning_budget_end,
params.reasoning_budget_forced,
params.reasoning_budget_tokens,
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
prefill_tokens));
}
if (params.has_logit_bias()) {

View File

@@ -31,10 +31,10 @@ import gguf
from gguf.vocab import MistralTokenizerType, MistralVocab
try:
from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.base import TokenizerVersion # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
SentencePieceTokenizer,
)
@@ -45,9 +45,9 @@ except ImportError:
_MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
_mistral_common_installed = False
TokenizerVersion = None
Tekkenizer = None
SentencePieceTokenizer = None
TokenizerVersion: Any = None
Tekkenizer: Any = None
SentencePieceTokenizer: Any = None
_mistral_import_error_msg = (
"Mistral format requires `mistral-common` to be installed. Please run "
"`pip install mistral-common[image,audio]` to install it."
@@ -145,6 +145,7 @@ class ModelBase:
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
self._is_nvfp4 = False
self._is_mxfp4 = False
# Apply heuristics to figure out typical tensor encoding based on first tensor's dtype
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
@@ -220,7 +221,7 @@ class ModelBase:
if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
tensor_names_from_index.update(weight_map.keys())
part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None)
part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None) # ty: ignore[invalid-assignment]
part_names = sorted(part_dict.keys())
else:
weight_map = {}
@@ -712,6 +713,7 @@ class ModelBase:
def prepare_tensors(self):
# detect NVFP4 quantization (ModelOpt format)
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
quant_method = (self.hparams.get("quantization_config") or {}).get("quant_method")
quant_layers = (self.hparams.get("quantization_config") or {}).get("quantized_layers") or {}
quant_config_file = self.dir_model / "hf_quant_config.json"
@@ -728,6 +730,7 @@ class ModelBase:
quant_algo = "NVFP4"
self._is_nvfp4 = quant_algo == "NVFP4"
self._is_mxfp4 = quant_method == "mxfp4"
# NVFP4 weights are repacked and written directly to gguf_writer.
# This must run before dequant_model so NVFP4 tensors are removed
@@ -876,6 +879,12 @@ class ModelBase:
if self.metadata.name is None:
self.metadata.name = self.dir_model.name
if self.ftype in (gguf.LlamaFileType.ALL_F32, gguf.LlamaFileType.MOSTLY_F16, gguf.LlamaFileType.MOSTLY_BF16):
if self._is_nvfp4:
self.ftype = gguf.LlamaFileType.MOSTLY_NVFP4
elif self._is_mxfp4:
self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
# Generate parameter weight class (useful for leader boards) if not yet determined
if self.metadata.size_label is None and total_params > 0:
self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
@@ -1062,6 +1071,10 @@ class TextModel(ModelBase):
self.gguf_writer.add_head_count_kv(n_head_kv)
logger.info(f"gguf: key-value head count = {n_head_kv}")
if self.hparams.get("is_causal") is False:
self.gguf_writer.add_causal_attention(False)
logger.info("gguf: causal attention = False")
# TODO: Handle "sliding_attention" similarly when models start implementing it
rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
if (rope_type := rope_params.get("rope_type")) is not None:
@@ -4260,6 +4273,16 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
@ModelBase.register("InternVisionModel")
class InternVisionModel(MmprojModel):
min_dynamic_tiles: int = 0
max_dynamic_tiles: int = 0
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.min_dynamic_tiles = self.global_config.get("min_dynamic_patch", 0)
self.max_dynamic_tiles = self.global_config.get("max_dynamic_patch", 0)
def set_gguf_parameters(self):
assert self.hparams_vision is not None
if isinstance(self.hparams_vision['image_size'], list):
@@ -4282,6 +4305,11 @@ class InternVisionModel(MmprojModel):
downsample_ratio = self.global_config.get("downsample_ratio")
assert downsample_ratio is not None
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
# older models may not have min/max_dynamic_patch in config
if self.min_dynamic_tiles > 0:
self.gguf_writer.add_vision_preproc_min_tiles(self.min_dynamic_tiles)
if self.max_dynamic_tiles > 0:
self.gguf_writer.add_vision_preproc_max_tiles(self.max_dynamic_tiles)
def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".position_embd." in new_name:
@@ -5878,7 +5906,7 @@ class InternLM2Model(TextModel):
logger.error(f'Error: Missing {tokenizer_path}')
sys.exit(1)
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
@@ -6199,7 +6227,7 @@ class BertModel(TextModel):
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size)
else:
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
@@ -8876,7 +8904,7 @@ class T5Model(TextModel):
if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}")
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
@@ -9013,7 +9041,7 @@ class T5EncoderModel(TextModel):
if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}")
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
@@ -11121,8 +11149,7 @@ class GptOssModel(TextModel):
# TODO: remove once MXFP4 is supported more generally
def dequant_model(self):
quant_config = self.hparams.get("quantization_config")
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
if self._is_mxfp4:
return
return super().dequant_model()
@@ -12275,6 +12302,7 @@ class LazyTorchTensor(gguf.LazyBase):
kwargs = {}
if func is torch.Tensor.numpy:
assert len(args)
return args[0].numpy()
return cls._wrap_fn(func)(*args, **kwargs)

View File

@@ -112,11 +112,11 @@ class Tensor:
(n_dims, name_len, dtype) = struct.unpack('<3I', data[offset:offset + 12])
assert n_dims >= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}'
assert name_len < 4096, 'Absurd tensor name length'
quant = gguf.GGML_QUANT_SIZES.get(dtype)
self.dtype = gguf.GGMLQuantizationType(dtype)
quant = gguf.GGML_QUANT_SIZES.get(self.dtype)
assert quant is not None, 'Unknown tensor type'
(blksize, tysize) = quant
offset += 12
self.dtype= gguf.GGMLQuantizationType(dtype)
self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)])
offset += 4 * n_dims
self.name = bytes(data[offset:offset + name_len])

View File

@@ -199,10 +199,13 @@ class LoraTorchTensor:
kwargs = {}
if func is torch.permute:
assert len(args)
return type(args[0]).permute(*args, **kwargs)
elif func is torch.reshape:
assert len(args)
return type(args[0]).reshape(*args, **kwargs)
elif func is torch.stack:
assert len(args)
assert isinstance(args[0], Sequence)
dim = kwargs.get("dim", 0)
assert dim == 0
@@ -211,6 +214,7 @@ class LoraTorchTensor:
torch.stack([b._lora_B for b in args[0]], dim),
)
elif func is torch.cat:
assert len(args)
assert isinstance(args[0], Sequence)
dim = kwargs.get("dim", 0)
assert dim == 0
@@ -362,7 +366,7 @@ if __name__ == '__main__':
logger.error(f"Model {hparams['architectures'][0]} is not supported")
sys.exit(1)
class LoraModel(model_class):
class LoraModel(model_class): # ty: ignore[unsupported-base]
model_arch = model_class.model_arch
lora_alpha: float

View File

@@ -14,7 +14,7 @@ The unified auto-parser uses a pure differential, compositional approach (inspir
**Analysis + Parser Building in Two Steps**:
1. `autoparser::autoparser tmpl_analysis(tmpl)` — runs all differential comparisons and populates the analysis structs
2. `autoparser::peg_generator::generate_parser(tmpl, params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
2. `autoparser::peg_generator::generate_parser(tmpl, generation_params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
## Data Structures
@@ -34,7 +34,7 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
### `analyze_tools` and its sub-structs
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`, `uses_python_dicts`)
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`)
- [common/chat-auto-parser.h:196-200](common/chat-auto-parser.h#L196-L200) — `tool_function_analysis`: `name_prefix`, `name_suffix`, `close` markers around function names
- [common/chat-auto-parser.h:202-210](common/chat-auto-parser.h#L202-L210) — `tool_arguments_analysis`: `start/end` container markers, `name_prefix/suffix`, `value_prefix/suffix`, `separator`
- [common/chat-auto-parser.h:212-217](common/chat-auto-parser.h#L212-L217) — `tool_id_analysis`: `pos` enum, `prefix`/`suffix` markers around call ID values
@@ -47,12 +47,21 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
| Value | Description |
|-----------------|-----------------------------------------------------------------------------------|
| `NONE` | No reasoning markers detected |
| `TAG_BASED` | Standard tag-based: `<think>...</think>` |
| `DELIMITER` | Delimiter-based: reasoning ends at a delimiter (e.g., `[BEGIN FINAL RESPONSE]`) |
| `FORCED_OPEN` | Template ends with open reasoning tag when `enable_thinking=true` |
| `FORCED_CLOSED` | `enable_thinking=false` emits both tags; `enable_thinking=true` emits only start |
| `TAG_BASED` | Tag-based: `<think>...</think>` (start can be empty for delimiter-style formats) |
| `TOOLS_ONLY` | Reasoning only appears in tool call responses, not plain content |
**Generation Prompt & Reasoning Prefill**: Computed in `common_chat_templates_apply_jinja` before invoking either the specialized handlers or the auto-parser, by rendering the template twice — once with `add_generation_prompt=false` and once with `add_generation_prompt=true` — and storing the diff suffix as `generation_params::generation_prompt`. This string is propagated into `common_chat_params::generation_prompt` and `common_chat_parser_params::generation_prompt`.
The generation prompt is prepended to model output before PEG parsing via `wrap_for_generation_prompt()`. The portion *before* the reasoning start marker (if any) is prepended as a literal to ensure any boilerplate added by the template is consumed. The full string is also fed to the grammar sampler via `llama_sampler_accept` (stored in `common_params_sampling::grammar_prefill`), advancing the grammar past tokens already in the prompt. It is used to determine the reasoning budget sampler's initial state — COUNTING if the prefill tokens begin with the reasoning start sequence (but don't also contain the end sequence), IDLE otherwise.
**`grammar_prefill`** (`common_params_sampling`): The generation prompt string tokenized and accepted by the grammar sampler at init time. Only applied when `grammar_external` is false (i.e., the grammar was not set explicitly by the user).
Three outcomes for reasoning-prefill handling (in `generate_parser()`):
1. **Start+end in generation prompt** (e.g. `<think></think>\n`): the parser sees reasoning as opened and immediately closed; whitespace-only reasoning content is discarded.
2. **Only start in generation prompt** (e.g. `<think>\n`): the parser sees reasoning as already open.
3. **Start marker present but not at the end** (e.g. Apriel's `<|begin_assistant|>` followed by boilerplate): the marker is a template artifact; the start literal is cleared so reasoning uses delimiter-style (end-only). For templates that ignore `add_generation_prompt` (empty diff), the rendered `data.prompt` is used as fallback — but only for non-TOOLS_ONLY modes, since in TOOLS_ONLY the start tag is model-generated and may appear in prior conversation turns.
**`content_mode`**: How the template wraps assistant content.
| Value | Description |
@@ -261,16 +270,16 @@ Text is segmentized into markers and non-marker fragments using `segmentize_mark
- Searches `diff.right` (output with reasoning) for the reasoning content needle
- Uses PEG parsers to find surrounding markers:
- If both pre/post markers found in `diff.right``TAG_BASED` (both tags visible in diff = no forced close)
- If both found but post marker only in the full output B → `FORCED_CLOSED`
- If only post marker found → `DELIMITER`
- If both pre/post markers found in `diff.right``TAG_BASED`
- If both found but post marker only in the full output B → `TAG_BASED` (template forces markers; handled via prefill)
- If only post marker found → `TAG_BASED` (delimiter-style, empty start)
- Sets `reasoning.start` and `reasoning.end`
**R2 — `compare_thinking_enabled()`**: Compares `enable_thinking=false` vs `true` with a generation prompt.
- Detects `FORCED_OPEN`: `enable_thinking=true` adds a non-empty marker at the end of the prompt (where model will start generating) — sets `reasoning.start`, mode = `FORCED_OPEN`
- Detects `FORCED_CLOSED`: `enable_thinking=false` produces both start+end markers; `enable_thinking=true` produces only start marker
- Handles the reverse case: if both start and end are still empty, looks for a single-segment diff on each side to extract both markers
- Detects template-added reasoning markers: `enable_thinking=true` appends a non-empty marker sets `reasoning.start`, mode = `TAG_BASED`
- Handles the reverse case (`enable_thinking=false` appends the marker instead): extracts both start (from the preceding segment) and end markers; mode = `TAG_BASED`
- The reasoning prefill (markers added by the template) is later extracted in `common_chat_templates_apply_jinja` and prepended to model output before parsing
**R3 — `compare_reasoning_scope()`**: Compares assistant message with reasoning+text-content vs reasoning+tool-calls.
@@ -343,7 +352,7 @@ Classification logic:
A workaround array in `common/chat-diff-analyzer.cpp` applies post-hoc patches after analysis. Each workaround is a lambda that inspects the template source and overrides analysis results. Current workarounds:
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')`: sets `reasoning.mode = FORCED_OPEN` with `<think>`/`</think>` markers if no reasoning was detected
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')` but not `<SPECIAL_12>`: sets `reasoning.mode = TAG_BASED` with `<think>`/`</think>` markers if no reasoning was detected
2. **Granite 3.3** — source contains specific "Write your thoughts" text: forces `TAG_BASED` reasoning with `<think>`/`</think>` and `WRAPPED_WITH_REASONING` content with `<response>`/`</response>`
3. **Cohere Command R+** — source contains `<|CHATBOT_TOKEN|>`: sets `ALWAYS_WRAPPED` content mode if no content start is already set
4. **Functionary 3.1** — source contains `set has_code_interpreter`: forces `PLAIN` content, specific `per_call_start/end`, clears preserved tokens to only keep Functionary-specific markers
@@ -355,12 +364,13 @@ Each analyzer struct (`analyze_reasoning`, `analyze_content`, `analyze_tools`) i
#### Reasoning Parser (`analyze_reasoning::build_parser`)
| Mode | Parser |
|-----------------------------------|---------------------------------------------------------------------|
| Not extracting reasoning | `eps()` |
| `FORCED_OPEN` or `FORCED_CLOSED` | `reasoning(until(end)) + end` — opening tag was in the prompt |
| `TAG_BASED` or `TOOLS_ONLY` | `optional(start + reasoning(until(end)) + end)` |
| `DELIMITER` | `optional(reasoning(until(end)) + end)` — no start marker |
| Mode | Parser |
|-----------------------------------------------|---------------------------------------------------------------------------|
| Not extracting reasoning | `eps()` |
| `TAG_BASED` or `TOOLS_ONLY` (non-empty start) | `optional(start + reasoning(until(end)) + end + space())` |
| `TAG_BASED` or `TOOLS_ONLY` (empty start) | `optional(reasoning(until(end)) + end + space())` — delimiter-style |
Note: The start marker may be empty either because the analyzer detected delimiter-style reasoning, or because `generate_parser()` cleared a template artifact start marker (see Generation Prompt & Reasoning Prefill above). Whitespace-only reasoning content (e.g. from a `<think></think>` prefill) is discarded by the mapper.
#### Content Parser (`analyze_content::build_parser`)
@@ -410,9 +420,7 @@ All three tool parsers return:
reasoning + optional(content(until(trigger_marker))) + tool_calls + end()
```
### Python Dict Format
When `format.uses_python_dicts` is true (detected when single-quoted strings appear in JSON argument context), `build_parser()` pre-registers a `json-string` rule that accepts both single-quoted and double-quoted strings. This is done before any `p.json()` call so all JSON parsing inherits the flexible rule.
Each returned parser is wrapped by `wrap_for_generation_prompt()`, which prepends a literal for any boilerplate prefix of the generation prompt (the portion before the reasoning start marker).
## Mapper
@@ -421,22 +429,22 @@ When `format.uses_python_dicts` is true (detected when single-quoted strings app
- **Buffered arguments**: Before `tool_name` is known, argument text goes to `args_buffer`; once the name is set, the buffer is flushed to `current_tool->arguments`
- **`args_target()`**: Returns a reference to whichever destination is currently active (buffer or tool args), eliminating branching
- **`closing_quote_pending`**: Tracks whether a closing `"` needs to be appended when a string argument value is finalized (for schema-declared string types in tagged format)
- **Quote normalization**: Python-style quotes (`'key': 'value'`) are converted to JSON (`"key": "value"`)
- **Whitespace-only reasoning**: Reasoning content that consists entirely of whitespace (e.g. from a `<think></think>` prefill) is cleared so the message shows no reasoning
- **Brace auto-closing**: At tool close, unclosed `{` braces are closed automatically
## Files
| File | Purpose |
|-------------------------------------------|----------------------------------------------------------------------|
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `templates_params` |
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, |
| | `compare_variants()`, string helpers |
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
| `tools/parser/template-analysis.cpp` | Template analysis tool |
| File | Purpose |
|-------------------------------------------|---------------------------------------------------------------------------------|
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `generation_params` |
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, `compare_variants()`, |
| | `wrap_for_generation_prompt()`, string helpers |
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
| `tools/parser/template-analysis.cpp` | Template analysis tool |
## Testing & Debugging
@@ -516,10 +524,10 @@ To support a new template format:
## Edge Cases and Quirks
1. **Forced Thinking**: When `enable_thinking=true` and the model prompt ends with an open reasoning tag (e.g., `<think>`), the parser enters forced thinking mode and immediately expects reasoning content without waiting for a start marker.
1. **Generation Prompt & Reasoning Prefill**: The generation prompt is extracted by diffing `add_generation_prompt=false` vs `true` in `common_chat_templates_apply_jinja`, so it contains exactly what the template appends — avoiding false positives from prior conversation turns.
2. **Per-Call vs Per-Section Markers**: Some templates wrap each tool call individually (`per_call_start/end`); others wrap the entire section (`section_start/end`). T2 (`check_per_call_markers()`) disambiguates by checking if the second call in a two-call output starts with the section marker.
3. **Python Dict Format**: The Seed template family uses single-quoted JSON (`'key': 'value'`). The `uses_python_dicts` flag causes the PEG builder to register a flexible `json-string` rule accepting both quote styles before any JSON rules are built.
4. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
5. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
6. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
7. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.
3. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
4. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
5. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
6. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.
7. **Undetected Tool Format**: If `analyze_tools` concludes tool calling is supported but cannot determine the format, `build_parser()` logs an error and returns `eps()` (graceful degradation) rather than aborting.

View File

@@ -12,9 +12,9 @@ Legend:
- 🟡 Partially supported by this backend
- ❌ Not supported by this backend
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
| Operation | BLAS | CANN | CPU | CUDA | MTL | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ABS | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
@@ -23,63 +23,63 @@ Legend:
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | | ❌ | ✅ | ❌ | | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ | ✅ | ❌ | | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
@@ -91,31 +91,31 @@ Legend:
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | | ❌ | 🟡 | ✅ | | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | | ❌ | 🟡 | ✅ | | ❌ | ❌ |
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | | ❌ | ❌ | ✅ | | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | | ❌ | ❌ | ✅ | | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -28,9 +28,6 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
return f'({result})?' if min_items == 0 else result
def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
has_min = min_value != None
has_max = max_value != None
def digit_range(from_char: str, to_char: str):
out.append("[")
if from_char == to_char:
@@ -106,7 +103,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
out.append(to_str[i])
out.append("]")
if has_min and has_max:
if min_value is not None and max_value is not None:
if min_value < 0 and max_value < 0:
out.append("\"-\" (")
_generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
@@ -133,7 +130,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
less_decimals = max(decimals_left - 1, 1)
if has_min:
if min_value is not None:
if min_value < 0:
out.append("\"-\" (")
_generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
@@ -177,7 +174,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
more_digits(length - 1, less_decimals)
return
if has_max:
if max_value is not None:
if max_value >= 0:
if top_level:
out.append("\"-\" [1-9] ")

View File

@@ -64,7 +64,7 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
print("Using SentenceTransformer to apply all numbered layers")
model = SentenceTransformer(model_path)
tokenizer = model.tokenizer
config = model[0].auto_model.config # type: ignore
config = model[0].auto_model.config
else:
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
@@ -108,8 +108,8 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
print(f"Model file: {type(model).__module__}")
# Verify the model is using the correct sliding window
if hasattr(model.config, 'sliding_window'): # type: ignore
print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore
if hasattr(model.config, 'sliding_window'):
print(f"Model's sliding_window: {model.config.sliding_window}")
else:
print("Model config does not have sliding_window attribute")
@@ -152,7 +152,7 @@ def main():
device = next(model.parameters()).device
else:
# For SentenceTransformer, get device from the underlying model
device = next(model[0].auto_model.parameters()).device # type: ignore
device = next(model[0].auto_model.parameters()).device
model_name = os.path.basename(model_path)
@@ -177,7 +177,7 @@ def main():
print(f"{token_id:6d} -> '{token_str}'")
print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}")
else:
# Standard approach: use base model output only
encoded = tokenizer(
@@ -205,12 +205,12 @@ def main():
print(f"Embedding dimension: {all_embeddings.shape[1]}")
if len(all_embeddings.shape) == 1:
n_embd = all_embeddings.shape[0] # type: ignore
n_embd = all_embeddings.shape[0]
n_embd_count = 1
all_embeddings = all_embeddings.reshape(1, -1)
else:
n_embd = all_embeddings.shape[1] # type: ignore
n_embd_count = all_embeddings.shape[0] # type: ignore
n_embd = all_embeddings.shape[1]
n_embd_count = all_embeddings.shape[0]
print()

View File

@@ -2,7 +2,7 @@
import argparse
import sys
from common import compare_tokens # type: ignore
from common import compare_tokens # type: ignore[import-not-found]
def parse_arguments():

View File

@@ -6,7 +6,7 @@ import re
from copy import copy
from enum import Enum
from inspect import getdoc, isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
from docstring_parser import parse
from pydantic import BaseModel, create_model
@@ -1158,7 +1158,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
# Assert that the parameter has a type annotation
if param.annotation == inspect.Parameter.empty:
raise TypeError(f"Parameter '{param.name}' in function '{func.__name__}' lacks a type annotation")
raise TypeError(f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a type annotation""")
# Find the parameter's description in the docstring
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
@@ -1166,7 +1166,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
# Assert that the parameter has a description
if not param_doc or not param_doc.description:
raise ValueError(
f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring")
f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a description in the docstring""")
# Add parameter details to the schema
param_docs.append((param.name, param_doc))
@@ -1177,7 +1177,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
dynamic_fields[param.name] = (
param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
# Creating the dynamic model
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)
dynamic_model = create_model(f"{getattr(func, '__name__')}", **dynamic_fields)
for name, param_doc in param_docs:
dynamic_model.model_fields[name].description = param_doc.description
@@ -1285,7 +1285,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
if items != {}:
array = {"properties": items}
array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
fields[field_name] = (List[array_type], ...)
fields[field_name] = (list[array_type], ...) # ty: ignore[invalid-type-form]
else:
fields[field_name] = (list, ...)
elif field_type == "object":

View File

@@ -1788,9 +1788,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // src
ggml_tensor * src1 = dst->src[1]; // index
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16
|| dst->type == GGML_TYPE_BF16);
switch (src0->type) {
case GGML_TYPE_BF16:
case GGML_TYPE_F16:
case GGML_TYPE_F32:
if (src0->type == dst->type) {
@@ -1881,6 +1883,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
break;
}
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
@@ -1891,7 +1894,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
}
acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
dst->type);
@@ -1965,7 +1968,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
// Only check env once.
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (weight_to_nz && is_matmul_weight(weight)) {
if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) {
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
} else {
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@@ -2146,6 +2149,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
switch (type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
ggml_cann_mat_mul_fp(ctx, dst);
break;
case GGML_TYPE_Q4_0:

View File

@@ -1234,7 +1234,8 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (!need_transform(tensor->type)) {
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
if (weight_to_nz && tensor->type != GGML_TYPE_BF16
&& is_matmul_weight((const ggml_tensor *) tensor)) {
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
weight_format_to_nz(tensor, offset, ctx->device);
@@ -1443,7 +1444,8 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t
if (ne0 % MATRIX_ROW_PADDING != 0) {
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
}
} else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
} else if (weight_to_nz && tensor->type != GGML_TYPE_BF16
&& is_matmul_weight((const ggml_tensor *) tensor)) {
// NZ format weight are not support quantized yet.
// If ND tensor transform to NZ, size may changed.
int64_t shape[] = { tensor->ne[1], tensor->ne[0] };
@@ -2283,6 +2285,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_MUL_MAT:
{
switch (op->src[0]->type) {
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
@@ -2320,6 +2325,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
case GGML_TYPE_Q8_0:
return true;
default:
@@ -2332,6 +2340,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
switch (op->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
return true;
default:
return false;
@@ -2341,20 +2352,30 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_CPY:
{
ggml_tensor * src = op->src[0];
#ifdef ASCEND_310P
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) {
// only support F32 and F16.
// only support F32 and F16 on 310P.
return false;
}
#else
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_BF16) ||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16 && src->type != GGML_TYPE_BF16)) {
// only support F32, F16 and BF16.
return false;
}
#endif
return true;
}
break;
case GGML_OP_CONT:
{
// TODO: support GGML_TYPE_BF16
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
return true;
default:
return false;

View File

@@ -572,9 +572,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
set(KLEIDIAI_FETCH_ARGS
URL ${KLEIDIAI_DOWNLOAD_URL}
DOWNLOAD_EXTRACT_TIMESTAMP NEW
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}
)
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW)
endif()
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28")
FetchContent_Declare(KleidiAI_Download

View File

@@ -3194,6 +3194,7 @@ class tinyBLAS_PPC {
private:
__attribute__((always_inline))
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);
@@ -3204,6 +3205,7 @@ class tinyBLAS_PPC {
}
}
__attribute__((always_inline))
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);

View File

@@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_SOURCES_CUDA ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else()
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
list(APPEND GGML_SOURCES_CUDA
template-instances/fattn-vec-instance-f16-f16.cu
template-instances/fattn-vec-instance-q4_0-q4_0.cu
template-instances/fattn-vec-instance-q8_0-q8_0.cu
template-instances/fattn-vec-instance-bf16-bf16.cu)
endif()
ggml_add_backend_library(ggml-cuda

View File

@@ -41,6 +41,16 @@ template<typename dst_t, typename src_t>
return __bfloat162float(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
return __float22half2_rn(x);
} else if constexpr(std::is_same_v<src_t, nv_bfloat162> && std::is_same_v<dst_t, float2>) {
#ifdef GGML_USE_HIP
return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
#else
#if __CUDA_ARCH__ >= 800
return __bfloat1622float2(x);
#else
return make_float2(__bfloat162float(x.x), __bfloat162float(x.y));
#endif // __CUDA_ARCH__ >= 800
#endif // GGML_USE_HIP
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
// bypass compile error on cuda 12.0.1
#ifdef GGML_USE_HIP

View File

@@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
return sum;
}
template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c;
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
float sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
__align__(16) nv_bfloat162 tmp[cpy_ne];
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
#ifdef V_DOT2_F32_F16_AVAILABLE
// FIXME replace macros in vector FA kernel with templating and use FP32 for BF16
ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]));
#else
ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
return sum;
}
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
}
}
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
static_assert(std::is_same_v<T, float>, "BF16 V dequantization only supports float output");
static_assert(ne % 2 == 0, "bad ne");
__align__(16) nv_bfloat162 tmp[ne/2];
ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(tmp, (const nv_bfloat16 *) vx + i0);
float2 * dst_f2 = (float2 *) dst;
#pragma unroll
for (int l = 0; l < ne/2; ++l) {
dst_f2[l] = ggml_cuda_cast<float2>(tmp[l]);
}
}
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
const block_q4_0 * x = (const block_q4_0 *) vx;
@@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_Q8_0) {
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_BF16) {
return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
} else {
static_assert(type_K == -1, "bad type");
return nullptr;
@@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
return dequantize_V_q5_1<T, ne>;
} else if constexpr (type_V == GGML_TYPE_Q8_0) {
return dequantize_V_q8_0<T, ne>;
} else if constexpr (type_V == GGML_TYPE_BF16) {
return dequantize_V_bf16<float, ne>;
} else {
static_assert(type_V == -1, "bad type");
return nullptr;

View File

@@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec(
#endif // GGML_USE_HIP
constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q;
constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q;
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4;
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16;
#ifdef V_DOT2_F32_F16_AVAILABLE
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
#else
@@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec(
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
half2 tmp[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
if constexpr (type_V == GGML_TYPE_BF16) {
float2 tmp_f[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp_f,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]);
}
} else {
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
}
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
@@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
@@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
@@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
@@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)

View File

@@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
@@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
@@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
@@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
@@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
#else
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
#endif // GGML_CUDA_FA_ALL_QUANTS
GGML_ABORT("fatal error");
@@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
#endif // GGML_CUDA_FA_ALL_QUANTS
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_BF16:
break;
default:
return BEST_FATTN_KERNEL_NONE;

View File

@@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
}
}
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
@@ -173,11 +173,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
return 1;
}
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_dst) {
case 1:
return 1;
return small_k ? nwarps : 1;
case 2:
case 3:
case 4:
@@ -193,7 +193,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
return 1;
}
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
@@ -208,7 +208,7 @@ static __global__ void mul_mat_vec_q(
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
@@ -414,14 +414,16 @@ static __global__ void mul_mat_vec_q(
template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) {
const int nwarps = calc_nwarps(type, ncols_dst, table_id);
const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
const int64_t nblocks = (nrows_x + rpb - 1) / rpb;
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
const dim3 block_dims(warp_size, nwarps, 1);
return {block_nums, block_dims};
}
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@@ -434,7 +436,7 @@ static void mul_mat_vec_q_switch_fusion(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@@ -444,7 +446,7 @@ static void mul_mat_vec_q_switch_fusion(
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@@ -488,11 +490,33 @@ static void mul_mat_vec_q_switch_ncols_dst(
switch (ncols_dst) {
case 1: {
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
// When K is small, increase rows_per_block to match nwarps so each warp has more work to do
// Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_iter_1warp = vdr * warp_size / qi;
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
if (use_small_k) {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id, true);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
} else {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
}
} break;
case 2: {
constexpr int c_ncols_dst = 2;

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16);

View File

@@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16);

View File

@@ -5,7 +5,7 @@ import os
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

View File

@@ -45,6 +45,7 @@ static int opt_verbose = 0;
static int opt_profile = 0;
static int opt_hostbuf = 1; // hostbuf ON by default
static int opt_experimental = 0;
static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only
// Enable all stages by default
static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;
@@ -1693,7 +1694,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
// Start the DSP-side service. We need to pass the queue ID to the
// DSP in a FastRPC call; the DSP side will import the queue and start
// listening for packets in a callback.
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx);
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx);
if (err != 0) {
GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
@@ -3372,6 +3373,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
const char * str_etm = getenv("GGML_HEXAGON_ETM");
const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
@@ -3381,8 +3383,9 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask;
opt_opsync = str_opsync ? atoi(str_opsync) : 0;
opt_profile = str_profile ? atoi(str_profile) : 0;
opt_etm = str_etm ? atoi(str_etm) : 0;
opt_etm = str_etm ? atoi(str_etm) : 0;
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx;
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {

View File

@@ -40,6 +40,24 @@ target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
# HMX acceleration: available on v73+ architectures
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-matmul-ops.c
)
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-matmul-ops.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)

View File

@@ -175,6 +175,86 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) {
return q->capacity;
}
// ---------------------------------------------------------------------------
// Overflow-safe DMA push: all UDMA type1 descriptor fields (roiwidth,
// roiheight, srcstride, dststride) are 16-bit, max 65535. This helper
// transparently handles values that exceed the 16-bit limit and submits
// chained DMA transtions.
//
// Case 1 (fast path): all params fit in 16 bits -> direct dma_queue_push.
// Case 2 (contiguous block): width == srcstride == dststride. Reshape the
// flat transfer into a 2D descriptor with sub_width <= 65535. Produces a
// single descriptor, preserving async DMA behavior.
// Case 3 (stride overflow): srcstride or dststride > 65535. Issue rows
// one at a time. The first N-1 rows are pushed+popped synchronously;
// the last row is left async so the caller can pop it.
// ---------------------------------------------------------------------------
#define UDMA_MAX_FIELD_VAL 65535u
static inline bool dma_queue_push_chained(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t width, size_t nrows) {
// Fast path: everything fits in 16 bits.
if (__builtin_expect(
width <= UDMA_MAX_FIELD_VAL &&
nrows <= UDMA_MAX_FIELD_VAL &&
src_stride <= UDMA_MAX_FIELD_VAL &&
dst_stride <= UDMA_MAX_FIELD_VAL, 1)) {
return dma_queue_push(q, dptr, dst_stride, src_stride, width, nrows);
}
// Case 2: contiguous block (width == src_stride == dst_stride).
// Reshape total bytes into sub_width * sub_nrows where sub_width <= 65535.
if (width == src_stride && width == dst_stride) {
size_t total = width * nrows;
// Pick the largest 128-byte-aligned sub_width that divides total evenly.
size_t sub_width = UDMA_MAX_FIELD_VAL & ~(size_t)127; // 65408
while (sub_width > 0 && total % sub_width != 0) {
sub_width -= 128;
}
if (sub_width == 0) {
// Fallback: use original width (must fit) with adjusted nrows.
// This shouldn't happen for 128-aligned DMA sizes.
sub_width = width;
}
size_t sub_nrows = total / sub_width;
// Handle sub_nrows > 65535 by issuing chunked descriptors.
const uint8_t *src = (const uint8_t *)dptr.src;
uint8_t *dst = (uint8_t *)dptr.dst;
size_t rows_done = 0;
while (rows_done < sub_nrows) {
size_t chunk = sub_nrows - rows_done;
if (chunk > UDMA_MAX_FIELD_VAL) chunk = UDMA_MAX_FIELD_VAL;
dma_ptr p = dma_make_ptr(dst + rows_done * sub_width, src + rows_done * sub_width);
if (!dma_queue_push(q, p, sub_width, sub_width, sub_width, chunk))
return false;
rows_done += chunk;
// Complete all chunks without waiting except the last one, so the
// caller's single dma_queue_pop drains the final descriptor.
if (rows_done < sub_nrows)
dma_queue_pop_nowait(q);
}
return true;
}
// Case 3: stride overflow — fall back to row-by-row.
{
const uint8_t *src = (const uint8_t *)dptr.src;
uint8_t *dst = (uint8_t *)dptr.dst;
for (size_t r = 0; r < nrows; ++r) {
dma_ptr p = dma_make_ptr(dst + r * dst_stride,
src + r * src_stride);
if (!dma_queue_push(q, p, 0, 0, width, 1))
return false;
if (r + 1 < nrows)
dma_queue_pop_nowait(q);
}
return true;
}
}
#ifdef __cplusplus
} // extern "C"
#endif

View File

@@ -29,10 +29,22 @@ static inline uint64_t hex_get_pktcnt() {
return pktcnt;
}
static inline int32_t hex_is_aligned(void * addr, uint32_t align) {
static inline size_t hmx_ceil_div(size_t num, size_t den) {
return (num + den - 1) / den;
}
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0;
}
static inline size_t hex_align_up(size_t v, size_t align) {
return hmx_ceil_div(v, align) * align;
}
static inline size_t hex_align_down(size_t v, size_t align) {
return (v / align) * align;
}
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n;
@@ -43,6 +55,14 @@ static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m);
}
static inline size_t hex_smin(size_t a, size_t b) {
return a < b ? a : b;
}
static inline size_t hex_smax(size_t a, size_t b) {
return a > b ? a : b;
}
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
Q6_l2fetch_AP((void *) p, control);

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,72 @@
// HMX operation entry-point declarations.
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
#ifndef HMX_OPS_H
#define HMX_OPS_H
#include <stddef.h>
#include <stdint.h>
#ifndef restrict
# define restrict __restrict
#endif
#ifdef __cplusplus
extern "C" {
#endif
struct htp_context; // forward declaration
typedef struct {
float *dst;
const float *activation;
const __fp16 *permuted_weight;
int m;
int k;
int n;
int act_stride;
int weight_stride;
int dst_stride;
int ne02;
int ne03;
int ne12;
int ne13;
size_t src0_nb2;
size_t src0_nb3;
size_t src1_nb2;
size_t src1_nb3;
size_t dst_nb2;
size_t dst_nb3;
} hmx_matmul_w16a32_batched_params_t;
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
// act_stride: activation row stride in elements (= k for contiguous, or
// nb[1]/sizeof(float) for permuted tensors like attention Q).
// weight_stride: weight row stride in elements (= k for compact weights, or
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const __fp16 *permuted_weight,
int m, int k, int n,
int act_stride,
int weight_stride);
// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32.
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx,
const hmx_matmul_w16a32_batched_params_t *params);
// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL)
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const uint8_t *permuted_weight,
int m, int k, int n,
int weight_type);
#ifdef __cplusplus
}
#endif
#endif // HMX_OPS_H

View File

@@ -0,0 +1,34 @@
// Conditional fine-grained profiling macros for HMX operations.
//
// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this
// header) to instrument sub-operation latencies with HAP qtimer. When the
// macro is not defined the TIMER_* helpers expand to nothing so there is zero
// overhead.
//
// Usage:
// TIMER_DEFINE(my_phase); // declare accumulator variable
// TIMER_START(my_phase); // snapshot start time
// ... work ...
// TIMER_STOP(my_phase); // accumulate elapsed ticks
// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase));
#ifndef HMX_PROFILE_H
#define HMX_PROFILE_H
#include <HAP_perf.h>
// #define ENABLE_PROFILE_TIMERS
#if defined(ENABLE_PROFILE_TIMERS)
# define TIMER_DEFINE(name) int64_t name##_ticks = 0
# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count()
# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0
# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks)
#else
# define TIMER_DEFINE(name)
# define TIMER_START(name)
# define TIMER_STOP(name)
# define TIMER_US(name) 0LL
#endif
#endif // HMX_PROFILE_H

View File

@@ -0,0 +1,88 @@
// HMX tile-level inline helpers (FP16 32x32 tile operations).
// Ported from htp-ops-lib/include/dsp/hmx_utils.h. (https://github.com/haozixu/htp-ops-lib)
#ifndef HMX_UTILS_H
#define HMX_UTILS_H
#include <hexagon_types.h>
#include <stddef.h>
#define HMX_FP16_TILE_N_ROWS 32
#define HMX_FP16_TILE_N_COLS 32
#define HMX_FP16_TILE_N_ELMS 1024
#define HMX_FP16_TILE_SIZE 2048
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) {
asm volatile("bias = mxmem2(%0)" :: "r"(scales));
}
// Initialise aligned 256-byte area with scale vector + zero padding.
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
HVX_Vector *pv = (HVX_Vector *)out_scales;
*pv++ = v_scale;
*pv = Q6_V_vzero();
}
// Load multiple contiguous tiles with :deep streaming.
// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt].
// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank
// boundary, otherwise the mxmem instruction will raise a precise bus error.
// Callers must ensure their VTCM layout satisfies this constraint.
static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles,
const __fp16 *col_tiles,
size_t n_tiles) {
size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1;
asm volatile(
"{ activation.hf = mxmem(%0, %1):deep\n"
"weight.hf = mxmem(%2, %3) }\n"
:: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit)
: "memory");
}
// Load a single activation+weight tile pair (no :deep streaming).
// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula
// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047.
// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation
// places a tile near a 4 MB bank boundary, the oversized region crosses it and
// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly
// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047).
static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile,
const __fp16 *wt_tile) {
asm volatile(
"{ activation.hf = mxmem(%0, %1)\n"
"weight.hf = mxmem(%2, %3) }\n"
:: "r"(act_tile), "r"(2047),
"r"(wt_tile), "r"(2047)
: "memory");
}
static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) {
// Use the combined convert-and-store instruction (matches the reference
// Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence
// "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter.
asm volatile(
"mxmem(%0, %1):after.hf = acc\n"
:: "r"(out), "r"(0)
: "memory");
}
// Compute inner product of two vectors of tiles and store result.
static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out,
const __fp16 *row_tiles,
const __fp16 *col_tiles,
size_t n_tiles) {
hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles);
hmx_consume_accumulator_fp16(out);
}
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
uint8_t *p = *vtcm_ptr;
*vtcm_ptr += size;
return p;
}
#endif // HMX_UTILS_H

View File

@@ -30,6 +30,12 @@ struct htp_context {
atomic_bool vtcm_needs_release;
uint32_t opmask;
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
#ifdef HTP_HAS_HMX
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation)
#endif
};
#endif /* HTP_CTX_H */

View File

@@ -32,13 +32,14 @@ enum htp_status {
// Duplicated here because we can't include full ggml.h in the htp build.
// We have some static_asserts in the cpp code to ensure things are in sync.
enum htp_data_type {
HTP_TYPE_F32 = 0,
HTP_TYPE_F16 = 1,
HTP_TYPE_Q4_0 = 2,
HTP_TYPE_Q8_0 = 8,
HTP_TYPE_I32 = 26,
HTP_TYPE_I64 = 27,
HTP_TYPE_MXFP4 = 39,
HTP_TYPE_F32 = 0,
HTP_TYPE_F16 = 1,
HTP_TYPE_Q4_0 = 2,
HTP_TYPE_Q8_0 = 8,
HTP_TYPE_IQ4_NL = 20,
HTP_TYPE_I32 = 26,
HTP_TYPE_I64 = 27,
HTP_TYPE_MXFP4 = 39,
HTP_TYPE_COUNT
};
@@ -87,6 +88,8 @@ static inline size_t htp_t_block_size(uint32_t t) {
return QK4_0;
case HTP_TYPE_Q8_0:
return QK8_0;
case HTP_TYPE_IQ4_NL:
return QK4_NL;
case HTP_TYPE_MXFP4:
return QK_MXFP4;
default:
@@ -105,6 +108,8 @@ static inline size_t htp_type_nbytes(uint32_t t) {
return sizeof(block_q4_0);
case HTP_TYPE_Q8_0:
return sizeof(block_q8_0);
case HTP_TYPE_IQ4_NL:
return sizeof(block_iq4_nl);
case HTP_TYPE_MXFP4:
return sizeof(block_mxfp4);
default:

View File

@@ -7,7 +7,7 @@
#include "remote.idl"
interface htp_iface : remote_handle64 {
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx);
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx);
AEEResult stop();
AEEResult enable_etm();
AEEResult disable_etm();

View File

@@ -9,6 +9,9 @@
#include "hex-utils.h"
#include "hvx-types.h"
#define hvx_vmem(A) *((HVX_Vector *)(A))
#define hvx_vmemu(A) *((HVX_UVector *)(A))
static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
// Rotate as needed.
v = Q6_V_vlalign_VVR(v, v, (size_t) dst);
@@ -112,11 +115,15 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
return Q6_Q_and_QQ(p_exp, p_frac);
}
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
const HVX_Vector zero = Q6_V_vsplat_R(0);
static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) {
const HVX_Vector zero = Q6_V_vzero();
HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)));
return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0));
}
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
#if __HVX_ARCH__ < 79
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
@@ -128,6 +135,30 @@ static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
return v;
}
#if __HVX_ARCH__ >= 79
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v, one);
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
}
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
}
#else
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v, one);
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
}
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
}
#endif
/* Q6_Vsf_equals_Vw is only available on v73+.*/
#if __HVX_ARCH__ < 73
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)

View File

@@ -25,6 +25,10 @@
#include "htp-ops.h"
#include "worker-pool.h"
#ifdef HTP_HAS_HMX
#include "hmx-ops.h"
#endif // HTP_HAS_HMX
AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
struct htp_context * ctx;
int err = 0;
@@ -163,6 +167,9 @@ static int vtcm_acquire(struct htp_context * ctx) {
}
ctx->vtcm_inuse = true;
return 0;
}
@@ -246,7 +253,7 @@ static void vtcm_free(struct htp_context * ctx) {
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
static void htp_error_callback(dspqueue_t queue, int error, void * context);
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) {
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) {
struct htp_context * ctx = (struct htp_context *) handle;
if (!ctx) {
@@ -280,6 +287,21 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
return AEE_ENOMEMORY;
}
#ifdef HTP_HAS_HMX
if (use_hmx) {
ctx->vtcm_scratch_size = ctx->vtcm_size;
ctx->hmx_enabled = 1;
FARF(HIGH, "HMX enabled: vtcm-scratch %zu", ctx->vtcm_scratch_size);
} else {
// HMX disabled: skip HMX initialisation so the
// dispatch loop falls through to the HVX compute paths.
ctx->hmx_enabled = 0;
ctx->vtcm_scratch_size = ctx->vtcm_size;
FARF(HIGH, "HMX disabled (use_hmx=0): vtcm-scratch %zu", ctx->vtcm_scratch_size);
}
#endif
qurt_sysenv_max_hthreads_t hw_threads;
qurt_sysenv_get_max_hw_threads(&hw_threads);
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
@@ -340,6 +362,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
for (int i = 0; i < ctx->n_threads; i++) {
dma_queue_delete(ctx->dma[i]);
}
#ifdef HTP_HAS_HMX
if (ctx->hmx_enabled) {
ctx->hmx_enabled = 0;
}
#endif
vtcm_free(ctx);
@@ -375,8 +403,9 @@ static int send_htp_rsp(struct htp_context * c,
struct dspqueue_buffer * bufs,
size_t n_bufs,
struct profile_data * prof) {
// Prep response struct
// Prep response struct (zero-init to clear cmp/unused union)
struct htp_general_rsp rsp;
memset(&rsp, 0, sizeof(rsp));
rsp.op = op;
rsp.status = status;
rsp.prof_usecs = prof->usecs;
@@ -1037,6 +1066,210 @@ static void proc_flash_attn_ext_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof);
}
#ifdef HTP_HAS_HMX
// ---------------------------------------------------------------------------
// HMX operation wrappers — self-contained, bypass htp_ops_context / htp_spad.
// VTCM, DMA and thread dispatch are managed inside the HMX kernels.
// ---------------------------------------------------------------------------
static void proc_hmx_matmul_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
size_t n_bufs) {
// HMX weight tile requires N to be 32-aligned.
if (req->src0.ne[1] % 32 != 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
const bool is_batched = (req->src0.ne[2] * req->src0.ne[3] > 1 ||
req->src1.ne[2] * req->src1.ne[3] > 1);
// Quantised HMX kernels only handle flat 2D matmul (host already rejects
// batched quantised, but guard here too). F16 batched matmul is handled
// by the dedicated wrapper in hmx-matmul-ops.c.
if (is_batched &&
req->src0.type != HTP_TYPE_F16) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// HMX assumes contiguous row-major layout. Fall back for permuted
// tensors where strides are non-monotonic (e.g. transposed KV cache).
if (req->src0.nb[0] > req->src0.nb[1] ||
req->src1.nb[0] > req->src1.nb[1]) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// M alignment: when M > 32 but not 32-aligned, we split into
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
const int m_total = (int) req->src1.ne[1];
const int m_tail = m_total % 32;
const int m_hmx = m_total - m_tail;
if (m_hmx == 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights.
// Other types (e.g. MXFP4) fall back to HVX.
{
uint32_t wtype = req->src0.type;
if (wtype != HTP_TYPE_F16 &&
wtype != HTP_TYPE_Q4_0 &&
wtype != HTP_TYPE_Q8_0 &&
wtype != HTP_TYPE_IQ4_NL) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// Quantised HMX path requires K aligned to 256 (x4x2 super-block).
// F16 HMX path requires K aligned to 32 (tile width).
if (wtype != HTP_TYPE_F16 && req->src0.ne[0] % 256 != 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
if (wtype == HTP_TYPE_F16 && req->src0.ne[0] % 32 != 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
}
(void) n_bufs;
struct dspqueue_buffer rsp_bufs[1];
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);
// src0 = weights, src1 = activation, dst = output
void * wgt = (void *) bufs[0].ptr;
float * act = (float *) bufs[1].ptr;
float * dst = (float *) bufs[2].ptr;
int k = (int) req->src0.ne[0]; // inner dimension
int n = (int) req->src0.ne[1]; // weight columns
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
int ret = -1;
const int ne02 = (int) req->src0.ne[2];
const int ne03 = (int) req->src0.ne[3];
const int ne12 = (int) req->src1.ne[2];
const int ne13 = (int) req->src1.ne[3];
// Row strides in elements. For compact tensors these equal k; for
// permuted attention views they can be larger, so pass the real stride.
const int act_stride = (int)(req->src1.nb[1] / sizeof(float));
const int weight_stride = (int)(req->src0.nb[1] / sizeof(__fp16));
switch (req->src0.type) {
case HTP_TYPE_F16:
if (is_batched) {
hmx_matmul_w16a32_batched_params_t batch_params = {
.dst = dst,
.activation = act,
.permuted_weight = (const __fp16 *) wgt,
.m = m_hmx,
.k = k,
.n = n,
.act_stride = act_stride,
.weight_stride = weight_stride,
.dst_stride = (int)(req->dst.nb[1] / sizeof(float)),
.ne02 = ne02,
.ne03 = ne03,
.ne12 = ne12,
.ne13 = ne13,
.src0_nb2 = req->src0.nb[2],
.src0_nb3 = req->src0.nb[3],
.src1_nb2 = req->src1.nb[2],
.src1_nb3 = req->src1.nb[3],
.dst_nb2 = req->dst.nb[2],
.dst_nb3 = req->dst.nb[3],
};
ret = hmx_mat_mul_permuted_w16a32_batched(ctx, &batch_params);
} else {
ret = hmx_mat_mul_permuted_w16a32(ctx, dst, act,
(const __fp16 *) wgt,
m_hmx, k, n,
act_stride,
weight_stride);
}
break;
default:
ret = hmx_mat_mul_permuted_qk_0_d16a32(ctx, dst, act,
(const uint8_t *) wgt,
m_hmx, k, n, (int) req->src0.type);
break;
}
if (ret == 0) {
rsp_status = HTP_STATUS_OK;
} else {
FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
vtcm_release(ctx);
req->flags &= ~HTP_OPFLAGS_SKIP_QUANTIZE;
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
vtcm_release(ctx);
}
// --- Phase 2: HVX on the remaining m_tail rows ---
if (m_tail > 0 && rsp_status == HTP_STATUS_OK) {
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0; // weights: unchanged
octx.src1 = req->src1;
octx.src1.ne[1] = m_tail; // only tail rows
octx.dst = req->dst;
octx.dst.ne[1] = m_tail; // only tail rows
// Always re-quantize tail src1: HMX Phase 1 overwrites VTCM,
// so any previously cached quantized data (SKIP_QUANTIZE pipeline)
// is invalid.
octx.flags = req->flags & ~HTP_OPFLAGS_SKIP_QUANTIZE;
octx.op = req->op;
octx.n_threads = ctx->n_threads;
// Offset activation and dst pointers past the HMX-processed rows.
// Use nb[1] (row stride in bytes) to compute the byte offset.
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t)((uint8_t *) bufs[1].ptr + (size_t) m_hmx * req->src1.nb[1]);
octx.dst.data = (uint32_t)((uint8_t *) bufs[2].ptr + (size_t) m_hmx * req->dst.nb[1]);
FARF(HIGH, "proc_hmx_matmul: HVX tail m_tail=%d act=%p dst=%p",
m_tail, (void *)(uintptr_t) octx.src1.data, (void *)(uintptr_t) octx.dst.data);
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
uint32_t hvx_ret = op_matmul(&octx);
vtcm_release(ctx);
if (hvx_ret != HTP_STATUS_OK) {
FARF(ERROR, "HVX tail matmul failed (ret=%u)", hvx_ret);
rsp_status = HTP_STATUS_INTERNAL_ERR;
}
} else {
rsp_status = HTP_STATUS_INTERNAL_ERR;
}
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
#endif // HTP_HAS_HMX
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
struct htp_context * ctx = (struct htp_context *) context;
@@ -1089,7 +1322,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
FARF(ERROR, "Bad matmul-req buffer list");
continue;
}
proc_matmul_req(ctx, &req, bufs, n_bufs);
#ifdef HTP_HAS_HMX
if (ctx->hmx_enabled) {
proc_hmx_matmul_req(ctx, &req, bufs, n_bufs);
} else
#endif
{
proc_matmul_req(ctx, &req, bufs, n_bufs);
}
break;
case HTP_OP_MUL_MAT_ID:

View File

@@ -53,9 +53,6 @@ endif()
message(STATUS "HIP and hipBLAS found")
# Workaround old compilers
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024")
file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
@@ -74,12 +71,11 @@ if (GGML_CUDA_FA_ALL_QUANTS)
list(APPEND GGML_SOURCES_ROCM ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else()
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
list(APPEND GGML_SOURCES_ROCM
../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
endif()
ggml_add_backend_library(ggml-hip
@@ -132,6 +128,11 @@ endif()
if (CXX_IS_HIPCC)
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug")
# CMake on Windows doesn't support the HIP language yet.
# Therefore we workaround debug build's failure on HIP backend this way.
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES COMPILE_FLAGS "-O2 -g")
endif()
target_link_libraries(ggml-hip PRIVATE hip::device)
else()
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP)

View File

@@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND)
list(APPEND GGML_SOURCES_MUSA ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else()
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
list(APPEND GGML_SOURCES_MUSA
../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
endif()
set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)

View File

@@ -1162,12 +1162,18 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp
return nullptr;
}
// Fix: Prevent division by zero if blck_size is 0 (e.g., deprecated types)
if (ggml_blck_size((enum ggml_type)tensor->type) == 0) {
GGML_LOG_ERROR("[%s] invalid tensor type received (blck_size is 0): %u\n", __func__, tensor->type);
return nullptr;
}
ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
// ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
if (result == nullptr) {
GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type);
return nullptr;
}

View File

@@ -4667,22 +4667,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
if (a->ne[3] != b->ne[3]) {
return false;
}
ggml_type a_type = a->type;
if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
) {
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
return false;
}
}
ggml_type src0_type = op->src[0]->type;
if (src0_type == GGML_TYPE_BF16 ) {
// TODO: support GGML_TYPE_BF16
// FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
return false;
}
// TODO: The configuration below needs more work to be supported with oneDNN
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&

View File

@@ -4604,12 +4604,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
};
const bool use_subgroup_reduce = device->subgroup_arithmetic;
for (uint32_t si = 0; si < 3; si++) {
const uint32_t S_V = gdn_sizes[si];
GGML_ASSERT(is_pow2(S_V));
uint32_t lanes_per_column;
if (S_V >= 128u && device->subgroup_clustered) {
lanes_per_column = 8u;
} else {
// Use largest power-of-two that divides both S_V and subgroup_size so that
// (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0.
// This means we don't need extra bounds checking logic in the shader.
lanes_per_column = std::min(S_V, device->subgroup_size);
}
const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
size_t gdn_len;
const void * gdn_data;
if (use_subgroup_reduce && need_clustered_shader) {
gdn_len = gated_delta_net_f32_len;
gdn_data = (const void *)gated_delta_net_f32_data;
} else if (use_subgroup_reduce) {
gdn_len = gated_delta_net_f32_nocluster_len;
gdn_data = (const void *)gated_delta_net_f32_nocluster_data;
} else {
gdn_len = gated_delta_net_f32_shmem_len;
gdn_data = (const void *)gated_delta_net_f32_shmem_data;
}
const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column;
const std::array<uint32_t, 3> wg_denoms = {1u, 1u, cols_per_wg};
for (uint32_t kda = 0; kda < 2; kda++) {
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
"main", 7, sizeof(vk_op_gated_delta_net_push_constants),
{1, 1, 1}, {gdn_sizes[si], kda}, 1);
gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
}
}
}
@@ -10438,7 +10468,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, { H, n_seqs, 1u });
pc, { H, n_seqs, S_v });
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
@@ -16018,6 +16048,7 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev)
case 0xE20C: // B570
return 18;
case 0xE20B: // B580
case 0xE211: // Pro B60
return 20;
default:
return 0;

View File

@@ -1,11 +1,25 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#if USE_SUBGROUP_CLUSTERED
#extension GL_KHR_shader_subgroup_clustered : enable
#endif
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0,
// so no bounds checking is needed.
layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint KDA = 0;
layout(constant_id = 2) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 3) const uint LANES_PER_COLUMN = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN;
const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN;
layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint H;
@@ -27,14 +41,61 @@ layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
shared FLOAT_TYPE s_k[S_V];
shared FLOAT_TYPE s_q[S_V];
shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED
shared FLOAT_TYPE temp[SUBGROUP_SIZE];
// This does a reduction across groups of LANES_PER_COLUMN
FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) {
const uint lane = gl_SubgroupInvocationID;
temp[lane] = partial;
barrier();
[[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) {
FLOAT_TYPE other = temp[lane ^ s];
barrier();
temp[lane] += other;
barrier();
}
const FLOAT_TYPE result = temp[lane];
barrier();
return result;
}
#endif
// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant
FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) {
switch (LANES_PER_COLUMN) {
case 1u:
return partial;
#if USE_SUBGROUP_CLUSTERED
// Workaround for GLSL requiring a literal constant for the cluster size.
// The branches should all fold away.
case 2u:
return subgroupClusteredAdd(partial, 2u);
case 4u:
return subgroupClusteredAdd(partial, 4u);
case 8u:
return subgroupClusteredAdd(partial, 8u);
case 16u:
return subgroupClusteredAdd(partial, 16u);
case 32u:
return subgroupClusteredAdd(partial, 32u);
case 64u:
return subgroupClusteredAdd(partial, 64u);
#endif
default:
#if USE_SUBGROUP_ADD
return subgroupAdd(partial);
#else
return reduce_add_shmem(partial);
#endif
}
}
void main() {
const uint head_id = gl_WorkGroupID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint col = gl_LocalInvocationID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN;
const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN);
const uint iq1 = head_id % neq1;
const uint iq3 = seq_id / rq3;
@@ -42,9 +103,9 @@ void main() {
const uint state_size = S_V * S_V;
const uint state_base = (seq_id * H + head_id) * state_size;
FLOAT_TYPE state[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
FLOAT_TYPE s_shard[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
}
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
@@ -53,76 +114,56 @@ void main() {
const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
const uint k_off = q_off;
const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
if (KDA != 0) {
const uint g_base = gb_off * S_V;
s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
}
barrier();
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
FLOAT_TYPE k_reg[ROWS_PER_LANE];
FLOAT_TYPE q_reg[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
const uint i = r * LANES_PER_COLUMN + lane;
k_reg[r] = FLOAT_TYPE(data_k[k_off + i]);
q_reg[r] = FLOAT_TYPE(data_q[q_off + i]);
}
FLOAT_TYPE g_exp[ROWS_PER_LANE];
if (KDA == 0) {
const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
FLOAT_TYPE kv_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
kv_col += dot(
vec4(state[i], state[i+1], state[i+2], state[i+3]),
vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
);
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
g_exp[r] = g_val;
}
FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
FLOAT_TYPE attn_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
sv = g_val * sv + kv * delta_col;
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
}
data_dst[attn_off + col] = attn_col * scale;
} else {
FLOAT_TYPE kv_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
kv_col += dot(gv * sv, kv);
const uint g_base = gb_off * S_V;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
const uint i = r * LANES_PER_COLUMN + lane;
g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i]));
}
}
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
FLOAT_TYPE attn_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
sv = gv * sv + kv * delta_col;
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
FLOAT_TYPE kv_shard = 0.0;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
kv_shard += g_exp[r] * s_shard[r] * k_reg[r];
}
FLOAT_TYPE kv_col = reduce_partial(kv_shard);
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
}
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
FLOAT_TYPE attn_partial = 0.0;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col;
attn_partial += s_shard[r] * q_reg[r];
}
FLOAT_TYPE attn_col = reduce_partial(attn_partial);
if (lane == 0) {
data_dst[attn_off + col] = attn_col * scale;
}
attn_off += S_V * H;
barrier();
}
[[unroll]] for (uint i = 0; i < S_V; i++) {
data_dst[s_off + state_base + col * S_V + i] = state[i];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
}
}

View File

@@ -987,7 +987,9 @@ void process_shaders() {
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}}));
string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

View File

@@ -95,6 +95,11 @@ struct ggml_webgpu_generic_shader_decisions {
uint32_t wg_size = 0;
};
struct ggml_webgpu_ssm_conv_shader_decisions {
uint32_t block_size;
uint32_t tokens_per_wg;
};
/** Argsort **/
struct ggml_webgpu_argsort_shader_lib_context {
@@ -131,6 +136,26 @@ struct ggml_webgpu_set_rows_shader_decisions {
uint32_t wg_size;
};
/** Set **/
struct ggml_webgpu_set_pipeline_key {
ggml_type type;
bool inplace;
bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
return type == other.type && inplace == other.inplace;
}
};
struct ggml_webgpu_set_pipeline_key_hash {
size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
/** Get Rows **/
struct ggml_webgpu_get_rows_pipeline_key {
@@ -186,6 +211,67 @@ struct ggml_webgpu_pad_pipeline_key_hash {
}
};
/** Solve Tri **/
struct ggml_webgpu_solve_tri_pipeline_key {
int type;
int n;
int k;
bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
return type == other.type && n == other.n && k == other.k;
}
};
struct ggml_webgpu_solve_tri_pipeline_key_hash {
size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.n);
ggml_webgpu_hash_combine(seed, key.k);
return seed;
}
};
/** SSM Conv **/
struct ggml_webgpu_ssm_conv_pipeline_key {
int type;
int vectorized;
bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
return type == other.type && vectorized == other.vectorized;
}
};
/** Gated Delta Net **/
struct ggml_webgpu_gated_delta_net_pipeline_key {
int type;
int s_v;
int kda;
bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
return type == other.type && s_v == other.s_v && kda == other.kda;
}
};
struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.s_v);
ggml_webgpu_hash_combine(seed, key.kda);
return seed;
}
};
struct ggml_webgpu_ssm_conv_pipeline_key_hash {
size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
};
/** Scale **/
struct ggml_webgpu_scale_pipeline_key {
@@ -466,14 +552,22 @@ class ggml_webgpu_shader_lib {
unary_pipelines; // type/op/inplace
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
scale_pipelines; // inplace
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
solve_tri_pipelines; // type
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
ssm_conv_pipelines; // type/vectorized
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
webgpu_pipeline,
ggml_webgpu_gated_delta_net_pipeline_key_hash>
gated_delta_net_pipelines; // type/S_v/kda
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
pad_pipelines; // circular/non-circular
pad_pipelines; // circular/non-circular
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
binary_pipelines; // type/op/inplace/overlap
binary_pipelines; // type/op/inplace/overlap
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
concat_pipelines; // type
concat_pipelines; // type
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
repeat_pipelines; // type
repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
@@ -487,6 +581,7 @@ class ggml_webgpu_shader_lib {
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@@ -519,11 +614,11 @@ class ggml_webgpu_shader_lib {
switch (key.op) {
case GGML_OP_RMS_NORM:
defines.push_back("OP_RMS_NORM");
defines.push_back("RMS_NORM");
variant = "rms_norm";
break;
case GGML_OP_L2_NORM:
defines.push_back("OP_L2_NORM");
defines.push_back("L2_NORM");
variant = "l2_norm";
break;
default:
@@ -535,8 +630,9 @@ class ggml_webgpu_shader_lib {
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
const uint32_t row_norm_wg_size = 128u;
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
return row_norm_pipelines[key];
@@ -609,6 +705,46 @@ class ggml_webgpu_shader_lib {
return set_rows_pipelines[key];
}
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace };
auto it = set_pipelines.find(key);
if (it != set_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "set";
switch (key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_I32:
defines.push_back("TYPE_I32");
variant += "_i32";
break;
default:
GGML_ABORT("Unsupported type for set shader");
}
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_set, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
set_pipelines[key] = pipeline;
return set_pipelines[key];
}
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
auto it = cumsum_pipelines.find(1);
if (it != cumsum_pipelines.end()) {
@@ -695,6 +831,7 @@ class ggml_webgpu_shader_lib {
switch (key.src_type) {
case GGML_TYPE_F32:
defines.push_back("FLOAT_PARALLEL");
if (key.vectorized) {
defines.push_back("F32_VEC");
defines.push_back("SRC_TYPE=vec4<f32>");
@@ -709,6 +846,7 @@ class ggml_webgpu_shader_lib {
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("FLOAT_PARALLEL");
defines.push_back("F16");
defines.push_back("SRC_TYPE=f16");
defines.push_back("DST_TYPE=f32");
@@ -716,6 +854,7 @@ class ggml_webgpu_shader_lib {
variant += "_f16";
break;
case GGML_TYPE_I32:
defines.push_back("FLOAT_PARALLEL");
defines.push_back("I32");
defines.push_back("SRC_TYPE=i32");
defines.push_back("DST_TYPE=i32");
@@ -794,6 +933,128 @@ class ggml_webgpu_shader_lib {
return scale_pipelines[key];
}
webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_solve_tri_pipeline_key key = {
.type = context.dst->type,
.n = (int) context.src0->ne[0],
.k = (int) context.src1->ne[0],
};
auto it = solve_tri_pipelines.find(key);
if (it != solve_tri_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "solve_tri";
switch (key.type) {
case GGML_TYPE_F32:
variant += "_f32";
break;
default:
GGML_ABORT("Unsupported type for solve_tri shader");
}
const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size);
const uint32_t k_tile = wg_size;
const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row);
defines.push_back(std::string("N=") + std::to_string(key.n));
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("K_TILE=") + std::to_string(k_tile));
defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n));
auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
solve_tri_pipelines[key] = pipeline;
return solve_tri_pipelines[key];
}
webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_ssm_conv_pipeline_key key = {
.type = context.dst->type,
.vectorized = context.src1->ne[0] == 4,
};
auto it = ssm_conv_pipelines.find(key);
if (it != ssm_conv_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "ssm_conv";
switch (key.type) {
case GGML_TYPE_F32:
variant += "_f32";
break;
default:
GGML_ABORT("Unsupported type for ssm_conv shader");
}
if (key.vectorized) {
defines.push_back("VECTORIZED");
variant += "_vec4";
}
constexpr uint32_t block_size = 32u;
constexpr uint32_t tokens_per_wg = 8u;
defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u");
defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u");
auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);
auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
decisions->block_size = block_size;
decisions->tokens_per_wg = tokens_per_wg;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
ssm_conv_pipelines[key] = pipeline;
return ssm_conv_pipelines[key];
}
webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_gated_delta_net_pipeline_key key = {
.type = context.dst->type,
.s_v = (int) context.src2->ne[0],
.kda = context.src3->ne[0] == context.src2->ne[0],
};
auto it = gated_delta_net_pipelines.find(key);
if (it != gated_delta_net_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "gated_delta_net";
switch (key.type) {
case GGML_TYPE_F32:
variant += "_f32";
break;
default:
GGML_ABORT("Unsupported type for gated_delta_net shader");
}
if (key.kda) {
defines.push_back("KDA");
variant += "_kda";
}
defines.push_back("S_V=" + std::to_string(key.s_v) + "u");
defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u");
auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
gated_delta_net_pipelines[key] = pipeline;
return gated_delta_net_pipelines[key];
}
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };

View File

@@ -880,6 +880,68 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
params, entries, wg_x);
}
static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = inplace,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
std::vector<uint32_t> params = {
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
1u,
(uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size),
(uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size),
(uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size),
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
(uint32_t) src1->ne[2],
(uint32_t) src1->ne[3],
};
std::vector<wgpu::BindGroupEntry> entries;
uint32_t binding_index = 0;
if (!inplace) {
entries.push_back({ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) });
binding_index++;
}
entries.push_back({ .binding = binding_index,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) });
entries.push_back({ .binding = binding_index + 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
@@ -935,6 +997,208 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
(uint32_t) src1->ne[0],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size);
const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_ssm_conv_shader_decisions *>(pipeline.context.get());
const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) src1->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
token_tiles,
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size);
const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2];
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * src3,
ggml_tensor * src4,
ggml_tensor * src5,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.src2 = src2,
.src3 = src3,
.src4 = src4,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx);
const uint32_t s_v = (uint32_t) src2->ne[0];
const uint32_t h = (uint32_t) src2->ne[1];
const uint32_t n_tokens = (uint32_t) src2->ne[2];
const uint32_t n_seqs = (uint32_t) src2->ne[3];
const float scale = 1.0f / sqrtf((float) s_v);
uint32_t scale_u32;
memcpy(&scale_u32, &scale, sizeof(scale_u32));
std::vector<uint32_t> params = {
h,
n_tokens,
n_seqs,
s_v * h * n_tokens * n_seqs,
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
(uint32_t) (src2->nb[2] / ggml_type_size(src2->type)),
(uint32_t) (src2->nb[3] / ggml_type_size(src2->type)),
(uint32_t) (src4->nb[1] / ggml_type_size(src4->type)),
(uint32_t) (src4->nb[2] / ggml_type_size(src4->type)),
(uint32_t) (src4->nb[3] / ggml_type_size(src4->type)),
(uint32_t) src0->ne[1],
(uint32_t) (src2->ne[3] / src0->ne[3]),
scale_u32,
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(src2),
.offset = ggml_webgpu_tensor_align_offset(ctx, src2),
.size = ggml_webgpu_tensor_binding_size(ctx, src2) },
{ .binding = 3,
.buffer = ggml_webgpu_tensor_buf(src3),
.offset = ggml_webgpu_tensor_align_offset(ctx, src3),
.size = ggml_webgpu_tensor_binding_size(ctx, src3) },
{ .binding = 4,
.buffer = ggml_webgpu_tensor_buf(src4),
.offset = ggml_webgpu_tensor_align_offset(ctx, src4),
.size = ggml_webgpu_tensor_binding_size(ctx, src4) },
{ .binding = 5,
.buffer = ggml_webgpu_tensor_buf(src5),
.offset = ggml_webgpu_tensor_align_offset(ctx, src5),
.size = ggml_webgpu_tensor_binding_size(ctx, src5) },
{ .binding = 6,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs);
}
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * idx,
@@ -1016,6 +1280,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * idx,
ggml_tensor * dst) {
const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32;
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.src1 = nullptr,
@@ -1060,7 +1326,10 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1));
uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]);
uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows;
uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@@ -1632,7 +1901,7 @@ static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * s
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.dst = dst,
.max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = inplace,
};
@@ -2176,6 +2445,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_CPY:
case GGML_OP_CONT:
return ggml_webgpu_cpy(ctx, src0, node);
case GGML_OP_SET:
return ggml_webgpu_set(ctx, src0, src1, node);
case GGML_OP_SET_ROWS:
return ggml_webgpu_set_rows(ctx, src0, src1, node);
case GGML_OP_GET_ROWS:
@@ -2219,6 +2490,12 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_DIAG:
case GGML_OP_TRI:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SOLVE_TRI:
return ggml_webgpu_solve_tri(ctx, src0, src1, node);
case GGML_OP_SSM_CONV:
return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
case GGML_OP_GATED_DELTA_NET:
return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node);
case GGML_OP_PAD:
return ggml_webgpu_pad(ctx, src0, node);
case GGML_OP_ARGMAX:
@@ -2957,7 +3234,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
/* .is_host = */ NULL, // defaults to false
},
/* .device = */
dev,
dev,
/* .context = */ NULL
};
@@ -3040,6 +3317,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
(op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
break;
case GGML_OP_SET:
supports_op = src0->type == src1->type && src0->type == op->type &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32);
break;
case GGML_OP_SET_ROWS:
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
@@ -3180,6 +3461,27 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
}
}
break;
case GGML_OP_TRI:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_DIAG:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_SOLVE_TRI:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
break;
case GGML_OP_SSM_CONV:
supports_op = op->type == GGML_TYPE_F32;
break;
case GGML_OP_GATED_DELTA_NET:
{
const uint32_t s_v = (uint32_t) src2->ne[0];
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 &&
op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 &&
s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
}
break;
case GGML_OP_CLAMP:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
@@ -3201,12 +3503,6 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_COS:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_DIAG:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_TRI:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_PAD:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;

View File

@@ -0,0 +1,132 @@
@group(0) @binding(0)
var<storage, read_write> src_q: array<f32>;
@group(0) @binding(1)
var<storage, read_write> src_k: array<f32>;
@group(0) @binding(2)
var<storage, read_write> src_v: array<f32>;
@group(0) @binding(3)
var<storage, read_write> src_g: array<f32>;
@group(0) @binding(4)
var<storage, read_write> src_beta: array<f32>;
@group(0) @binding(5)
var<storage, read_write> src_state: array<f32>;
@group(0) @binding(6)
var<storage, read_write> dst: array<f32>;
struct Params {
h: u32,
n_tokens: u32,
n_seqs: u32,
s_off: u32,
sq1: u32,
sq2: u32,
sq3: u32,
sv1: u32,
sv2: u32,
sv3: u32,
sb1: u32,
sb2: u32,
sb3: u32,
neq1: u32,
rq3: u32,
scale: f32,
};
@group(0) @binding(7)
var<uniform> params: Params;
var<workgroup> sh_k: array<f32, S_V>;
var<workgroup> sh_q: array<f32, S_V>;
var<workgroup> sh_g: array<f32, S_V>;
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>
) {
let head_id = workgroup_id.x;
let seq_id = workgroup_id.y;
let col = local_id.x;
let iq1 = head_id % params.neq1;
let iq3 = seq_id / params.rq3;
let state_size = S_V * S_V;
let state_base = (seq_id * params.h + head_id) * state_size;
var state: array<f32, S_V>;
for (var i = 0u; i < S_V; i++) {
state[i] = src_state[state_base + col * S_V + i];
}
var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V;
for (var t = 0u; t < params.n_tokens; t++) {
let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1;
let k_off = q_off;
let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1;
let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1;
sh_q[col] = src_q[q_off + col];
sh_k[col] = src_k[k_off + col];
#ifdef KDA
let g_base = gb_off * S_V;
sh_g[col] = exp(src_g[g_base + col]);
#endif
workgroupBarrier();
let v_val = src_v[v_off + col];
let beta_val = src_beta[gb_off];
var kv_col = 0.0;
var delta_col = 0.0;
var attn_col = 0.0;
#ifdef KDA
for (var i = 0u; i < S_V; i++) {
kv_col += (sh_g[i] * state[i]) * sh_k[i];
}
delta_col = (v_val - kv_col) * beta_val;
for (var i = 0u; i < S_V; i++) {
state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col;
attn_col += state[i] * sh_q[i];
}
#else
let g_val = exp(src_g[gb_off]);
for (var i = 0u; i < S_V; i++) {
kv_col += state[i] * sh_k[i];
}
delta_col = (v_val - g_val * kv_col) * beta_val;
for (var i = 0u; i < S_V; i++) {
state[i] = g_val * state[i] + sh_k[i] * delta_col;
attn_col += state[i] * sh_q[i];
}
#endif
dst[attn_off + col] = attn_col * params.scale;
attn_off += S_V * params.h;
workgroupBarrier();
}
for (var i = 0u; i < S_V; i++) {
dst[params.s_off + state_base + col * S_V + i] = state[i];
}
}

View File

@@ -640,6 +640,35 @@ var<uniform> params: Params;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
#ifdef FLOAT_PARALLEL
let blocks_per_row = params.ne0 / BLOCK_SIZE;
let row_count = params.n_rows * params.ne2 * params.ne3;
if (gid.x >= blocks_per_row * row_count) {
return;
}
let block_idx = gid.x % blocks_per_row;
var row_idx = gid.x / blocks_per_row;
let i_dst3 = row_idx / (params.ne2 * params.n_rows);
row_idx = row_idx % (params.ne2 * params.n_rows);
let i_dst2 = row_idx / params.n_rows;
let i_dst1 = row_idx % params.n_rows;
let i_idx2 = i_dst3 % params.idx2;
let i_idx1 = i_dst2 % params.idx1;
let i_idx0 = i_dst1;
let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
let idx_val = u32(idx[i_idx]);
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
copy_elements(i_src_row, i_dst_row, block_idx);
#else
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
return;
}
@@ -664,5 +693,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
copy_elements(i_src_row, i_dst_row, i);
}
#endif
}

View File

@@ -81,11 +81,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
}
sum = scratch[0];
#ifdef OP_RMS_NORM
#ifdef RMS_NORM
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
#elif OP_L2_NORM
#elif defined(L2_NORM)
let scale = 1.0/max(sqrt(sum), params.eps);
#endif
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {

View File

@@ -0,0 +1,109 @@
#ifdef TYPE_I32
#define TYPE i32
#else
#define TYPE f32
#endif
#ifndef INPLACE
@group(0) @binding(0)
var<storage, read_write> src0: array<TYPE>;
#define SRC1_BINDING 1
#else
#define SRC1_BINDING 0
#endif
#define DST_BINDING SRC1_BINDING + 1
#define PARAMS_BINDING SRC1_BINDING + 2
@group(0) @binding(SRC1_BINDING)
var<storage, read_write> src1: array<TYPE>;
@group(0) @binding(DST_BINDING)
var<storage, read_write> dst: array<TYPE>;
struct Params {
ne: u32,
offset_src0: u32,
offset_src1: u32,
offset_view: u32,
stride_src10: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst10: u32,
stride_dst11: u32,
stride_dst12: u32,
stride_dst13: u32,
src1_ne0: u32,
src1_ne1: u32,
src1_ne2: u32,
src1_ne3: u32,
};
@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;
fn decode_src1_coords(idx: u32) -> vec4<u32> {
var i = idx;
let plane = params.src1_ne2 * params.src1_ne1 * params.src1_ne0;
let i3 = i / plane;
i = i % plane;
let row = params.src1_ne1 * params.src1_ne0;
let i2 = i / row;
i = i % row;
let i1 = i / params.src1_ne0;
let i0 = i % params.src1_ne0;
return vec4<u32>(i0, i1, i2, i3);
}
fn decode_view_coords(rel: u32) -> vec4<u32> {
let i3 = rel / params.stride_dst13;
let rem3 = rel % params.stride_dst13;
let i2 = rem3 / params.stride_dst12;
let rem2 = rem3 % params.stride_dst12;
let i1 = rem2 / params.stride_dst11;
let i0 = rem2 % params.stride_dst11;
return vec4<u32>(i0, i1, i2, i3);
}
fn view_rel_from_coords(coords: vec4<u32>) -> u32 {
return coords.x * params.stride_dst10 + coords.y * params.stride_dst11 +
coords.z * params.stride_dst12 + coords.w * params.stride_dst13;
}
fn src1_idx_from_coords(coords: vec4<u32>) -> u32 {
return coords.x * params.stride_src10 + coords.y * params.stride_src11 +
coords.z * params.stride_src12 + coords.w * params.stride_src13;
}
fn in_set_view(rel: u32, coords: vec4<u32>) -> bool {
return view_rel_from_coords(coords) == rel;
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
#ifdef INPLACE
let coords = decode_src1_coords(gid.x);
let src1_idx = params.offset_src1 + src1_idx_from_coords(coords);
let dst_idx = params.offset_view + view_rel_from_coords(coords);
dst[dst_idx] = src1[src1_idx];
#else
let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view);
let coords = decode_view_coords(rel);
if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) {
dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)];
} else {
dst[gid.x] = src0[params.offset_src0 + gid.x];
}
#endif
}

View File

@@ -0,0 +1,121 @@
@group(0) @binding(0)
var<storage, read_write> src0: array<f32>;
@group(0) @binding(1)
var<storage, read_write> src1: array<f32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
stride_src00: u32,
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_src10: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
k: u32,
ne2: u32,
ne3: u32,
};
@group(0) @binding(3)
var<uniform> params: Params;
var<workgroup> shA: array<f32, BATCH_N * N>;
var<workgroup> shB: array<f32, BATCH_N * K_TILE>;
fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
return params.offset_src0 +
col * params.stride_src00 +
row * params.stride_src01 +
i2 * params.stride_src02 +
i3 * params.stride_src03;
}
fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
return params.offset_src1 +
col * params.stride_src10 +
row * params.stride_src11 +
i2 * params.stride_src12 +
i3 * params.stride_src13;
}
fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
return params.offset_dst +
col * params.stride_dst0 +
row * params.stride_dst1 +
i2 * params.stride_dst2 +
i3 * params.stride_dst3;
}
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>
) {
let batch = workgroup_id.y;
let col = workgroup_id.x * WG_SIZE + local_id.x;
let i3 = batch / params.ne2;
let i2 = batch % params.ne2;
let active_lane = local_id.x < K_TILE;
let active_col = active_lane && col < params.k;
var X: array<f32, N>;
for (var row_base = 0u; row_base < N; row_base += BATCH_N) {
let cur_n = min(BATCH_N, N - row_base);
for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) {
let tile_row = i / N;
let tile_col = i % N;
shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)];
}
for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) {
let tile_row = i / K_TILE;
let tile_col = i % K_TILE;
let global_col = workgroup_id.x * WG_SIZE + tile_col;
let sh_idx = tile_row * K_TILE + tile_col;
if (global_col < params.k) {
shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)];
} else {
shB[sh_idx] = 0.0;
}
}
workgroupBarrier();
if (active_col) {
for (var row_offset = 0u; row_offset < cur_n; row_offset++) {
let r = row_base + row_offset;
var b = shB[row_offset * K_TILE + local_id.x];
let a_row = row_offset * N;
for (var t = 0u; t < r; t++) {
b -= shA[a_row + t] * X[t];
}
let x = b / shA[a_row + r];
X[r] = x;
dst[dst_idx(r, col, i2, i3)] = x;
}
}
workgroupBarrier();
}
}

View File

@@ -0,0 +1,65 @@
@group(0) @binding(0)
var<storage, read_write> src0: array<f32>;
@group(0) @binding(1)
var<storage, read_write> src1: array<f32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
stride_src01: u32,
stride_src02: u32,
stride_src11: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
nc: u32,
nr: u32,
n_t: u32,
n_s: u32,
token_tiles: u32,
};
@group(0) @binding(3)
var<uniform> params: Params;
@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i1 = gid.x;
let tile_y = gid.y / TOKENS_PER_WG;
let local_token = gid.y % TOKENS_PER_WG;
let i3 = tile_y / params.token_tiles;
let token_tile = tile_y % params.token_tiles;
let i2 = token_tile * TOKENS_PER_WG + local_token;
if (i1 >= params.nr || i2 >= params.n_t || i3 >= params.n_s) {
return;
}
let src0_base = params.offset_src0 + i3 * params.stride_src02 + i2 + i1 * params.stride_src01;
let src1_base = params.offset_src1 + i1 * params.stride_src11;
var sum = 0.0;
#ifdef VECTORIZED
sum =
src0[src0_base + 0u] * src1[src1_base + 0u] +
src0[src0_base + 1u] * src1[src1_base + 1u] +
src0[src0_base + 2u] * src1[src1_base + 2u] +
src0[src0_base + 3u] * src1[src1_base + 3u];
#else
for (var i0 = 0u; i0 < params.nc; i0++) {
sum += src0[src0_base + i0] * src1[src1_base + i0];
}
#endif
let dst_idx = params.offset_dst + i3 * params.stride_dst2 + i2 * params.stride_dst1 + i1 * params.stride_dst0;
dst[dst_idx] = sum;
}

View File

@@ -301,6 +301,8 @@ class Keys:
IMAGE_SIZE = "clip.vision.image_size"
IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels"
IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels"
PREPROC_MIN_TILES = "clip.vision.preproc_min_tiles"
PREPROC_MAX_TILES = "clip.vision.preproc_max_tiles"
PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size"
PATCH_SIZE = "clip.vision.patch_size"
EMBEDDING_LENGTH = "clip.vision.embedding_length"
@@ -3869,6 +3871,8 @@ class LlamaFileType(IntEnum):
# MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack
MOSTLY_TQ1_0 = 36 # except 1d tensors
MOSTLY_TQ2_0 = 37 # except 1d tensors
MOSTLY_MXFP4_MOE = 38 # except 1d tensors
MOSTLY_NVFP4 = 39 # except 1d tensors
GUESSED = 1024 # not specified in the model file

View File

@@ -1156,6 +1156,12 @@ class GGUFWriter:
def add_vision_min_pixels(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value)
def add_vision_preproc_max_tiles(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PREPROC_MAX_TILES, value)
def add_vision_preproc_min_tiles(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PREPROC_MIN_TILES, value)
def add_vision_preproc_image_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value)
@@ -1300,7 +1306,7 @@ class GGUFWriter:
else:
raise ValueError("Invalid GGUF metadata value type or value")
return kv_data
return bytes(kv_data)
@staticmethod
def format_n_bytes_to_str(num: int) -> str:

View File

@@ -138,7 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
if isinstance(meta_noop, tuple):
dtype, shape = meta_noop
assert callable(shape)
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape)) # ty: ignore[call-top-callable]
else:
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)

View File

@@ -91,11 +91,11 @@ class __Quant(ABC):
def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
cls.qtype = qtype
cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
cls.__quantize_lazy: Any = LazyNumpyTensor._wrap_fn(
cls.__quantize_array,
meta_noop=(np.uint8, cls.__shape_to_bytes)
)
cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
cls.__dequantize_lazy: Any = LazyNumpyTensor._wrap_fn(
cls.__dequantize_array,
meta_noop=(np.float32, cls.__shape_from_bytes)
)

View File

@@ -11,33 +11,33 @@ from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVa
try:
from sentencepiece import SentencePieceProcessor
except ImportError:
SentencePieceProcessor = None
SentencePieceProcessor: Any = None
try:
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found]
_filter_valid_tokenizer_files,
)
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
SentencePieceTokenizer,
)
except ImportError:
_mistral_common_installed = False
MistralTokenizer = None
Tekkenizer = None
SentencePieceTokenizer = None
_filter_valid_tokenizer_files = None
MistralTokenizer: Any = None
Tekkenizer: Any = None
SentencePieceTokenizer: Any = None
_filter_valid_tokenizer_files: Any = None
else:
_mistral_common_installed = True
try:
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found]
get_one_valid_tokenizer_file,
)
except ImportError:
# We still want the conversion to work with older mistral-common versions.
get_one_valid_tokenizer_file = None
get_one_valid_tokenizer_file: Any = None
import gguf
@@ -703,7 +703,7 @@ class MistralVocab(Vocab):
tokenizer_file_path = base_path / tokenizer_file
self.tokenizer = MistralTokenizer.from_file(
self.tokenizer: Any = MistralTokenizer.from_file(
tokenizer_file_path
).instruct_tokenizer.tokenizer
self.tokenizer_type = (

View File

@@ -7,7 +7,6 @@
{%- set available_tool_string = '' -%}
{%- set add_tool_id = true -%}
{%- set add_thoughts = true -%} {# whether to include <thinking> reasoning blocks #}
{%- set add_generation_prompt = true -%} {# whether to emit reasoning starter before assistant response #}
{# Optional token placeholders (safe defaults) #}
{%- set bos_token = bos_token or '' -%}
{%- set eos_token = eos_token or '' -%}

View File

@@ -15,10 +15,10 @@
{%- set ns.is_tool = false -%}
{%- for tool in message['tool_calls']-%}
{%- if not ns.is_first -%}
{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}}
{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] | tojson + '\n' + '```' + '<tool▁call▁end>'}}
{%- set ns.is_first = true -%}
{%- else -%}
{{'\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}}
{{'\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] | tojson + '\n' + '```' + '<tool▁call▁end>'}}
{%- endif -%}
{%- endfor -%}
{{'<tool▁calls▁end><end▁of▁sentence>'}}

View File

@@ -28,25 +28,25 @@
{%- set ns.is_last_user = true -%}{{'<User>' + message['content']}}
{%- endif -%}
{%- if message['role'] == 'assistant' and message['tool_calls'] -%}
{%- if ns.is_last_user -%}{{'<Assistant></think>'}}
{%- if ns.is_last_user -%}{{'<Assistant><think></think>'}}
{%- endif -%}
{%- set ns.is_last_user = false -%}
{%- set ns.is_first = false -%}
{%- set ns.is_tool = false -%}
{%- for tool in message['tool_calls'] -%}
{%- if not ns.is_first -%}
{%- if not message['content'] -%}{{'<tool▁calls▁begin><tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}}
{%- else -%}{{message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}}
{%- if not message['content'] -%}{{'<tool▁calls▁begin><tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] | tojson + '<tool▁call▁end>'}}
{%- else -%}{{message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] | tojson + '<tool▁call▁end>'}}
{%- endif -%}
{%- set ns.is_first = true -%}
{%- else -%}{{'<tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}}
{%- else -%}{{'<tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] | tojson + '<tool▁call▁end>'}}
{%- endif -%}
{%- endfor -%}{{'<tool▁calls▁end><end▁of▁sentence>'}}
{%- endif -%}
{%- if message['role'] == 'assistant' and not message['tool_calls'] -%}
{%- if ns.is_last_user -%}{{'<Assistant>'}}
{%- if message['prefix'] is defined and message['prefix'] and thinking -%}{{'<think>'}}
{%- else -%}{{'</think>'}}
{%- else -%}{{'<think></think>'}}
{%- endif -%}
{%- endif -%}
{%- set ns.is_last_user = false -%}
@@ -65,7 +65,7 @@
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool -%}{{'<Assistant>'}}
{%- if not thinking -%}{{'</think>'}}
{%- else -%}{{'<think>'}}
{%- if not thinking -%}{{'<think></think>'}}
{%- else -%}{{'<think>'}}
{%- endif -%}
{%- endif %}

View File

@@ -49,7 +49,7 @@ Example function tool call syntax:
{%- endif -%}
{%- set tool_name = tc['function']['name'] -%}
{%- set tool_args = tc['function']['arguments'] -%}
{{- '<tool▁call▁begin>' + tc['type'] + '<tool▁sep>' + tool_name + '\n' + '```json' + '\n' + tool_args + '\n' + '```' + '<tool▁call▁end>' -}}
{{- '<tool▁call▁begin>' + tc['type'] + '<tool▁sep>' + tool_name + '\n' + '```json' + '\n' + tool_args | tojson + '\n' + '```' + '<tool▁call▁end>' -}}
{%- endfor -%}
{{- '<tool▁calls▁end><end▁of▁sentence>' -}}
{%- endif -%}

View File

@@ -42,9 +42,9 @@
{%- if 'tool_calls' in message and message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%}
{%- if tool_call["function"]["name"] == "python" -%}
{{ '<|python_tag|>' + tool_call['function']['arguments'] }}
{{ '<|python_tag|>' + tool_call['function']['arguments'] | tojson }}
{%- else -%}
{{ '<function=' + tool_call['function']['name'] + '>' + tool_call['function']['arguments'] + '</function>' }}
{{ '<function=' + tool_call['function']['name'] + '>' + tool_call['function']['arguments'] | tojson + '</function>' }}
{%- endif -%}
{%- endfor -%}
{{ '<|eom_id|>' }}

View File

@@ -1,5 +1,5 @@
{
"extraPaths": ["gguf-py", "examples/model-conversion/scripts"],
"extraPaths": ["gguf-py", "examples/model-conversion/scripts", "examples/model-conversion/scripts/utils"],
"pythonVersion": "3.9",
"pythonPlatform": "All",
"reportUnusedImport": "warning",

View File

@@ -684,6 +684,7 @@ else:
sys.exit(1)
assert isinstance(hexsha8_baseline, str)
name_baseline = bench_data.get_commit_name(hexsha8_baseline)
hexsha8_compare = name_compare = None
@@ -717,6 +718,7 @@ else:
parser.print_help()
sys.exit(1)
assert isinstance(hexsha8_compare, str)
name_compare = bench_data.get_commit_name(hexsha8_compare)
# Get tool-specific configuration

View File

@@ -95,9 +95,9 @@ if __name__ == '__main__':
'-p', 'Hey',
'--no-warmup',
'--log-disable',
'-no-cnv']
'-st']
if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
cmd.append('-fa')
cmd += ('-fa', 'on')
try:
subprocess.check_call(cmd)
except subprocess.CalledProcessError:

View File

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
import sys
from collections import defaultdict
def parse_log_file(filepath):
"""Parse log file and extract function VGPR usage."""
import re
functions = defaultdict(lambda: {'vgprs': 0, 'spill': 0, 'location': ''})
try:
with open(filepath, 'r') as f:
content = f.read()
# Find all function entries with VGPR usage including location
pattern = r'([^:]+:\d+):.*?Function Name: (\S+).*?VGPRs: (\d+).*?VGPRs Spill: (\d+)'
matches = re.findall(pattern, content, re.DOTALL)
for location, func_name, vgprs, spill in matches:
functions[func_name]['vgprs'] = int(vgprs)
functions[func_name]['spill'] = int(spill)
# Extract just the filename and line number
parts = location.split('/')
if len(parts) > 0:
short_location = parts[-1] # Get last part (filename)
# Check if there's a line number after filename
if ':' in short_location:
functions[func_name]['location'] = short_location
else:
functions[func_name]['location'] = location
else:
functions[func_name]['location'] = location
except FileNotFoundError:
print(f"Error: File {filepath} not found", file=sys.stderr) # noqa: NP100
sys.exit(1)
return functions
def main():
if len(sys.argv) < 2:
print("Usage: ./vgpr_check.py <log_file>", file=sys.stderr) # noqa: NP100
sys.exit(1)
log_file = sys.argv[1]
ignored = {
'_ZL21gated_linear_attn_f32ILi128EEviiiifPKfS1_S1_S1_S1_Pf',
'_ZL18flash_attn_ext_f16ILi64ELi64ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi64ELi64ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL13rwkv_wkv7_f32ILi128EEviiiiPKfS1_S1_S1_S1_S1_S1_Pf',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi2ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi64ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi64ELi64ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL24mul_mat_q_stream_k_fixupIL9ggml_type22ELi8ELb1EEvPKiS2_PfPKfiiimimimi',
'_ZL9mul_mat_qIL9ggml_type3ELi32ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type3ELi48ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type20ELi32ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type17ELi64ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL15flash_attn_tileILi256ELi256ELi32ELi1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL9mul_mat_qIL9ggml_type19ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type17ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type22ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type19ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type19ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type7ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type3ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type3ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type7ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type7ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type11ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type11ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL24mul_mat_q_stream_k_fixupIL9ggml_type11ELi128ELb0EEvPKiS2_PfPKfiiimimimi',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL9mul_mat_qIL9ggml_type2ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi32ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
}
functions = parse_log_file(log_file)
found_issues = False
# First print all ignored functions (deduplicated)
printed_ignored = set()
for func_name, data in sorted(functions.items()):
total_vgprs = int(data['vgprs']) + int(data['spill'])
if total_vgprs > 256 and func_name in ignored and func_name not in printed_ignored:
location = data.get('location', log_file)
print(f"{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) [IGNORED]") # noqa: NP100
printed_ignored.add(func_name)
# Then print new functions with issues in red
for func_name, data in sorted(functions.items()):
total_vgprs = int(data['vgprs']) + int(data['spill'])
if total_vgprs > 256 and func_name not in ignored:
status = "[IGNORED]" if func_name in ignored else ""
location = data.get('location', log_file)
# Print in red if not ignored
color_code = "\033[91m" if func_name not in ignored else ""
reset_code = "\033[0m" if func_name not in ignored else ""
print(f"{color_code}{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) {status}{reset_code}") # noqa: NP100
if func_name not in ignored:
found_issues = True
sys.exit(1 if found_issues else 0)
if __name__ == "__main__":
main()

View File

@@ -241,10 +241,10 @@ class CodeEditor(QPlainTextEdit):
if not self.isReadOnly():
selection = QTextEdit.ExtraSelection()
line_color = QColorConstants.Yellow.lighter(160)
selection.format.setBackground(line_color) # pyright: ignore[reportAttributeAccessIssue]
selection.format.setProperty(QTextFormat.Property.FullWidthSelection, True) # pyright: ignore[reportAttributeAccessIssue]
selection.cursor = self.textCursor() # pyright: ignore[reportAttributeAccessIssue]
selection.cursor.clearSelection() # pyright: ignore[reportAttributeAccessIssue]
selection.format.setBackground(line_color) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
selection.format.setProperty(QTextFormat.Property.FullWidthSelection, True) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
selection.cursor = self.textCursor() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
selection.cursor.clearSelection() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
extra_selections.append(selection)
self.setExtraSelections(extra_selections)
@@ -262,8 +262,8 @@ class CodeEditor(QPlainTextEdit):
)
extra = QTextEdit.ExtraSelection()
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue]
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue]
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
self.setExtraSelections(self.extraSelections() + [extra])
@@ -274,8 +274,8 @@ class CodeEditor(QPlainTextEdit):
cursor.select(QTextCursor.SelectionType.LineUnderCursor)
extra = QTextEdit.ExtraSelection()
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue]
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue]
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
self.setExtraSelections(self.extraSelections() + [extra])
@@ -395,8 +395,8 @@ class JinjaTester(QMainWindow):
ensure_ascii=ensure_ascii,
)
)
env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
env.globals["raise_exception"] = raise_exception
env.globals["strftime_now"] = lambda format: datetime.now().strftime(format) # ty: ignore[invalid-assignment]
env.globals["raise_exception"] = raise_exception # ty: ignore[invalid-assignment]
try:
template = env.from_string(template_str)
output = template.render(context)

View File

@@ -189,6 +189,7 @@ def benchmark(
data: list[dict] = []
assert isinstance(prompts, list)
for i, p in enumerate(prompts):
if seed_offset >= 0:
random.seed(3 * (seed_offset + 1000 * i) + 1)

View File

@@ -39,6 +39,9 @@ opmask=
nhvx=
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
hmx=
[ "$HMX" != "" ] && hmx="GGML_HEXAGON_USE_HMX=$HMX"
ndev=
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
@@ -51,7 +54,7 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
$verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb \
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 256 -fa on \

Some files were not shown because too many files have changed in this diff Show More