Compare commits

...

17 Commits
b8995 ... b9012

Author SHA1 Message Date
Julien Denize
048a490f76 convert : Mistral format yarn apply_scale support (#22612)
* [BUGFIX] Mistral format apply_scale support.

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix misunderstood boolean parameters

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-05-03 21:51:21 +02:00
JM Robles
db44417b02 convert : apply Q/K RoPE permutation in NVFP4 repack path (#22611)
Llama-architecture q_proj/k_proj weights need an axis-0 row permutation
to match GGML's RoPE convention. The BF16 path applies this in
LlamaModel.modify_tensors via LlamaModel.permute, but the NVFP4 path
bypasses modify_tensors and writes weights directly through
ModelBase._repack_nvfp4. Without the permutation, attention heads end
up scrambled at inference and the model produces gibberish.

This change overrides _repack_nvfp4 on LlamaModel and applies the same
permutation to both the nibble-packed weight and the per-block scale
before delegating to ModelBase._repack_nvfp4 via super(). Reuses the
existing LlamaModel.permute static helper and respects the existing
undo_permute flag, so subclasses (Mistral, Granite, Llama4, etc.)
inherit the fix automatically.
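As a sanity check, the head permutation being mirrored here can be sketched in numpy. This follows the shape logic of the existing LlamaModel.permute helper named above (the real converter operates on torch tensors, so treat this as an illustrative stand-in):

```python
import numpy as np

def permute(weights: np.ndarray, n_head: int, n_head_kv: int) -> np.ndarray:
    # Reorder axis-0 rows so the two rotary halves of each head are
    # interleaved, matching GGML's RoPE convention.
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))

# One head of dim 8: rows [0..7] become [0, 4, 1, 5, 2, 6, 3, 7].
w = np.arange(8 * 4).reshape(8, 4)
out = permute(w, 1, 1)
```

Applying the same row permutation to the per-block scale tensor keeps each scale aligned with its (now moved) weight rows, which is why the fix permutes both before delegating to super().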

Verified on TinyLlama-1.1B reproducer: perplexity drops from 4419
(gibberish) to 43.9, matching the BF16-dequantized baseline (44.0).
Also verified end-to-end on ALIA-40b-instruct-2601 (BSC, Llama
architecture) with multilingual generation in Spanish/Catalan/Basque/
Galician all coherent with the fix applied.

Co-authored-by: Chema <chema@montevive.ai>
2026-05-03 18:22:00 +03:00
lucy
d05fe1d7da fix: CUDA device PCI bus ID de-dupe OOMing (ignoring other 3 gpus entirely) (#22533)
* fix: CUDA device PCI bus ID detection for multi-GPU de-dupe

* HIP, MUSA macros

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-05-02 22:19:25 +02:00
Georgi Gerganov
0754b7b6fe server : avoid checkpoint data host copies (#22558)
* server : avoid checkpoint data host copies

* llama : refactor llama_io_read_i
2026-05-02 18:03:25 +03:00
JusteLeo
09294365a9 ggml-virtgpu: fix circular dependency in headers (#22557) 2026-05-02 21:28:50 +08:00
Csaba Kecskemeti
63d93d1733 convert : disable uint types (#18908) 2026-05-02 09:05:59 +03:00
Shawn Gu
c5a3bc39b1 opencl: Adreno optimization for MoE - MxFP4 (#22301)
* MoE Mxfp4 CLC kernel added, router reorder on GPU

* Pass test-backend-ops for MoE mxfp4 Adreno CLC

* remove putenv in llama-model.cpp

* fix indent style and whitespace

* opencl: remove unnecessary headers

* opencl: do not save cl_program objects

* opencl: remove unnecessary assert

* fix precision issue

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-05-01 23:02:24 -07:00
Johannes Gäßler
9dbb372610 Github: update issue templates (#22594) 2026-05-02 07:56:13 +02:00
Georgi Gerganov
228e836344 sync : ggml 2026-05-02 08:55:29 +03:00
Georgi Gerganov
ed23489f42 ggml : bump version to 0.10.2 (ggml/1474) 2026-05-02 08:55:29 +03:00
Georgi Gerganov
457e2288c9 sync : ggml 2026-05-02 07:22:35 +03:00
Georgi Gerganov
e8ec7ab058 ggml : try fix win32 build (whisper/0) 2026-05-02 07:22:35 +03:00
Yiwei Shao
1a03cf47f6 hexagon: hmx flash attention (#22347)
* hmx: extract shared interleave headers and unify matmul batched

* hmx: add HMX-accelerated flash attention for prefill

* hmx: replace asm wrappers with Q6_ intrinsics in hmx-utils.h

Switches three single-instruction helpers from inline asm to the matching
Q6_ intrinsics, matching the style established by aizip f8737609a and used
by the upstream PR #21554 hmx-matmul-ops.c rewrite:

  hmx_set_output_scales       asm "bias=mxmem2"  -> Q6_bias_mxmem2_A
  hmx_load_tile_pair_fp16     asm packet         -> Q6_activation_hf_mxmem_RR
                                                    + Q6_weight_hf_mxmem_RR
  hmx_consume_accumulator_fp16 asm "mxmem=acc"   -> Q6_mxmem_AR_after_hf

hmx_load_tiles_fp16 stays on inline asm: it uses ":deep" activation
streaming, and the mixed Q6_activation_hf_mxmem_RR_deep + non-deep
Q6_weight_hf_mxmem_RR pair fails the HMX backend constraint check
("activate weight pair (1) exceeds limit (1)"). The asm bundle keeps
both halves in one VLIW packet and avoids the diagnostic.

Functionally equivalent — same instructions emitted; the Q6_ intrinsics
just give the compiler more visibility for scheduling.

* hmx: drop the duplicate interleave_fp16_weight_chunk_to_tiles

* hmx: apply upstream optimizations to hmx-flash-attn-ops.c

Apply restrict, __builtin_assume, and pointer accumulation to the three HMX workers (qk_dot, o_update, o_norm) and to the matching inline HMX loops in op_hmx_flash_attn_ext.


* hmx: unify interleave helper

* hmx: multi-thread Q load / O store and enable prefill FA dispatch

Extract inline Q-load and O-store loops into worker_pool-parallel helpers
(fa_phase_q_load, fa_phase_o_store) so HVX threads split the F32↔F16
conversion work across row ranges.  Also relax the softmax threading
gate from n_row_vec_cnt >= n_threads to >= 2, which was unnecessarily
forcing single-thread fallback when n_rows_g < 512.

On the dispatch side, remove the ne[2] != 1 guard that blocked multi-head
(prefill) FA from reaching the HTP backend — GQA is already handled
internally by both the HMX and HVX flash-attention paths.
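The per-thread row slicing described above can be sketched generically; `split_rows` is a hypothetical stand-in for how a worker pool hands out `[start, end)` ranges, not the actual fa_phase_q_load/fa_phase_o_store code:

```python
def split_rows(n_rows: int, n_threads: int) -> list[tuple[int, int]]:
    # Evenly divide [0, n_rows) into per-thread [start, end) slices,
    # giving the remainder rows to the first few threads.
    base, rem = divmod(n_rows, n_threads)
    ranges, start = [], 0
    for t in range(n_threads):
        end = start + base + (1 if t < rem else 0)
        ranges.append((start, end))
        start = end
    return ranges

slices = split_rows(1000, 4)  # each worker converts its own F32<->F16 slice
```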

* hmx: relax matmul pipeline gate to cover k > n shapes (e.g. FFN_down)

* hmx: optimize FA softmax mask phase (no-ALiBi fast path + GQA dedup)

* hmx: Add an asm memory clobber at the phase boundary to prevent reorder bug

* [experimental]: fp16 softmax (EXP2_HF) to accelerate fa

Bake log2(e) into qk_scale and use hvx_exp2_hf directly for P and m_diff
(base-2 consistent, matches htp-ops-lib). ~22 ALU ops for 64 lanes vs
~44 for the F32 round-trip path.
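The base-2 trick relies on the identity exp(x) = 2^(x * log2(e)), so folding log2(e) into qk_scale once lets every lane use the cheaper exp2. A scalar sketch (illustrative only; the real kernel uses the HVX hvx_exp2_hf polynomial):

```python
import math

LOG2_E = math.log2(math.e)

def softmax_exp_f32(xs, scale):
    # Reference path: natural exp on pre-scaled logits.
    es = [math.exp(x * scale) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def softmax_exp2_baked(xs, scale):
    # Bake log2(e) into the scale so exp(x*scale) == 2**(x*qk_scale).
    qk_scale = scale * LOG2_E
    es = [2.0 ** (x * qk_scale) for x in xs]
    s = sum(es)
    return [e / s for e in es]

a = softmax_exp_f32([0.3, -1.2, 2.0], 0.5)
b = softmax_exp2_baked([0.3, -1.2, 2.0], 0.5)
```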

* hmx flash-attn: refine cost model coefficients based on profiling data

* hmx flash-attn: replace asm clobber with targeted volatile reads on vtcm_d_tiles

* hmx flash-attn: fix prefill correctness (dst indexing, softmax reduce, V stride)

* hmx flash-attn: fix p_tiles dual-tile OOB race; enable MT + pipeline

* hmx flash-attn: preserve additive mask bias in no-ALiBi fast path

The no-ALiBi fast path (max_bias==0) was skipping mask add entirely on
the assumption that mask values are only {0, -inf}.  This is wrong when
the mask carries additive positional bias — those terms were silently
dropped.  Keep the slope-mul skip (slope≡1.0) but add mask back so the
bias survives; vmux still clamps below -16 to -inf.

Also add HMX FA coverage to test-backend-ops: prefill shapes (nb=64,
nb=32) × {mask on/off} × {ALiBi on/off} × {softcap on/off}, F16 KV,
hs ∈ {64, 128}.

* hmx: fix softcap+EXP2_HF interaction, tighten matmul pipeline gate, add FA tests

- flash-attn: when EXP2_HF is on AND logit_softcap is active, fold
  log2(e) into the post-tanh multiplier (v_cap) instead of pre-baking
  it into qk_scale.  Pre-baking shifted the tanh knee from x≈c to
  x≈c/log2(e) and produced numerically wrong softcapped outputs
  whenever both knobs were enabled.
- flash-attn softmax (fa_softmax_thread): replace the union+memcpy
  scalar extract pattern with HVX vmux-based per-row accumulators on
  rowmax/rowsum.  Add hvx_vec_get_f16 helper in hvx-base.h.  Functional
  parity, less scalar code, clearer hf/qf16 lane-format contract.
- matmul (hmx_mat_mul_permuted_qk_0_d16a32): pick pipeline vs sequential
  layout based on whether the chunker actually yields >=2 n-chunks,
  instead of the static (m>=128 && n>=256) gate.  Avoids paying for
  output double-buffer + worker dispatch when there is no HMX/HVX
  overlap to gain (e.g. shapes that collapse to one n-chunk).
- tests: add HMX flash-attention coverage over the
  {mask, ALiBi (max_bias), logit_softcap} cross-product for the prefill
  path — head_dim 64/128, GQA 4×4, kv=512/nb=64 plus a kv=113/nb=32
  non-aligned case.
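The softcap interaction in the first bullet can be checked with scalar math: softcap applies c * tanh(x/c) to the scaled logit before exponentiation, so log2(e) must multiply the tanh *output* (v_cap), not the tanh input via qk_scale. A minimal sketch, assuming nothing beyond the formulas above:

```python
import math

LOG2_E = math.log2(math.e)

def softcap_exp2_correct(x, scale, cap):
    # Fold log2(e) into the post-tanh multiplier: exp(s) == 2**(s * log2(e)).
    s = cap * math.tanh((x * scale) / cap)
    return 2.0 ** (s * LOG2_E)

def softcap_exp2_prebaked(x, scale, cap):
    # Buggy variant: log2(e) pre-baked into qk_scale shifts the tanh knee
    # from x ~ cap to x ~ cap / log2(e).
    s = cap * math.tanh((x * scale * LOG2_E) / cap)
    return 2.0 ** s

x, scale, cap = 3.0, 1.0, 2.0
good = softcap_exp2_correct(x, scale, cap)
bad  = softcap_exp2_prebaked(x, scale, cap)
```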

* [Help Wanted]: refactor D matrix computation into separate function for clarity and maintainability

* format code

* hexagon: looks like -O3 is causing issues with the large code base, switch to -O2 and -flto instead

* hexagon: use hex_ prefix for swap_ptr

* hexagon: move vtcm_seq_alloc into vtcm-utils.h

More vtcm allocator updates are coming, so it makes sense to start a separate header for it.

* hmx-utils: add hmx_prefix for layout converters

* hmx-mm: move main hmx_mm functions to the end, remove unused fwd decls, etc

* hmx-mm: remove unused qweight_fetch_task_state_t and minor alignment fixes

* hmx-fa: minor alignment fixes

* hmx-fa: move hmx_flash_atten into hmx-ops.h

* hmx-fa: remove redundant workpool pointer in the hmx_fa_ctx, plus minor alignment updates

* hmx-fa: minor alignment and simplifications

* hexagon: move FA_EXP_F16 option to hostside CMake file

* hmx-fa: use hvx_vec_splat_f16 instead of fp16_to_bits

* hmx-fa: add hvx_splat_u16/u8 and use that in the fa instead of custom hvx_fill

* hmx-fa: some more alignment updates in the core fa function

* hmx-fa: keep slopes in vtcm in fp16

Saves malloc/free and removes the need for float -> fp16 downcast on every use.

* hexagon: consistent noinline usage (after static)

* hex-hmx: consistent use FARF_HIGH to enable debug output

* hmx-utils: no need for always_inline attr

* hex-hmx: consistent noinline usage (static noinline ...)

* hex-hmx: simplify init_col_scales

* hexagon: fix editorconfig errors

* hmx-mm: minor alignment fixes

---------

Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
2026-05-01 20:29:13 -07:00
ddh0
b97ebdc98f llama-quant : fix --tensor-type when default qtype is overridden (#22572)
fix #22544 (my fault!)

Credit to @Anai-Guo, ref #22559 - since that one was closed due to the
new contributor policy I am taking the liberty of re-submitting that PR
here.
2026-05-01 19:55:55 +02:00
Aparna M P
2098fd6169 hexagon: enable non-contiguous row tensor support for unary ops (#22574) 2026-05-01 10:09:23 -07:00
Aleksander Grygier
ab6120cde5 webui: Spring Cleaning Refactor v1 (#22505)
* wip: server_tools

* feat: Integrate with `/tools` endpoint

* feat: Builtin + MCP + JSON Schema Tools WIP

* refactor

* displayName -> display_name

* snake_case everywhere

* rm redundant field

* feat: Improvements

* chore: update webui build output

* refactor: Updates after server updates

* chore: update webui build output

* change arg to --tools all

* feat: UI improvements

* chore: update webui build output

* add readme mention

* llama-gen-docs

* chore: update webui build output

* chore: update webui build output

* chore: update webui build output

* feat: Reorganize settings sections

* feat: Separate dialogs for MCP Servers Settings and Import/Export

* feat: WIP

* feat: WIP

* feat: WIP

* feat: WIP

* feat: WIP

* feat: WIP

* WIP on allozaur/20677-webui-server-tools

* feat: UI improvements

* chore: Update package lock

* chore: Run `npm audit fix`

* feat: UI WIP

* feat: UI

* refactor: Desktop Icon Strip DRY

* feat: Cleaner rendering and transition for ChatScreen

* feat: UI improvements

* feat: UI improvement

* feat: Remove MCP Server "enable" switch from Tools submenu

* chore: Run `npm audit fix`

* feat: WIP

* feat: Logic improvements

* refactor: Cleanup

* refactor: DRY

* test: Fix Chat Sidebar UI Tests

* chore: Update package lock

* refactor: Cleanup

* feat: Chat Message Action Card with Continue and Permission flow implementations

* feat: Add agentic steering messages, draft messages and improve chat UX

* fix: Search results UI

* test: Fix unit test

* feat: UI/UX improvements

* refactor: Simplify `useToolsPanel` access in components

* feat: Implement Processing Info Context API

* feat: Implement 'Go back to chat' functionality for settings

* feat: Enhance MCP Server management in Chat Form Attachments

* style: Minor UI and branding adjustments

* chore: Update webui static build output

* chore: Formatting, linting & type checks

* feat: Draft messages logic

* feat: UI improvements

* feat: Steering Messages improvements

* refactor: Cleanup

* refactor: Cleanup

* feat: Improve UI

* refactor: Settings navigation hook

* refactor: DRY code

* refactor: DRY ChatMessageUser UI components

* refactor: Desktop Icon Strip DRY

* refactor: Tools & permissions

* fix: Navigation condition

* refactor: Cleanup

* refactor: Cleanup

* refactor: Cleanup

* fix: preserve reasoning_content in agentic flow

* refactor: Storybook cleanup

* refactor: isInViewport util function

* refactor: Rename globally `onClick` to `onclick`

* chore: `npm audit fix`

* refactor: Action Icon usage

* refactor: Naming

* refactor: JS in `class` directive

* refactor: Chat components cleanup WIP

* refactor: Components structure

* refactor: Cleanup WIP

* feat: New ChatAttachmentsPreview component

* feat: UI improvements

* feat: UI improvements

* refactor: Cleanup

* refactor: ChatAttachmentsPreview UI/UX

* refactor: Remove dead code

* refactor: Cleanup

* fix: Model Name aliases displaying

* feat: Shortcut improvements

* refactor: Chat Message

* feat: Move Import/Export to settings

* refactor: Cleanup

* refactor: Cleanup

* refactor: Cleanup

* refactor: Cleanup

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-05-01 18:36:29 +02:00
Masashi Yoshimura
c3c1505392 ggml-webgpu: Fix vectorized handling in mul-mat and mul-mat-id (#22578)
* Fix vectorized condition of mul-mat-fast pipeline and add vectorized variant to mul-mat-id

* Apply suggestion from @CISC

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-05-01 07:55:01 -07:00
214 changed files with 11063 additions and 8205 deletions

View File

@@ -12,6 +12,8 @@ body:
after recreating the CMake build directory and with `-DGGML_CCACHE=OFF`.
If the compilation succeeds with ccache disabled you should be able to permanently fix the issue
by clearing `~/.cache/ccache` (on Linux).
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: textarea
id: commit
attributes:

View File

@@ -1,5 +1,5 @@
name: Bug (model use)
description: Something goes wrong when using a model (in general, not specific to a single llama.cpp module).
description: Something goes wrong when running a model (crashes, garbled outputs, etc.).
title: "Eval bug: "
labels: ["bug-unconfirmed", "model evaluation"]
body:
@@ -12,6 +12,8 @@ body:
If you encountered the issue while using an external UI (e.g. ollama),
please reproduce your issue using one of the examples/binaries in this repository.
The `llama-completion` binary can be used for simple and reproducible model inference.
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: textarea
id: version
attributes:

View File

@@ -10,6 +10,8 @@ body:
This issue template is intended for miscellaneous bugs that don't fit into any other category.
If you encountered the issue while using an external UI (e.g. ollama),
please reproduce your issue using one of the examples/binaries in this repository.
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: textarea
id: version
attributes:

View File

@@ -8,6 +8,8 @@ body:
value: |
[Please post your idea first in Discussion if there is not yet a consensus for this enhancement request. This will help to keep this issue tracker focused on enhancements that the community has agreed needs to be implemented.](https://github.com/ggml-org/llama.cpp/discussions/categories/ideas)
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: checkboxes
id: prerequisites
attributes:

View File

@@ -8,6 +8,8 @@ body:
value: |
Don't forget to check for any [duplicate research issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3A%22research+%F0%9F%94%AC%22)
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: checkboxes
id: research-stage
attributes:

View File

@@ -9,6 +9,8 @@ body:
Don't forget to [check for existing refactor issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3Arefactoring) in case it's already covered.
Also you may want to check [Pull request refactor label as well](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Aopen+is%3Apr+label%3Arefactoring) for duplicates too.
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
- type: textarea
id: background-description
attributes:

View File

@@ -2889,6 +2889,20 @@ class LlamaModel(TextModel):
.swapaxes(1, 2)
.reshape(weights.shape))
def _repack_nvfp4(self, name: str, weight: Tensor, scale: Tensor, scale2: Tensor, input_scale: Tensor):
# Mirror the BF16 Q/K RoPE permutation site in modify_tensors; the NVFP4 path bypasses it.
if self.undo_permute:
n_head = self.find_hparam(["n_heads", "num_attention_heads"], optional=True)
n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"], optional=True)
if n_head is not None:
if name.endswith("q_proj.weight"):
weight = LlamaModel.permute(weight, n_head, n_head)
scale = LlamaModel.permute(scale, n_head, n_head)
elif name.endswith("k_proj.weight"):
weight = LlamaModel.permute(weight, n_head, n_kv_head)
scale = LlamaModel.permute(scale, n_head, n_kv_head)
super()._repack_nvfp4(name, weight, scale, scale2, input_scale)
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
@@ -12702,11 +12716,12 @@ class MistralModel(LlamaModel):
def set_mistral_config(gguf_writer: gguf.GGUFWriter, hparams: dict):
if "yarn" in hparams:
yarn_params = hparams["yarn"]
mscale_all_dim = 1.0 if not yarn_params["apply_scale"] else 0.0
gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
gguf_writer.add_rope_scaling_yarn_log_mul(mscale_all_dim)
gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
if "llama_4_scaling" in hparams:
@@ -13232,17 +13247,18 @@ class LazyTorchTensor(gguf.LazyBase):
}
# only used when byteswapping data. Only correct size is needed
# TODO: uncomment uint64, uint32, and uint16, ref: https://github.com/pytorch/pytorch/issues/58734
_dtype_byteswap_map: dict[torch.dtype, type] = {
torch.float64: np.float64,
torch.float32: np.float32,
torch.bfloat16: np.float16,
torch.float16: np.float16,
torch.int64: np.int64,
torch.uint64: np.uint64,
# torch.uint64: np.uint64,
torch.int32: np.int32,
torch.uint32: np.uint32,
# torch.uint32: np.uint32,
torch.int16: np.int16,
torch.uint16: np.uint16,
# torch.uint16: np.uint16,
torch.int8: np.int8,
torch.uint8: np.uint8,
torch.bool: np.uint8,

View File

@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 10)
set(GGML_VERSION_PATCH 1)
set(GGML_VERSION_PATCH 2)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

View File

@@ -5431,8 +5431,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
char pci_bus_id[16] = {};
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
char pci_bus_id[32] = {};
CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i));
dev_ctx->pci_bus_id = pci_bus_id;
dev_ctx->op_offload_min_batch_size = min_batch_size;

View File

@@ -55,6 +55,7 @@
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceGetAttribute hipDeviceGetAttribute
#define cudaDeviceGetPCIBusId hipDeviceGetPCIBusId
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t

View File

@@ -39,6 +39,7 @@
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceGetPCIBusId musaDeviceGetPCIBusId
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t

View File

@@ -22,7 +22,8 @@ message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")

View File

@@ -2254,8 +2254,7 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
return false;
}
if (dst->ne[2] != 1 || dst->ne[3] != 1) {
// FA during prompt still needs work
if (dst->ne[3] != 1) {
return false;
}
@@ -2421,8 +2420,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
return false;
}
// TODO: add support for non-contigiuos tensors
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
// TODO: add support for non-contiguous elements within a row
if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) {
return false;
}

View File

@@ -44,6 +44,11 @@ target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
# HMX acceleration: available on v73+ architectures
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
@@ -52,11 +57,13 @@ if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-queue.c
hmx-matmul-ops.c
hmx-flash-attn-ops.c
)
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-matmul-ops.c
hmx-flash-attn-ops.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)

View File

@@ -138,15 +138,15 @@ set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
#Compiler Options
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O2")
set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O2")
set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}")
set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}")

View File

@@ -17,13 +17,14 @@
#include "htp-ctx.h"
#include "htp-ops.h"
#include "hmx-ops.h"
// Must be multiple of 32
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
// This is a bit of a hack because the compiler is struggling to properly inline
// the default hvx_vec_f32_to_f16 with output into the local array.
static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
{
*(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
}
@@ -621,6 +622,17 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
return HTP_STATUS_NO_SUPPORT;
}
#ifdef HTP_HAS_HMX
// HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) {
int ret = hmx_flash_attn_ext(octx);
if (ret == HTP_STATUS_OK) {
return ret;
}
// VTCM too small or other failure -> fall through to HVX path
}
#endif
struct htp_fa_context factx;
factx.octx = octx;

View File

@@ -74,6 +74,12 @@ static inline size_t hex_smax(size_t a, size_t b) {
return a > b ? a : b;
}
static inline void hex_swap_ptr(void ** p1, void ** p2) {
void * t = *p1;
*p1 = *p2;
*p2 = t;
}
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
Q6_l2fetch_AP((void *) p, control);

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -61,6 +61,9 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
int m, int k, int n,
int weight_type);
// HMX flash attention
int hmx_flash_attn_ext(struct htp_ops_context * octx);
#ifdef __cplusplus
}
#endif

View File

@@ -4,6 +4,9 @@
#ifndef HMX_UTILS_H
#define HMX_UTILS_H
#include "hvx-base.h"
#include <assert.h>
#include <hexagon_types.h>
#include <stddef.h>
@@ -12,21 +15,188 @@
#define HMX_FP16_TILE_N_ELMS 1024
#define HMX_FP16_TILE_SIZE 2048
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
// Initialise aligned 256-byte area with scale vector + zero padding.
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
HVX_Vector *pv = (HVX_Vector *)out_scales;
*pv++ = v_scale;
*pv = Q6_V_vzero();
static inline void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
volatile HVX_Vector *pv = (HVX_Vector *) out_scales;
pv[0] = v_scale;
pv[1] = Q6_V_vzero();
}
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
// --- Shared scatter offsets and interleave helper ---
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
uint8_t *p = *vtcm_ptr;
*vtcm_ptr += size;
return p;
// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile.
// word[i] = i*128 maps K-row-pair i to byte offset i*128.
// Column offset (n*4) is added at runtime. Entries 0..15 cover one tile (region 2047);
// entries 16..31 cover the next adjacent tile (region 4095) — pick region size at the
// call site to scatter into one tile (masked) or two contiguous tiles (unmasked).
static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = {
0 * 128, 1 * 128, 2 * 128, 3 * 128, 4 * 128, 5 * 128, 6 * 128, 7 * 128, 8 * 128, 9 * 128, 10 * 128,
11 * 128, 12 * 128, 13 * 128, 14 * 128, 15 * 128, 16 * 128, 17 * 128, 18 * 128, 19 * 128, 20 * 128, 21 * 128,
22 * 128, 23 * 128, 24 * 128, 25 * 128, 26 * 128, 27 * 128, 28 * 128, 29 * 128, 30 * 128, 31 * 128,
};
// Scatter row-major FP16 data (in VTCM scratch) into transposed [K][N] tiles.
// vtcm_src: [n_cols][src_stride] row-major fp16 (only first k elements per row are used)
// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16
// Processes rows [start_row, end_row) for multi-thread slicing.
// Full range: start_row=0, end_row=n_cols.
static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
const __fp16 * restrict vtcm_src,
int n_cols,
int k,
int src_stride,
int start_row,
int end_row) {
assert(k % HMX_FP16_TILE_N_COLS == 0);
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
// Each hvx_vmemu load brings 64 fp16 = 128 bytes covering 2 adjacent K-tiles.
// When n_k_tiles is even, scatter into 2 K-tiles per call (region 4095, no mask)
// using the upper half of hmx_transpose_scatter_offsets. Tail one K-tile (when
// n_k_tiles is odd) falls back to single-tile masked scatter.
const bool pair_scatter = (n_k_tiles & 1) == 0;
const size_t pair_region = (size_t) (2 * HMX_FP16_TILE_SIZE - 1);
const size_t single_region = (size_t) (HMX_FP16_TILE_SIZE - 1);
__builtin_assume(k > 0);
__builtin_assume(end_row > start_row);
if (pair_scatter) {
// Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter.
const int c_step = 2 * HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
const int n_c_iters = k / c_step;
for (int r = start_row; r < end_row; r += 2) {
const int ct = r / HMX_FP16_TILE_N_ROWS;
const int local_r = r % HMX_FP16_TILE_N_ROWS;
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
if (p1) {
for (int i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0);
p0 += c_byte_step;
HVX_Vector v1 = hvx_vmemu(p1);
p1 += c_byte_step;
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1);
tile_base += dst_step;
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0);
p0 += c_byte_step;
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
tile_base += dst_step;
}
}
}
} else {
// Fallback: scatter one K-tile per call (region 2047, masked).
const int c_step = HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
const int n_c_iters = k / c_step;
for (int r = start_row; r < end_row; r += 2) {
const int ct = r / HMX_FP16_TILE_N_ROWS;
const int local_r = r % HMX_FP16_TILE_N_ROWS;
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
if (p1) {
for (int i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0);
p0 += c_byte_step;
HVX_Vector v1 = hvx_vmemu(p1);
p1 += c_byte_step;
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1);
tile_base += dst_step;
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0);
p0 += c_byte_step;
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
tile_base += dst_step;
}
}
}
}
}
// Interleave row-major FP16 data into column-major tile format.
// Input: [n_rows, head_dim] row-major. Output: tile[dim_tile][row_tile].
// Processes rows [start_row, end_row) for multi-thread slicing.
// Full range: start_row=0, end_row=n_rows.
static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
const __fp16 * restrict src,
int n_rows,
int head_dim,
int src_stride,
int n_row_tiles,
int start_row,
int end_row) {
__builtin_assume(head_dim > 0);
const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS;
for (int r = start_row; r < end_row; r += 2) {
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows;
const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride);
const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL;
// Row-pair invariants hoisted out of the c loop.
const int r0 = r / HMX_FP16_TILE_N_ROWS;
const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
// tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0).
// Each c step (+= 64) advances both by 2 dim-tiles worth of fp16.
__fp16 * tb0 = tiles_out + (size_t) r0 * HMX_FP16_TILE_N_ELMS;
__fp16 * tb1 = tb0 + tile_stride_elms;
const size_t tb_step = 2 * tile_stride_elms;
if (pv_in1) {
for (int c = 0; c < head_dim; c += 64) {
HVX_Vector v0 = *pv_in0++;
HVX_Vector v1 = *pv_in1++;
HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp);
tb0 += tb_step;
tb1 += tb_step;
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int c = 0; c < head_dim; c += 64) {
HVX_Vector v0 = *pv_in0++;
HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2);
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp);
tb0 += tb_step;
tb1 += tb_step;
}
}
}
}
#endif // HMX_UTILS_H

View File

@@ -77,6 +77,12 @@ static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
return x;
}
static inline _Float16 hvx_vec_get_f16(HVX_Vector v) {
_Float16 __attribute__((aligned(128))) x;
hvx_vec_store_a(&x, 2, v);
return x;
}
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
// abs by clearing the fp16 sign bit
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);

View File

@@ -7,7 +7,8 @@
#include "hvx-base.h"
#define hvx_splat_loop_body(dst_type, vec_store) \
#define hvx_splat_pragma(x) _Pragma(#x)
#define hvx_splat_loop_body(dst_type, vec_store, unroll_cnt) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
\
@@ -16,7 +17,7 @@
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
hvx_splat_pragma(unroll(unroll_cnt)) \
for (; i < nvec; i++) { \
vdst[i] = src; \
} \
@@ -25,31 +26,47 @@
} \
} while(0)
static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
static inline void hvx_splat_a(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
assert((unsigned long) dst % 128 == 0);
hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a);
hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a, 4);
}
static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u);
static inline void hvx_splat_u(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u, 4);
}
static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) {
static inline void hvx_splat_f32_a(void * restrict dst, float v, uint32_t n) {
hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float));
}
static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) {
static inline void hvx_splat_f32_u(void * restrict dst, float v, uint32_t n) {
hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float));
}
static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) {
static inline void hvx_splat_f16_a(void * restrict dst, _Float16 v, uint32_t n) {
hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
}
static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) {
static inline void hvx_splat_f16_u(void * restrict dst, _Float16 v, uint32_t n) {
hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
}
static inline void hvx_splat_u16_a(void * restrict dst, uint16_t v, uint32_t n) {
hvx_splat_a(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t));
}
static inline void hvx_splat_u16_u(void * restrict dst, uint16_t v, uint32_t n) {
hvx_splat_u(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t));
}
static inline void hvx_splat_u8_a(void * restrict dst, uint8_t v, uint32_t n) {
hvx_splat_a(dst, Q6_Vb_vsplat_R(v), n, 1);
}
static inline void hvx_splat_u8_u(void * restrict dst, uint8_t v, uint32_t n) {
hvx_splat_u(dst, Q6_Vb_vsplat_R(v), n, 1);
}
#define hvx_copy_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \

View File

@@ -17,7 +17,7 @@
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
#define EXP_ONE (0x3f800000) // 1.0
#define EXP_RANGE_R (0x42B16666) // 88.7
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
@@ -163,7 +163,7 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
HVX_Vector vec_out = Q6_V_vzero();
static const float kInf = INFINITY;
static const float kMaxExp = 88.7f;
static const float kMaxExp = 88.7228f;
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
const HVX_Vector inf = hvx_vec_splat_f32(kInf);

View File

@@ -26,8 +26,8 @@ struct htp_unary_context {
const uint8_t * data_src0;
uint8_t * data_dst;
size_t src0_row_size;
size_t dst_row_size;
size_t src0_data_row_size; // actual data bytes per row
size_t dst_data_row_size; // actual data bytes per row
size_t src0_row_size_aligned;
size_t dst_row_size_aligned;
@@ -41,6 +41,40 @@ struct htp_unary_context {
uint32_t nc;
};
// Convert flat row index to DDR byte offset using the tensor's actual strides.
// ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3
static inline size_t unary_row_offset(uint32_t ir,
uint32_t ne1, uint32_t ne2,
size_t nb1, size_t nb2, size_t nb3) {
const uint32_t i1 = ir % ne1;
const uint32_t i2 = (ir / ne1) % ne2;
const uint32_t i3 = ir / (ne1 * ne2);
return i1 * nb1 + i2 * nb2 + i3 * nb3;
}
// Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice
// boundary of src and dst so the nb1 stride stays valid for all rows.
static inline uint32_t unary_block_size(uint32_t ir,
uint32_t end_row,
uint32_t block,
bool src_contig,
bool dst_contig,
uint32_t src_ne1,
uint32_t dst_ne1) {
uint32_t limit = MIN(block, end_row - ir);
if (!src_contig) {
const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1;
limit = MIN(limit, src_slice_end - ir);
}
if (!dst_contig) {
const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1;
limit = MIN(limit, dst_slice_end - ir);
}
return limit;
}
#define htp_unary_preamble \
const uint32_t ne00 = src->ne[0]; \
const uint32_t ne01 = src->ne[1]; \
@@ -276,8 +310,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
int32_t * op_params = octx->op_params;
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
const size_t src0_row_size = uctx->src0_row_size;
const size_t dst_row_size = uctx->dst_row_size;
const size_t src0_data_row_size = uctx->src0_data_row_size;
const size_t dst_data_row_size = uctx->dst_data_row_size;
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
@@ -303,7 +337,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
size_t src0_spad_half_size = uctx->src0_spad_half_size;
size_t dst_spad_half_size = uctx->dst_spad_half_size;
const int BLOCK = uctx->block;
// Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride
// 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every
// transfer stays within a nb1-uniform region. Skipped for contiguous tensors.
const bool src0_contig = (nb02 == (size_t)ne01 * nb01) &&
(nb03 == (size_t)ne02 * nb02);
const bool dst_contig = (nb2 == (size_t)ne1 * nb1) &&
(nb3 == (size_t)ne2 * nb2);
const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01);
const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1);
const uint32_t BLOCK = MIN(src0_max_block, dst_max_block);
if (BLOCK == 0) {
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
octx->src0_spad.size_per_thread, src0_row_size_aligned);
@@ -312,21 +355,23 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
dma_queue * dma_queue = octx->ctx->dma[ith];
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) {
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
// Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_queue_push(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
nb1, dst_row_size_aligned, dst_data_row_size, 0);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03);
dma_queue_push(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off),
src0_row_size_aligned, nb01, src0_data_row_size, block_size);
ir += block_size;
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
for (uint32_t ir = src0_start_row; ir < src0_end_row; ) {
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
@@ -361,18 +406,25 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
break;
}
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
dst_row_size, dst_row_size_aligned, block_size);
const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3);
dma_queue_push(dma_queue,
dma_make_ptr(data_dst + dst_off, dst_spad),
nb1, dst_row_size_aligned, dst_data_row_size, block_size);
// prefetch the N+2 loop iteration, if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
const uint32_t next_ir = ir + block_size;
if (next_ir < src0_end_row) {
const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
const uint32_t pref_ir = next_ir + next_block_size;
if (pref_ir < src0_end_row) {
const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03);
dma_queue_push(dma_queue,
dma_make_ptr(src0_spad, data_src + src0_pref_off),
src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size);
}
}
ir += block_size;
}
dma_queue_flush(dma_queue);
@@ -426,11 +478,11 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
const size_t src0_row_size = src0->nb[1];
const size_t dst_row_size = dst->nb[1];
const size_t src0_data_row_size = src0->ne[0] * sizeof(float);
const size_t dst_data_row_size = dst->ne[0] * sizeof(float);
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN);
// VTCM scratchpads for all tensors
// N rows per thread, padded to HVX vector size
@@ -468,8 +520,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
.data_src0 = (const uint8_t *)src0->data,
.data_dst = (uint8_t *)dst->data,
.src0_row_size = src0_row_size,
.dst_row_size = dst_row_size,
.src0_data_row_size = src0_data_row_size,
.dst_data_row_size = dst_data_row_size,
.src0_row_size_aligned = src0_row_size_aligned,
.dst_row_size_aligned = dst_row_size_aligned,

View File

@@ -0,0 +1,16 @@
#ifndef VTCM_UTILS_H
#define VTCM_UTILS_H
#include "hex-utils.h"
#include <assert.h>
#include <stdint.h>
#include <hexagon_types.h>
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
uint8_t *p = *vtcm_ptr;
*vtcm_ptr += size;
return p;
}
#endif // VTCM_UTILS_H

View File

@@ -107,6 +107,10 @@ set(GGML_OPENCL_KERNELS
mul_mv_id_mxfp4_f32_flat
gemm_moe_mxfp4_f32
gemv_moe_mxfp4_f32
gemm_moe_mxfp4_f32_ns
gemv_moe_mxfp4_f32_ns
moe_reorder_b
moe_sort_by_expert
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q4_0_f32_l4_lm

View File

@@ -416,6 +416,15 @@ struct ggml_backend_opencl_context {
ggml_cl_buffer prealloc_src0;
ggml_cl_buffer prealloc_src1;
// prealloc buffers for MoE router table preprocess
bool toggle_reorder = false;
ggml_cl_buffer prealloc_post_router;
ggml_cl_buffer prealloc_emap;
ggml_cl_buffer prealloc_hist;
ggml_cl_buffer prealloc_tile_offset;
ggml_cl_buffer prealloc_total_tiles;
ggml_cl_buffer prealloc_slot_counter;
cl_program program_add;
cl_program program_add_id;
cl_program program_clamp;
@@ -531,6 +540,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1;
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns;
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans;
cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle;
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
@@ -587,6 +597,9 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns;
cl_kernel kernel_moe_reorder_b;
cl_kernel kernel_moe_histogram, kernel_moe_scan, kernel_moe_fill, kernel_moe_scatter;
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
cl_kernel kernel_mul_mv_id_mxfp4_f32;
@@ -945,6 +958,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans4_ns", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
@@ -2864,6 +2879,77 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// gemv_moe_mxfp4_f32_ns
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemv_moe_mxfp4_f32_ns.cl.h"
};
#else
const std::string kernel_src = read_file("gemv_moe_mxfp4_f32_ns.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_mxfp4_f32_ns", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// gemm_moe_mxfp4_f32_ns
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemm_moe_mxfp4_f32_ns.cl.h"
};
#else
const std::string kernel_src = read_file("gemm_moe_mxfp4_f32_ns.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// moe_reorder_b
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "moe_reorder_b.cl.h"
};
#else
const std::string kernel_src = read_file("moe_reorder_b.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_moe_reorder_b = clCreateKernel(prog, "kernel_moe_reorder_b", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// moe_sort_by_expert
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "moe_sort_by_expert.cl.h"
};
#else
const std::string kernel_src = read_file("moe_sort_by_expert.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_moe_histogram = clCreateKernel(prog, "kernel_moe_histogram", &err), err));
CL_CHECK((backend_ctx->kernel_moe_scan = clCreateKernel(prog, "kernel_moe_scan", &err), err));
CL_CHECK((backend_ctx->kernel_moe_fill = clCreateKernel(prog, "kernel_moe_fill", &err), err));
CL_CHECK((backend_ctx->kernel_moe_scatter = clCreateKernel(prog, "kernel_moe_scatter", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// gemv_noshuffle_q6_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3651,13 +3737,12 @@ struct ggml_tensor_extra_cl_mxfp4 {
CL_CHECK(clReleaseMemObject(e));
e = nullptr;
}
if (q != nullptr) {
if (q_img != nullptr) {
CL_CHECK(clReleaseMemObject(q_img));
q = nullptr;
q_img = nullptr;
}
// Currently, q_img and d_img are not used. They can be image1d_buffer_t
// Currently, e_img is not used. It can be an image1d_buffer_t
// that wraps around q and d to utilize image access path.
q_img = nullptr;
e_img = nullptr;
size_q = 0;
size_e = 0;
@@ -4740,7 +4825,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
GGML_UNUSED(backend_ctx);
int ne01 = tensor->ne[1];
return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
}
inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
@@ -5151,8 +5236,9 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
CL_CHECK(err);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
// Adreno moe mxfp4 kernel needs special transpose and unshuffling
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans4_ns;
int ne00 = tensor->ne[0];
int ne01 = tensor->ne[1];
@@ -5172,9 +5258,21 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
CL_CHECK(clReleaseMemObject(data_device));
tensor->extra = extra;
// Create image for Q
cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32};
cl_image_desc img_desc_q = {
CL_MEM_OBJECT_IMAGE1D_BUFFER,
static_cast<size_t>(ggml_nelements(tensor) / 8),
0, 0, 0, 0, 0, 0, 0,
{ extra->q }
};
extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
tensor->extra = extra;
return;
}
#endif
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
@@ -5912,7 +6010,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans4_ns;
int ne00 = tensor->ne[0];
int ne01 = tensor->ne[1];
@@ -5936,7 +6034,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
CL_CHECK(clReleaseMemObject(data_device));
return;
}
#endif
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
@@ -12763,6 +12862,118 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
}
}
static void moe_router_reoerder(ggml_backend_t backend, const ggml_tensor * src, int ne20) {
cl_int err;
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra;
cl_ulong offset = extra->offset + src->view_offs;
const int ne21 = src->ne[1];
const int nb21 = src->nb[1];
const int ne02 = nb21 / src->nb[0];
const int n_tile_size = 32;
const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02;
cl_buffer_region region;
region.origin = offset;
region.size = nb21 * ne21;
cl_mem original_router_buf = clCreateSubBuffer(extra->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_post_router.allocate(backend_ctx->context, sizeof(int) * max_post_router_tile * n_tile_size);
region.origin = 0;
region.size = sizeof(int) * max_post_router_tile * n_tile_size;
cl_mem post_router_buf = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_emap.allocate(backend_ctx->context, sizeof(short) * max_post_router_tile);
region.origin = 0;
region.size = sizeof(short) * max_post_router_tile;
cl_mem emap_buf = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_hist.allocate(backend_ctx->context, sizeof(int) * ne02);
region.origin = 0;
region.size = sizeof(int) * ne02;
cl_mem hist_buf = clCreateSubBuffer(backend_ctx->prealloc_hist.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_tile_offset.allocate(backend_ctx->context, sizeof(int) * ne02);
region.origin = 0;
region.size = sizeof(int) * ne02;
cl_mem tile_offset_buf = clCreateSubBuffer(backend_ctx->prealloc_tile_offset.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_slot_counter.allocate(backend_ctx->context, sizeof(int) * ne02);
region.origin = 0;
region.size = sizeof(int) * ne02;
cl_mem slot_counter_buf = clCreateSubBuffer(backend_ctx->prealloc_slot_counter.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
backend_ctx->prealloc_total_tiles.allocate(backend_ctx->context, sizeof(int));
region.origin = 0;
region.size = sizeof(int);
cl_mem total_tiles_buf = clCreateSubBuffer(backend_ctx->prealloc_total_tiles.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
// Histogram
cl_kernel kernel = backend_ctx->kernel_moe_histogram;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &hist_buf));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &ne21));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne20));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne02));
size_t histogram_global_size[] = {(size_t)(((ne21 + 63) / 64) * 64), static_cast<size_t>(ne20), 1};
size_t histogram_local_size[] = {64, static_cast<size_t>(ne20), 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src);
// Scan
kernel = backend_ctx->kernel_moe_scan;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &hist_buf));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &tile_offset_buf));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &total_tiles_buf));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &slot_counter_buf));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n_tile_size));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne02));
size_t scan_global_size[] = {1};
size_t scan_local_size[] = {1};
backend_ctx->enqueue_ndrange_kernel(kernel, 1, scan_global_size, scan_local_size, src);
// Fill
kernel = backend_ctx->kernel_moe_fill;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &post_router_buf));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &total_tiles_buf));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &n_tile_size));
size_t fill_global_size[] = {(size_t)(((max_post_router_tile + 63) / 64) * 64), n_tile_size, 1};
size_t fill_local_size[] = {64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, fill_global_size, fill_local_size, src);
// Scatter
kernel = backend_ctx->kernel_moe_scatter;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &post_router_buf));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &emap_buf));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tile_offset_buf));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &slot_counter_buf));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne21));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne20));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02));
backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src);
CL_CHECK(clReleaseMemObject(original_router_buf));
CL_CHECK(clReleaseMemObject(hist_buf));
CL_CHECK(clReleaseMemObject(tile_offset_buf));
CL_CHECK(clReleaseMemObject(total_tiles_buf));
CL_CHECK(clReleaseMemObject(slot_counter_buf));
CL_CHECK(clReleaseMemObject(post_router_buf));
CL_CHECK(clReleaseMemObject(emap_buf));
}
static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -12824,6 +13035,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int r2 = ne12/ne02;
const int r3 = ne13/ne03;
@@ -12836,6 +13048,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
int nrows = 1; // number of rows in src1
int ndst = 4; // number of values produced by each subgroup
const int n_tile_size = 32;
const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02;
cl_kernel kernel;
// subgroup mat vec
@@ -12967,11 +13182,10 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
size_t local_size[3] = {64, 2, 1};
size_t global_size[3] = {64, 2, 1};
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
int tile_size = 320;
if (ne12 == 1) { // for gemv
kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32_ns;
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
// create a sub_buffer for src2
cl_buffer_region region;
@@ -12985,78 +13199,154 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
// preprocess router table
int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
void * host_src2 = malloc(ne21 * nb21);
CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
int total_experts = nb21 / nb20;
int out_idx = 0;
for (int i_expert = 0; i_expert < ne02; i_expert++) {
for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
for (int j = 0; j < ne21; j++) {
for (int i = 0; i < ne20; i++) {
int expert = ((int *)host_src2)[j * total_experts + i];
if (i_expert == expert) {
((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);
((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
out_idx += 4;
}
}
}
}
}
buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
// create a sub_buffer for src1
region.origin = offset1;
region.size = ne10 * ne11 * ne12 * sizeof(float);
src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(tile_size);
global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
}
// create image for src1
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
// create a sub_buffer for src1
cl_buffer_region region;
region.origin = offset1;
region.size = ne10 * ne11 * ne12 * sizeof(float);
src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// create image for src1
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
// Set kernel args
int arg_idx = 0;
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
if (ne12 == 1) {
// Set kernel args
int arg_idx = 0;
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
} else {
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size));
// launch kernel
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
// deallocate sub buffers and images
CL_CHECK(clReleaseMemObject(src1_sub_buffer));
CL_CHECK(clReleaseMemObject(buf_src1_image));
CL_CHECK(clReleaseMemObject(buf_src2));
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns;
// Reorder router if called from test-backend-ops or when new router is generated.
// Otherwise reuse the reordered result from previous mul_mat_id call.
if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) {
moe_router_reoerder(backend, src2, ne20);
backend_ctx->toggle_reorder = false;
}
cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image;
cl_mem buf_src2, buf_src2_emap;
cl_buffer_region region;
region.origin = 0;
region.size = sizeof(int) * max_post_router_tile * n_tile_size;
GGML_ASSERT(backend_ctx->prealloc_post_router.buffer);
buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
region.origin = 0;
region.size = sizeof(short) * max_post_router_tile;
buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// Reorder activations
// create a sub_buffer for src1
region.origin = offset1;
region.size = ne10 * ne11 * ne12 * sizeof(float);
sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// Create image for reordered src1
// Use pre-allocated placeholder
region.origin = 0;
region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float);
backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size);
buf_src1_reordered = clCreateSubBuffer(
backend_ctx->prealloc_act_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&status);
CL_CHECK(status);
cl_image_format image_format_buf_src1;
cl_image_desc image_desc_buf_src1;
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
unsigned short map_ratio = ne20 / ne11;
GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n");
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer)));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio));
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size));
size_t reorder_b_local_size[3] = {256, 1, 1};
size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1};
// Dispatch reorder kernel
backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst);
// MoE kernel prepare
// Create sub buffer for dst
region.origin = offsetd;
region.size = ne0 * ne1 * ne2 * sizeof(float);
sub_buf_dst = clCreateSubBuffer(
extrad->data_device,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&status);
CL_CHECK(status);
// Create image for dst
cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT};
cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}};
buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status);
CL_CHECK(status);
// Set kernel args
int arg_idx = 0;
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q_img));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer)));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
// set thread grid
global_size[1] = static_cast<size_t>((ne01 + 63) / 64);
global_size[2] = static_cast<size_t>(max_post_router_tile);
local_size[1] = 1;
local_size[2] = 1;
// Dispatch kernel
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
clReleaseMemObject(sub_buf_src1_pre);
clReleaseMemObject(buf_src1_reordered);
clReleaseMemObject(image_src1_reordered);
clReleaseMemObject(buf_src2);
clReleaseMemObject(buf_src2_emap);
clReleaseMemObject(sub_buf_dst);
clReleaseMemObject(buf_dst_image);
}
// launch kernel
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
// deallocate sub buffers and images
CL_CHECK(clReleaseMemObject(src1_sub_buffer));
CL_CHECK(clReleaseMemObject(buf_src1_image));
CL_CHECK(clReleaseMemObject(buf_src2));
return;
} // else fallback to generic kernel
} // fallback to generic MoE mxfp4 kernel
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
#ifdef GGML_OPENCL_SOA_Q
@@ -14002,6 +14292,13 @@ static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, co
size_t local_work_size[] = {(size_t)ne00_padded, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
const int ne21 = dst->ne[1];
if ((strstr(src0->name, "_moe") != NULL) && (ne21 != 1)) {
backend_ctx->toggle_reorder = true;
}
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
}
static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {


@@ -371,6 +371,93 @@ kernel void kernel_restore_block_mxfp4_trans(
b->e = src_e[src_blk_offset];
}
kernel void kernel_convert_block_mxfp4_trans4_ns(
global struct block_mxfp4 * src0,
__global uint * dst_q,
__global uchar * dst_e,
uint ne00,
uint ne01
) {
uint i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
uint ne00_blk = ne00 / QK_MXFP4;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
global struct block_mxfp4 * b = src0 + src_blk_offset;
dst_e[dst_blk_offset] = b->e;
// extract quantization and unshuffle
ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0];
ushort8 post_block = (ushort8)(0);
uchar * pre_block_ptr = (uchar *)(&pre_block);
uchar * post_block_ptr = (uchar *)(&post_block);
for (int i = 0; i < QK_MXFP4 / 4; ++i) {
uchar x0 = pre_block_ptr[2*i + 0];
uchar x1 = pre_block_ptr[2*i + 1];
post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
post_block_ptr[i + QK_MXFP4 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
}
uint4 q_block = as_uint4(post_block);
uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
dst_q[offset] = q_block.x;
dst_q[offset + ne01] = q_block.y;
dst_q[offset + ne01 * 2] = q_block.z;
dst_q[offset + ne01 * 3] = q_block.w;
}
kernel void kernel_restore_block_mxfp4_trans4_ns(
__global uint * src_q,
__global uchar * src_e,
__global struct block_mxfp4 * dst0,
uint ne00,
uint ne01
) {
uint i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
uint ne00_blk = ne00 / QK_MXFP4;
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
__global struct block_mxfp4 * b = dst0 + dst_blk_offset;
b->e = src_e[src_d_offset];
// collect transposed quantization parts for a block
uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
uint4 q_block;
q_block.x = src_q[src_q_offset];
q_block.y = src_q[src_q_offset + ne01];
q_block.z = src_q[src_q_offset + ne01 * 2];
q_block.w = src_q[src_q_offset + ne01 * 3];
ushort8 post_block = as_ushort8(q_block);
ushort8 pre_block = (ushort8)(0);
uchar * pre_block_ptr = (uchar *)(&pre_block);
uchar * post_block_ptr = (uchar *)(&post_block);
for (int i = 0; i < QK_MXFP4 / 4; ++i) {
uchar x0 = post_block_ptr[i + 0];
uchar x1 = post_block_ptr[i + QK_MXFP4 / 4];
pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
}
((__global ushort8 *)(&(b->qs[0])))[0] = pre_block;
}
//------------------------------------------------------------------------------
// block_q8_0
//------------------------------------------------------------------------------


@@ -0,0 +1,302 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
#define TILESIZE_K 16
#define TILESIZE_M 64
#define TILESIZE_N 32
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
sign_b.hi = fp4x8.s0 & 0x8000;
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
ushort2 fp16_packed_a_1, fp16_packed_b_1;
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
sign_b.hi = fp4x8.s1 & 0x8000;
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \
acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \
acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \
acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \
acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \
acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \
acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \
acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \
acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \
acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \
acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \
acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \
acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \
acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \
acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \
acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \
acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \
acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \
acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \
acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \
acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \
acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \
acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \
acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \
acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \
acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \
acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \
acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \
acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \
acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \
acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \
acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \
c_reg.lo += convert_float8(acc.lo); \
c_reg.hi += convert_float8(acc.hi); \
acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \
acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \
acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \
acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \
acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \
acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \
acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \
acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \
acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \
acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \
acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \
acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \
acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \
acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \
acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \
acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \
acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \
acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \
acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \
acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \
acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \
acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \
acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \
acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \
acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \
acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \
acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \
acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \
acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \
acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \
acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \
acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \
c_reg.lo += convert_float8(acc.lo); \
c_reg.hi += convert_float8(acc.hi); \
static inline half e8m0_to_fp16(uchar x) {
ushort bits;
bits = (ushort)(x) - (ushort)(112);
bits = ((bits & 0x00E0) != 0) ? 0x7C00 : (bits << 10);
return as_half(bits);
}
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair
kernel void kernel_gemm_moe_mxfp4_f32_ns(
__read_only image1d_buffer_t src0_q,
__global uchar * src0_d,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global ushort * src2_emap,
__write_only image1d_buffer_t dst,
__global int * total_tiles,
uint ne00,
uint ne01
) {
uint block_id_m = get_global_id(1); // m_tile
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
return;
}
__private half16 reg_a;
__private float32 reg_c = (float32)(0);
__local half4 shared_b[128];
const ushort expert_id = src2_emap[block_id_n];
const uint row = block_id_m * TILESIZE_M;
const uint col = block_id_n * TILESIZE_N;
uint sub_block_id_m = get_local_id(0);
uint2 b_global_offset;
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
b_global_offset.y = b_global_offset.x + (16 * ne00);
uint2 b_local_offset;
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
b_local_offset.y = b_local_offset.x + 16;
// Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
// First sub-block
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5);
uint b_sub_offset = col * ne00 + step;
// Load scale for current mxfp4 block
uint s_offset = s_sub_offset + get_global_id(0);
float s = e8m0_to_fp32(src0_d[s_offset]);
// Load 16 fp4 (64-bits) in transposed layout
uint2 mxfp4x16;
mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
float8 bx8_f32;
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
// Convert to half and store to LM to share within the subgroup
half8 bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantization
reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
// 32 16x16 fp16 dot products with 8-element reduction for better precision
half16 acc;
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
// Repeat for second sub-block
uint half_step = step + TILESIZE_K;
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
b_sub_offset = col * ne00 + half_step;
// Load next 16 fp4 (64-bits) in transposed layout
mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
// Convert to half and store to LM to share within the subgroup
bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantization
reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
// 32 16x16 fp16 dot products with 3-level reduction for better precision
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
// Load post-router indices and share in LM
__local uint out_idx[TILESIZE_N];
if (get_local_id(0) < TILESIZE_N) {
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
if (idx == 0xFFFFFFFF) {
idx = src2[block_id_n * TILESIZE_N + 0];
}
out_idx[get_local_id(0)] = idx * ne01;
}
barrier(CLK_LOCAL_MEM_FENCE);
// Scatter results back to original position in output grid
uint m_offset = row + get_local_id(0);
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
// Padded slots alias the first output index in the tile; after the barrier, write the correct result last so it overrides them
barrier(CLK_GLOBAL_MEM_FENCE);
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
}


@@ -0,0 +1,161 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define QK_MXFP4 32
#define N_SIMDGROUP 4
#define SIMDGROUP_WIDTH 64
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
sign_b.hi = fp4x8.s0 & 0x8000;
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
ushort2 fp16_packed_a_1, fp16_packed_b_1;
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
sign_b.hi = fp4x8.s1 & 0x8000;
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_mxfp4_f32_ns(
__global uint * src0_q,
__global uchar * src0_e,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne11
) {
uint i01 = get_global_id(0);
uint i20 = get_global_id(2);
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
uint i11 = i20 % ne11;
uint expert_id = src2[i20];
uint expert_offset = expert_id * ne00 * ne01 / 32;
__private float sum = 0.0f; // each thread calculates a partial sum of one output
// loop along ne00 at block granularity, advancing 4 blocks per iteration
for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
// load one block of q
uint4 regQ;
uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01;
regQ.s0 = src0_q[block_offset];
regQ.s1 = src0_q[block_offset + ne01];
regQ.s2 = src0_q[block_offset + ne01 * 2];
regQ.s3 = src0_q[block_offset + ne01 * 3];
uint offset = i11 * ne00 / 4 + ib00 * 8;
half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
float4 shared_y4;
shared_y4 = read_imagef(src1, (offset + 0));
float4 acc = shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 1));
acc += shared_y4 * convert_float4(fp16x8.hi);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
shared_y4 = read_imagef(src1, (offset + 2));
acc += shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 3));
acc += shared_y4 * convert_float4(fp16x8.hi);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
shared_y4 = read_imagef(src1, (offset + 4));
acc += shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 5));
acc += shared_y4 * convert_float4(fp16x8.hi);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
shared_y4 = read_imagef(src1, (offset + 6));
acc += shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 7));
acc += shared_y4 * convert_float4(fp16x8.hi);
uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
}
// reduction in local memory, assumes #subgroups=4
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
// 1 output per thread in subgroup 0
if (sgid == 0) {
dst = dst + (offsetd >> 2);
dst[i01 + i20 * ne01] = sum;
}
}


@@ -0,0 +1,30 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define QK4_0 32
kernel void kernel_moe_reorder_b(
global float4 * src,
global uint * router,
global float4 * dst,
global int * total_tiles,
uint K,
ushort map_ratio,
uint tile_size
) {
uint k_4 = get_global_id(0);
uint post_router_idx = get_global_id(1);
if ((k_4 >= (K / 4)) || (post_router_idx >= total_tiles[0] * tile_size)) {
return;
}
uint router_idx = router[post_router_idx];
float4 out = (float4)(0);
if (router_idx != 0xFFFFFFFF) {
ushort activation_idx = router_idx / map_ratio;
out = src[activation_idx * K / 4 + k_4];
}
dst[post_router_idx * K / 4 + k_4] = out;
}


@@ -0,0 +1,82 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void kernel_moe_histogram(
__global const int * input,
__global int * hist,
uint N,
uint topK,
uint n_experts
) {
uint n = get_global_id(0);
uint k = get_global_id(1);
if (n >= N || k >= topK) {
return;
}
int expert_id = input[n * n_experts + k];
atomic_inc(&hist[expert_id]);
}
__kernel void kernel_moe_scan(
__global int * hist,
__global int * tile_offset,
__global int * total_tiles,
__global int * slot_counter,
int tile_size,
uint n_experts
) {
int offset = 0;
for (int v = 0; v < n_experts; v++) {
int count = hist[v];
int tiles = (count + tile_size - 1) / tile_size;
tile_offset[v] = offset;
offset += tiles;
hist[v] = 0;
slot_counter[v] = 0;
}
*total_tiles = offset;
}
__kernel void kernel_moe_scatter(
__global const int * input,
__global int * post_router,
__global ushort * emap,
__global const int * tile_offset,
__global int * slot_counter,
int N,
int topK,
uint n_experts
) {
uint n = get_global_id(0);
uint k = get_global_id(1);
if (n >= N || k >= topK) {
return;
}
int val = input[n * n_experts + k];
int local_slot = atomic_inc(&slot_counter[val]);
int tile_idx = tile_offset[val] + (local_slot / 32);
int lane = local_slot % 32;
int out_pos = tile_idx * 32 + lane;
post_router[out_pos] = n * topK + k;
emap[tile_idx] = val;
}
__kernel void kernel_moe_fill(
__global int * post_router,
__global int * total_tiles,
int tile_size
) {
int tile_id = get_global_id(0);
int vec_id_in_tile = get_global_id(1);
if (tile_id < total_tiles[0]) {
post_router[tile_id * tile_size + vec_id_in_tile] = 0xFFFFFFFF;
}
}


@@ -1,6 +1,7 @@
#include "virtgpu-shm.h"
#include "virtgpu.h"
#include "ggml-remoting.h"
#include <assert.h>


@@ -1,4 +1,5 @@
#include "virtgpu.h"
#include "ggml-remoting.h"
#include <stdio.h>
#include <unistd.h>


@@ -18,8 +18,6 @@
#include <cstring>
#include "ggml-remoting.h"
#define VIRGL_RENDERER_UNSTABLE_APIS 1
#include "apir_hw.h"
#include <drm/virtgpu_drm.h>


@@ -1779,12 +1779,12 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_mul_mat_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.use_subgroup_matrix = context.supports_subgroup_matrix;
auto it = mul_mat_fast_pipelines.find(key);
@@ -2143,6 +2143,9 @@ class ggml_webgpu_shader_lib {
// variant suffix for src1 type
variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
if (key.vectorized) {
variant += "_vectorized";
}
auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);


@@ -55,8 +55,13 @@
uint64_t ggml_graph_next_uid(void) {
#ifdef _MSC_VER
#if defined(_WIN32)
static volatile LONG counter = 1;
return (uint64_t) InterlockedIncrement(&counter) - 1;
#else
static volatile long long counter = 1;
return (uint64_t) _InterlockedIncrement64(&counter) - 1;
#endif
#else
static uint64_t counter = 1;
return __atomic_fetch_add(&counter, 1, __ATOMIC_RELAXED);

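Both branches of `ggml_graph_next_uid` above implement the same thing: a relaxed atomic fetch-add on a shared counter. A portable sketch using `std::atomic` in place of the per-compiler intrinsics (the function name here is illustrative, not the project's):

```cpp
#include <atomic>
#include <cstdint>

// Portable equivalent of the UID counter: both the MSVC InterlockedIncrement
// path and the GCC __atomic_fetch_add path reduce to this fetch-add.
// Relaxed ordering is enough because callers only need uniqueness, not any
// ordering relative to other memory operations.
static uint64_t graph_next_uid(void) {
    static std::atomic<uint64_t> counter{1};
    return counter.fetch_add(1, std::memory_order_relaxed);
}
```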

@@ -1 +1 @@
387fa29fbbf3149f06a631c7850b6c35c24b0232
19eac6f0edaf285506eb6228d31bb9caeda9aba1


@@ -2253,6 +2253,28 @@ public:
llama_io_write_buffer(
uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
~llama_io_write_buffer() {
#if 1
// TODO: add backend support to batch tensor_get? or some other way to speed this up
for (const auto & info : winfos) {
ggml_backend_tensor_get(info.tensor, info.ptr, info.offset, info.size);
}
#else
// flush the writes asynchronously
// this helps on Macs, but on other devices - it does not. just an example
std::vector<std::future<void>> futures;
futures.reserve(winfos.size());
for (const auto & info : winfos) {
futures.push_back(std::async(std::launch::async, [info]() {
ggml_backend_tensor_get(info.tensor, info.ptr, info.offset, info.size);
}));
}
for (auto & f : futures) {
f.wait();
}
#endif
}
void write(const void * src, size_t size) override {
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
@@ -2267,7 +2289,10 @@ public:
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
ggml_backend_tensor_get(tensor, ptr, offset, size);
// save the write for later during destruction
winfos.push_back({tensor, ptr, size, offset});
ptr += size;
size_written += size;
buf_size -= size;
@@ -2281,25 +2306,48 @@ private:
uint8_t * ptr;
size_t buf_size = 0;
size_t size_written = 0;
struct write_info {
const ggml_tensor * tensor;
uint8_t * ptr;
size_t size;
size_t offset;
};
std::vector<write_info> winfos;
};
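The `llama_io_write_buffer` change above swaps an eager `ggml_backend_tensor_get` per tensor for a defer-and-flush-in-destructor pattern. A minimal standalone C++ sketch of that pattern, with a plain `memcpy` standing in for the backend copy and an illustrative struct name:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Sketch of the deferred-write pattern: copies are recorded as
// (source, destination, size) entries and only executed once, in the
// destructor, instead of one backend round-trip per write_tensor call.
struct deferred_writer {
    struct write_info {
        const uint8_t * src;
        uint8_t *       dst;
        size_t          size;
    };

    std::vector<write_info> winfos;

    // Record a pending copy instead of performing it immediately.
    void write_tensor(const uint8_t * src, uint8_t * dst, size_t size) {
        winfos.push_back({src, dst, size});
    }

    // Flush all recorded copies in one pass
    // (ggml_backend_tensor_get in the original).
    ~deferred_writer() {
        for (const auto & info : winfos) {
            std::memcpy(info.dst, info.src, info.size);
        }
    }
};
```

Batching the device-to-host copies at the end gives a backend the chance to coalesce or overlap them, which is what the `#else` branch in the diff experiments with via `std::async`.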
class llama_io_read_buffer : public llama_io_read_i {
public:
llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
const uint8_t * read(size_t size) override {
const uint8_t * base_ptr = ptr;
~llama_io_read_buffer() {
// flush the reads
for (const auto & info : rinfos) {
ggml_backend_tensor_set(info.tensor, info.ptr, info.offset, info.size);
}
}
void read(void * dst, size_t size) override {
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
memcpy(dst, ptr, size);
ptr += size;
size_read += size;
buf_size -= size;
return base_ptr;
}
void read_to(void * dst, size_t size) override {
memcpy(dst, read(size), size);
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
// save for later during destruction
rinfos.push_back({tensor, ptr, size, offset});
ptr += size;
size_read += size;
buf_size -= size;
}
size_t n_bytes() override {
@@ -2310,6 +2358,14 @@ private:
const uint8_t * ptr;
size_t buf_size = 0;
size_t size_read = 0;
struct read_info {
ggml_tensor * tensor;
const uint8_t * ptr;
size_t size;
size_t offset;
};
std::vector<read_info> rinfos;
};
class llama_io_write_file : public llama_io_write_i {
@@ -2341,15 +2397,15 @@ class llama_io_read_file : public llama_io_read_i {
public:
llama_io_read_file(llama_file * f) : file(f) {}
void read_to(void * dst, size_t size) override {
void read(void * dst, size_t size) override {
file->read_raw(dst, size);
size_read += size;
}
const uint8_t * read(size_t size) override {
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
temp_buffer.resize(size);
read_to(temp_buffer.data(), size);
return temp_buffer.data();
read(temp_buffer.data(), size);
ggml_backend_tensor_set(tensor, temp_buffer.data(), offset, size);
}
size_t n_bytes() override {


@@ -1,5 +1,7 @@
#include "llama-io.h"
#include <vector>
void llama_io_write_i::write_string(const std::string & str) {
uint32_t str_size = str.size();
@@ -9,7 +11,10 @@ void llama_io_write_i::write_string(const std::string & str) {
void llama_io_read_i::read_string(std::string & str) {
uint32_t str_size;
read_to(&str_size, sizeof(str_size));
read(&str_size, sizeof(str_size));
str.assign((const char *) read(str_size), str_size);
std::vector<char> buf(str_size);
read(buf.data(), str_size);
str.assign(buf.data(), str_size);
}
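The rewritten `read_string` above reads a length prefix and then copies the payload through a staging buffer instead of aliasing the reader's internal pointer. A self-contained sketch of that decoding, assuming the stream layout is a `uint32_t` byte count followed by the raw characters (`read_bytes` here stands in for the `read` method):

```cpp
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Sketch of length-prefixed string decoding: read a uint32_t length, then
// stage the payload in a local buffer before constructing the std::string.
static std::string decode_string(const uint8_t * stream, size_t & cursor) {
    auto read_bytes = [&](void * dst, size_t size) {
        std::memcpy(dst, stream + cursor, size);
        cursor += size;
    };

    uint32_t str_size = 0;
    read_bytes(&str_size, sizeof(str_size));   // length prefix

    std::vector<char> buf(str_size);           // staging buffer, as in the new code
    read_bytes(buf.data(), str_size);
    return std::string(buf.data(), str_size);
}
```

Going through the staging buffer is what lets the pointer-returning `read(size_t)` overload be dropped from the interface entirely.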


@@ -25,8 +25,8 @@ public:
llama_io_read_i() = default;
virtual ~llama_io_read_i() = default;
virtual const uint8_t * read(size_t size) = 0;
virtual void read_to(void * dst, size_t size) = 0;
virtual void read(void * dst, size_t size) = 0;
virtual void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0;
// bytes read so far
virtual size_t n_bytes() = 0;


@@ -1900,14 +1900,14 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
uint32_t n_stream_cur;
io.read_to(&n_stream_cur, sizeof(n_stream_cur));
io.read(&n_stream_cur, sizeof(n_stream_cur));
if (n_stream_cur != n_stream) {
throw std::runtime_error("n_stream mismatch");
}
for (uint32_t s = 0; s < n_stream; ++s) {
uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count));
io.read(&cell_count, sizeof(cell_count));
if (cell_count == 0) {
continue;
@@ -2082,8 +2082,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
llama_pos pos;
uint32_t n_seq_id;
io.read_to(&pos, sizeof(pos));
io.read_to(&n_seq_id, sizeof(n_seq_id));
io.read(&pos, sizeof(pos));
io.read(&n_seq_id, sizeof(n_seq_id));
if (n_seq_id != 1) {
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
@@ -2092,7 +2092,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
if (hparams.n_pos_per_embd() > 1) {
llama_kv_cell_ext ext;
io.read_to(&ext, sizeof(ext));
io.read(&ext, sizeof(ext));
ubatch.pos[i + ubatch.n_tokens] = ext.y;
ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
@@ -2101,7 +2101,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
// read the sequence id, but directly discard it - we will use dest_seq_id instead
{
llama_seq_id seq_id;
io.read_to(&seq_id, sizeof(seq_id));
io.read(&seq_id, sizeof(seq_id));
}
ubatch.pos[i] = pos;
@@ -2143,20 +2143,20 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
llama_pos pos;
uint32_t n_seq_id;
io.read_to(&pos, sizeof(pos));
io.read_to(&n_seq_id, sizeof(n_seq_id));
io.read(&pos, sizeof(pos));
io.read(&n_seq_id, sizeof(n_seq_id));
cells.pos_set(i, pos);
if (hparams.n_pos_per_embd() > 1) {
llama_kv_cell_ext ext;
io.read_to(&ext, sizeof(ext));
io.read(&ext, sizeof(ext));
cells.ext_set(i, ext);
}
for (uint32_t j = 0; j < n_seq_id; ++j) {
llama_seq_id seq_id;
io.read_to(&seq_id, sizeof(seq_id));
io.read(&seq_id, sizeof(seq_id));
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
@@ -2189,8 +2189,8 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
uint32_t v_trans;
uint32_t n_layer;
io.read_to(&v_trans, sizeof(v_trans));
io.read_to(&n_layer, sizeof(n_layer));
io.read(&v_trans, sizeof(v_trans));
io.read(&n_layer, sizeof(n_layer));
if (n_layer != layers.size()) {
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
@@ -2217,7 +2217,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read type of key
int32_t k_type_i_ref;
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
io.read(&k_type_i_ref, sizeof(k_type_i_ref));
const int32_t k_type_i = (int32_t) k->type;
if (k_type_i != k_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
@@ -2226,7 +2226,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read row size of key
uint64_t k_size_row_ref;
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
io.read(&k_size_row_ref, sizeof(k_size_row_ref));
const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
if (k_size_row != k_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
@@ -2236,13 +2236,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
if (cell_count) {
if (sinfo.is_contiguous()) {
// Fast path: contiguous cells, single memcpy
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row);
} else {
// Slow path: scatter to non-contiguous positions
const void * src = io.read(cell_count * k_size_row);
for (uint32_t i = 0; i < cell_count; ++i) {
const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
io.read_tensor(k, dst_offset, k_size_row);
}
}
}
@@ -2261,7 +2260,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read type of value
int32_t v_type_i_ref;
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t) v->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
@@ -2270,7 +2269,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read row size of value
uint64_t v_size_row_ref;
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
io.read(&v_size_row_ref, sizeof(v_size_row_ref));
const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
if (v_size_row != v_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
@@ -2280,13 +2279,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
if (cell_count) {
if (sinfo.is_contiguous()) {
// Fast path: contiguous cells, single memcpy
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
io.read_tensor(v, sinfo.head() * v_size_row, cell_count * v_size_row);
} else {
// Slow path: scatter to non-contiguous positions
const void * src = io.read(cell_count * v_size_row);
for (uint32_t i = 0; i < cell_count; ++i) {
const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
io.read_tensor(v, dst_offset, v_size_row);
}
}
}
@@ -2305,7 +2303,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read type of value
int32_t v_type_i_ref;
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t) v->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
@@ -2314,7 +2312,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read element size of value
uint32_t v_size_el_ref;
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
io.read(&v_size_el_ref, sizeof(v_size_el_ref));
const size_t v_size_el = ggml_type_size(v->type);
if (v_size_el != v_size_el_ref) {
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
@@ -2323,7 +2321,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read GQA embedding size
uint32_t n_embd_v_gqa_ref;
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
io.read(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
return false;
@@ -2335,15 +2333,14 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
const uint32_t h = sinfo.head();
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (h + j * cells.size()) * v_size_el;
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
io.read_tensor(v, dst_offset, cell_count * v_size_el);
}
} else {
// Slow path: scatter to non-contiguous positions
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const void * src = io.read(cell_count * v_size_el);
for (uint32_t i = 0; i < cell_count; ++i) {
const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
io.read_tensor(v, dst_offset, v_size_el);
}
}
}
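The contiguous fast path versus scattered slow path pattern recurs throughout the KV-cache restore code above. A simplified C++ sketch of the two cases, with `memcpy` standing in for `io.read_tensor` and the function signature invented for the illustration (`idxs` plays the role of `sinfo.idxs[0]`):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Restore n rows of row_size bytes each from src into dst.
// Contiguous case: one copy covering the whole cell range starting at `head`.
// Scattered case: one copy per row, to the cell index given by idxs[i].
static void restore_rows(uint8_t * dst, const uint8_t * src,
                         const std::vector<size_t> & idxs, size_t row_size,
                         bool contiguous, size_t head) {
    const size_t n = idxs.size();
    if (contiguous) {
        // Fast path: single bulk copy.
        std::memcpy(dst + head * row_size, src, n * row_size);
    } else {
        // Slow path: scatter each row to its own destination cell.
        for (size_t i = 0; i < n; ++i) {
            std::memcpy(dst + idxs[i] * row_size, src + i * row_size, row_size);
        }
    }
}
```

In the diff, replacing `ggml_backend_tensor_set(..., io.read(...))` with `io.read_tensor(...)` keeps this same structure but lets the reader defer the actual backend copies, matching the deferred-write side.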


@@ -743,7 +743,7 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i
GGML_UNUSED(flags);
uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count));
io.read(&cell_count, sizeof(cell_count));
bool res = true;
@@ -879,8 +879,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
llama_pos pos;
uint32_t n_seq_id;
io.read_to(&pos, sizeof(pos));
io.read_to(&n_seq_id, sizeof(n_seq_id));
io.read(&pos, sizeof(pos));
io.read(&n_seq_id, sizeof(n_seq_id));
if (n_seq_id != 0) {
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
@@ -920,14 +920,14 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
llama_pos pos;
uint32_t n_seq_id;
io.read_to(&pos, sizeof(pos));
io.read_to(&n_seq_id, sizeof(n_seq_id));
io.read(&pos, sizeof(pos));
io.read(&n_seq_id, sizeof(n_seq_id));
cell.pos = pos;
for (uint32_t j = 0; j < n_seq_id; ++j) {
llama_seq_id seq_id;
io.read_to(&seq_id, sizeof(seq_id));
io.read(&seq_id, sizeof(seq_id));
if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) {
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max);
@@ -961,8 +961,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
uint32_t s_trans;
uint32_t n_layer;
io.read_to(&s_trans, sizeof(s_trans));
io.read_to(&n_layer, sizeof(n_layer));
io.read(&s_trans, sizeof(s_trans));
io.read(&n_layer, sizeof(n_layer));
if (n_layer != hparams.n_layer) {
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
@@ -984,7 +984,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read type of key
int32_t r_type_i_ref;
io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
io.read(&r_type_i_ref, sizeof(r_type_i_ref));
const int32_t r_type_i = (int32_t) r_l[il]->type;
if (r_type_i != r_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
@@ -993,7 +993,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read row size of key
uint64_t r_size_row_ref;
io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
io.read(&r_size_row_ref, sizeof(r_size_row_ref));
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
if (r_size_row != r_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
@@ -1002,7 +1002,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
if (cell_count) {
// Read and set the keys for the whole cell range
ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
io.read_tensor(r_l[il], head * r_size_row, cell_count * r_size_row);
}
}
@@ -1013,7 +1013,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read type of value
int32_t s_type_i_ref;
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
const int32_t s_type_i = (int32_t)s_l[il]->type;
if (s_type_i != s_type_i_ref) {
@@ -1023,7 +1023,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read row size of value
uint64_t s_size_row_ref;
io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
io.read(&s_size_row_ref, sizeof(s_size_row_ref));
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
if (s_size_row != s_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
@@ -1032,7 +1032,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
if (cell_count) {
// Read and set the values for the whole cell range
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
io.read_tensor(s_l[il], head * s_size_row, cell_count * s_size_row);
}
}
} else {
@@ -1045,7 +1045,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read type of value
int32_t s_type_i_ref;
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
const int32_t s_type_i = (int32_t)s_l[il]->type;
if (s_type_i != s_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
@@ -1054,7 +1054,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read element size of value
uint32_t s_size_el_ref;
io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
io.read(&s_size_el_ref, sizeof(s_size_el_ref));
const size_t s_size_el = ggml_type_size(s_l[il]->type);
if (s_size_el != s_size_el_ref) {
LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
@@ -1063,7 +1063,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// Read state embedding size
uint32_t n_embd_s_ref;
io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
io.read(&n_embd_s_ref, sizeof(n_embd_s_ref));
if (n_embd_s != n_embd_s_ref) {
LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
return false;
@@ -1073,7 +1073,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_s; ++j) {
const size_t dst_offset = (head + j * size) * s_size_el;
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
io.read_tensor(s_l[il], dst_offset, cell_count * s_size_el);
}
}
}


@@ -1994,7 +1994,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
}
if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) {
if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false)) {
// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
// cancel the factor from the convert script
hparams.rope_yarn_log_mul /= 0.1f;
@@ -2868,7 +2868,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
hparams.f_attn_temp_offset = 0.0f;


@@ -683,9 +683,9 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_mod
LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n",
__func__, tensor_name.c_str(), ggml_type_name(new_type), ggml_type_name(qtype));
new_type = qtype;
manual = true;
break;
}
manual = true;
break;
}
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large


@@ -36,7 +36,7 @@ using json = nlohmann::ordered_json;
constexpr int HTTP_POLLING_SECONDS = 1;
static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
if (pos_min == -1) {
pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
}
@@ -46,19 +46,15 @@ static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int i
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto cur = server_prompt_checkpoint {
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.n_tokens = */ n_tokens,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
};
ckpt.pos_min = pos_min;
ckpt.pos_max = pos_max;
ckpt.n_tokens = n_tokens;
ckpt.data.resize(checkpoint_size);
const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}
return cur;
}
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
@@ -364,7 +360,12 @@ struct server_slot {
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
const auto n_tokens = prompt.tokens.size();
spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
//const int64_t t_start = ggml_time_us();
server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens);
//const int64_t t_total = ggml_time_us() - t_start;
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
@@ -1836,7 +1837,8 @@ private:
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
const auto & cur = slot.prompt.checkpoints.emplace_back(server_get_checkpoint(ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max));
auto & cur = slot.prompt.checkpoints.emplace_back();
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
SLT_WRN(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",


@@ -1,5 +1,5 @@
<script lang="ts">
import * as Tooltip from '../src/lib/components/ui/tooltip';
import * as Tooltip from '../../src/lib/components/ui/tooltip';
interface Props {
children: any;


@@ -1,7 +1,7 @@
import type { Preview } from '@storybook/sveltekit';
import '../src/app.css';
import ModeWatcherDecorator from './ModeWatcherDecorator.svelte';
import TooltipProviderDecorator from './TooltipProviderDecorator.svelte';
import ModeWatcherDecorator from './decorators/ModeWatcherDecorator.svelte';
import TooltipProviderDecorator from './decorators/TooltipProviderDecorator.svelte';
const preview: Preview = {
parameters: {


@@ -3640,9 +3640,9 @@
}
},
"node_modules/bits-ui": {
"version": "2.17.3",
"resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-2.17.3.tgz",
"integrity": "sha512-Bef41uY9U2jaBJHPhcPvmBNkGec5Wx2z6eioDsTmsaR2vH4QoaOcPi75gzCG3+/2TNr6v/qBwzgWNPYCxNtrEA==",
"version": "2.18.0",
"resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-2.18.0.tgz",
"integrity": "sha512-GLOBZRVy3hxNHIQ2MpD/+5aK9KcBFZRhUJtZ1UDABXdlVR4K6zFpgt4T+Rwuhf2sQzlc6yK1q/DprHPjwT4Pjw==",
"dev": true,
"license": "MIT",
"dependencies": {


@@ -28,7 +28,6 @@ import type {
ApiRouterModelsUnloadResponse,
// Chat types
ChatAttachmentDisplayItem,
ChatAttachmentPreviewItem,
ChatMessageType,
ChatRole,
ChatUploadedFile,
@@ -92,7 +91,6 @@ declare global {
ApiRouterModelsUnloadResponse,
// Chat types
ChatAttachmentDisplayItem,
ChatAttachmentPreviewItem,
ChatMessagePromptProgress,
ChatMessageSiblingInfo,
ChatMessageTimings,


@@ -1,3 +1,5 @@
import { isElementInViewport } from '$lib/utils/viewport';
/**
* Svelte action that fades in an element when it enters the viewport.
* Uses IntersectionObserver for efficient viewport detection.
@@ -12,17 +14,8 @@ export function fadeInView(
) {
const { duration = 300, y = 0, skipIfVisible = false } = options;
if (skipIfVisible) {
const rect = node.getBoundingClientRect();
const isAlreadyVisible =
rect.top < window.innerHeight &&
rect.bottom > 0 &&
rect.left < window.innerWidth &&
rect.right > 0;
if (isAlreadyVisible) {
return;
}
if (skipIfVisible && isElementInViewport(node)) {
return;
}
node.style.opacity = '0';


@@ -0,0 +1,11 @@
---
name: app
description: Opinionated app components building on top of ./ui primitives
---
- Can include business logic and state management
- Can include data fetching and caching logic
- Should use original spelling for HTML-native events and `camelCase` for custom events
- Props and markup attributes should be listed alphabetically
- Use JS Objects and Arrays for CSS classes and styles when they are dynamic
- Whenever there can be repetition in the component's markup, if it's too small to be decoupled as a separate component — use Svelte 5's `{#snippet}` + `{@render}`


@@ -5,15 +5,16 @@
import { TooltipSide } from '$lib/enums';
interface Props {
icon: Component;
tooltip: string;
variant?: ButtonVariant;
size?: ButtonSize;
iconSize?: string;
ariaLabel?: string;
class?: string;
disabled?: boolean;
icon: Component;
iconSize?: string;
onclick: (e?: MouseEvent) => void;
'aria-label'?: string;
size?: ButtonSize;
stopPropagationOnClick?: boolean;
tooltip: string;
variant?: ButtonVariant;
tooltipSide?: TooltipSide;
}
@@ -26,8 +27,9 @@
disabled = false,
iconSize = 'h-3 w-3',
tooltipSide = TooltipSide.TOP,
stopPropagationOnClick = false,
onclick,
'aria-label': ariaLabel
ariaLabel
}: Props = $props();
</script>
@@ -37,13 +39,18 @@
{variant}
{size}
{disabled}
{onclick}
onclick={(e: MouseEvent) => {
if (stopPropagationOnClick) e.stopPropagation();
onclick?.(e);
}}
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{@const IconComponent = icon}
<IconComponent class={iconSize} />
{#if icon}
{@const IconComponent = icon}
<IconComponent class={iconSize} />
{/if}
</Button>
</Tooltip.Trigger>


@@ -1,18 +1,17 @@
<script lang="ts">
import { Copy } from '@lucide/svelte';
import { copyToClipboard } from '$lib/utils';
import ActionIcon from './ActionIcon.svelte';
interface Props {
ariaLabel?: string;
canCopy?: boolean;
text: string;
}
let { ariaLabel = 'Copy to clipboard', canCopy = true, text }: Props = $props();
export let ariaLabel: string = 'Copy to clipboard';
export let canCopy: boolean = true;
export let text: string;
</script>
<Copy
class="h-3 w-3 flex-shrink-0 cursor-{canCopy ? 'pointer' : 'not-allowed'}"
aria-label={ariaLabel}
<ActionIcon
icon={Copy}
tooltip={ariaLabel}
iconSize="h-4 w-4"
disabled={!canCopy}
onclick={() => canCopy && copyToClipboard(text)}
/>


@@ -1,27 +0,0 @@
<script lang="ts">
import { X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
interface Props {
id: string;
onRemove?: (id: string) => void;
class?: string;
iconSize?: number;
}
let { id, onRemove, class: className = '', iconSize = 3 }: Props = $props();
</script>
<Button
type="button"
variant="ghost"
size="icon-sm"
class="bg-white/20 p-0 hover:bg-white/30 {className}"
onclick={(e: MouseEvent) => {
e.stopPropagation();
onRemove?.(id);
}}
aria-label="Remove file"
>
<X class="h-{iconSize} w-{iconSize}" />
</Button>


@@ -1,46 +0,0 @@
<script lang="ts">
import { Eye } from '@lucide/svelte';
import { ActionIconCopyToClipboard } from '$lib/components/app';
import { FileTypeText } from '$lib/enums';
interface Props {
code: string;
language: string;
disabled?: boolean;
onPreview?: (code: string, language: string) => void;
}
let { code, language, disabled = false, onPreview }: Props = $props();
const showPreview = $derived(language?.toLowerCase() === FileTypeText.HTML);
function handlePreview() {
if (disabled) return;
onPreview?.(code, language);
}
</script>
<div class="code-block-actions">
<div class="copy-code-btn" class:opacity-50={disabled} class:!cursor-not-allowed={disabled}>
<ActionIconCopyToClipboard
text={code}
canCopy={!disabled}
ariaLabel={disabled ? 'Code incomplete' : 'Copy code'}
/>
</div>
{#if showPreview}
<button
class="preview-code-btn"
class:opacity-50={disabled}
class:!cursor-not-allowed={disabled}
title={disabled ? 'Code incomplete' : 'Preview code'}
aria-label="Preview code"
aria-disabled={disabled}
type="button"
onclick={handlePreview}
>
<Eye size={16} />
</button>
{/if}
</div>


@@ -9,11 +9,5 @@
/** Styled icon button for action triggers with tooltip. */
export { default as ActionIcon } from './ActionIcon.svelte';
/** Code block actions component (copy, preview). */
export { default as ActionIconsCodeBlock } from './ActionIconsCodeBlock.svelte';
/** Copy-to-clipboard icon button with click handler. */
/** Copy-to-clipboard icon button with clipboard logic. */
export { default as ActionIconCopyToClipboard } from './ActionIconCopyToClipboard.svelte';
/** Remove/delete icon button with X icon. */
export { default as ActionIconRemove } from './ActionIconRemove.svelte';


@@ -1,5 +1,4 @@
<script lang="ts">
import { cn } from '$lib/components/ui/utils';
import type { Snippet } from 'svelte';
interface Props {
@@ -13,10 +12,10 @@
</script>
<button
class={cn(
class={[
'inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75',
className
)}
]}
{onclick}
>
{#if icon}


@@ -1,39 +0,0 @@
<script lang="ts">
import { ModelModality } from '$lib/enums';
import { MODALITY_ICONS, MODALITY_LABELS } from '$lib/constants';
import { cn } from '$lib/components/ui/utils';
type DisplayableModality = ModelModality.VISION | ModelModality.AUDIO;
interface Props {
modalities: ModelModality[];
class?: string;
}
let { modalities, class: className = '' }: Props = $props();
// Filter to only modalities that have icons (VISION, AUDIO)
const displayableModalities = $derived(
modalities.filter(
(m): m is DisplayableModality => m === ModelModality.VISION || m === ModelModality.AUDIO
)
);
</script>
{#each displayableModalities as modality, index (index)}
{@const IconComponent = MODALITY_ICONS[modality]}
{@const label = MODALITY_LABELS[modality]}
<span
class={cn(
'inline-flex items-center gap-1 rounded-md bg-muted px-2 py-1 text-xs font-medium',
className
)}
>
{#if IconComponent}
<IconComponent class="h-3 w-3" />
{/if}
{label}
</span>
{/each}


@@ -0,0 +1,32 @@
<script lang="ts">
import { Eye, Mic } from '@lucide/svelte';
import { ModelModality } from '$lib/enums';
interface Props {
modalities: ModelModality[];
class?: string;
}
let { modalities, class: className = '' }: Props = $props();
</script>
{#each modalities as modality (modality)}
{#if modality === ModelModality.VISION || modality === ModelModality.AUDIO}
<span
class={[
'inline-flex items-center gap-1 rounded-md bg-muted px-2 py-1 text-xs font-medium',
className
]}
>
{#if modality === ModelModality.VISION}
<Eye class="h-3 w-3" />
Vision
{:else}
<Mic class="h-3 w-3" />
Audio
{/if}
</span>
{/if}
{/each}


@@ -6,11 +6,8 @@
*
*/
- /** Badge displaying chat statistics (tokens, timing). */
export { default as BadgeChatStatistic } from './BadgeChatStatistic.svelte';
- /** Generic info badge with optional tooltip and click handler. */
export { default as BadgeInfo } from './BadgeInfo.svelte';
- /** Badge indicating model modality (vision, audio, tools). */
- export { default as BadgeModality } from './BadgeModality.svelte';
+ export { default as BadgesModality } from './BadgesModality.svelte';


@@ -1,284 +0,0 @@
<script lang="ts">
import { Button } from '$lib/components/ui/button';
import * as Alert from '$lib/components/ui/alert';
import { SyntaxHighlightedCode } from '$lib/components/app';
import { FileText, Image, Music, FileIcon, Eye, Info } from '@lucide/svelte';
import {
isTextFile,
isImageFile,
isPdfFile,
isAudioFile,
getLanguageFromFilename,
createBase64DataUrl
} from '$lib/utils';
import { convertPDFToImage } from '$lib/utils/browser-only';
import { modelsStore } from '$lib/stores/models.svelte';
interface Props {
// Either an uploaded file or a stored attachment
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
// For uploaded files
preview?: string;
name?: string;
textContent?: string;
// For checking vision modality
activeModelId?: string;
}
let { uploadedFile, attachment, preview, name, textContent, activeModelId }: Props = $props();
let hasVisionModality = $derived(
activeModelId ? modelsStore.modelSupportsVision(activeModelId) : false
);
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');
// Determine file type from uploaded file or attachment
let isAudio = $derived(isAudioFile(attachment, uploadedFile));
let isImage = $derived(isImageFile(attachment, uploadedFile));
let isPdf = $derived(isPdfFile(attachment, uploadedFile));
let isText = $derived(isTextFile(attachment, uploadedFile));
let displayPreview = $derived(
uploadedFile?.preview ||
(isImage && attachment && 'base64Url' in attachment ? attachment.base64Url : preview)
);
let displayTextContent = $derived(
uploadedFile?.textContent ||
(attachment && 'content' in attachment ? attachment.content : textContent)
);
let language = $derived(getLanguageFromFilename(displayName));
let IconComponent = $derived(() => {
if (isImage) return Image;
if (isText || isPdf) return FileText;
if (isAudio) return Music;
return FileIcon;
});
let pdfViewMode = $state<'text' | 'pages'>('pages');
let pdfImages = $state<string[]>([]);
let pdfImagesLoading = $state(false);
let pdfImagesError = $state<string | null>(null);
async function loadPdfImages() {
if (!isPdf || pdfImages.length > 0 || pdfImagesLoading) return;
pdfImagesLoading = true;
pdfImagesError = null;
try {
let file: File | null = null;
if (uploadedFile?.file) {
file = uploadedFile.file;
} else if (isPdf && attachment) {
// Check if we have pre-processed images
if (
'images' in attachment &&
attachment.images &&
Array.isArray(attachment.images) &&
attachment.images.length > 0
) {
pdfImages = attachment.images;
return;
}
// Convert base64 back to File for processing
if ('base64Data' in attachment && attachment.base64Data) {
const base64Data = attachment.base64Data;
const byteCharacters = atob(base64Data);
const byteNumbers = new Array(byteCharacters.length);
for (let i = 0; i < byteCharacters.length; i++) {
byteNumbers[i] = byteCharacters.charCodeAt(i);
}
const byteArray = new Uint8Array(byteNumbers);
file = new File([byteArray], displayName, { type: 'application/pdf' });
}
}
if (file) {
pdfImages = await convertPDFToImage(file);
} else {
throw new Error('No PDF file available for conversion');
}
} catch (error) {
pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
} finally {
pdfImagesLoading = false;
}
}
export function reset() {
pdfImages = [];
pdfImagesLoading = false;
pdfImagesError = null;
pdfViewMode = 'pages';
}
$effect(() => {
if (isPdf && pdfViewMode === 'pages') {
loadPdfImages();
}
});
</script>
<div class="space-y-4">
<div class="flex items-center justify-end gap-6">
{#if isPdf}
<div class="flex items-center gap-2">
<Button
variant={pdfViewMode === 'text' ? 'default' : 'outline'}
size="sm"
onclick={() => (pdfViewMode = 'text')}
disabled={pdfImagesLoading}
>
<FileText class="mr-1 h-4 w-4" />
Text
</Button>
<Button
variant={pdfViewMode === 'pages' ? 'default' : 'outline'}
size="sm"
onclick={() => {
pdfViewMode = 'pages';
loadPdfImages();
}}
disabled={pdfImagesLoading}
>
{#if pdfImagesLoading}
<div
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
></div>
{:else}
<Eye class="mr-1 h-4 w-4" />
{/if}
Pages
</Button>
</div>
{/if}
</div>
<div class="flex-1 overflow-auto">
{#if isImage && displayPreview}
<div class="flex items-center justify-center">
<img
src={displayPreview}
alt={displayName}
class="max-h-full rounded-lg object-contain shadow-lg"
/>
</div>
{:else if isPdf && pdfViewMode === 'pages'}
{#if !hasVisionModality && activeModelId}
<Alert.Root class="mb-4">
<Info class="h-4 w-4" />
<Alert.Title>Preview only</Alert.Title>
<Alert.Description>
<span class="inline-flex">
The selected model does not support vision. Only the extracted
<!-- svelte-ignore a11y_click_events_have_key_events -->
<!-- svelte-ignore a11y_no_static_element_interactions -->
<span class="mx-1 cursor-pointer underline" onclick={() => (pdfViewMode = 'text')}>
text
</span>
will be sent to the model.
</span>
</Alert.Description>
</Alert.Root>
{/if}
{#if pdfImagesLoading}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<div
class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"
></div>
<p class="text-muted-foreground">Converting PDF to images...</p>
</div>
</div>
{:else if pdfImagesError}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
<p class="mb-4 text-muted-foreground">Failed to load PDF images</p>
<p class="text-sm text-muted-foreground">{pdfImagesError}</p>
<Button class="mt-4" onclick={() => (pdfViewMode = 'text')}>View as Text</Button>
</div>
</div>
{:else if pdfImages.length > 0}
<div class="max-h-[70vh] space-y-4 overflow-auto">
{#each pdfImages as image, index (image)}
<div class="text-center">
<p class="mb-2 text-sm text-muted-foreground">Page {index + 1}</p>
<img
src={image}
alt="PDF Page {index + 1}"
class="mx-auto max-w-full rounded-lg shadow-lg"
/>
</div>
{/each}
</div>
{:else}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
<p class="mb-4 text-muted-foreground">No PDF pages available</p>
</div>
</div>
{/if}
{:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent}
<SyntaxHighlightedCode code={displayTextContent} {language} maxWidth="calc(69rem - 2rem)" />
{:else if isAudio}
<div class="flex items-center justify-center p-8">
<div class="w-full max-w-md text-center">
<Music class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
{#if uploadedFile?.preview}
<audio controls class="mb-4 w-full" src={uploadedFile.preview}>
Your browser does not support the audio element.
</audio>
{:else if isAudio && attachment && 'mimeType' in attachment && 'base64Data' in attachment}
<audio
controls
class="mb-4 w-full"
src={createBase64DataUrl(attachment.mimeType, attachment.base64Data)}
>
Your browser does not support the audio element.
</audio>
{:else}
<p class="mb-4 text-muted-foreground">Audio preview not available</p>
{/if}
<p class="text-sm text-muted-foreground">
{displayName}
</p>
</div>
</div>
{:else}
<div class="flex items-center justify-center p-8">
<div class="text-center">
{#if IconComponent}
<IconComponent class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
{/if}
<p class="mb-4 text-muted-foreground">Preview not available for this file type</p>
</div>
</div>
{/if}
</div>
</div>


@@ -1,165 +0,0 @@
<script lang="ts">
import { ActionIconRemove } from '$lib/components/app';
import { formatFileSize, getFileTypeLabel, getPreviewText, isTextFile } from '$lib/utils';
import { AttachmentType } from '$lib/enums';
interface Props {
class?: string;
id: string;
onClick?: (event?: MouseEvent) => void;
onRemove?: (id: string) => void;
name: string;
readonly?: boolean;
size?: number;
textContent?: string;
// Either uploaded file or stored attachment
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
}
let {
class: className = '',
id,
onClick,
onRemove,
name,
readonly = false,
size,
textContent,
uploadedFile,
attachment
}: Props = $props();
let isText = $derived(isTextFile(attachment, uploadedFile));
let fileTypeLabel = $derived.by(() => {
if (uploadedFile?.type) {
return getFileTypeLabel(uploadedFile.type);
}
if (attachment) {
if ('mimeType' in attachment && attachment.mimeType) {
return getFileTypeLabel(attachment.mimeType);
}
if (attachment.type) {
return getFileTypeLabel(attachment.type);
}
}
return getFileTypeLabel(name);
});
let pdfProcessingMode = $derived.by(() => {
if (attachment?.type === AttachmentType.PDF) {
const pdfAttachment = attachment as DatabaseMessageExtraPdfFile;
return pdfAttachment.processedAsImages ? 'Sent as Image' : 'Sent as Text';
}
return null;
});
</script>
{#if isText}
{#if readonly}
<!-- Readonly mode (ChatMessage) -->
<button
class="cursor-pointer rounded-lg border border-border bg-muted p-3 transition-shadow hover:shadow-md {className} w-full max-w-2xl"
onclick={onClick}
aria-label={`Preview ${name}`}
type="button"
>
<div class="flex items-start gap-3">
<div class="flex min-w-0 flex-1 flex-col items-start text-left">
<span class="w-full truncate text-sm font-medium text-foreground">{name}</span>
{#if size}
<span class="text-xs text-muted-foreground">{formatFileSize(size)}</span>
{/if}
{#if textContent}
<div class="relative mt-2 w-full">
<div
class="overflow-hidden font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground"
>
{getPreviewText(textContent)}
</div>
{#if textContent.length > 150}
<div
class="pointer-events-none absolute right-0 bottom-0 left-0 h-6 bg-gradient-to-t from-muted to-transparent"
></div>
{/if}
</div>
{/if}
</div>
</div>
</button>
{:else}
<!-- Non-readonly mode (ChatForm) -->
<button
class="group relative rounded-lg border border-border bg-muted p-3 {className} {textContent
? 'max-h-24 max-w-72'
: 'max-w-36'} cursor-pointer text-left"
onclick={onClick}
>
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
<ActionIconRemove {id} {onRemove} />
</div>
<div class="pr-8">
<span class="mb-3 block truncate text-sm font-medium text-foreground">{name}</span>
{#if textContent}
<div class="relative">
<div
class="overflow-hidden font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground"
style="max-height: 3rem; line-height: 1.2em;"
>
{getPreviewText(textContent)}
</div>
{#if textContent.length > 150}
<div
class="pointer-events-none absolute right-0 bottom-0 left-0 h-4 bg-gradient-to-t from-muted to-transparent"
></div>
{/if}
</div>
{/if}
</div>
</button>
{/if}
{:else}
<button
class="group flex items-center gap-3 rounded-lg border border-border bg-muted p-3 {className} relative"
onclick={onClick}
>
<div
class="flex h-8 w-8 items-center justify-center rounded bg-primary/10 text-xs font-medium text-primary"
>
{fileTypeLabel}
</div>
<div class="flex flex-col gap-0.5">
<span
class="max-w-24 truncate text-sm font-medium text-foreground {readonly
? ''
: 'group-hover:pr-6'} md:max-w-32"
>
{name}
</span>
{#if pdfProcessingMode}
<span class="text-left text-xs text-muted-foreground">{pdfProcessingMode}</span>
{:else if size}
<span class="text-left text-xs text-muted-foreground">{formatFileSize(size)}</span>
{/if}
</div>
{#if !readonly}
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
<ActionIconRemove {id} {onRemove} />
</div>
{/if}
</button>
{/if}


@@ -1,287 +0,0 @@
<script lang="ts">
import {
ChatAttachmentMcpPrompt,
ChatAttachmentMcpResource,
ChatAttachmentThumbnailImage,
ChatAttachmentThumbnailFile,
HorizontalScrollCarousel,
DialogChatAttachmentPreview,
DialogChatAttachmentsViewAll,
DialogMcpResourcePreview
} from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { AttachmentType } from '$lib/enums';
import type {
DatabaseMessageExtraMcpPrompt,
DatabaseMessageExtraMcpResource,
MCPResourceAttachment
} from '$lib/types';
import { getAttachmentDisplayItems } from '$lib/utils';
interface Props {
class?: string;
style?: string;
// For ChatMessage - stored attachments
attachments?: DatabaseMessageExtra[];
readonly?: boolean;
// For ChatForm - pending uploads
onFileRemove?: (fileId: string) => void;
uploadedFiles?: ChatUploadedFile[];
// Image size customization
imageClass?: string;
imageHeight?: string;
imageWidth?: string;
// Limit display to single row with "+ X more" button
limitToSingleRow?: boolean;
// For vision modality check
activeModelId?: string;
}
let {
class: className = '',
style = '',
attachments = [],
readonly = false,
onFileRemove,
uploadedFiles = $bindable([]),
// Default to small size for form previews
imageClass = '',
imageHeight = 'h-24',
imageWidth = 'w-auto',
limitToSingleRow = false,
activeModelId
}: Props = $props();
let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments }));
let carouselRef: HorizontalScrollCarousel | undefined = $state();
let isScrollable = $state(false);
let previewDialogOpen = $state(false);
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
let mcpResourcePreviewOpen = $state(false);
let mcpResourcePreviewExtra = $state<DatabaseMessageExtraMcpResource | null>(null);
let showViewAll = $derived(limitToSingleRow && displayItems.length > 0 && isScrollable);
let viewAllDialogOpen = $state(false);
function openPreview(item: ChatAttachmentDisplayItem, event?: MouseEvent) {
event?.stopPropagation();
event?.preventDefault();
previewItem = {
uploadedFile: item.uploadedFile,
attachment: item.attachment,
preview: item.preview,
name: item.name,
size: item.size,
textContent: item.textContent
};
previewDialogOpen = true;
}
function openMcpResourcePreview(extra: DatabaseMessageExtraMcpResource) {
mcpResourcePreviewExtra = extra;
mcpResourcePreviewOpen = true;
}
function toMcpResourceAttachment(
extra: DatabaseMessageExtraMcpResource,
id: string
): MCPResourceAttachment {
return {
id,
resource: {
uri: extra.uri,
name: extra.name,
title: extra.name,
serverName: extra.serverName
}
};
}
$effect(() => {
if (carouselRef && displayItems.length) {
carouselRef.resetScroll();
}
});
</script>
{#if displayItems.length > 0}
<div class={className} {style}>
{#if limitToSingleRow}
<HorizontalScrollCarousel
bind:this={carouselRef}
onScrollableChange={(scrollable) => (isScrollable = scrollable)}
>
{#each displayItems as item (item.id)}
{#if item.isMcpPrompt}
{@const mcpPrompt =
item.attachment?.type === AttachmentType.MCP_PROMPT
? (item.attachment as DatabaseMessageExtraMcpPrompt)
: item.uploadedFile?.mcpPrompt
? {
type: AttachmentType.MCP_PROMPT as const,
name: item.name,
serverName: item.uploadedFile.mcpPrompt.serverName,
promptName: item.uploadedFile.mcpPrompt.promptName,
content: item.textContent ?? '',
arguments: item.uploadedFile.mcpPrompt.arguments
}
: null}
{#if mcpPrompt}
<ChatAttachmentMcpPrompt
class="max-w-[300px] min-w-[200px] flex-shrink-0 {limitToSingleRow
? 'first:ml-4 last:mr-4'
: ''}"
prompt={mcpPrompt}
{readonly}
isLoading={item.isLoading}
loadError={item.loadError}
onRemove={onFileRemove ? () => onFileRemove(item.id) : undefined}
/>
{/if}
{:else if item.isMcpResource && item.attachment?.type === AttachmentType.MCP_RESOURCE}
{@const mcpResource = item.attachment as DatabaseMessageExtraMcpResource}
<ChatAttachmentMcpResource
class="flex-shrink-0 {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
attachment={toMcpResourceAttachment(mcpResource, item.id)}
onClick={() => openMcpResourcePreview(mcpResource)}
/>
{:else if item.isImage && item.preview}
<ChatAttachmentThumbnailImage
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onClick={(event) => openPreview(item, event)}
/>
{:else}
<ChatAttachmentThumbnailFile
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
id={item.id}
name={item.name}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
attachment={item.attachment}
uploadedFile={item.uploadedFile}
onClick={(event) => openPreview(item, event)}
/>
{/if}
{/each}
</HorizontalScrollCarousel>
{#if showViewAll}
<div class="mt-2 -mr-2 flex justify-end px-4">
<Button
type="button"
variant="ghost"
size="sm"
class="h-6 text-xs text-muted-foreground hover:text-foreground"
onclick={() => (viewAllDialogOpen = true)}
>
View all ({displayItems.length})
</Button>
</div>
{/if}
{:else}
<div class="flex flex-wrap items-start justify-end gap-3">
{#each displayItems as item (item.id)}
{#if item.isMcpPrompt}
{@const mcpPrompt =
item.attachment?.type === AttachmentType.MCP_PROMPT
? (item.attachment as DatabaseMessageExtraMcpPrompt)
: item.uploadedFile?.mcpPrompt
? {
type: AttachmentType.MCP_PROMPT as const,
name: item.name,
serverName: item.uploadedFile.mcpPrompt.serverName,
promptName: item.uploadedFile.mcpPrompt.promptName,
content: item.textContent ?? '',
arguments: item.uploadedFile.mcpPrompt.arguments
}
: null}
{#if mcpPrompt}
<ChatAttachmentMcpPrompt
class="max-w-[300px] min-w-[200px]"
prompt={mcpPrompt}
{readonly}
isLoading={item.isLoading}
loadError={item.loadError}
onRemove={onFileRemove ? () => onFileRemove(item.id) : undefined}
/>
{/if}
{:else if item.isMcpResource && item.attachment?.type === AttachmentType.MCP_RESOURCE}
{@const mcpResource = item.attachment as DatabaseMessageExtraMcpResource}
<ChatAttachmentMcpResource
attachment={toMcpResourceAttachment(mcpResource, item.id)}
onClick={() => openMcpResourcePreview(mcpResource)}
/>
{:else if item.isImage && item.preview}
<ChatAttachmentThumbnailImage
class="cursor-pointer"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onClick={(event) => openPreview(item, event)}
/>
{:else}
<ChatAttachmentThumbnailFile
class="cursor-pointer"
id={item.id}
name={item.name}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
attachment={item.attachment}
uploadedFile={item.uploadedFile}
onClick={(event?: MouseEvent) => openPreview(item, event)}
/>
{/if}
{/each}
</div>
{/if}
</div>
{/if}
{#if previewItem}
<DialogChatAttachmentPreview
bind:open={previewDialogOpen}
uploadedFile={previewItem.uploadedFile}
attachment={previewItem.attachment}
preview={previewItem.preview}
name={previewItem.name}
size={previewItem.size}
textContent={previewItem.textContent}
{activeModelId}
/>
{/if}
<DialogChatAttachmentsViewAll
bind:open={viewAllDialogOpen}
{uploadedFiles}
{attachments}
{readonly}
{onFileRemove}
imageHeight="h-64"
{imageClass}
{activeModelId}
/>
{#if mcpResourcePreviewExtra}
<DialogMcpResourcePreview bind:open={mcpResourcePreviewOpen} extra={mcpResourcePreviewExtra} />
{/if}


@@ -0,0 +1,119 @@
<script lang="ts">
import {
ChatAttachmentsListItem,
DialogChatAttachmentsPreview,
DialogMcpResourcePreview,
HorizontalScrollCarousel
} from '$lib/components/app';
import type { DatabaseMessageExtraMcpResource } from '$lib/types';
import { getAttachmentDisplayItems, isMcpPrompt, isMcpResource } from '$lib/utils';
interface Props {
class?: string;
style?: string;
// For ChatMessage - stored attachments
attachments?: DatabaseMessageExtra[];
readonly?: boolean;
// For ChatForm - pending uploads
onFileRemove?: (fileId: string) => void;
uploadedFiles?: ChatUploadedFile[];
// Image size customization
imageClass?: string;
imageHeight?: string;
imageWidth?: string;
// Limit display to single row with "+ X more" button
limitToSingleRow?: boolean;
// For vision modality check
activeModelId?: string;
}
let {
class: className = '',
style = '',
attachments = [],
readonly = false,
onFileRemove,
uploadedFiles = $bindable([]),
// Default to small size for form previews
imageClass = '',
imageHeight = 'h-24',
imageWidth = 'w-auto',
limitToSingleRow = false,
activeModelId
}: Props = $props();
let carouselRef: HorizontalScrollCarousel | undefined = $state();
let mcpResourcePreviewOpen = $state(false);
let mcpResourcePreviewExtra = $state<DatabaseMessageExtraMcpResource | null>(null);
let previewFocusIndex = $state(0);
let viewAllDialogOpen = $state(false);
let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments }));
function openPreview(item: ChatAttachmentDisplayItem, event?: MouseEvent) {
event?.stopPropagation();
event?.preventDefault();
// Find the index of the clicked item among non-MCP attachments
const nonMcpItems = displayItems.filter((i) => !isMcpPrompt(i) && !isMcpResource(i));
const index = nonMcpItems.findIndex((i) => i.id === item.id);
previewFocusIndex = index >= 0 ? index : 0;
viewAllDialogOpen = true;
}
function openMcpResourcePreview(extra: DatabaseMessageExtraMcpResource) {
mcpResourcePreviewExtra = extra;
mcpResourcePreviewOpen = true;
}
$effect(() => {
if (carouselRef && displayItems.length) {
carouselRef.resetScroll();
}
});
</script>
{#snippet attachmentitem(item: ChatAttachmentDisplayItem)}
<ChatAttachmentsListItem
{imageClass}
{imageHeight}
{imageWidth}
{item}
{limitToSingleRow}
{onFileRemove}
onMcpResourcePreview={openMcpResourcePreview}
onPreview={(i: ChatAttachmentDisplayItem, event?: MouseEvent) => openPreview(i, event)}
{readonly}
/>
{/snippet}
{#if displayItems.length > 0}
<div class={className} {style}>
{#if limitToSingleRow}
<HorizontalScrollCarousel bind:this={carouselRef}>
{#each displayItems as item (item.id)}
{@render attachmentitem(item)}
{/each}
</HorizontalScrollCarousel>
{:else}
<div class="flex flex-wrap items-start justify-end gap-3">
{#each displayItems as item (item.id)}
{@render attachmentitem(item)}
{/each}
</div>
{/if}
</div>
{/if}
<DialogChatAttachmentsPreview
{activeModelId}
{attachments}
bind:open={viewAllDialogOpen}
{previewFocusIndex}
{uploadedFiles}
/>
{#if mcpResourcePreviewExtra}
<DialogMcpResourcePreview extra={mcpResourcePreviewExtra} bind:open={mcpResourcePreviewOpen} />
{/if}


@@ -0,0 +1,132 @@
<script lang="ts">
import {
ChatAttachmentsListItemMcpPrompt,
ChatAttachmentsListItemMcpResource,
ChatAttachmentsListItemThumbnailImage,
ChatAttachmentsListItemThumbnailFile
} from '$lib/components/app';
import { AttachmentType } from '$lib/enums';
import type {
ChatAttachmentDisplayItem,
DatabaseMessageExtraMcpPrompt,
DatabaseMessageExtraMcpResource,
MCPResourceAttachment
} from '$lib/types';
import { isMcpPrompt, isMcpResource, isPdfFile } from '$lib/utils';
interface Props {
class?: string;
imageClass?: string;
imageHeight?: string;
imageWidth?: string;
item: ChatAttachmentDisplayItem;
limitToSingleRow?: boolean;
onFileRemove?: (fileId: string) => void;
onMcpResourcePreview?: (extra: DatabaseMessageExtraMcpResource) => void;
onPreview?: (item: ChatAttachmentDisplayItem) => void;
readonly?: boolean;
}
let {
class: className = '',
imageClass = '',
imageHeight = 'h-24',
imageWidth = 'w-auto',
item,
limitToSingleRow = false,
onFileRemove,
onMcpResourcePreview,
onPreview,
readonly = false
}: Props = $props();
const scrollClasses = $derived(limitToSingleRow ? 'first:ml-4 last:mr-4' : '');
function toMcpResourceAttachment(
extra: DatabaseMessageExtraMcpResource,
id: string
): MCPResourceAttachment {
return {
id,
resource: {
uri: extra.uri,
name: extra.name,
title: extra.name,
serverName: extra.serverName
}
};
}
</script>
{#if isMcpPrompt(item)}
{@const mcpPrompt =
item.attachment?.type === AttachmentType.MCP_PROMPT
? (item.attachment as DatabaseMessageExtraMcpPrompt)
: item.uploadedFile?.mcpPrompt
? {
type: AttachmentType.MCP_PROMPT as const,
name: item.name,
serverName: item.uploadedFile.mcpPrompt.serverName,
promptName: item.uploadedFile.mcpPrompt.promptName,
content: item.textContent ?? '',
arguments: item.uploadedFile.mcpPrompt.arguments
}
: null}
{#if mcpPrompt}
<ChatAttachmentsListItemMcpPrompt
class="max-w-[300px] min-w-[200px] flex-shrink-0 {className} {scrollClasses}"
prompt={mcpPrompt}
{readonly}
isLoading={item.isLoading}
loadError={item.loadError}
onRemove={onFileRemove ? () => onFileRemove(item.id) : undefined}
/>
{/if}
{:else if isMcpResource(item)}
{@const mcpResource = item.attachment as DatabaseMessageExtraMcpResource}
<ChatAttachmentsListItemMcpResource
class="flex-shrink-0 {className} {scrollClasses}"
attachment={toMcpResourceAttachment(mcpResource, item.id)}
onclick={() => onMcpResourcePreview?.(mcpResource)}
/>
{:else if item.isImage && item.preview}
<ChatAttachmentsListItemThumbnailImage
class="flex-shrink-0 cursor-pointer {className} {scrollClasses}"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onclick={() => onPreview?.(item)}
/>
{:else if isPdfFile(item.attachment, item.uploadedFile)}
<ChatAttachmentsListItemThumbnailFile
class="flex-shrink-0 cursor-pointer {className} {scrollClasses}"
id={item.id}
name={item.name}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
attachment={item.attachment}
uploadedFile={item.uploadedFile}
onclick={() => onPreview?.(item)}
/>
{:else}
<ChatAttachmentsListItemThumbnailFile
class="flex-shrink-0 cursor-pointer {className} {scrollClasses}"
id={item.id}
name={item.name}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
attachment={item.attachment}
uploadedFile={item.uploadedFile}
onclick={() => onPreview?.(item)}
/>
{/if}


@@ -1,40 +1,41 @@
<script lang="ts">
- import { ChatMessageMcpPromptContent, ActionIconRemove } from '$lib/components/app';
+ import { ChatMessageMcpPromptContent, ActionIcon } from '$lib/components/app';
+ import { X } from '@lucide/svelte';
import type { DatabaseMessageExtraMcpPrompt } from '$lib/types';
import { McpPromptVariant } from '$lib/enums';
interface Props {
class?: string;
- prompt: DatabaseMessageExtraMcpPrompt;
- readonly?: boolean;
isLoading?: boolean;
loadError?: string;
onRemove?: () => void;
+ prompt: DatabaseMessageExtraMcpPrompt;
+ readonly?: boolean;
}
let {
class: className = '',
- prompt,
- readonly = false,
isLoading = false,
loadError,
- onRemove
+ onRemove,
+ prompt,
+ readonly = false
}: Props = $props();
</script>
<div class="group relative {className}">
<ChatMessageMcpPromptContent
- {prompt}
- variant={McpPromptVariant.ATTACHMENT}
{isLoading}
{loadError}
+ {prompt}
+ variant={McpPromptVariant.ATTACHMENT}
/>
{#if !readonly && onRemove}
<div
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
>
- <ActionIconRemove id={prompt.name} onRemove={() => onRemove?.()} />
+ <ActionIcon icon={X} tooltip="Remove" stopPropagationOnClick onclick={() => onRemove?.()} />
</div>
{/if}
</div>


@@ -1,46 +1,47 @@
<script lang="ts">
import { Loader2, AlertCircle } from '@lucide/svelte';
- import { cn } from '$lib/components/ui/utils';
import { mcpStore } from '$lib/stores/mcp.svelte';
import type { MCPResourceAttachment } from '$lib/types';
import * as Tooltip from '$lib/components/ui/tooltip';
- import { ActionIconRemove } from '$lib/components/app';
+ import { ActionIcon } from '$lib/components/app';
+ import { X } from '@lucide/svelte';
import { getResourceIcon, getResourceDisplayName } from '$lib/utils';
interface Props {
attachment: MCPResourceAttachment;
- onRemove?: (attachmentId: string) => void;
- onClick?: () => void;
class?: string;
+ onclick?: () => void;
+ onRemove?: (attachmentId: string) => void;
}
- let { attachment, onRemove, onClick, class: className }: Props = $props();
- function getStatusClass(attachment: MCPResourceAttachment): string {
- if (attachment.error) return 'border-red-500/50 bg-red-500/10';
- if (attachment.loading) return 'border-border/50 bg-muted/30';
- return 'border-border/50 bg-muted/30';
- }
+ let { attachment, class: className, onclick, onRemove }: Props = $props();
const ResourceIcon = $derived(
getResourceIcon(attachment.resource.mimeType, attachment.resource.uri)
);
const serverName = $derived(mcpStore.getServerDisplayName(attachment.resource.serverName));
const favicon = $derived(mcpStore.getServerFavicon(attachment.resource.serverName));
+ function getStatusClass(attachment: MCPResourceAttachment): string {
+ if (attachment.error) return 'border-red-500/50 bg-red-500/10';
+ if (attachment.loading) return 'border-border/50 bg-muted/30';
+ return 'border-border/50 bg-muted/30';
+ }
</script>
<Tooltip.Root>
<Tooltip.Trigger>
<button
- type="button"
- class={cn(
+ class={[
'flex flex-shrink-0 items-center gap-1.5 rounded-md border px-2 py-0.75 text-sm transition-colors',
getStatusClass(attachment),
- onClick && 'cursor-pointer hover:bg-muted/50',
+ onclick && 'cursor-pointer hover:bg-muted/50',
className
- )}
- onclick={onClick}
- disabled={!onClick}
+ ]}
+ disabled={!onclick}
+ {onclick}
+ type="button"
>
{#if attachment.loading}
<Loader2 class="h-3 w-3 animate-spin text-muted-foreground" />
@@ -55,11 +56,13 @@
</span>
{#if onRemove}
- <ActionIconRemove
+ <ActionIcon
class="-my-2 -mr-1.5 bg-transparent"
- iconSize={2}
- id={attachment.id}
- {onRemove}
+ icon={X}
+ iconSize="h-2 w-2"
+ onclick={() => onRemove?.(attachment.id)}
+ stopPropagationOnClick
+ tooltip="Remove"
/>
{/if}
</button>
@@ -69,12 +72,12 @@
<div class="flex items-center gap-1 text-xs">
{#if favicon}
<img
- src={favicon}
- alt=""
+ alt={attachment.resource.serverName}
class="h-3 w-3 shrink-0 rounded-sm"
onerror={(e) => {
(e.currentTarget as HTMLImageElement).style.display = 'none';
}}
+ src={favicon}
/>
{/if}


@@ -0,0 +1,174 @@
<script lang="ts">
import { X } from '@lucide/svelte';
import {
formatFileSize,
getFileTypeLabel,
getPreviewText,
isPdfFile,
isTextFile
} from '$lib/utils';
import { ActionIcon } from '$lib/components/app';
import { AttachmentType } from '$lib/enums';
interface Props {
attachment?: DatabaseMessageExtra;
class?: string;
id: string;
onclick?: (event: MouseEvent) => void;
onRemove?: (id: string) => void;
name: string;
readonly?: boolean;
size?: number;
textContent?: string;
// Either uploaded file or stored attachment
uploadedFile?: ChatUploadedFile;
}
let {
attachment,
class: className = '',
id,
onclick,
onRemove,
name,
readonly = false,
size,
textContent,
uploadedFile
}: Props = $props();
let isPdf = $derived(isPdfFile(attachment, uploadedFile));
let isPdfWithContent = $derived(isPdf && !!textContent);
let isText = $derived(isTextFile(attachment, uploadedFile));
let isTextWithContent = $derived(isText && !!textContent);
let fileTypeLabel = $derived.by(() => {
if (uploadedFile?.type) {
return getFileTypeLabel(uploadedFile.type);
}
if (attachment) {
if ('mimeType' in attachment && attachment.mimeType) {
return getFileTypeLabel(attachment.mimeType);
}
if (attachment.type) {
return getFileTypeLabel(attachment.type);
}
}
return getFileTypeLabel(name);
});
let pdfProcessingMode = $derived.by(() => {
if (attachment?.type === AttachmentType.PDF) {
const pdfAttachment = attachment as DatabaseMessageExtraPdfFile;
return pdfAttachment.processedAsImages ? 'Sent as Image' : 'Sent as Text';
}
return null;
});
</script>
{#snippet textPreview(content: string)}
<div class="relative">
<div
class="font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground {!readonly
? 'max-h-3rem line-height-1.2'
: ''}"
>
{getPreviewText(content)}
</div>
{#if content.length > 150}
<div
class="pointer-events-none absolute right-0 bottom-0 left-0 h-4 bg-gradient-to-t from-muted to-transparent {readonly
? 'h-6'
: ''}"
></div>
{/if}
</div>
{/snippet}
{#snippet removeButton()}
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
<ActionIcon icon={X} tooltip="Remove" stopPropagationOnClick onclick={() => onRemove?.(id)} />
</div>
{/snippet}
{#snippet fileIcon()}
<div
class="flex h-8 w-8 items-center justify-center rounded bg-primary/10 text-xs font-medium text-primary"
>
{fileTypeLabel}
</div>
{/snippet}
{#snippet info(text: string | undefined)}
{#if text}
<span class="text-xs text-muted-foreground">{text}</span>
{/if}
{/snippet}
{#if isTextWithContent || isPdfWithContent}
<button
aria-label={readonly ? `Preview ${name}` : undefined}
class="rounded-lg border border-border bg-muted p-3 {className} cursor-pointer {readonly
? 'w-full max-w-2xl transition-shadow hover:shadow-md'
: `group relative text-left ${textContent ? 'max-h-24 max-w-72' : 'max-w-36'}`} overflow-hidden"
{onclick}
type="button"
>
{#if !readonly}
{@render removeButton()}
{/if}
<div class={[!readonly && 'pr-8', 'overflow-hidden']}>
{#if readonly}
<div class="flex items-start gap-3">
<div class="flex min-w-0 flex-1 flex-col items-start text-left">
<span class="w-full truncate text-sm font-medium text-foreground">{name}</span>
{@render info(pdfProcessingMode || (size ? formatFileSize(size) : undefined))}
{#if textContent}
{@render textPreview(textContent)}
{/if}
</div>
</div>
{:else}
<span class="mb-3 block truncate text-sm font-medium text-foreground">{name}</span>
{#if textContent}
{@render textPreview(textContent)}
{/if}
{/if}
</div>
</button>
{:else}
<button
class="group flex items-center gap-3 rounded-lg border border-border bg-muted p-3 {className} relative"
{onclick}
type="button"
>
{@render fileIcon()}
<div class="flex flex-col items-start gap-0.5">
<span
class="max-w-24 truncate text-sm font-medium text-foreground {readonly
? ''
: 'group-hover:pr-6'} md:max-w-32"
>
{name}
</span>
{@render info(pdfProcessingMode || (size ? formatFileSize(size) : undefined))}
</div>
{#if !readonly}
{@render removeButton()}
{/if}
</button>
{/if}


@@ -1,64 +1,65 @@
<script lang="ts">
import { ActionIconRemove } from '$lib/components/app';
import { ActionIcon } from '$lib/components/app';
import { X } from '@lucide/svelte';
interface Props {
class?: string;
height?: string;
id: string;
imageClass?: string;
onclick?: (event?: MouseEvent) => void;
onRemove?: (id: string) => void;
name: string;
preview: string;
readonly?: boolean;
onRemove?: (id: string) => void;
onClick?: (event?: MouseEvent) => void;
class?: string;
// Customizable size props
width?: string;
height?: string;
imageClass?: string;
}
let {
class: className = '',
height = 'h-16',
id,
imageClass = '',
onclick,
onRemove,
name,
preview,
readonly = false,
onRemove,
onClick,
class: className = '',
// Default to small size for form previews
width = 'w-auto',
height = 'h-16',
imageClass = ''
width = 'w-auto'
}: Props = $props();
</script>
{#snippet image()}
<img src={preview} alt={name} class="{height} {width} cursor-pointer object-cover {imageClass}" />
{/snippet}
<div
class="group relative overflow-hidden rounded-lg bg-muted shadow-lg dark:border dark:border-muted {className}"
>
{#if onClick}
{#if onclick}
<button
type="button"
class="block h-full w-full rounded-lg focus:ring-2 focus:ring-primary focus:ring-offset-2 focus:outline-none"
onclick={onClick}
aria-label="Preview {name}"
class="block h-full w-full rounded-lg focus:ring-2 focus:ring-primary focus:ring-offset-2 focus:outline-none"
{onclick}
type="button"
>
<img
src={preview}
alt={name}
class="{height} {width} cursor-pointer object-cover {imageClass}"
/>
{@render image()}
</button>
{:else}
<img
src={preview}
alt={name}
class="{height} {width} cursor-pointer object-cover {imageClass}"
/>
{@render image()}
{/if}
{#if !readonly}
<div
class="absolute top-1 right-1 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
>
<ActionIconRemove {id} {onRemove} class="text-white" />
<ActionIcon
class="text-white"
icon={X}
onclick={() => onRemove?.(id)}
stopPropagationOnClick
tooltip="Remove"
/>
</div>
{/if}
</div>


@@ -0,0 +1,190 @@
<script lang="ts">
import {
ChatAttachmentsPreviewCurrentItem,
ChatAttachmentsPreviewFileInfo,
ChatAttachmentsPreviewNavButtons,
ChatAttachmentsPreviewThumbnailStrip
} from '$lib/components/app';
import { modelsStore } from '$lib/stores/models.svelte';
import {
createBase64DataUrl,
formatFileSize,
getAttachmentDisplayItems,
getLanguageFromFilename,
isAudioFile,
isImageFile,
isMcpPrompt,
isMcpResource,
isPdfFile,
isTextFile
} from '$lib/utils';
interface PreviewItem {
id: string;
name: string;
size?: number;
preview?: string;
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
textContent?: string;
isImage: boolean;
isAudio: boolean;
}
interface Props {
uploadedFiles?: ChatUploadedFile[];
attachments?: DatabaseMessageExtra[];
activeModelId?: string;
class?: string;
previewFocusIndex?: number;
}
let {
uploadedFiles = [],
attachments = [],
activeModelId,
class: className = '',
previewFocusIndex = 0
}: Props = $props();
let allItems = $derived(
getAttachmentDisplayItems({ uploadedFiles, attachments })
.filter((item) => !isMcpPrompt(item) && !isMcpResource(item))
.map(
(item): PreviewItem => ({
...item,
isImage: isImageFile(item.attachment, item.uploadedFile),
isAudio: isAudioFile(item.attachment, item.uploadedFile)
})
)
);
let currentIndex = $state(0);
$effect(() => {
if (previewFocusIndex >= 0 && previewFocusIndex < allItems.length) {
currentIndex = previewFocusIndex;
}
});
$effect(() => {
const handler = (e: Event) => {
const delta = (e as CustomEvent).detail;
if (delta < 0) {
currentIndex = currentIndex > 0 ? currentIndex - 1 : allItems.length - 1;
} else {
currentIndex = currentIndex < allItems.length - 1 ? currentIndex + 1 : 0;
}
};
document.addEventListener('chat-attachments-nav', handler);
return () => document.removeEventListener('chat-attachments-nav', handler);
});
$effect(() => {
const index = currentIndex;
setTimeout(() => {
const thumbnail = document.querySelector(`[data-thumbnail-index="${index}"]`);
thumbnail?.scrollIntoView({ behavior: 'smooth', inline: 'center', block: 'nearest' });
}, 0);
});
let currentItem = $derived(allItems[currentIndex] ?? null);
let displayName = $derived(
currentItem?.name ||
currentItem?.uploadedFile?.name ||
currentItem?.attachment?.name ||
'Unknown File'
);
let isAudio = $derived(
currentItem ? isAudioFile(currentItem.attachment, currentItem.uploadedFile) : false
);
let isImage = $derived(
currentItem ? isImageFile(currentItem.attachment, currentItem.uploadedFile) : false
);
let isPdf = $derived(
currentItem ? isPdfFile(currentItem.attachment, currentItem.uploadedFile) : false
);
let isText = $derived(
currentItem ? isTextFile(currentItem.attachment, currentItem.uploadedFile) : false
);
let displayPreview = $derived(
currentItem?.uploadedFile?.preview ||
(isImage && currentItem?.attachment && 'base64Url' in currentItem.attachment
? currentItem.attachment.base64Url
: currentItem?.preview)
);
let displayTextContent = $derived(
currentItem?.uploadedFile?.textContent ||
(currentItem?.attachment && 'content' in currentItem.attachment
? currentItem.attachment.content
: currentItem?.textContent)
);
let language = $derived(getLanguageFromFilename(displayName));
let fileSize = $derived(currentItem?.size ? formatFileSize(currentItem.size) : '');
let hasVisionModality = $derived(
currentItem && activeModelId ? modelsStore.modelSupportsVision(activeModelId) : false
);
let audioSrc = $derived(
isAudio && currentItem
? (currentItem.uploadedFile?.preview ??
(currentItem.attachment &&
'mimeType' in currentItem.attachment &&
'base64Data' in currentItem.attachment
? createBase64DataUrl(
currentItem.attachment.mimeType,
currentItem.attachment.base64Data
)
: null))
: null
);
export function prev() {
currentIndex = currentIndex > 0 ? currentIndex - 1 : allItems.length - 1;
}
export function next() {
currentIndex = currentIndex < allItems.length - 1 ? currentIndex + 1 : 0;
}
function onNavigate(index: number) {
currentIndex = index;
}
</script>
<div class="{className} flex flex-col text-white">
<div class="relative flex min-h-0 flex-1 items-center justify-center overflow-hidden">
<ChatAttachmentsPreviewNavButtons onPrev={prev} onNext={next} show={allItems.length > 1} />
<div class="flex h-full w-full flex-col items-center justify-start overflow-auto py-4">
{#if currentItem}
<ChatAttachmentsPreviewFileInfo {displayName} {fileSize} />
<ChatAttachmentsPreviewCurrentItem
{currentItem}
{isImage}
{isAudio}
{isPdf}
{isText}
{displayPreview}
{displayTextContent}
{audioSrc}
{language}
{hasVisionModality}
{activeModelId}
/>
{/if}
<ChatAttachmentsPreviewThumbnailStrip items={allItems} {currentIndex} {onNavigate} />
</div>
</div>
</div>
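The `prev()`, `next()`, and `chat-attachments-nav` handler above all inline the same wrap-around ternary. That logic can be factored into one standalone helper (hypothetical; the component keeps the ternaries inline):

```typescript
// Wrap-around index stepping over a list of `length` items.
// ((x % n) + n) % n maps negative intermediate values into [0, n).
function wrapIndex(current: number, delta: number, length: number): number {
  if (length === 0) return 0;
  return ((current + delta) % length + length) % length;
}
```

For example, `wrapIndex(0, -1, 3)` yields `2`, matching `prev()` wrapping from the first item to the last, and `wrapIndex(2, 1, 3)` yields `0`, matching `next()` wrapping back to the start.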


@@ -0,0 +1,65 @@
<script lang="ts">
import type { ChatAttachmentDisplayItem } from '$lib/types';
import { Image, Music, FileText, FileIcon } from '@lucide/svelte';
import ChatAttachmentsPreviewCurrentItemPdf from './ChatAttachmentsPreviewCurrentItemPdf.svelte';
import ChatAttachmentsPreviewCurrentItemImage from './ChatAttachmentsPreviewCurrentItemImage.svelte';
import ChatAttachmentsPreviewCurrentItemAudio from './ChatAttachmentsPreviewCurrentItemAudio.svelte';
import ChatAttachmentsPreviewCurrentItemText from './ChatAttachmentsPreviewCurrentItemText.svelte';
import ChatAttachmentsPreviewCurrentItemUnavailable from './ChatAttachmentsPreviewCurrentItemUnavailable.svelte';
interface Props {
currentItem: ChatAttachmentDisplayItem | null;
isImage: boolean;
isAudio: boolean;
isPdf: boolean;
isText: boolean;
displayPreview: string | undefined;
displayTextContent: string | undefined;
audioSrc: string | null;
language: string;
hasVisionModality: boolean;
activeModelId?: string;
}
let {
currentItem,
isImage,
isAudio,
isPdf,
isText,
displayPreview,
displayTextContent,
audioSrc,
language,
hasVisionModality,
activeModelId
}: Props = $props();
let IconComponent = $derived(
isImage ? Image : isText || isPdf ? FileText : isAudio ? Music : FileIcon
);
let isUnavailable = $derived(!isPdf && !isImage && !(isText && displayTextContent) && !isAudio);
</script>
{#if currentItem}
{#key currentItem.id}
{#if isPdf}
<ChatAttachmentsPreviewCurrentItemPdf
{currentItem}
displayName={currentItem.name}
{displayTextContent}
{hasVisionModality}
{activeModelId}
/>
{:else if isImage}
<ChatAttachmentsPreviewCurrentItemImage {currentItem} {displayPreview} />
{:else if isText && displayTextContent}
<ChatAttachmentsPreviewCurrentItemText {displayTextContent} {language} />
{:else if isAudio}
<ChatAttachmentsPreviewCurrentItemAudio {currentItem} {audioSrc} />
{:else if isUnavailable}
<ChatAttachmentsPreviewCurrentItemUnavailable {IconComponent} />
{/if}
{/key}
{/if}


@@ -0,0 +1,26 @@
<script lang="ts">
import { Music } from '@lucide/svelte';
interface Props {
currentItem: { name?: string } | null;
audioSrc: string | null;
}
let { currentItem, audioSrc }: Props = $props();
</script>
<div class="flex flex-1 items-center justify-center p-8">
<div class="w-full max-w-md text-center">
<Music class="mx-auto mb-4 h-16 w-16 text-white/50" />
{#if audioSrc}
<audio controls class="mb-4 w-full" src={audioSrc}>
Your browser does not support the audio element.
</audio>
{:else}
<p class="mb-4 text-white/70">Audio preview not available</p>
{/if}
<p class="text-sm text-white/50">{currentItem?.name || 'Audio'}</p>
</div>
</div>


@@ -0,0 +1,18 @@
<script lang="ts">
interface Props {
currentItem: { name?: string } | null;
displayPreview: string | undefined;
}
let { currentItem, displayPreview }: Props = $props();
</script>
{#if displayPreview}
<div class="flex flex-1 items-center justify-center">
<img
src={displayPreview}
alt={currentItem?.name || 'preview'}
class="max-h-[80vh] max-w-[80vw] rounded-lg object-contain shadow-lg"
/>
</div>
{/if}


@@ -0,0 +1,174 @@
<script lang="ts">
import type { ChatAttachmentDisplayItem } from '$lib/types';
import { FileText, Eye, Info } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as Alert from '$lib/components/ui/alert';
import { SyntaxHighlightedCode } from '$lib/components/app';
import { getLanguageFromFilename } from '$lib/utils';
import { convertPDFToImage } from '$lib/utils/browser-only';
import { PdfViewMode } from '$lib/enums';
interface Props {
currentItem: ChatAttachmentDisplayItem | null;
displayName: string;
displayTextContent: string | undefined;
hasVisionModality: boolean;
activeModelId?: string;
}
let { currentItem, displayName, displayTextContent, hasVisionModality, activeModelId }: Props =
$props();
let pdfViewMode = $state<PdfViewMode>(PdfViewMode.PAGES);
let pdfImages = $state<string[]>([]);
let pdfImagesLoading = $state(false);
let pdfImagesError = $state<string | null>(null);
let language = $derived(getLanguageFromFilename(displayName));
async function loadPdfImages() {
if (pdfImages.length > 0 || pdfImagesLoading || !currentItem) return;
pdfImagesLoading = true;
pdfImagesError = null;
try {
let file: File | null = null;
if (currentItem.uploadedFile?.file) {
file = currentItem.uploadedFile.file;
} else if (currentItem.attachment) {
// Check if we have pre-processed images
if (
'images' in currentItem.attachment &&
currentItem.attachment.images &&
Array.isArray(currentItem.attachment.images) &&
currentItem.attachment.images.length > 0
) {
pdfImages = currentItem.attachment.images;
return;
}
// Convert base64 back to File for processing
if ('base64Data' in currentItem.attachment && currentItem.attachment.base64Data) {
const base64Data = currentItem.attachment.base64Data;
const byteCharacters = atob(base64Data);
const byteNumbers = new Array(byteCharacters.length);
for (let i = 0; i < byteCharacters.length; i++) {
byteNumbers[i] = byteCharacters.charCodeAt(i);
}
const byteArray = new Uint8Array(byteNumbers);
file = new File([byteArray], displayName, { type: 'application/pdf' });
}
}
if (file) {
pdfImages = await convertPDFToImage(file);
} else {
throw new Error('No PDF file available for conversion');
}
} catch (error) {
pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
} finally {
pdfImagesLoading = false;
}
}
$effect(() => {
if (pdfViewMode === PdfViewMode.PAGES) {
loadPdfImages();
}
});
</script>
<div class="mb-4 flex items-center justify-end gap-2">
<Button
variant={pdfViewMode === PdfViewMode.TEXT ? 'default' : 'outline'}
size="sm"
onclick={() => (pdfViewMode = PdfViewMode.TEXT)}
disabled={pdfImagesLoading}
>
<FileText class="mr-1 h-4 w-4" />
Text
</Button>
<Button
variant={pdfViewMode === PdfViewMode.PAGES ? 'default' : 'outline'}
size="sm"
onclick={() => (pdfViewMode = PdfViewMode.PAGES)}
disabled={pdfImagesLoading}
>
{#if pdfImagesLoading}
<div
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
></div>
{:else}
<Eye class="mr-1 h-4 w-4" />
{/if}
Pages
</Button>
</div>
{#if !hasVisionModality && activeModelId && currentItem}
<Alert.Root class="mb-4 max-w-4xl">
<Info class="h-4 w-4" />
<Alert.Title>Preview only</Alert.Title>
<Alert.Description>
<span class="inline-flex">
The selected model does not support vision. Only the extracted
<!-- svelte-ignore a11y_click_events_have_key_events -->
<!-- svelte-ignore a11y_no_static_element_interactions -->
<span
class="mx-1 cursor-pointer underline"
onclick={() => (pdfViewMode = PdfViewMode.TEXT)}
>
text
</span>
will be sent to the model.
</span>
</Alert.Description>
</Alert.Root>
{/if}
{#if pdfImagesLoading}
<div class="flex flex-1 items-center justify-center p-8">
<div class="text-center">
<div
class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-white border-t-transparent"
></div>
<p class="text-white/70">Converting PDF to images...</p>
</div>
</div>
{:else if pdfImagesError}
<div class="flex flex-1 items-center justify-center p-8">
<div class="text-center">
<FileText class="mx-auto mb-4 h-16 w-16 text-white/50" />
<p class="mb-4 text-white/70">Failed to load PDF images</p>
<p class="text-sm text-white/50">{pdfImagesError}</p>
</div>
</div>
{:else if pdfImages.length > 0}
{#each pdfImages as image, index (image)}
<p class="mb-2 text-sm text-white/50">Page {index + 1}</p>
<img src={image} alt="PDF Page {index + 1}" class="mx-auto max-w-[85vw] rounded-lg shadow-lg" />
<div class="h-4"></div>
{/each}
{:else}
<div class="flex flex-1 items-center justify-center p-8">
<div class="text-center">
<FileText class="mx-auto mb-4 h-16 w-16 text-white/50" />
<p class="text-white/70">No PDF pages available</p>
</div>
</div>
{/if}
{#if pdfViewMode === PdfViewMode.TEXT && displayTextContent}
<div class="px-4 pb-4">
<SyntaxHighlightedCode
class="max-w-4xl"
code={displayTextContent}
{language}
maxHeight="none"
/>
</div>
{/if}
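The base64-to-`File` step inside `loadPdfImages` above can be sketched as a small helper, assuming plain base64 input with no `data:` URL prefix. Writing straight into a `Uint8Array` also avoids the intermediate `number[]` the component allocates:

```typescript
// Decode plain base64 into raw bytes (sketch of the conversion step
// in loadPdfImages; the component then wraps the bytes in a File).
function base64ToBytes(base64Data: string): Uint8Array {
  const byteCharacters = atob(base64Data);
  const bytes = new Uint8Array(byteCharacters.length);
  for (let i = 0; i < byteCharacters.length; i++) {
    bytes[i] = byteCharacters.charCodeAt(i);
  }
  return bytes;
}
```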


@@ -0,0 +1,21 @@
<script lang="ts">
import { SyntaxHighlightedCode } from '$lib/components/app';
interface Props {
displayTextContent: string | undefined;
language: string;
}
let { displayTextContent, language }: Props = $props();
</script>
{#if displayTextContent}
<div class="px-4 pb-4">
<SyntaxHighlightedCode
class="max-w-4xl"
code={displayTextContent}
{language}
maxHeight="none"
/>
</div>
{/if}


@@ -0,0 +1,17 @@
<script lang="ts">
import type { Component } from 'svelte';
interface Props {
IconComponent: Component;
}
let { IconComponent }: Props = $props();
</script>
<div class="flex flex-1 items-center justify-center p-8">
<div class="text-center">
<IconComponent class="mx-auto mb-4 h-16 w-16 text-white/50" />
<p class="text-white/70">Preview not available for this file type</p>
</div>
</div>


@@ -0,0 +1,16 @@
<script lang="ts">
interface Props {
displayName: string;
fileSize: string;
}
let { displayName, fileSize }: Props = $props();
</script>
<div class="sticky top-0 z-[20] mb-4 rounded-lg bg-black/5 px-4 py-2 text-center backdrop-blur-md">
<p class="font-medium text-white">{displayName}</p>
{#if fileSize}
<p class="text-xs text-white/60">{fileSize}</p>
{/if}
</div>


@@ -0,0 +1,34 @@
<script lang="ts">
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
interface Props {
onPrev: () => void;
onNext: () => void;
show: boolean;
}
let { onPrev, onNext, show }: Props = $props();
</script>
{#if show}
<Button
variant="secondary"
size="icon"
class="absolute top-1/2 left-4 z-10 h-8 w-8 -translate-y-1/2 rounded-full bg-background/5 p-0 text-white!"
onclick={onPrev}
aria-label="Previous"
>
<ChevronLeft class="size-4" />
</Button>
<Button
variant="secondary"
size="icon"
class="absolute top-1/2 right-4 z-10 h-8 w-8 -translate-y-1/2 rounded-full bg-background/5 p-0 text-white!"
onclick={onNext}
aria-label="Next"
>
<ChevronRight class="size-4" />
</Button>
{/if}


@@ -0,0 +1,63 @@
<script lang="ts">
import { Music, FileText } from '@lucide/svelte';
import { HorizontalScrollCarousel } from '$lib/components/app/misc';
interface PreviewItem {
id: string;
name: string;
isImage: boolean;
isAudio: boolean;
preview?: string;
}
interface Props {
items: PreviewItem[];
currentIndex: number;
onNavigate: (index: number) => void;
}
let { items, currentIndex, onNavigate }: Props = $props();
function getFileExtension(name: string): string {
const parts = name.split('.');
if (parts.length > 1) {
return parts.pop()?.toUpperCase() ?? '';
}
return '';
}
</script>
{#if items.length > 1}
<div class="sticky bottom-0 z-10 mt-4 flex-shrink-0">
<HorizontalScrollCarousel class="max-w-full">
{#each items as item, index (item.id)}
<button
data-thumbnail-index={index}
class={[
'relative flex-shrink-0 cursor-pointer overflow-hidden rounded border-2 bg-black/80 backdrop-blur-sm transition-all hover:opacity-90',
index === currentIndex ? 'border-white' : 'border-transparent opacity-60',
'[&:not(:first-child)]:last:mr-4 [&:not(:last-child)]:first:ml-4'
]}
onclick={() => onNavigate(index)}
aria-label={`Go to ${item.name}`}
>
{#if item.isImage && item.preview}
<img src={item.preview} alt={item.name} class="h-12 w-12 object-cover" />
{:else}
<div
class="bg-foreground-muted/50 flex h-12 w-12 flex-col items-center justify-center gap-0.5 py-1"
>
{#if item.isAudio}
<Music class="h-4 w-4 text-white/70" />
{:else}
<FileText class="h-4 w-4 text-white/70" />
{/if}
<span class="font-mono text-[9px] text-white/60">{getFileExtension(item.name)}</span>
</div>
{/if}
</button>
{/each}
</HorizontalScrollCarousel>
</div>
{/if}


@@ -1,117 +0,0 @@
<script lang="ts">
import {
ChatAttachmentThumbnailImage,
ChatAttachmentThumbnailFile,
DialogChatAttachmentPreview
} from '$lib/components/app';
import { getAttachmentDisplayItems } from '$lib/utils';
interface Props {
uploadedFiles?: ChatUploadedFile[];
attachments?: DatabaseMessageExtra[];
readonly?: boolean;
onFileRemove?: (fileId: string) => void;
imageHeight?: string;
imageWidth?: string;
imageClass?: string;
activeModelId?: string;
}
let {
uploadedFiles = [],
attachments = [],
readonly = false,
onFileRemove,
imageHeight = 'h-24',
imageWidth = 'w-auto',
imageClass = '',
activeModelId
}: Props = $props();
let previewDialogOpen = $state(false);
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments }));
let imageItems = $derived(displayItems.filter((item) => item.isImage));
let fileItems = $derived(displayItems.filter((item) => !item.isImage));
function openPreview(item: (typeof displayItems)[0], event?: Event) {
if (event) {
event.preventDefault();
event.stopPropagation();
}
previewItem = {
uploadedFile: item.uploadedFile,
attachment: item.attachment,
preview: item.preview,
name: item.name,
size: item.size,
textContent: item.textContent
};
previewDialogOpen = true;
}
</script>
<div class="space-y-4">
<div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1">
{#if fileItems.length > 0}
<div>
<h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3>
<div class="flex flex-wrap items-start gap-3">
{#each fileItems as item (item.id)}
<ChatAttachmentThumbnailFile
class="cursor-pointer"
id={item.id}
name={item.name}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
attachment={item.attachment}
uploadedFile={item.uploadedFile}
onClick={(event?: MouseEvent) => openPreview(item, event)}
/>
{/each}
</div>
</div>
{/if}
{#if imageItems.length > 0}
<div>
<h3 class="mb-3 text-sm font-medium text-foreground">Images ({imageItems.length})</h3>
<div class="flex flex-wrap items-start gap-3">
{#each imageItems as item (item.id)}
{#if item.preview}
<ChatAttachmentThumbnailImage
class="cursor-pointer"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onClick={(event) => openPreview(item, event)}
/>
{/if}
{/each}
</div>
</div>
{/if}
</div>
</div>
{#if previewItem}
<DialogChatAttachmentPreview
bind:open={previewDialogOpen}
uploadedFile={previewItem.uploadedFile}
attachment={previewItem.attachment}
preview={previewItem.preview}
name={previewItem.name}
size={previewItem.size}
textContent={previewItem.textContent}
{activeModelId}
/>
{/if}


@@ -1,14 +1,13 @@
<script lang="ts">
import {
ChatAttachmentsList,
ChatAttachmentMcpResources,
ChatFormActions,
ChatFormFileInputInvisible,
ChatFormPromptPicker,
ChatFormResourcePicker,
ChatFormTextarea
ChatFormMcpResourcesList,
ChatFormPickers,
ChatFormTextarea,
DialogMcpResourcesBrowser
} from '$lib/components/app';
import { DialogMcpResources } from '$lib/components/app/dialogs';
import {
CLIPBOARD_CONTENT_QUOTE_PREFIX,
INPUT_CLASSES,
@@ -54,6 +53,8 @@
isLoading?: boolean;
placeholder?: string;
showMcpPromptButton?: boolean;
showAddButton?: boolean;
showModelSelector?: boolean;
// Event Handlers
onAttachmentRemove?: (index: number) => void;
@@ -73,6 +74,8 @@
isLoading = false,
placeholder = 'Type a message...',
showMcpPromptButton = false,
showAddButton = true,
showModelSelector = true,
uploadedFiles = $bindable([]),
value = $bindable(''),
onAttachmentRemove,
@@ -85,31 +88,21 @@
onValueChange
}: Props = $props();
/**
*
*
* STATE
*
*
*/
// Component References
let audioRecorder: AudioRecorder | undefined;
let chatFormActionsRef: ChatFormActions | undefined = $state(undefined);
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
let promptPickerRef: ChatFormPromptPicker | undefined = $state(undefined);
let resourcePickerRef: ChatFormResourcePicker | undefined = $state(undefined);
let pickersRef: { handleKeydown: (event: KeyboardEvent) => boolean } | undefined =
$state(undefined);
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
// Audio Recording State
let isRecording = $state(false);
let recordingSupported = $state(false);
// Prompt Picker State
// Picker State
let isPromptPickerOpen = $state(false);
let promptSearchQuery = $state('');
// Inline Resource Picker State (triggered by @)
let isInlineResourcePickerOpen = $state(false);
let resourceSearchQuery = $state('');
@@ -117,22 +110,12 @@
let isResourceDialogOpen = $state(false);
let preSelectedResourceUri = $state<string | undefined>(undefined);
/**
*
*
* DERIVED STATE
*
*
*/
// Configuration
let currentConfig = $derived(config());
let pasteLongTextToFileLength = $derived.by(() => {
const n = Number(currentConfig.pasteLongTextToFileLen);
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
});
// Model Selection Logic
let isRouter = $derived(isRouterMode());
let conversationModel = $derived(
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
@@ -158,7 +141,6 @@
return null;
});
// Form Validation State
let hasModelSelected = $derived(!isRouter || !!conversationModel || !!selectedModelId());
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
let hasAttachments = $derived(
@@ -166,27 +148,11 @@
);
let canSubmit = $derived(value.trim().length > 0 || hasAttachments);
/**
*
*
* LIFECYCLE
*
*
*/
onMount(() => {
recordingSupported = isAudioRecordingSupported();
audioRecorder = new AudioRecorder();
});
/**
*
*
* PUBLIC API
*
*
*/
export function focus() {
textareaRef?.focus();
}
@@ -199,10 +165,6 @@
chatFormActionsRef?.openModelSelector();
}
/**
* Check if a model is selected, open selector if not
* @returns true if model is selected, false otherwise
*/
export function checkModelSelected(): boolean {
if (!hasModelSelected) {
chatFormActionsRef?.openModelSelector();
@@ -211,14 +173,6 @@
return true;
}
/**
*
*
* EVENT HANDLERS - File Management
*
*
*/
function handleFileSelect(files: File[]) {
onFilesAdd?.(files);
}
@@ -238,14 +192,6 @@
}
}
/**
*
*
* EVENT HANDLERS - Input & Keyboard
*
*
*/
function handleInput() {
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
const hasServers = mcpStore.hasEnabledServers(perChatOverrides);
@@ -273,11 +219,7 @@
}
function handleKeydown(event: KeyboardEvent) {
if (isPromptPickerOpen && promptPickerRef?.handleKeydown(event)) {
return;
}
if (isInlineResourcePickerOpen && resourcePickerRef?.handleKeydown(event)) {
if (pickersRef?.handleKeydown(event)) {
return;
}
@@ -388,14 +330,6 @@
}
}
/**
*
*
* EVENT HANDLERS - Prompt Picker
*
*
*/
function handlePromptLoadStart(
placeholderId: string,
promptInfo: MCPPromptInfo,
@@ -474,14 +408,6 @@
textareaRef?.focus();
}
/**
*
*
* EVENT HANDLERS - Inline Resource Picker
*
*
*/
function handleInlineResourcePickerClose() {
isInlineResourcePickerOpen = false;
resourceSearchQuery = '';
@@ -489,7 +415,6 @@
}
function handleInlineResourceSelect() {
// Clear the @query from input after resource is attached
if (value.startsWith(RESOURCE_TRIGGER_PREFIX)) {
value = '';
onValueChange?.('');
@@ -512,14 +437,6 @@
isResourceDialogOpen = true;
}
/**
*
*
* EVENT HANDLERS - Audio Recording
*
*
*/
async function handleMicClick() {
if (!audioRecorder || !recordingSupported) {
console.warn('Audio recording not supported');
@@ -552,29 +469,27 @@
<form
class="relative {className}"
onsubmit={(e) => {
e.preventDefault();
onsubmit={(event) => {
event.preventDefault();
if (!canSubmit || disabled || hasLoadingAttachments) return;
onSubmit?.();
}}
>
<ChatFormPromptPicker
bind:this={promptPickerRef}
isOpen={isPromptPickerOpen}
searchQuery={promptSearchQuery}
onClose={handlePromptPickerClose}
<ChatFormPickers
bind:this={pickersRef}
{isPromptPickerOpen}
{promptSearchQuery}
{isInlineResourcePickerOpen}
{resourceSearchQuery}
onPromptPickerClose={handlePromptPickerClose}
onInlineResourcePickerClose={handleInlineResourcePickerClose}
onInlineResourceSelect={handleInlineResourceSelect}
onPromptLoadStart={handlePromptLoadStart}
onPromptLoadComplete={handlePromptLoadComplete}
onPromptLoadError={handlePromptLoadError}
/>
<ChatFormResourcePicker
bind:this={resourcePickerRef}
isOpen={isInlineResourcePickerOpen}
searchQuery={resourceSearchQuery}
onClose={handleInlineResourcePickerClose}
onResourceSelect={handleInlineResourceSelect}
onBrowse={handleBrowseResources}
onInlineResourceBrowse={handleBrowseResources}
/>
<div
@@ -611,7 +526,7 @@
/>
{#if mcpHasResourceAttachments()}
<ChatAttachmentMcpResources
<ChatFormMcpResourcesList
class="mb-3"
onResourceClick={(uri) => {
preSelectedResourceUri = uri;
@@ -624,10 +539,11 @@
class="px-3"
bind:this={chatFormActionsRef}
canSend={canSubmit}
hasText={value.trim().length > 0}
{disabled}
{isLoading}
{isRecording}
{showAddButton}
{showModelSelector}
{uploadedFiles}
onFileUpload={handleFileUpload}
onMicClick={handleMicClick}
@@ -640,7 +556,7 @@
</div>
</form>
<DialogMcpResources
<DialogMcpResourcesBrowser
bind:open={isResourceDialogOpen}
preSelectedUri={preSelectedResourceUri}
onAttach={(resource: MCPResourceInfo) => {


@@ -0,0 +1,33 @@
<script lang="ts">
import { Plus } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as Tooltip from '$lib/components/ui/tooltip';
import { ATTACHMENT_TOOLTIP_TEXT } from '$lib/constants';
interface Props {
disabled?: boolean;
onclick?: (e: MouseEvent) => void;
}
let { disabled = false, onclick }: Props = $props();
</script>
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
{disabled}
{onclick}
variant="secondary"
type="button"
>
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
<Plus class="h-4 w-4" />
</Button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{ATTACHMENT_TOOLTIP_TEXT}</p>
</Tooltip.Content>
</Tooltip.Root>


@@ -1,17 +1,18 @@
<script lang="ts">
import { Plus } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import type { Snippet } from 'svelte';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import {
ATTACHMENT_FILE_ITEMS,
ATTACHMENT_EXTRA_ITEMS,
ATTACHMENT_MCP_ITEMS,
ATTACHMENT_TOOLTIP_TEXT,
TOOLTIP_DELAY_DURATION
} from '$lib/constants';
import { AttachmentMenuItemId } from '$lib/enums';
import { ChatFormActionToolsSubmenu, ChatFormActionMcpServersSubmenu } from '$lib/components/app';
import {
ChatFormActionAddToolsSubmenu,
ChatFormActionAddMcpServersSubmenu
} from '$lib/components/app';
import { useAttachmentMenu } from '$lib/hooks/use-attachment-menu.svelte';
@@ -27,6 +28,7 @@
onMcpPromptClick?: () => void;
onMcpSettingsClick?: () => void;
onMcpResourcesClick?: () => void;
trigger: Snippet<[{ disabled: boolean }]>;
}
let {
@@ -40,7 +42,8 @@
onSystemPromptClick,
onMcpPromptClick,
onMcpSettingsClick,
onMcpResourcesClick
onMcpResourcesClick,
trigger
}: Props = $props();
let dropdownOpen = $state(false);
@@ -62,24 +65,7 @@
<div class="flex items-center gap-1 {className}">
<DropdownMenu.Root bind:open={dropdownOpen}>
<DropdownMenu.Trigger name="Attach files" {disabled}>
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
{disabled}
variant="secondary"
type="button"
>
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
<Plus class="h-4 w-4" />
</Button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{ATTACHMENT_TOOLTIP_TEXT}</p>
</Tooltip.Content>
</Tooltip.Root>
{@render trigger({ disabled })}
</DropdownMenu.Trigger>
<DropdownMenu.Content align="start" class="w-48">
@@ -161,9 +147,9 @@
{/if}
{/each}
<ChatFormActionToolsSubmenu />
<ChatFormActionAddToolsSubmenu />
<ChatFormActionMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
<ChatFormActionAddMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
{#each ATTACHMENT_MCP_ITEMS as item (item.id)}
{#if attachmentMenu.isItemVisible(item.visibleWhen)}


@@ -0,0 +1,149 @@
<script lang="ts">
import { Settings, Plus } from '@lucide/svelte';
import { Switch } from '$lib/components/ui/switch';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { McpLogo, DropdownMenuSearchable } from '$lib/components/app';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { HealthCheckStatus } from '$lib/enums';
import type { MCPServerSettingsEntry } from '$lib/types';
import { goto } from '$app/navigation';
interface Props {
onMcpSettingsClick?: () => void;
}
let { onMcpSettingsClick }: Props = $props();
let mcpSearchQuery = $state('');
let allMcpServers = $derived(mcpStore.getServersSorted());
let mcpServers = $derived(allMcpServers.filter((s) => s.enabled));
let hasMcpServers = $derived(mcpServers.length > 0);
// let hasAnyMcpServers = $derived(allMcpServers.length > 0);
let filteredMcpServers = $derived.by(() => {
const query = mcpSearchQuery.toLowerCase().trim();
if (!query) return mcpServers;
return mcpServers.filter((s) => {
const name = getServerLabel(s).toLowerCase();
const url = s.url.toLowerCase();
return name.includes(query) || url.includes(query);
});
});
function getServerLabel(server: MCPServerSettingsEntry): string {
return mcpStore.getServerLabel(server);
}
function isServerEnabledForChat(serverId: string): boolean {
return conversationsStore.isMcpServerEnabledForChat(serverId);
}
async function toggleServerForChat(serverId: string) {
await conversationsStore.toggleMcpServerForChat(serverId);
}
function handleMcpSubMenuOpen(open: boolean) {
if (open) {
mcpSearchQuery = '';
mcpStore.runHealthChecksForServers(allMcpServers);
}
}
function handleMcpSettingsClick() {
onMcpSettingsClick?.();
goto(`${hasMcpServers ? '' : '?add'}#/settings/mcp`);
}
</script>
<DropdownMenu.Root>
<DropdownMenu.Sub onOpenChange={handleMcpSubMenuOpen}>
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
<McpLogo class="h-4 w-4" />
<span>MCP Servers</span>
</DropdownMenu.SubTrigger>
<DropdownMenu.SubContent class="w-72 pt-0">
{#if hasMcpServers}
<DropdownMenuSearchable
placeholder="Search servers..."
bind:searchValue={mcpSearchQuery}
emptyMessage="No servers found"
isEmpty={filteredMcpServers.length === 0}
>
<div class="max-h-64 overflow-y-auto">
{#each filteredMcpServers as server (server.id)}
{@const healthState = mcpStore.getHealthCheckState(server.id)}
{@const hasError = healthState.status === HealthCheckStatus.ERROR}
{@const isEnabledForChat = isServerEnabledForChat(server.id)}
<button
type="button"
class="flex w-full items-center justify-between gap-2 rounded-sm px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
onclick={() => !hasError && toggleServerForChat(server.id)}
disabled={hasError}
>
<div class="flex min-w-0 flex-1 items-center gap-2">
{#if mcpStore.getServerFavicon(server.id)}
<img
src={mcpStore.getServerFavicon(server.id)}
alt=""
class="h-4 w-4 shrink-0 rounded-sm"
onerror={(e) => {
(e.currentTarget as HTMLImageElement).style.display = 'none';
}}
/>
{/if}
<span class="truncate text-sm">{getServerLabel(server)}</span>
{#if hasError}
<span
class="shrink-0 rounded bg-destructive/15 px-1.5 py-0.5 text-xs text-destructive"
>
Error
</span>
{/if}
</div>
<Switch
checked={isEnabledForChat}
disabled={hasError}
onclick={(e) => e.stopPropagation()}
onCheckedChange={() => toggleServerForChat(server.id)}
/>
</button>
{/each}
</div>
{#snippet footer()}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpSettingsClick}
>
<Settings class="h-4 w-4" />
<span>Manage MCP Servers</span>
</DropdownMenu.Item>
{/snippet}
</DropdownMenuSearchable>
{:else}
<div class="px-2 py-3 text-center text-sm text-muted-foreground">
No MCP servers configured
</div>
<DropdownMenu.Separator />
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpSettingsClick}
>
<Plus class="h-4 w-4" />
<span>Add MCP Servers</span>
</DropdownMenu.Item>
{/if}
</DropdownMenu.SubContent>
</DropdownMenu.Sub>
</DropdownMenu.Root>
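The submenu above filters servers case-insensitively against either the display label or the URL, with an empty query passing everything through. A minimal standalone sketch of that filter logic (using a simplified server shape, not the real `MCPServerSettingsEntry` type):

```typescript
// Hypothetical simplified shape; the real component resolves the label
// via mcpStore.getServerLabel(server) instead of a plain name field.
interface ServerEntry {
  name: string;
  url: string;
}

// Case-insensitive match on label or URL; a blank/whitespace query
// returns the full list unchanged, mirroring the $derived.by block.
function filterServers(servers: ServerEntry[], rawQuery: string): ServerEntry[] {
  const query = rawQuery.toLowerCase().trim();
  if (!query) return servers;
  return servers.filter(
    (s) => s.name.toLowerCase().includes(query) || s.url.toLowerCase().includes(query)
  );
}
```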


@@ -1,18 +1,17 @@
<script lang="ts">
import { Plus } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import type { Snippet } from 'svelte';
import * as Tooltip from '$lib/components/ui/tooltip';
import * as Sheet from '$lib/components/ui/sheet';
import { TOOLTIP_DELAY_DURATION } from '$lib/constants';
import {
ATTACHMENT_FILE_ITEMS,
ATTACHMENT_EXTRA_ITEMS,
ATTACHMENT_MCP_ITEMS,
ATTACHMENT_TOOLTIP_TEXT
ATTACHMENT_MCP_ITEMS
} from '$lib/constants/attachment-menu';
import { ChatFormActionToolsSubmenu, ChatFormActionMcpServersSubmenu } from '$lib/components/app';
import { McpLogo } from '$lib/components/app';
import { useAttachmentMenu } from '$lib/hooks/use-attachment-menu.svelte';
import { AttachmentMenuItemId } from '$lib/enums';
import { PencilRuler } from '@lucide/svelte';
interface Props {
class?: string;
@@ -24,8 +23,8 @@
onFileUpload?: () => void;
onSystemPromptClick?: () => void;
onMcpPromptClick?: () => void;
onMcpSettingsClick?: () => void;
onMcpResourcesClick?: () => void;
trigger: Snippet<[{ disabled: boolean; onclick?: () => void }]>;
}
let {
@@ -38,8 +37,8 @@
onFileUpload,
onSystemPromptClick,
onMcpPromptClick,
onMcpSettingsClick,
onMcpResourcesClick
onMcpResourcesClick,
trigger
}: Props = $props();
let sheetOpen = $state(false);
@@ -52,28 +51,14 @@
}
);
function handleMcpSettingsClick() {
sheetOpen = false;
onMcpSettingsClick?.();
}
const sheetItemClass =
'flex w-full items-center gap-3 rounded-md px-3 py-2.5 text-left text-sm transition-colors hover:bg-accent active:bg-accent disabled:cursor-not-allowed disabled:opacity-50';
</script>
<div class="flex items-center gap-1 {className}">
<Sheet.Root bind:open={sheetOpen}>
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
{disabled}
variant="secondary"
type="button"
onclick={() => (sheetOpen = true)}
>
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
<Plus class="h-4 w-4" />
</Button>
{@render trigger({ disabled, onclick: () => (sheetOpen = true) })}
<!-- <ChatFormActionAddButton {disabled} onclick={() => (sheetOpen = true)} /> -->
<Sheet.Content side="bottom" class="max-h-[85vh] gap-0 overflow-y-auto">
<Sheet.Header>
@@ -161,9 +146,17 @@
<div class="my-2 border-t"></div>
<ChatFormActionToolsSubmenu />
<a href="#/settings/mcp" class="flex items-center gap-3 px-3 py-2">
<McpLogo class="inline h-4 w-4" />
<ChatFormActionMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
<span class="text-sm">MCP Servers</span>
</a>
<a href="#/settings/chat/tools" class="flex items-center gap-3 px-3 py-2">
<PencilRuler class="inline h-4 w-4" />
<span class="text-sm">Tools</span>
</a>
{#each ATTACHMENT_MCP_ITEMS as item (item.id)}
{#if attachmentMenu.isItemVisible(item.visibleWhen)}


@@ -24,6 +24,7 @@
{#if toolsStore.loading}
<div class="px-3 py-4 text-center text-sm text-muted-foreground">
<Loader2 class="mx-auto mb-1 h-4 w-4 animate-spin" />
Loading tools...
</div>
{:else if toolsStore.isToolsEndpointUnreachable}
@@ -31,19 +32,21 @@
<span class="flex gap-2">
<Info class="mt-0.5 h-4 w-4 shrink-0" />
<span
>Run llama-server with <code>--tools</code> flag to enable
<strong>Built-in Tools</strong>.</span
>
<span>
Run llama-server with <code>--tools</code> flag to enable
<strong>Built-in Tools</strong>.
</span>
</span>
<span class="flex gap-2">
<Info class="mt-0.5 h-4 w-4 shrink-0" />
<span
>{hasMcpServersAvailable ? 'Enable' : 'Add'} MCP Server(s) to access
<strong>MCP Tools</strong>.</span
>
<span>
{hasMcpServersAvailable ? 'Enable' : 'Add'} MCP Server(s) to access
<strong>MCP Tools</strong>.
</span>
</span>
</div>
{:else if toolsStore.error}


@@ -0,0 +1,68 @@
<script lang="ts">
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
import ChatFormActionAddDropdown from './ChatFormActionAddDropdown.svelte';
import ChatFormActionAddSheet from './ChatFormActionAddSheet.svelte';
import ChatFormActionAddButton from './ChatFormActionAddButton.svelte';
interface Props {
disabled?: boolean;
hasAudioModality?: boolean;
hasMcpPromptsSupport?: boolean;
hasMcpResourcesSupport?: boolean;
hasVisionModality?: boolean;
onFileUpload?: () => void;
onMcpPromptClick?: () => void;
onMcpResourcesClick?: () => void;
onMcpSettingsClick?: () => void;
onSystemPromptClick?: () => void;
}
let {
disabled = false,
hasAudioModality = false,
hasMcpPromptsSupport = false,
hasMcpResourcesSupport = false,
hasVisionModality = false,
onFileUpload,
onMcpPromptClick,
onMcpResourcesClick,
onMcpSettingsClick,
onSystemPromptClick
}: Props = $props();
const isMobile = new IsMobile();
</script>
{#if isMobile.current}
<ChatFormActionAddSheet
{disabled}
{hasAudioModality}
{hasVisionModality}
{hasMcpPromptsSupport}
{hasMcpResourcesSupport}
{onFileUpload}
{onMcpPromptClick}
{onMcpResourcesClick}
>
{#snippet trigger({ disabled, onclick })}
<ChatFormActionAddButton {disabled} {onclick} />
{/snippet}
</ChatFormActionAddSheet>
{:else}
<ChatFormActionAddDropdown
{disabled}
{hasAudioModality}
{hasVisionModality}
{hasMcpPromptsSupport}
{hasMcpResourcesSupport}
{onFileUpload}
{onMcpPromptClick}
{onMcpResourcesClick}
{onMcpSettingsClick}
{onSystemPromptClick}
>
{#snippet trigger()}
<ChatFormActionAddButton {disabled} />
{/snippet}
</ChatFormActionAddDropdown>
{/if}


@@ -1,147 +0,0 @@
<script lang="ts">
import { Settings, Plus } from '@lucide/svelte';
import { Switch } from '$lib/components/ui/switch';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { McpLogo, DropdownMenuSearchable } from '$lib/components/app';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { HealthCheckStatus } from '$lib/enums';
import type { MCPServerSettingsEntry } from '$lib/types';
import { goto } from '$app/navigation';
interface Props {
onMcpSettingsClick?: () => void;
}
let { onMcpSettingsClick }: Props = $props();
let mcpSearchQuery = $state('');
let allMcpServers = $derived(mcpStore.getServersSorted());
let mcpServers = $derived(allMcpServers.filter((s) => s.enabled));
let hasMcpServers = $derived(mcpServers.length > 0);
// let hasAnyMcpServers = $derived(allMcpServers.length > 0);
let filteredMcpServers = $derived.by(() => {
const query = mcpSearchQuery.toLowerCase().trim();
if (!query) return mcpServers;
return mcpServers.filter((s) => {
const name = getServerLabel(s).toLowerCase();
const url = s.url.toLowerCase();
return name.includes(query) || url.includes(query);
});
});
function getServerLabel(server: MCPServerSettingsEntry): string {
return mcpStore.getServerLabel(server);
}
function isServerEnabledForChat(serverId: string): boolean {
return conversationsStore.isMcpServerEnabledForChat(serverId);
}
async function toggleServerForChat(serverId: string) {
await conversationsStore.toggleMcpServerForChat(serverId);
}
function handleMcpSubMenuOpen(open: boolean) {
if (open) {
mcpSearchQuery = '';
mcpStore.runHealthChecksForServers(allMcpServers);
}
}
function handleMcpSettingsClick() {
onMcpSettingsClick?.();
goto(`${hasMcpServers ? '' : '?add'}#/settings/mcp`);
}
</script>
<DropdownMenu.Sub onOpenChange={handleMcpSubMenuOpen}>
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
<McpLogo class="h-4 w-4" />
<span>MCP Servers</span>
</DropdownMenu.SubTrigger>
<DropdownMenu.SubContent class="w-72 pt-0">
{#if hasMcpServers}
<DropdownMenuSearchable
placeholder="Search servers..."
bind:searchValue={mcpSearchQuery}
emptyMessage="No servers found"
isEmpty={filteredMcpServers.length === 0}
>
<div class="max-h-64 overflow-y-auto">
{#each filteredMcpServers as server (server.id)}
{@const healthState = mcpStore.getHealthCheckState(server.id)}
{@const hasError = healthState.status === HealthCheckStatus.ERROR}
{@const isEnabledForChat = isServerEnabledForChat(server.id)}
<button
type="button"
class="flex w-full items-center justify-between gap-2 rounded-sm px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
onclick={() => !hasError && toggleServerForChat(server.id)}
disabled={hasError}
>
<div class="flex min-w-0 flex-1 items-center gap-2">
{#if mcpStore.getServerFavicon(server.id)}
<img
src={mcpStore.getServerFavicon(server.id)}
alt=""
class="h-4 w-4 shrink-0 rounded-sm"
onerror={(e) => {
(e.currentTarget as HTMLImageElement).style.display = 'none';
}}
/>
{/if}
<span class="truncate text-sm">{getServerLabel(server)}</span>
{#if hasError}
<span
class="shrink-0 rounded bg-destructive/15 px-1.5 py-0.5 text-xs text-destructive"
>
Error
</span>
{/if}
</div>
<Switch
checked={isEnabledForChat}
disabled={hasError}
onclick={(e: MouseEvent) => e.stopPropagation()}
onCheckedChange={() => toggleServerForChat(server.id)}
/>
</button>
{/each}
</div>
{#snippet footer()}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpSettingsClick}
>
<Settings class="h-4 w-4" />
<span>Manage MCP Servers</span>
</DropdownMenu.Item>
{/snippet}
</DropdownMenuSearchable>
{:else}
<div class="px-2 py-3 text-center text-sm text-muted-foreground">
No MCP servers configured
</div>
<DropdownMenu.Separator />
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpSettingsClick}
>
<Plus class="h-4 w-4" />
<span>Add MCP Servers</span>
</DropdownMenu.Item>
{/if}
</DropdownMenu.SubContent>
</DropdownMenu.Sub>


@@ -0,0 +1,160 @@
<script lang="ts">
import { chatStore } from '$lib/stores/chat.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isRouterMode, serverError } from '$lib/stores/server.svelte';
import { ModelsSelectorDropdown, ModelsSelectorSheet } from '$lib/components/app';
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
import { activeMessages } from '$lib/stores/conversations.svelte';
interface Props {
currentModel?: string;
disabled?: boolean;
forceForegroundText?: boolean;
hasAudioModality?: boolean;
hasVisionModality?: boolean;
hasModelSelected?: boolean;
isSelectedModelInCache?: boolean;
submitTooltip?: string;
useGlobalSelection?: boolean;
}
let {
currentModel,
disabled = false,
forceForegroundText = false,
hasAudioModality = $bindable(false),
hasVisionModality = $bindable(false),
hasModelSelected = $bindable(false),
isSelectedModelInCache = $bindable(true),
submitTooltip = $bindable(''),
useGlobalSelection = false
}: Props = $props();
let isRouter = $derived(isRouterMode());
let isOffline = $derived(!!serverError());
let conversationModel = $derived(
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
);
let lastSyncedConversationModel: string | null = null;
$effect(() => {
if (conversationModel && conversationModel !== lastSyncedConversationModel) {
lastSyncedConversationModel = conversationModel;
modelsStore.selectModelByName(conversationModel);
} else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) {
lastSyncedConversationModel = null;
// auto-select the first loaded model only when nothing is selected yet
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
if (first) modelsStore.selectModelById(first.id);
}
});
let activeModelId = $derived.by(() => {
const options = modelOptions();
if (!isRouter) {
return options.length > 0 ? options[0].model : null;
}
const selectedId = selectedModelId();
if (selectedId) {
const model = options.find((m) => m.id === selectedId);
if (model) return model.model;
}
if (conversationModel) {
const model = options.find((m) => m.model === conversationModel);
if (model) return model.model;
}
return null;
});
let modelPropsVersion = $state(0); // Used to trigger reactivity after fetch
$effect(() => {
if (activeModelId) {
const cached = modelsStore.getModelProps(activeModelId);
if (!cached) {
modelsStore.fetchModelProps(activeModelId).then(() => {
modelPropsVersion++;
});
}
}
});
$effect(() => {
hasAudioModality = activeModelId ? modelsStore.modelSupportsAudio(activeModelId) : false;
});
$effect(() => {
void modelPropsVersion;
hasVisionModality = activeModelId ? modelsStore.modelSupportsVision(activeModelId) : false;
});
$effect(() => {
hasModelSelected = !isRouter || !!conversationModel || !!selectedModelId();
});
$effect(() => {
if (!isRouter) {
isSelectedModelInCache = true;
} else if (conversationModel) {
isSelectedModelInCache = modelOptions().some((option) => option.model === conversationModel);
} else {
const currentModelId = selectedModelId();
if (!currentModelId) {
isSelectedModelInCache = false;
} else {
isSelectedModelInCache = modelOptions().some((option) => option.id === currentModelId);
}
}
});
$effect(() => {
if (!hasModelSelected) {
submitTooltip = 'Please select a model first';
} else if (!isSelectedModelInCache) {
submitTooltip = 'Selected model is not available, please select another';
} else {
submitTooltip = '';
}
});
let selectorModelRef: ModelsSelectorDropdown | ModelsSelectorSheet | undefined =
$state(undefined);
let isMobile = new IsMobile();
export function open() {
selectorModelRef?.open();
}
</script>
{#if isMobile.current}
<ModelsSelectorSheet
disabled={disabled || isOffline}
bind:this={selectorModelRef}
{currentModel}
{forceForegroundText}
{useGlobalSelection}
/>
{:else}
<ModelsSelectorDropdown
disabled={disabled || isOffline}
bind:this={selectorModelRef}
{currentModel}
{forceForegroundText}
{useGlobalSelection}
/>
{/if}
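The `$effect` in this new component syncs the model store to the conversation's model at most once per distinct value (so a later manual selection isn't clobbered), and otherwise auto-selects the first loaded model in router mode. The decision logic can be sketched as a pure function; the names here are illustrative stand-ins, not the real `modelsStore` API:

```typescript
// Illustrative decision logic for the model-sync $effect; the store
// interactions are represented as returned actions rather than calls.
type SyncAction =
  | { kind: 'select-by-name'; name: string }
  | { kind: 'select-first-loaded' }
  | { kind: 'none' };

function decideModelSync(
  conversationModel: string | null,
  lastSynced: string | null,
  isRouter: boolean,
  selectedModelId: string | null,
  loadedModelIds: string[]
): SyncAction {
  // Sync once per distinct conversation model.
  if (conversationModel && conversationModel !== lastSynced) {
    return { kind: 'select-by-name', name: conversationModel };
  }
  // Router mode with nothing selected yet: auto-pick the first loaded model.
  if (isRouter && !selectedModelId && loadedModelIds.length > 0) {
    return { kind: 'select-first-loaded' };
  }
  return { kind: 'none' };
}
```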


@@ -2,7 +2,6 @@
import { ArrowUp } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as Tooltip from '$lib/components/ui/tooltip';
import { cn } from '$lib/components/ui/utils';
interface Props {
canSend?: boolean;
@@ -20,12 +19,11 @@
<Button
type="submit"
disabled={isDisabled}
class={cn(
class={[
'h-8 w-8 rounded-full p-0',
showErrorState
? 'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
: ''
)}
showErrorState &&
'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
]}
{...props}
>
<span class="sr-only">Send</span>


@@ -2,31 +2,27 @@
import { Square } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import {
ChatFormActionAttachmentsDropdown,
ChatFormActionAttachmentsSheet,
ChatFormActionsAdd,
ChatFormActionModels,
ChatFormActionRecord,
ChatFormActionSubmit,
ModelsSelectorDropdown,
ModelsSelectorSheet
ChatFormActionSubmit
} from '$lib/components/app';
import { FileTypeCategory } from '$lib/enums';
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isRouterMode, serverError } from '$lib/stores/server.svelte';
import { config } from '$lib/stores/settings.svelte';
import { activeMessages, conversationsStore } from '$lib/stores/conversations.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { getFileTypeCategory } from '$lib/utils';
import { goto } from '$app/navigation';
interface Props {
canSend?: boolean;
canSubmit?: boolean;
class?: string;
disabled?: boolean;
isLoading?: boolean;
isRecording?: boolean;
hasText?: boolean;
showAddButton?: boolean;
showModelSelector?: boolean;
uploadedFiles?: ChatUploadedFile[];
onFileUpload?: () => void;
onMicClick?: () => void;
@@ -38,11 +34,13 @@
let {
canSend = false,
canSubmit = false,
class: className = '',
disabled = false,
isLoading = false,
isRecording = false,
hasText = false,
showAddButton = true,
showModelSelector = true,
uploadedFiles = [],
onFileUpload,
onMicClick,
@@ -53,124 +51,6 @@
}: Props = $props();
let currentConfig = $derived(config());
let isRouter = $derived(isRouterMode());
let isOffline = $derived(!!serverError());
let conversationModel = $derived(
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
);
let lastSyncedConversationModel: string | null = null;
$effect(() => {
if (conversationModel && conversationModel !== lastSyncedConversationModel) {
lastSyncedConversationModel = conversationModel;
modelsStore.selectModelByName(conversationModel);
} else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) {
lastSyncedConversationModel = null;
// auto-select the first loaded model only when nothing is selected yet
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
if (first) modelsStore.selectModelById(first.id);
}
});
let activeModelId = $derived.by(() => {
const options = modelOptions();
if (!isRouter) {
return options.length > 0 ? options[0].model : null;
}
const selectedId = selectedModelId();
if (selectedId) {
const model = options.find((m) => m.id === selectedId);
if (model) return model.model;
}
if (conversationModel) {
const model = options.find((m) => m.model === conversationModel);
if (model) return model.model;
}
return null;
});
let modelPropsVersion = $state(0); // Used to trigger reactivity after fetch
$effect(() => {
if (activeModelId) {
const cached = modelsStore.getModelProps(activeModelId);
if (!cached) {
modelsStore.fetchModelProps(activeModelId).then(() => {
modelPropsVersion++;
});
}
}
});
let hasAudioModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsAudio(activeModelId);
}
return false;
});
let hasVisionModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsVision(activeModelId);
}
return false;
});
let hasAudioAttachments = $derived(
uploadedFiles.some((file) => getFileTypeCategory(file.type) === FileTypeCategory.AUDIO)
);
let shouldShowRecordButton = $derived(
hasAudioModality && !hasText && !hasAudioAttachments && currentConfig.autoMicOnEmpty
);
let hasModelSelected = $derived(!isRouter || !!conversationModel || !!selectedModelId());
let isSelectedModelInCache = $derived.by(() => {
if (!isRouter) return true;
if (conversationModel) {
return modelOptions().some((option) => option.model === conversationModel);
}
const currentModelId = selectedModelId();
if (!currentModelId) return false;
return modelOptions().some((option) => option.id === currentModelId);
});
let submitTooltip = $derived.by(() => {
if (!hasModelSelected) {
return 'Please select a model first';
}
if (!isSelectedModelInCache) {
return 'Selected model is not available, please select another';
}
return '';
});
let selectorModelRef: ModelsSelectorDropdown | ModelsSelectorSheet | undefined =
$state(undefined);
let isMobile = new IsMobile();
export function openModelSelector() {
selectorModelRef?.open();
}
let hasMcpPromptsSupport = $derived.by(() => {
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
@@ -183,25 +63,34 @@
return mcpStore.hasResourcesCapability(perChatOverrides);
});
let hasAudioModality = $state(false);
let hasVisionModality = $state(false);
let hasModelSelected = $state(false);
let isSelectedModelInCache = $state(true);
let submitTooltip = $state('');
let hasAudioAttachments = $derived(
uploadedFiles.some((file) => getFileTypeCategory(file.type) === FileTypeCategory.AUDIO)
);
let shouldShowRecordButton = $derived(
hasAudioModality && !canSubmit && !hasAudioAttachments && currentConfig.autoMicOnEmpty
);
let selectorModelRef: ChatFormActionModels | undefined = $state(undefined);
export function openModelSelector() {
selectorModelRef?.open();
}
</script>
<div class="flex w-full items-center gap-3 {className}" style="container-type: inline-size">
<div class="mr-auto flex items-center gap-2">
{#if isMobile.current}
<ChatFormActionAttachmentsSheet
{disabled}
{hasAudioModality}
{hasVisionModality}
{hasMcpPromptsSupport}
{hasMcpResourcesSupport}
{onFileUpload}
{onSystemPromptClick}
{onMcpPromptClick}
onMcpSettingsClick={() => goto('#/settings/mcp')}
{onMcpResourcesClick}
/>
{:else}
<ChatFormActionAttachmentsDropdown
<div
class="flex w-full items-center gap-3 {className} {showAddButton ? '' : 'justify-end'}"
style="container-type: inline-size"
>
{#if showAddButton}
<div class="mr-auto flex items-center gap-2">
<ChatFormActionsAdd
{disabled}
{hasAudioModality}
{hasVisionModality}
@@ -213,30 +102,24 @@
{onMcpResourcesClick}
onMcpSettingsClick={() => goto('#/settings/mcp')}
/>
{/if}
</div>
</div>
{/if}
<div class="ml-auto flex items-center gap-2">
{#if isMobile.current}
<ModelsSelectorSheet
disabled={disabled || isOffline}
bind:this={selectorModelRef}
currentModel={conversationModel}
forceForegroundText
useGlobalSelection
/>
{:else}
<ModelsSelectorDropdown
disabled={disabled || isOffline}
bind:this={selectorModelRef}
currentModel={conversationModel}
forceForegroundText
useGlobalSelection
/>
{/if}
</div>
{#if showModelSelector}
<ChatFormActionModels
{disabled}
bind:this={selectorModelRef}
bind:hasAudioModality
bind:hasVisionModality
bind:hasModelSelected
bind:isSelectedModelInCache
bind:submitTooltip
forceForegroundText
useGlobalSelection
/>
{/if}
{#if isLoading && !hasText}
{#if isLoading && !canSubmit}
<Button
type="button"
variant="secondary"
@@ -253,10 +136,10 @@
<ChatFormActionRecord {disabled} {hasAudioModality} {isLoading} {isRecording} {onMicClick} />
{:else}
<ChatFormActionSubmit
canSend={canSend && hasModelSelected && isSelectedModelInCache}
canSend={canSend && (showModelSelector ? hasModelSelected && isSelectedModelInCache : true)}
{disabled}
tooltipLabel={submitTooltip}
showErrorState={hasModelSelected && !isSelectedModelInCache}
showErrorState={showModelSelector && hasModelSelected && !isSelectedModelInCache}
/>
{/if}
</div>
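The submit gating at the end of this hunk makes the model checks conditional on the selector being shown. The two expressions can be written as pure predicates (a sketch; the parameter names mirror the component props, the functions themselves are illustrative):

```typescript
// Sketch of the ChatFormActionSubmit gating: model availability checks
// only apply when the model selector is rendered.
function effectiveCanSend(
  canSend: boolean,
  showModelSelector: boolean,
  hasModelSelected: boolean,
  isSelectedModelInCache: boolean
): boolean {
  return canSend && (showModelSelector ? hasModelSelected && isSelectedModelInCache : true);
}

// Error styling appears only when the selector is shown and a model is
// selected but no longer present in the cache.
function showErrorState(
  showModelSelector: boolean,
  hasModelSelected: boolean,
  isSelectedModelInCache: boolean
): boolean {
  return showModelSelector && hasModelSelected && !isSelectedModelInCache;
}
```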


@@ -15,6 +15,7 @@
function handleFileSelect(event: Event) {
const input = event.target as HTMLInputElement;
if (input.files) {
onFileSelect?.(Array.from(input.files));
}


@@ -1,31 +0,0 @@
<script lang="ts">
import { browser } from '$app/environment';
import { config } from '$lib/stores/settings.svelte';
interface Props {
class?: string;
show?: boolean;
}
let { class: className = '', show = true }: Props = $props();
let sendOnEnter = $derived(config().sendOnEnter !== false);
let modKey = browser && /Mac|iPhone|iPad|iPod/.test(navigator.platform) ? 'Cmd' : 'Ctrl';
</script>
{#if show}
<div class="mt-6 items-center justify-center {className} hidden md:flex">
{#if sendOnEnter}
<p class="text-xs text-muted-foreground">
Press <kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">Enter</kbd> to send,
<kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">Shift + Enter</kbd> for new line
</p>
{:else}
<p class="text-xs text-muted-foreground">
Press <kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">{modKey} + Enter</kbd> to
send,
<kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">Enter</kbd> for new line
</p>
{/if}
</div>
{/if}


@@ -4,7 +4,10 @@
mcpResourceAttachments,
mcpHasResourceAttachments
} from '$lib/stores/mcp-resources.svelte';
import { ChatAttachmentMcpResource, HorizontalScrollCarousel } from '$lib/components/app';
import {
ChatAttachmentsListItemMcpResource,
HorizontalScrollCarousel
} from '$lib/components/app';
interface Props {
class?: string;
@@ -29,11 +32,11 @@
<div class={className}>
<HorizontalScrollCarousel gapSize="2">
{#each attachments as attachment, i (attachment.id)}
<ChatAttachmentMcpResource
<ChatAttachmentsListItemMcpResource
class={i === 0 ? 'ml-3' : ''}
{attachment}
onRemove={handleRemove}
onClick={() => handleResourceClick(attachment.resource.uri)}
onclick={() => handleResourceClick(attachment.resource.uri)}
/>
{/each}
</HorizontalScrollCarousel>

Some files were not shown because too many files have changed in this diff.