mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-05-07 16:57:34 +03:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
048a490f76 | ||
|
|
db44417b02 | ||
|
|
d05fe1d7da | ||
|
|
0754b7b6fe | ||
|
|
09294365a9 | ||
|
|
63d93d1733 | ||
|
|
c5a3bc39b1 | ||
|
|
9dbb372610 | ||
|
|
228e836344 | ||
|
|
ed23489f42 | ||
|
|
457e2288c9 | ||
|
|
e8ec7ab058 | ||
|
|
1a03cf47f6 | ||
|
|
b97ebdc98f | ||
|
|
2098fd6169 | ||
|
|
ab6120cde5 | ||
|
|
c3c1505392 |
@@ -12,6 +12,8 @@ body:
|
||||
after recreating the CMake build directory and with `-DGGML_CCACHE=OFF`.
|
||||
If the compilation succeeds with ccache disabled you should be able to permanently fix the issue
|
||||
by clearing `~/.cache/ccache` (on Linux).
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
- type: textarea
|
||||
id: commit
|
||||
attributes:
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/011-bug-results.yml
vendored
4
.github/ISSUE_TEMPLATE/011-bug-results.yml
vendored
@@ -1,5 +1,5 @@
|
||||
name: Bug (model use)
|
||||
description: Something goes wrong when using a model (in general, not specific to a single llama.cpp module).
|
||||
description: Something goes wrong when running a model (crashes, garbled outputs, etc.).
|
||||
title: "Eval bug: "
|
||||
labels: ["bug-unconfirmed", "model evaluation"]
|
||||
body:
|
||||
@@ -12,6 +12,8 @@ body:
|
||||
If you encountered the issue while using an external UI (e.g. ollama),
|
||||
please reproduce your issue using one of the examples/binaries in this repository.
|
||||
The `llama-completion` binary can be used for simple and reproducible model inference.
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
- type: textarea
|
||||
id: version
|
||||
attributes:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/019-bug-misc.yml
vendored
2
.github/ISSUE_TEMPLATE/019-bug-misc.yml
vendored
@@ -10,6 +10,8 @@ body:
|
||||
This issue template is intended for miscellaneous bugs that don't fit into any other category.
|
||||
If you encountered the issue while using an external UI (e.g. ollama),
|
||||
please reproduce your issue using one of the examples/binaries in this repository.
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
- type: textarea
|
||||
id: version
|
||||
attributes:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/020-enhancement.yml
vendored
2
.github/ISSUE_TEMPLATE/020-enhancement.yml
vendored
@@ -8,6 +8,8 @@ body:
|
||||
value: |
|
||||
[Please post your idea first in Discussion if there is not yet a consensus for this enhancement request. This will help to keep this issue tracker focused on enhancements that the community has agreed needs to be implemented.](https://github.com/ggml-org/llama.cpp/discussions/categories/ideas)
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
|
||||
- type: checkboxes
|
||||
id: prerequisites
|
||||
attributes:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/030-research.yml
vendored
2
.github/ISSUE_TEMPLATE/030-research.yml
vendored
@@ -8,6 +8,8 @@ body:
|
||||
value: |
|
||||
Don't forget to check for any [duplicate research issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3A%22research+%F0%9F%94%AC%22)
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
|
||||
- type: checkboxes
|
||||
id: research-stage
|
||||
attributes:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/040-refactor.yml
vendored
2
.github/ISSUE_TEMPLATE/040-refactor.yml
vendored
@@ -9,6 +9,8 @@ body:
|
||||
Don't forget to [check for existing refactor issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3Arefactoring) in case it's already covered.
|
||||
Also you may want to check [Pull request refactor label as well](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Aopen+is%3Apr+label%3Arefactoring) for duplicates too.
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
|
||||
- type: textarea
|
||||
id: background-description
|
||||
attributes:
|
||||
|
||||
@@ -2889,6 +2889,20 @@ class LlamaModel(TextModel):
|
||||
.swapaxes(1, 2)
|
||||
.reshape(weights.shape))
|
||||
|
||||
def _repack_nvfp4(self, name: str, weight: Tensor, scale: Tensor, scale2: Tensor, input_scale: Tensor):
|
||||
# Mirror the BF16 Q/K RoPE permutation site in modify_tensors; the NVFP4 path bypasses it.
|
||||
if self.undo_permute:
|
||||
n_head = self.find_hparam(["n_heads", "num_attention_heads"], optional=True)
|
||||
n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"], optional=True)
|
||||
if n_head is not None:
|
||||
if name.endswith("q_proj.weight"):
|
||||
weight = LlamaModel.permute(weight, n_head, n_head)
|
||||
scale = LlamaModel.permute(scale, n_head, n_head)
|
||||
elif name.endswith("k_proj.weight"):
|
||||
weight = LlamaModel.permute(weight, n_head, n_kv_head)
|
||||
scale = LlamaModel.permute(scale, n_head, n_kv_head)
|
||||
super()._repack_nvfp4(name, weight, scale, scale2, input_scale)
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
@@ -12702,11 +12716,12 @@ class MistralModel(LlamaModel):
|
||||
def set_mistral_config(gguf_writer: gguf.GGUFWriter, hparams: dict):
|
||||
if "yarn" in hparams:
|
||||
yarn_params = hparams["yarn"]
|
||||
mscale_all_dim = 1.0 if not yarn_params["apply_scale"] else 0.0
|
||||
gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
|
||||
gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
|
||||
gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
|
||||
gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
|
||||
gguf_writer.add_rope_scaling_yarn_log_mul(mscale_all_dim)
|
||||
gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
|
||||
|
||||
if "llama_4_scaling" in hparams:
|
||||
@@ -13232,17 +13247,18 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
}
|
||||
|
||||
# only used when byteswapping data. Only correct size is needed
|
||||
# TODO: uncomment uint64, uint32, and uint16, ref: https://github.com/pytorch/pytorch/issues/58734
|
||||
_dtype_byteswap_map: dict[torch.dtype, type] = {
|
||||
torch.float64: np.float64,
|
||||
torch.float32: np.float32,
|
||||
torch.bfloat16: np.float16,
|
||||
torch.float16: np.float16,
|
||||
torch.int64: np.int64,
|
||||
torch.uint64: np.uint64,
|
||||
# torch.uint64: np.uint64,
|
||||
torch.int32: np.int32,
|
||||
torch.uint32: np.uint32,
|
||||
# torch.uint32: np.uint32,
|
||||
torch.int16: np.int16,
|
||||
torch.uint16: np.uint16,
|
||||
# torch.uint16: np.uint16,
|
||||
torch.int8: np.int8,
|
||||
torch.uint8: np.uint8,
|
||||
torch.bool: np.uint8,
|
||||
|
||||
@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 10)
|
||||
set(GGML_VERSION_PATCH 1)
|
||||
set(GGML_VERSION_PATCH 2)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
@@ -5431,8 +5431,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
char pci_bus_id[16] = {};
|
||||
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
|
||||
char pci_bus_id[32] = {};
|
||||
CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i));
|
||||
dev_ctx->pci_bus_id = pci_bus_id;
|
||||
dev_ctx->op_offload_min_batch_size = min_batch_size;
|
||||
|
||||
|
||||
1
ggml/src/ggml-cuda/vendors/hip.h
vendored
1
ggml/src/ggml-cuda/vendors/hip.h
vendored
@@ -55,6 +55,7 @@
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceGetAttribute hipDeviceGetAttribute
|
||||
#define cudaDeviceGetPCIBusId hipDeviceGetPCIBusId
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
|
||||
1
ggml/src/ggml-cuda/vendors/musa.h
vendored
1
ggml/src/ggml-cuda/vendors/musa.h
vendored
@@ -39,6 +39,7 @@
|
||||
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
||||
#define cudaDeviceGetPCIBusId musaDeviceGetPCIBusId
|
||||
#define cudaDeviceProp musaDeviceProp
|
||||
#define cudaDeviceSynchronize musaDeviceSynchronize
|
||||
#define cudaError_t musaError_t
|
||||
|
||||
@@ -22,7 +22,8 @@ message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for
|
||||
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include(ExternalProject)
|
||||
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
|
||||
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
|
||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
|
||||
|
||||
|
||||
@@ -2254,8 +2254,7 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->ne[2] != 1 || dst->ne[3] != 1) {
|
||||
// FA during prompt still needs work
|
||||
if (dst->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2421,8 +2420,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contigiuos tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
|
||||
// TODO: add support for non-contiguous elements within a row
|
||||
if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,11 @@ target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
|
||||
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
|
||||
|
||||
if (GGML_HEXAGON_FA_EXP2_HF)
|
||||
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
|
||||
endif()
|
||||
|
||||
# HMX acceleration: available on v73+ architectures
|
||||
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
|
||||
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
@@ -52,11 +57,13 @@ if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-queue.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
)
|
||||
|
||||
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
|
||||
set_source_files_properties(
|
||||
hmx-matmul-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
PROPERTIES COMPILE_OPTIONS "-mhmx"
|
||||
)
|
||||
|
||||
|
||||
@@ -138,15 +138,15 @@ set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
|
||||
|
||||
#Compiler Options
|
||||
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O2")
|
||||
|
||||
set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||
set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
|
||||
set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O2")
|
||||
|
||||
set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}")
|
||||
set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}")
|
||||
|
||||
@@ -17,13 +17,14 @@
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hmx-ops.h"
|
||||
|
||||
// Must be multiple of 32
|
||||
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
|
||||
|
||||
// This is a bit of a hack because the compiler is struggling to properly inline
|
||||
// the default hvx_vec_f32_to_f16 with output into the local array.
|
||||
static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
|
||||
static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
|
||||
{
|
||||
*(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
|
||||
}
|
||||
@@ -621,6 +622,17 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
// HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV
|
||||
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) {
|
||||
int ret = hmx_flash_attn_ext(octx);
|
||||
if (ret == HTP_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
// VTCM too small or other failure -> fall through to HVX path
|
||||
}
|
||||
#endif
|
||||
|
||||
struct htp_fa_context factx;
|
||||
factx.octx = octx;
|
||||
|
||||
|
||||
@@ -74,6 +74,12 @@ static inline size_t hex_smax(size_t a, size_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
static inline void hex_swap_ptr(void ** p1, void ** p2) {
|
||||
void * t = *p1;
|
||||
*p1 = *p2;
|
||||
*p2 = t;
|
||||
}
|
||||
|
||||
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
|
||||
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
|
||||
Q6_l2fetch_AP((void *) p, control);
|
||||
|
||||
1840
ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c
Normal file
1840
ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -61,6 +61,9 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
|
||||
int m, int k, int n,
|
||||
int weight_type);
|
||||
|
||||
// HMX flash attention
|
||||
int hmx_flash_attn_ext(struct htp_ops_context * octx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
#ifndef HMX_UTILS_H
|
||||
#define HMX_UTILS_H
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <stddef.h>
|
||||
|
||||
@@ -12,21 +15,188 @@
|
||||
#define HMX_FP16_TILE_N_ELMS 1024
|
||||
#define HMX_FP16_TILE_SIZE 2048
|
||||
|
||||
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
|
||||
|
||||
// Initialise aligned 256-byte area with scale vector + zero padding.
|
||||
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
|
||||
HVX_Vector *pv = (HVX_Vector *)out_scales;
|
||||
*pv++ = v_scale;
|
||||
*pv = Q6_V_vzero();
|
||||
static inline void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
|
||||
volatile HVX_Vector *pv = (HVX_Vector *) out_scales;
|
||||
pv[0] = v_scale;
|
||||
pv[1] = Q6_V_vzero();
|
||||
}
|
||||
|
||||
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
|
||||
// --- Shared scatter offsets and interleave helper ---
|
||||
|
||||
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
|
||||
uint8_t *p = *vtcm_ptr;
|
||||
*vtcm_ptr += size;
|
||||
return p;
|
||||
// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile.
|
||||
// word[i] = i*128 maps K-row-pair i to byte offset i*128.
|
||||
// Column offset (n*4) is added at runtime. Entries 0..15 cover one tile (region 2047);
|
||||
// entries 16..31 cover the next adjacent tile (region 4095) — pick region size at the
|
||||
// call site to scatter into one tile (masked) or two contiguous tiles (unmasked).
|
||||
static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = {
|
||||
0 * 128, 1 * 128, 2 * 128, 3 * 128, 4 * 128, 5 * 128, 6 * 128, 7 * 128, 8 * 128, 9 * 128, 10 * 128,
|
||||
11 * 128, 12 * 128, 13 * 128, 14 * 128, 15 * 128, 16 * 128, 17 * 128, 18 * 128, 19 * 128, 20 * 128, 21 * 128,
|
||||
22 * 128, 23 * 128, 24 * 128, 25 * 128, 26 * 128, 27 * 128, 28 * 128, 29 * 128, 30 * 128, 31 * 128,
|
||||
};
|
||||
|
||||
// Scatter row-major FP16 data (in VTCM scratch) into transposed [K][N] tiles.
|
||||
// vtcm_src: [n_cols][src_stride] row-major fp16 (only first k elements per row are used)
|
||||
// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16
|
||||
// Processes rows [start_row, end_row) for multi-thread slicing.
|
||||
// Full range: start_row=0, end_row=n_cols.
|
||||
static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
const __fp16 * restrict vtcm_src,
|
||||
int n_cols,
|
||||
int k,
|
||||
int src_stride,
|
||||
int start_row,
|
||||
int end_row) {
|
||||
assert(k % HMX_FP16_TILE_N_COLS == 0);
|
||||
|
||||
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
|
||||
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
|
||||
// Each hvx_vmemu load brings 64 fp16 = 128 bytes covering 2 adjacent K-tiles.
|
||||
// When n_k_tiles is even, scatter into 2 K-tiles per call (region 4095, no mask)
|
||||
// using the upper half of hmx_transpose_scatter_offsets. Tail one K-tile (when
|
||||
// n_k_tiles is odd) falls back to single-tile masked scatter.
|
||||
const bool pair_scatter = (n_k_tiles & 1) == 0;
|
||||
const size_t pair_region = (size_t) (2 * HMX_FP16_TILE_SIZE - 1);
|
||||
const size_t single_region = (size_t) (HMX_FP16_TILE_SIZE - 1);
|
||||
__builtin_assume(k > 0);
|
||||
__builtin_assume(end_row > start_row);
|
||||
|
||||
if (pair_scatter) {
|
||||
// Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter.
|
||||
const int c_step = 2 * HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const int n_c_iters = k / c_step;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const int ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
|
||||
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
|
||||
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
|
||||
|
||||
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
|
||||
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
if (p1) {
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
HVX_Vector v1 = hvx_vmemu(p1);
|
||||
p1 += c_byte_step;
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: scatter one K-tile per call (region 2047, masked).
|
||||
const int c_step = HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const int n_c_iters = k / c_step;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const int ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
|
||||
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
|
||||
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
|
||||
|
||||
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
|
||||
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
if (p1) {
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
HVX_Vector v1 = hvx_vmemu(p1);
|
||||
p1 += c_byte_step;
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Interleave row-major FP16 data into column-major tile format.
|
||||
// Input: [n_rows, head_dim] row-major. Output: tile[dim_tile][row_tile].
|
||||
// Processes rows [start_row, end_row) for multi-thread slicing.
|
||||
// Full range: start_row=0, end_row=n_rows.
|
||||
static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
|
||||
const __fp16 * restrict src,
|
||||
int n_rows,
|
||||
int head_dim,
|
||||
int src_stride,
|
||||
int n_row_tiles,
|
||||
int start_row,
|
||||
int end_row) {
|
||||
__builtin_assume(head_dim > 0);
|
||||
const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows;
|
||||
|
||||
const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride);
|
||||
const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
// Row-pair invariants hoisted out of the c loop.
|
||||
const int r0 = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
|
||||
|
||||
// tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0).
|
||||
// Each c step (+= 64) advances both by 2 dim-tiles worth of fp16.
|
||||
__fp16 * tb0 = tiles_out + (size_t) r0 * HMX_FP16_TILE_N_ELMS;
|
||||
__fp16 * tb1 = tb0 + tile_stride_elms;
|
||||
const size_t tb_step = 2 * tile_stride_elms;
|
||||
|
||||
if (pv_in1) {
|
||||
for (int c = 0; c < head_dim; c += 64) {
|
||||
HVX_Vector v0 = *pv_in0++;
|
||||
HVX_Vector v1 = *pv_in1++;
|
||||
HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
|
||||
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
|
||||
((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp);
|
||||
tb0 += tb_step;
|
||||
tb1 += tb_step;
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int c = 0; c < head_dim; c += 64) {
|
||||
HVX_Vector v0 = *pv_in0++;
|
||||
HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2);
|
||||
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
|
||||
((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp);
|
||||
tb0 += tb_step;
|
||||
tb1 += tb_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // HMX_UTILS_H
|
||||
|
||||
@@ -77,6 +77,12 @@ static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline _Float16 hvx_vec_get_f16(HVX_Vector v) {
|
||||
_Float16 __attribute__((aligned(128))) x;
|
||||
hvx_vec_store_a(&x, 2, v);
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
|
||||
// abs by clearing the fp16 sign bit
|
||||
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
|
||||
|
||||
@@ -7,7 +7,8 @@
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
#define hvx_splat_loop_body(dst_type, vec_store) \
|
||||
#define hvx_splat_pragma(x) _Pragma(#x)
|
||||
#define hvx_splat_loop_body(dst_type, vec_store, unroll_cnt) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
\
|
||||
@@ -16,7 +17,7 @@
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
hvx_splat_pragma(unroll(unroll_cnt)) \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = src; \
|
||||
} \
|
||||
@@ -25,31 +26,47 @@
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
static inline void hvx_splat_a(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a);
|
||||
hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a, 4);
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u);
|
||||
static inline void hvx_splat_u(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u, 4);
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) {
|
||||
static inline void hvx_splat_f32_a(void * restrict dst, float v, uint32_t n) {
|
||||
hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) {
|
||||
static inline void hvx_splat_f32_u(void * restrict dst, float v, uint32_t n) {
|
||||
hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) {
|
||||
static inline void hvx_splat_f16_a(void * restrict dst, _Float16 v, uint32_t n) {
|
||||
hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) {
|
||||
static inline void hvx_splat_f16_u(void * restrict dst, _Float16 v, uint32_t n) {
|
||||
hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u16_a(void * restrict dst, uint16_t v, uint32_t n) {
|
||||
hvx_splat_a(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u16_u(void * restrict dst, uint16_t v, uint32_t n) {
|
||||
hvx_splat_u(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u8_a(void * restrict dst, uint8_t v, uint32_t n) {
|
||||
hvx_splat_a(dst, Q6_Vb_vsplat_R(v), n, 1);
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u8_u(void * restrict dst, uint8_t v, uint32_t n) {
|
||||
hvx_splat_u(dst, Q6_Vb_vsplat_R(v), n, 1);
|
||||
}
|
||||
|
||||
#define hvx_copy_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
|
||||
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
|
||||
#define EXP_ONE (0x3f800000) // 1.0
|
||||
#define EXP_RANGE_R (0x42B16666) // 88.7
|
||||
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
|
||||
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
|
||||
@@ -163,7 +163,7 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
|
||||
HVX_Vector vec_out = Q6_V_vzero();
|
||||
|
||||
static const float kInf = INFINITY;
|
||||
static const float kMaxExp = 88.7f;
|
||||
static const float kMaxExp = 88.7228f;
|
||||
|
||||
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
|
||||
|
||||
@@ -26,8 +26,8 @@ struct htp_unary_context {
|
||||
const uint8_t * data_src0;
|
||||
uint8_t * data_dst;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t dst_row_size;
|
||||
size_t src0_data_row_size; // actual data bytes per row
|
||||
size_t dst_data_row_size; // actual data bytes per row
|
||||
|
||||
size_t src0_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
@@ -41,6 +41,40 @@ struct htp_unary_context {
|
||||
uint32_t nc;
|
||||
};
|
||||
|
||||
// Convert flat row index to DDR byte offset using the tensor's actual strides.
|
||||
// ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3
|
||||
static inline size_t unary_row_offset(uint32_t ir,
|
||||
uint32_t ne1, uint32_t ne2,
|
||||
size_t nb1, size_t nb2, size_t nb3) {
|
||||
const uint32_t i1 = ir % ne1;
|
||||
const uint32_t i2 = (ir / ne1) % ne2;
|
||||
const uint32_t i3 = ir / (ne1 * ne2);
|
||||
return i1 * nb1 + i2 * nb2 + i3 * nb3;
|
||||
}
|
||||
// Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice
|
||||
// boundary of src and dst so the nb1 stride stays valid for all rows.
|
||||
static inline uint32_t unary_block_size(uint32_t ir,
|
||||
uint32_t end_row,
|
||||
uint32_t block,
|
||||
bool src_contig,
|
||||
bool dst_contig,
|
||||
uint32_t src_ne1,
|
||||
uint32_t dst_ne1) {
|
||||
uint32_t limit = MIN(block, end_row - ir);
|
||||
|
||||
if (!src_contig) {
|
||||
const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1;
|
||||
limit = MIN(limit, src_slice_end - ir);
|
||||
}
|
||||
|
||||
if (!dst_contig) {
|
||||
const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1;
|
||||
limit = MIN(limit, dst_slice_end - ir);
|
||||
}
|
||||
|
||||
return limit;
|
||||
}
|
||||
|
||||
#define htp_unary_preamble \
|
||||
const uint32_t ne00 = src->ne[0]; \
|
||||
const uint32_t ne01 = src->ne[1]; \
|
||||
@@ -276,8 +310,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
int32_t * op_params = octx->op_params;
|
||||
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
|
||||
|
||||
const size_t src0_row_size = uctx->src0_row_size;
|
||||
const size_t dst_row_size = uctx->dst_row_size;
|
||||
const size_t src0_data_row_size = uctx->src0_data_row_size;
|
||||
const size_t dst_data_row_size = uctx->dst_data_row_size;
|
||||
|
||||
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
|
||||
@@ -303,7 +337,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
size_t src0_spad_half_size = uctx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = uctx->dst_spad_half_size;
|
||||
|
||||
const int BLOCK = uctx->block;
|
||||
// Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride
|
||||
// 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every
|
||||
// transfer stays within a nb1-uniform region. Skipped for contiguous tensors.
|
||||
const bool src0_contig = (nb02 == (size_t)ne01 * nb01) &&
|
||||
(nb03 == (size_t)ne02 * nb02);
|
||||
const bool dst_contig = (nb2 == (size_t)ne1 * nb1) &&
|
||||
(nb3 == (size_t)ne2 * nb2);
|
||||
const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01);
|
||||
const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1);
|
||||
const uint32_t BLOCK = MIN(src0_max_block, dst_max_block);
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
@@ -312,21 +355,23 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) {
|
||||
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
|
||||
// Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
nb1, dst_row_size_aligned, dst_data_row_size, 0);
|
||||
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, block_size);
|
||||
const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03);
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off),
|
||||
src0_row_size_aligned, nb01, src0_data_row_size, block_size);
|
||||
ir += block_size;
|
||||
}
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ) {
|
||||
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
|
||||
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
|
||||
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
@@ -361,18 +406,25 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
break;
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
|
||||
dst_row_size, dst_row_size_aligned, block_size);
|
||||
const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3);
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(data_dst + dst_off, dst_spad),
|
||||
nb1, dst_row_size_aligned, dst_data_row_size, block_size);
|
||||
|
||||
// prefetch N+2 loop iteration if any
|
||||
const uint32_t pref_block = (ir + BLOCK * 2);
|
||||
if (pref_block < src0_end_row) {
|
||||
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, pref_block_size);
|
||||
const uint32_t next_ir = ir + block_size;
|
||||
if (next_ir < src0_end_row) {
|
||||
const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
const uint32_t pref_ir = next_ir + next_block_size;
|
||||
if (pref_ir < src0_end_row) {
|
||||
const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03);
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(src0_spad, data_src + src0_pref_off),
|
||||
src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size);
|
||||
}
|
||||
}
|
||||
ir += block_size;
|
||||
}
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
@@ -426,11 +478,11 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
const size_t src0_data_row_size = src0->ne[0] * sizeof(float);
|
||||
const size_t dst_data_row_size = dst->ne[0] * sizeof(float);
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN);
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
@@ -468,8 +520,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
.data_src0 = (const uint8_t *)src0->data,
|
||||
.data_dst = (uint8_t *)dst->data,
|
||||
|
||||
.src0_row_size = src0_row_size,
|
||||
.dst_row_size = dst_row_size,
|
||||
.src0_data_row_size = src0_data_row_size,
|
||||
.dst_data_row_size = dst_data_row_size,
|
||||
|
||||
.src0_row_size_aligned = src0_row_size_aligned,
|
||||
.dst_row_size_aligned = dst_row_size_aligned,
|
||||
|
||||
16
ggml/src/ggml-hexagon/htp/vtcm-utils.h
Normal file
16
ggml/src/ggml-hexagon/htp/vtcm-utils.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#ifndef VTCM_UTILS_H
|
||||
#define VTCM_UTILS_H
|
||||
|
||||
#include "hex-utils.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
// Sequential (bump) allocation from a caller-managed cursor.
//
// Returns the current value of *vtcm_ptr and advances the cursor by `size`
// bytes. No alignment or bounds checking is performed here.
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
    uint8_t * const base = *vtcm_ptr;
    *vtcm_ptr = base + size;
    return base;
}
|
||||
|
||||
#endif // VTCM_UTILS_H
|
||||
@@ -107,6 +107,10 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mv_id_mxfp4_f32_flat
|
||||
gemm_moe_mxfp4_f32
|
||||
gemv_moe_mxfp4_f32
|
||||
gemm_moe_mxfp4_f32_ns
|
||||
gemv_moe_mxfp4_f32_ns
|
||||
moe_reorder_b
|
||||
moe_sort_by_expert
|
||||
mul_mm_f32_f32_l4_lm
|
||||
mul_mm_f16_f32_l4_lm
|
||||
mul_mm_q4_0_f32_l4_lm
|
||||
|
||||
@@ -416,6 +416,15 @@ struct ggml_backend_opencl_context {
|
||||
ggml_cl_buffer prealloc_src0;
|
||||
ggml_cl_buffer prealloc_src1;
|
||||
|
||||
// prealloc buffers for MoE router table preprocess
|
||||
bool toggle_reorder = false;
|
||||
ggml_cl_buffer prealloc_post_router;
|
||||
ggml_cl_buffer prealloc_emap;
|
||||
ggml_cl_buffer prealloc_hist;
|
||||
ggml_cl_buffer prealloc_tile_offset;
|
||||
ggml_cl_buffer prealloc_total_tiles;
|
||||
ggml_cl_buffer prealloc_slot_counter;
|
||||
|
||||
cl_program program_add;
|
||||
cl_program program_add_id;
|
||||
cl_program program_clamp;
|
||||
@@ -531,6 +540,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
|
||||
cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1;
|
||||
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
|
||||
cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns;
|
||||
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans;
|
||||
cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
|
||||
@@ -587,6 +597,9 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
|
||||
cl_kernel kernel_timestep_embedding;
|
||||
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
|
||||
cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns;
|
||||
cl_kernel kernel_moe_reorder_b;
|
||||
cl_kernel kernel_moe_histogram, kernel_moe_scan, kernel_moe_fill, kernel_moe_scatter;
|
||||
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
|
||||
cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
|
||||
cl_kernel kernel_mul_mv_id_mxfp4_f32;
|
||||
@@ -945,6 +958,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans4_ns", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
|
||||
@@ -2864,6 +2879,77 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemv_moe_mxfp4_f32_ns
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "gemv_moe_mxfp4_f32_ns.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("gemv_moe_mxfp4_f32_ns.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_mxfp4_f32_ns", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemm_moe_mxfp4_f32_ns
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "gemm_moe_mxfp4_f32_ns.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("gemm_moe_mxfp4_f32_ns.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// moe_reorder_b
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "moe_reorder_b.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("moe_reorder_b.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_moe_reorder_b = clCreateKernel(prog, "kernel_moe_reorder_b", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// moe_sort_by_expert
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "moe_sort_by_expert.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("moe_sort_by_expert.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_moe_histogram = clCreateKernel(prog, "kernel_moe_histogram", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_moe_scan = clCreateKernel(prog, "kernel_moe_scan", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_moe_fill = clCreateKernel(prog, "kernel_moe_fill", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_moe_scatter = clCreateKernel(prog, "kernel_moe_scatter", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemv_noshuffle_q6_k_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -3651,13 +3737,12 @@ struct ggml_tensor_extra_cl_mxfp4 {
|
||||
CL_CHECK(clReleaseMemObject(e));
|
||||
e = nullptr;
|
||||
}
|
||||
if (q != nullptr) {
|
||||
if (q_img != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(q_img));
|
||||
q = nullptr;
|
||||
q_img = nullptr;
|
||||
}
|
||||
// Currently, q_img and d_img are not used. They can be image1d_buffer_t
|
||||
// Currently, e_img is not used. They can be image1d_buffer_t
|
||||
// that wraps around q and d to utilize image access path.
|
||||
q_img = nullptr;
|
||||
e_img = nullptr;
|
||||
size_q = 0;
|
||||
size_e = 0;
|
||||
@@ -4740,7 +4825,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
|
||||
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
|
||||
GGML_UNUSED(backend_ctx);
|
||||
int ne01 = tensor->ne[1];
|
||||
return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
|
||||
return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
|
||||
}
|
||||
|
||||
inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
|
||||
@@ -5151,8 +5236,9 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
CL_CHECK(err);
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
// Adreno moe mxfp4 kernel needs special transpose and unshuffling
|
||||
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans4_ns;
|
||||
|
||||
int ne00 = tensor->ne[0];
|
||||
int ne01 = tensor->ne[1];
|
||||
@@ -5172,9 +5258,21 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
tensor->extra = extra;
|
||||
|
||||
// Create image for Q
|
||||
cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32};
|
||||
cl_image_desc img_desc_q = {
|
||||
CL_MEM_OBJECT_IMAGE1D_BUFFER,
|
||||
static_cast<size_t>(ggml_nelements(tensor) / 8),
|
||||
0, 0, 0, 0, 0, 0, 0,
|
||||
{ extra->q }
|
||||
};
|
||||
extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
|
||||
tensor->extra = extra;
|
||||
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
|
||||
@@ -5912,7 +6010,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans4_ns;
|
||||
|
||||
int ne00 = tensor->ne[0];
|
||||
int ne01 = tensor->ne[1];
|
||||
@@ -5936,7 +6034,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
|
||||
@@ -12763,6 +12862,118 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild the reordered MoE router tables on the device.
//
// Enqueues four kernels in order (histogram -> scan -> fill -> scatter) over
// the router tensor `src`. The outputs are written into the backend's
// preallocated buffers (prealloc_post_router, prealloc_emap, prealloc_hist,
// prealloc_tile_offset, prealloc_slot_counter, prealloc_total_tiles); each of
// them is passed through allocate() here with the size needed for this router.
// Nothing is returned; callers read the prealloc buffers afterwards.
//
//   backend  OpenCL backend that owns the kernels and the prealloc buffers
//   src      router tensor: ne[1] rows, row stride nb[1] bytes
//   ne20     forwarded to the histogram/scatter kernels next to ne21 and used
//            as the second work-group dimension
//
// NOTE(review): the function name is misspelled ("reoerder"); it is kept
// unchanged because ggml_cl_mul_mat_id calls it under this name.
static void moe_router_reoerder(ggml_backend_t backend, const ggml_tensor * src, int ne20) {
    cl_int err;
    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra;
    cl_ulong offset = extra->offset + src->view_offs;

    const int ne21 = src->ne[1];
    const int nb21 = src->nb[1];
    // Row stride expressed in elements; presumably the total number of experts.
    const int ne02 = nb21 / src->nb[0];
    const int n_tile_size = 32;
    // Upper bound on the number of tiles after regrouping: all full tiles plus
    // one extra tile per ne02 entry.
    const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02;

    // View of the router data inside the tensor's device buffer.
    cl_buffer_region router_region;
    router_region.origin = offset;
    router_region.size   = static_cast<size_t>(nb21) * ne21;  // widen before multiplying
    cl_mem original_router_buf = clCreateSubBuffer(extra->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &router_region, &err);
    CL_CHECK(err);

    // (Re)allocate one preallocated buffer and return a sub-buffer that covers
    // its first `size` bytes. The caller releases the returned handle.
    auto prealloc_view = [&](ggml_cl_buffer & buf, size_t size) -> cl_mem {
        buf.allocate(backend_ctx->context, size);

        cl_buffer_region region;
        region.origin = 0;
        region.size   = size;
        cl_mem view = clCreateSubBuffer(buf.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
        CL_CHECK(err);
        return view;
    };

    cl_mem post_router_buf  = prealloc_view(backend_ctx->prealloc_post_router,  sizeof(int) * max_post_router_tile * n_tile_size);
    cl_mem emap_buf         = prealloc_view(backend_ctx->prealloc_emap,         sizeof(short) * max_post_router_tile);
    cl_mem hist_buf         = prealloc_view(backend_ctx->prealloc_hist,         sizeof(int) * ne02);
    cl_mem tile_offset_buf  = prealloc_view(backend_ctx->prealloc_tile_offset,  sizeof(int) * ne02);
    cl_mem slot_counter_buf = prealloc_view(backend_ctx->prealloc_slot_counter, sizeof(int) * ne02);
    cl_mem total_tiles_buf  = prealloc_view(backend_ctx->prealloc_total_tiles,  sizeof(int));

    // Histogram
    cl_kernel kernel = backend_ctx->kernel_moe_histogram;
    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf));
    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &hist_buf));
    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &ne21));
    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &ne20));
    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &ne02));

    // Dim 0 covers the ne21 rows rounded up to a multiple of 64; dim 1 is ne20.
    size_t histogram_global_size[] = {(size_t)(((ne21 + 63) / 64) * 64), static_cast<size_t>(ne20), 1};
    size_t histogram_local_size[]  = {64, static_cast<size_t>(ne20), 1};
    backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src);

    // Scan (launched as a single work-item)
    kernel = backend_ctx->kernel_moe_scan;
    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &hist_buf));
    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &tile_offset_buf));
    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &total_tiles_buf));
    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &slot_counter_buf));
    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &n_tile_size));
    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),    &ne02));

    size_t scan_global_size[] = {1};
    size_t scan_local_size[]  = {1};
    backend_ctx->enqueue_ndrange_kernel(kernel, 1, scan_global_size, scan_local_size, src);

    // Fill
    kernel = backend_ctx->kernel_moe_fill;
    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &post_router_buf));
    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &total_tiles_buf));
    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &n_tile_size));

    // Dim 0 covers max_post_router_tile rounded up to a multiple of 64; dim 1
    // is the tile size.
    size_t fill_global_size[] = {(size_t)(((max_post_router_tile + 63) / 64) * 64), static_cast<size_t>(n_tile_size), 1};
    size_t fill_local_size[]  = {64, 1, 1};
    backend_ctx->enqueue_ndrange_kernel(kernel, 3, fill_global_size, fill_local_size, src);

    // Scatter (same launch geometry as the histogram pass)
    kernel = backend_ctx->kernel_moe_scatter;
    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf));
    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &post_router_buf));
    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &emap_buf));
    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tile_offset_buf));
    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &slot_counter_buf));
    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),    &ne21));
    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),    &ne20));
    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),    &ne02));

    backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src);

    // Drop the temporary sub-buffer handles; the parent prealloc buffers stay
    // owned by backend_ctx.
    CL_CHECK(clReleaseMemObject(original_router_buf));
    CL_CHECK(clReleaseMemObject(hist_buf));
    CL_CHECK(clReleaseMemObject(tile_offset_buf));
    CL_CHECK(clReleaseMemObject(total_tiles_buf));
    CL_CHECK(clReleaseMemObject(slot_counter_buf));
    CL_CHECK(clReleaseMemObject(post_router_buf));
    CL_CHECK(clReleaseMemObject(emap_buf));
}
|
||||
|
||||
static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
@@ -12824,6 +13035,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
|
||||
const int ne0 = dst->ne[0];
|
||||
const int ne1 = dst->ne[1];
|
||||
const int ne2 = dst->ne[2];
|
||||
|
||||
const int r2 = ne12/ne02;
|
||||
const int r3 = ne13/ne03;
|
||||
@@ -12836,6 +13048,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
int nrows = 1; // number of row in src1
|
||||
int ndst = 4; // number of values produced by each subgroup
|
||||
|
||||
const int n_tile_size = 32;
|
||||
const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02;
|
||||
|
||||
cl_kernel kernel;
|
||||
|
||||
// subgroup mat vec
|
||||
@@ -12967,11 +13182,10 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
size_t local_size[3] = {64, 2, 1};
|
||||
size_t global_size[3] = {64, 2, 1};
|
||||
|
||||
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
|
||||
|
||||
int tile_size = 320;
|
||||
if (ne12 == 1) { // for gemv
|
||||
kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
|
||||
kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32_ns;
|
||||
|
||||
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
|
||||
|
||||
// create a sub_buffer for src2
|
||||
cl_buffer_region region;
|
||||
@@ -12985,78 +13199,154 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
global_size[1] = 4;
|
||||
global_size[2] = static_cast<size_t>(ne20);
|
||||
local_size[1] = 4;
|
||||
} else { // for gemm
|
||||
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
|
||||
|
||||
// preprocess router table
|
||||
int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
|
||||
void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
|
||||
void * host_src2 = malloc(ne21 * nb21);
|
||||
CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
|
||||
int total_experts = nb21 / nb20;
|
||||
int out_idx = 0;
|
||||
for (int i_expert = 0; i_expert < ne02; i_expert++) {
|
||||
for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
|
||||
for (int j = 0; j < ne21; j++) {
|
||||
for (int i = 0; i < ne20; i++) {
|
||||
int expert = ((int *)host_src2)[j * total_experts + i];
|
||||
if (i_expert == expert) {
|
||||
((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);
|
||||
((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
|
||||
((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
|
||||
((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
|
||||
out_idx += 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
|
||||
// create a sub_buffer for src1
|
||||
region.origin = offset1;
|
||||
region.size = ne10 * ne11 * ne12 * sizeof(float);
|
||||
src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// set thread grid
|
||||
global_size[0] = static_cast<size_t>(tile_size);
|
||||
global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
|
||||
}
|
||||
// create image for src1
|
||||
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
|
||||
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
|
||||
buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// create a sub_buffer for src1
|
||||
cl_buffer_region region;
|
||||
region.origin = offset1;
|
||||
region.size = ne10 * ne11 * ne12 * sizeof(float);
|
||||
src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// create image for src1
|
||||
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
|
||||
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
|
||||
buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// Set kernel args
|
||||
int arg_idx = 0;
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
|
||||
if (ne12 == 1) {
|
||||
// Set kernel args
|
||||
int arg_idx = 0;
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
|
||||
} else {
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size));
|
||||
|
||||
// launch kernel
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
|
||||
|
||||
// deallocate sub buffers and images
|
||||
CL_CHECK(clReleaseMemObject(src1_sub_buffer));
|
||||
CL_CHECK(clReleaseMemObject(buf_src1_image));
|
||||
CL_CHECK(clReleaseMemObject(buf_src2));
|
||||
|
||||
} else { // for gemm
|
||||
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns;
|
||||
|
||||
// Reorder router if called from test-backend-ops or when new router is generated.
|
||||
// Otherwise reuse the reordered result from previous mul_mat_id call.
|
||||
if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) {
|
||||
moe_router_reoerder(backend, src2, ne20);
|
||||
backend_ctx->toggle_reorder = false;
|
||||
}
|
||||
|
||||
cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image;
|
||||
cl_mem buf_src2, buf_src2_emap;
|
||||
|
||||
cl_buffer_region region;
|
||||
region.origin = 0;
|
||||
region.size = sizeof(int) * max_post_router_tile * n_tile_size;
|
||||
GGML_ASSERT(backend_ctx->prealloc_post_router.buffer);
|
||||
buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
region.origin = 0;
|
||||
region.size = sizeof(short) * max_post_router_tile;
|
||||
buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// Reorder activations
|
||||
// create a sub_buffer for src1
|
||||
region.origin = offset1;
|
||||
region.size = ne10 * ne11 * ne12 * sizeof(float);
|
||||
sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// Create image for reordered src1
|
||||
// Use pre-allocated placeholder
|
||||
region.origin = 0;
|
||||
region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float);
|
||||
backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size);
|
||||
buf_src1_reordered = clCreateSubBuffer(
|
||||
backend_ctx->prealloc_act_trans.buffer,
|
||||
0,
|
||||
CL_BUFFER_CREATE_TYPE_REGION,
|
||||
®ion,
|
||||
&status);
|
||||
CL_CHECK(status);
|
||||
cl_image_format image_format_buf_src1;
|
||||
cl_image_desc image_desc_buf_src1;
|
||||
image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
|
||||
image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}};
|
||||
image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
unsigned short map_ratio = ne20 / ne11;
|
||||
GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n");
|
||||
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre));
|
||||
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2));
|
||||
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered));
|
||||
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer)));
|
||||
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio));
|
||||
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size));
|
||||
|
||||
size_t reorder_b_local_size[3] = {256, 1, 1};
|
||||
size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1};
|
||||
|
||||
// Dispatch reorder kernel
|
||||
backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst);
|
||||
|
||||
// MoE kernel prepare
|
||||
// Create sub buffer for dst
|
||||
region.origin = offsetd;
|
||||
region.size = ne0 * ne1 * ne2 * sizeof(float);
|
||||
sub_buf_dst = clCreateSubBuffer(
|
||||
extrad->data_device,
|
||||
0,
|
||||
CL_BUFFER_CREATE_TYPE_REGION,
|
||||
®ion,
|
||||
&status);
|
||||
CL_CHECK(status);
|
||||
// Create image for dst
|
||||
cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT};
|
||||
cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}};
|
||||
buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// Set kernel args
|
||||
int arg_idx = 0;
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q_img));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer)));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
|
||||
|
||||
// set thread grid
|
||||
global_size[1] = static_cast<size_t>((ne01 + 63) / 64);
|
||||
global_size[2] = static_cast<size_t>(max_post_router_tile);
|
||||
local_size[1] = 1;
|
||||
local_size[2] = 1;
|
||||
|
||||
// Dispatch kernel
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
|
||||
|
||||
clReleaseMemObject(sub_buf_src1_pre);
|
||||
clReleaseMemObject(buf_src1_reordered);
|
||||
clReleaseMemObject(image_src1_reordered);
|
||||
clReleaseMemObject(buf_src2);
|
||||
clReleaseMemObject(buf_src2_emap);
|
||||
clReleaseMemObject(sub_buf_dst);
|
||||
clReleaseMemObject(buf_dst_image);
|
||||
}
|
||||
|
||||
// launch kernel
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
|
||||
|
||||
// deallocate sub buffers and images
|
||||
CL_CHECK(clReleaseMemObject(src1_sub_buffer));
|
||||
CL_CHECK(clReleaseMemObject(buf_src1_image));
|
||||
CL_CHECK(clReleaseMemObject(buf_src2));
|
||||
return;
|
||||
} // else fallback to generic kernel
|
||||
} // fallback to generic MoE mxfp4 kernel
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
@@ -14002,6 +14292,13 @@ static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
size_t local_work_size[] = {(size_t)ne00_padded, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
const int ne21 = dst->ne[1];
|
||||
if ((strstr(src0->name, "_moe") != NULL) && (ne21 != 1)) {
|
||||
backend_ctx->toggle_reorder = true;
|
||||
}
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
}
|
||||
|
||||
static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
||||
@@ -371,6 +371,93 @@ kernel void kernel_restore_block_mxfp4_trans(
|
||||
b->e = src_e[src_blk_offset];
|
||||
}
|
||||
|
||||
// Convert one block_mxfp4 per work-item into the split, transposed layout:
//   - the block's e field goes to dst_e at [plane][block][row]
//   - the 16 bytes of qs are regrouped and written as four uints to dst_q,
//     each ne01 elements apart.
// Work-item mapping: dim 0 = row (of ne01), dim 1 = block index along ne00,
// dim 2 = plane. No bounds checks are performed here.
kernel void kernel_convert_block_mxfp4_trans4_ns(
    global struct block_mxfp4 * src0,
    __global uint * dst_q,
    __global uchar * dst_e,
    uint ne00,
    uint ne01
) {
    const uint blk   = get_global_id(1);
    const uint row   = get_global_id(0);
    const uint plane = get_global_id(2);

    const uint blks_per_row   = ne00 / QK_MXFP4;
    const uint plane_blk_base = plane * blks_per_row * ne01;

    // Source blocks are row-major; the destination is block-major (transposed).
    global struct block_mxfp4 * in_blk = src0 + (blk + row * blks_per_row + plane_blk_base);
    dst_e[row + blk * ne01 + plane_blk_base] = in_blk->e;

    // Pull the 16 quant bytes into private memory and regroup the nibbles.
    ushort8 packed = ((global ushort8 *)(&(in_blk->qs[0])))[0];
    ushort8 planar = (ushort8)(0);

    uchar * packed_bytes = (uchar *)(&packed);
    uchar * planar_bytes = (uchar *)(&planar);

    for (int i = 0; i < QK_MXFP4 / 4; ++i) {
        const uchar even = packed_bytes[2 * i];
        const uchar odd  = packed_bytes[2 * i + 1];

        // Low nibbles of a byte pair go to the first half, high nibbles to the second half.
        planar_bytes[i]                = convert_uchar(even & 0x0F) | convert_uchar((odd & 0x0F) << 4);
        planar_bytes[i + QK_MXFP4 / 4] = convert_uchar((even & 0xF0) >> 4) | convert_uchar(odd & 0xF0);
    }

    const uint4 words = as_uint4(planar);

    // The four words of a block are stored ne01 apart.
    const uint q_base = plane_blk_base * 4 + blk * ne01 * 4 + row;
    dst_q[q_base]            = words.x;
    dst_q[q_base + ne01]     = words.y;
    dst_q[q_base + ne01 * 2] = words.z;
    dst_q[q_base + ne01 * 3] = words.w;
}
|
||||
|
||||
// Rebuild one block_mxfp4 per work-item from the split, transposed layout:
// reads the e byte from src_e and four uints (ne01 apart) from src_q, undoes
// the nibble regrouping, and writes the block back to dst0 in row-major
// block order.
// Work-item mapping: dim 0 = row (of ne01), dim 1 = block index along ne00,
// dim 2 = plane. No bounds checks are performed here.
kernel void kernel_restore_block_mxfp4_trans4_ns(
    __global uint * src_q,
    __global uchar * src_e,
    __global struct block_mxfp4 * dst0,
    uint ne00,
    uint ne01
) {
    const uint blk   = get_global_id(1);
    const uint row   = get_global_id(0);
    const uint plane = get_global_id(2);

    const uint blks_per_row   = ne00 / QK_MXFP4;
    const uint plane_blk_base = plane * blks_per_row * ne01;

    __global struct block_mxfp4 * out_blk = dst0 + (blk + row * blks_per_row + plane_blk_base);
    out_blk->e = src_e[row + blk * ne01 + plane_blk_base];

    // Gather the four transposed words that make up this block.
    const uint q_base = plane_blk_base * 4 + blk * ne01 * 4 + row;
    const uint4 words = (uint4)(
        src_q[q_base],
        src_q[q_base + ne01],
        src_q[q_base + ne01 * 2],
        src_q[q_base + ne01 * 3]);

    ushort8 planar = as_ushort8(words);
    ushort8 packed = (ushort8)(0);

    uchar * packed_bytes = (uchar *)(&packed);
    uchar * planar_bytes = (uchar *)(&planar);

    for (int i = 0; i < QK_MXFP4 / 4; ++i) {
        const uchar first  = planar_bytes[i];
        const uchar second = planar_bytes[i + QK_MXFP4 / 4];

        // Low nibbles of the two halves form the even byte, high nibbles the odd byte.
        packed_bytes[2 * i]     = convert_uchar(first & 0x0F) | convert_uchar((second & 0x0F) << 4);
        packed_bytes[2 * i + 1] = convert_uchar((first & 0xF0) >> 4) | convert_uchar(second & 0xF0);
    }

    ((__global ushort8 *)(&(out_blk->qs[0])))[0] = packed;
}
|
||||
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q8_0
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
302
ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl
Normal file
302
ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl
Normal file
@@ -0,0 +1,302 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
|
||||
|
||||
#define TILESIZE_K 16
|
||||
#define TILESIZE_M 64
|
||||
#define TILESIZE_N 32
|
||||
|
||||
|
||||
// Assemble the fp16 bit pattern for one 4-bit value.
//   mag  - the three magnitude bits already moved to bits 11..9 (mask 0x0E00)
//   sign - the sign bit already moved to bit 15 (mask 0x8000)
static inline ushort mxfp4_nibble_to_fp16_bits(ushort mag, ushort sign) {
    // Any non-zero magnitude gets the exponent re-bias of 0x3800.
    const ushort bias = (mag != 0) ? 0x3800 : 0x0;
    // Magnitude 0x0200 contributes only the bias (result 0x3800); other values keep their bits.
    const ushort frac = (mag != 0x0200) ? mag : 0x0;
    return sign + bias + frac;
}

// Decode the four nibbles of q (lowest nibble first) into four fp16 bit patterns.
static inline ushort4 mxfp4x4_to_fp16_bits(ushort q) {
    return (ushort4)(
        mxfp4_nibble_to_fp16_bits((q << 9) & 0x0E00, (q << 12) & 0x8000),
        mxfp4_nibble_to_fp16_bits((q << 5) & 0x0E00, (q <<  8) & 0x8000),
        mxfp4_nibble_to_fp16_bits((q << 1) & 0x0E00, (q <<  4) & 0x8000),
        mxfp4_nibble_to_fp16_bits((q >> 3) & 0x0E00,  q        & 0x8000));
}

// Expand eight packed 4-bit values (two ushorts, lowest nibble first) into a half8.
// Elements 0..3 come from fp4x8.s0, elements 4..7 from fp4x8.s1.
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
    return as_half8((ushort8)(mxfp4x4_to_fp16_bits(fp4x8.s0), mxfp4x4_to_fp16_bits(fp4x8.s1)));
}
|
||||
|
||||
|
||||
// dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset)
//
// Accumulates 16 dot products of a_reg (16 halves, consumed as four groups of
// four) against entries of b_lm into c_reg (16 floats, via .lo/.hi).
//   - groups .s0123 / .s4567 / .s89ab / .scdef read b_lm at lm_offset + 0..15,
//     + 32..47, + 64..79 and + 96..111 respectively
//   - partial sums are kept in half precision for 8 elements at a time, then
//     converted and added to c_reg in float
//
// NOTE: the macro is not hygienic. It writes to a variable named `acc`
// (16 halves wide, with .lo/.hi) that must already be declared in the
// caller's scope. It expands to multiple statements, so do not use it as the
// unbraced body of an if/for.
// The final line deliberately has no trailing line continuation, so the text
// that follows the macro can never be spliced into it.
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
    acc.s0  = dot(a_reg.s0123, b_lm[lm_offset + 0]); \
    acc.s1  = dot(a_reg.s0123, b_lm[lm_offset + 1]); \
    acc.s2  = dot(a_reg.s0123, b_lm[lm_offset + 2]); \
    acc.s3  = dot(a_reg.s0123, b_lm[lm_offset + 3]); \
    acc.s4  = dot(a_reg.s0123, b_lm[lm_offset + 4]); \
    acc.s5  = dot(a_reg.s0123, b_lm[lm_offset + 5]); \
    acc.s6  = dot(a_reg.s0123, b_lm[lm_offset + 6]); \
    acc.s7  = dot(a_reg.s0123, b_lm[lm_offset + 7]); \
    acc.s8  = dot(a_reg.s0123, b_lm[lm_offset + 8]); \
    acc.s9  = dot(a_reg.s0123, b_lm[lm_offset + 9]); \
    acc.sa  = dot(a_reg.s0123, b_lm[lm_offset + 10]); \
    acc.sb  = dot(a_reg.s0123, b_lm[lm_offset + 11]); \
    acc.sc  = dot(a_reg.s0123, b_lm[lm_offset + 12]); \
    acc.sd  = dot(a_reg.s0123, b_lm[lm_offset + 13]); \
    acc.se  = dot(a_reg.s0123, b_lm[lm_offset + 14]); \
    acc.sf  = dot(a_reg.s0123, b_lm[lm_offset + 15]); \
    acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \
    acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \
    acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \
    acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \
    acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \
    acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \
    acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \
    acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \
    acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \
    acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \
    acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \
    acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \
    acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \
    acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \
    acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \
    acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \
    c_reg.lo += convert_float8(acc.lo); \
    c_reg.hi += convert_float8(acc.hi); \
    acc.s0  = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \
    acc.s1  = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \
    acc.s2  = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \
    acc.s3  = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \
    acc.s4  = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \
    acc.s5  = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \
    acc.s6  = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \
    acc.s7  = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \
    acc.s8  = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \
    acc.s9  = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \
    acc.sa  = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \
    acc.sb  = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \
    acc.sc  = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \
    acc.sd  = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \
    acc.se  = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \
    acc.sf  = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \
    acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \
    acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \
    acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \
    acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \
    acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \
    acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \
    acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \
    acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \
    acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \
    acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \
    acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \
    acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \
    acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \
    acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \
    acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \
    acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \
    c_reg.lo += convert_float8(acc.lo); \
    c_reg.hi += convert_float8(acc.hi);
|
||||
|
||||
|
||||
// Convert an E8M0 scale (value 2^(x - 127)) to fp16.
//
// fp16 has a 5-bit exponent with bias 15, so the exponent field is x - 112:
//   x >= 143        -> too large for fp16, returns +inf (0x7C00)
//   113 <= x <= 142 -> normal number, exponent field x - 112, zero mantissa
//   103 <= x <= 112 -> fp16 subnormal, a single mantissa bit (2^-24 .. 2^-15)
//   x <  103        -> below the smallest fp16 subnormal, returns 0
//
// The earlier form computed (ushort)x - 112 and tested bits 7..5; for x < 112
// the subtraction wrapped and the result was +inf instead of a tiny value.
// Results for x >= 113 are unchanged.
static inline half e8m0_to_fp16(uchar x) {
    ushort bits;
    if (x >= 143) {
        bits = 0x7C00;
    } else if (x >= 113) {
        bits = (ushort)((x - 112) << 10);
    } else if (x >= 103) {
        // subnormal: value = mantissa * 2^-24, so 2^(x-127) needs mantissa 2^(x-103)
        bits = (ushort)(1 << (x - 103));
    } else {
        bits = 0;
    }
    return as_half(bits);
}
|
||||
|
||||
// Convert an E8M0 scale byte to float by placing it in the fp32 exponent field.
// x == 0 has no normal fp32 encoding and maps to the bit pattern 0x00400000.
static inline float e8m0_to_fp32(uchar x) {
    if (x == 0) {
        return as_float(0x00400000);
    }
    return as_float((uint) x << 23);
}
|
||||
|
||||
|
||||
// MoE GEMM over MXFP4 weights.
//
// Each work-item owns one weight row inside an M-tile (TILESIZE_M rows) and
// accumulates TILESIZE_N (32) float outputs in reg_c for the N-tile selected
// by get_global_id(2). Results are scattered into dst using the post-router
// indices in src2.
//
//   src0_q      - quantized weights, read one uint per texel (.x)
//   src0_d      - one E8M0 scale byte per 32-element weight block
//   src1        - reordered activations, read as float4 texels
//   src2        - post-router indices, TILESIZE_N per N-tile; 0xFFFFFFFF marks padding
//   src2_emap   - expert id for each N-tile
//   dst         - output, written one float per texel
//   total_tiles - total_tiles[0] is the number of valid N-tiles
//   ne00 / ne01 - K extent / number of weight rows per expert
__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair
kernel void kernel_gemm_moe_mxfp4_f32_ns(
    __read_only image1d_buffer_t src0_q,
    __global uchar * src0_d,
    __read_only image1d_buffer_t src1,
    __global uint * src2,
    __global ushort * src2_emap,
    __write_only image1d_buffer_t dst,
    __global int * total_tiles,
    uint ne00,
    uint ne01
) {
    uint block_id_m = get_global_id(1); // m_tile
    uint block_id_n = get_global_id(2); // n_tile

    // Boundary check
    // NOTE(review): work-items that return here skip the barriers below and do
    // not contribute to shared_b. This looks safe only if the row test is
    // uniform across the work-group (e.g. ne01 a multiple of TILESIZE_M) —
    // confirm against the host-side dispatch conditions.
    if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
        return;
    }

    __private half16 reg_a;                  // 16 dequantized weights for this row (one K sub-block)
    __private float32 reg_c = (float32)(0);  // 32 output accumulators, one per column of the N-tile
    __local half4 shared_b[128];             // 16 (K) x 32 (N) activations in half precision

    const ushort expert_id = src2_emap[block_id_n];

    const uint row = block_id_m * TILESIZE_M;
    const uint col = block_id_n * TILESIZE_N;

    // Per-work-item offsets for the cooperative load of B: each work-item
    // loads two float4 texels (columns c and c + 16) and stores them to shared_b.
    uint sub_block_id_m = get_local_id(0);
    uint2 b_global_offset;
    b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
    b_global_offset.y = b_global_offset.x + (16 * ne00);
    uint2 b_local_offset;
    b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
    b_local_offset.y = b_local_offset.x + 16;

    // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks
    for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
        // First sub-block
        uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
        uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5);
        uint b_sub_offset = col * ne00 + step;

        // Load scale for current mxfp4 block (shared by both sub-blocks of this iteration)
        uint s_offset = s_sub_offset + get_global_id(0);
        float s = e8m0_to_fp32(src0_d[s_offset]);

        // Load 16 fp4 (64-bits) in transposed layout
        uint2 mxfp4x16;
        mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
        mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;

        // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
        float8 bx8_f32;
        bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
        bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
        // Convert to half and store to LM to share within the subgroup
        half8 bx8_f16 = convert_half8(bx8_f32);
        shared_b[b_local_offset.x] = bx8_f16.lo;
        shared_b[b_local_offset.y] = bx8_f16.hi;

        // Dequantization
        reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
        reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;

        sub_group_barrier(CLK_LOCAL_MEM_FENCE);

        // 32 16x16 fp16 dot product with 8 elements reduction for better precision
        // (`acc` is the scratch variable the dotx16_reduce8 macro writes to)
        half16 acc;
        dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
        dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);

        // Repeat for second sub-block
        // NOTE(review): shared_b is overwritten below without a barrier after
        // the reads above; presumably this relies on the sub-group executing
        // in lockstep — confirm.
        uint half_step = step + TILESIZE_K;
        q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
        b_sub_offset = col * ne00 + half_step;

        // Load next 16 fp4 (64-bits) in transposed layout
        mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
        mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;

        // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
        bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
        bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
        // Convert to half and store to LM to share within the subgroup
        bx8_f16 = convert_half8(bx8_f32);
        shared_b[b_local_offset.x] = bx8_f16.lo;
        shared_b[b_local_offset.y] = bx8_f16.hi;

        // Dequantization
        reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
        reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;

        sub_group_barrier(CLK_LOCAL_MEM_FENCE);

        // 32 16x16 fp16 dot product with 8 elements reduction for better precision
        dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
        dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
    }

    // Load the post-router indices of this N-tile and share them in local memory
    __local uint out_idx[TILESIZE_N];

    if (get_local_id(0) < TILESIZE_N) {
        uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
        // Padding slots are redirected to the tile's first output position;
        // whatever they write there is overwritten by the final store below.
        if (idx == 0xFFFFFFFF) {
            idx = src2[block_id_n * TILESIZE_N + 0];
        }
        out_idx[get_local_id(0)] = idx * ne01;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Scatter results back to original position in output grid
    uint m_offset = row + get_local_id(0);

    write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
    write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
    write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
    write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
    write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
    write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
    write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
    write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
    write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
    write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
    write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
    write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
    write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
    write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
    write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
    write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
    write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
    write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
    write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
    write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
    write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
    write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
    write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
    write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
    write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
    write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
    write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
    write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
    write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
    write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
    write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));

    // Padding columns were redirected to the first output index of the tile above;
    // write the real result for that index last so it overrides them.
    barrier(CLK_GLOBAL_MEM_FENCE);
    write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
}
|
||||
161
ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl
Normal file
161
ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl
Normal file
@@ -0,0 +1,161 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
|
||||
#define QK_MXFP4 32
|
||||
#define N_SIMDGROUP 4
|
||||
#define SIMDGROUP_WIDTH 64
|
||||
|
||||
// Assemble the fp16 bit pattern for one 4-bit value.
//   mag  - the three magnitude bits already moved to bits 11..9 (mask 0x0E00)
//   sign - the sign bit already moved to bit 15 (mask 0x8000)
static inline ushort mxfp4_nibble_to_fp16_bits(ushort mag, ushort sign) {
    // Any non-zero magnitude gets the exponent re-bias of 0x3800.
    const ushort bias = (mag != 0) ? 0x3800 : 0x0;
    // Magnitude 0x0200 contributes only the bias (result 0x3800); other values keep their bits.
    const ushort frac = (mag != 0x0200) ? mag : 0x0;
    return sign + bias + frac;
}

// Decode the four nibbles of q (lowest nibble first) into four fp16 bit patterns.
static inline ushort4 mxfp4x4_to_fp16_bits(ushort q) {
    return (ushort4)(
        mxfp4_nibble_to_fp16_bits((q << 9) & 0x0E00, (q << 12) & 0x8000),
        mxfp4_nibble_to_fp16_bits((q << 5) & 0x0E00, (q <<  8) & 0x8000),
        mxfp4_nibble_to_fp16_bits((q << 1) & 0x0E00, (q <<  4) & 0x8000),
        mxfp4_nibble_to_fp16_bits((q >> 3) & 0x0E00,  q        & 0x8000));
}

// Expand eight packed 4-bit values (two ushorts, lowest nibble first) into a half8.
// Elements 0..3 come from fp4x8.s0, elements 4..7 from fp4x8.s1.
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
    return as_half8((ushort8)(mxfp4x4_to_fp16_bits(fp4x8.s0), mxfp4x4_to_fp16_bits(fp4x8.s1)));
}
|
||||
|
||||
// Convert an E8M0 scale byte to float by placing it in the fp32 exponent field.
// x == 0 has no normal fp32 encoding and maps to the bit pattern 0x00400000.
static inline float e8m0_to_fp32(uchar x) {
    if (x == 0) {
        return as_float(0x00400000);
    }
    return as_float((uint) x << 23);
}
|
||||
|
||||
|
||||
// MoE GEMV over MXFP4 weights in the transposed (trans4) layout.
//
// One output element dst[row + slot * ne01] is produced per (row, slot) pair,
// where slot = get_global_id(2) selects the expert via src2[slot]. The K loop
// is split across N_SIMDGROUP work-items along local dimension 1; their
// partial sums are combined through local memory and written by slice 0.
//
//   src0_q  - quantized weights, four uints per block stored ne01 apart
//   src0_e  - one E8M0 scale byte per block
//   src1    - activations, read as float4 texels
//   src2    - expert id per slot
//   dst     - float output; offsetd is a byte offset applied before the store
//   ne00 / ne01 / ne11 - K extent / rows per expert / activation rows (slot % ne11)
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_mxfp4_f32_ns(
    __global uint * src0_q,
    __global uchar * src0_e,
    __read_only image1d_buffer_t src1,
    __global uint * src2,
    __global float * dst,
    ulong offsetd,
    int ne00,
    int ne01,
    int ne11
) {
    const uint row    = get_global_id(0);
    const uint slot   = get_global_id(2);
    const uint kslice = get_local_id(1);
    const uint lane   = get_sub_group_local_id();

    const uint act_row = slot % ne11;

    // Offset of the selected expert, in units of 32-element blocks.
    const uint expert_blk_base = src2[slot] * ne00 * ne01 / 32;

    const uint n_blocks = ne00 / QK_MXFP4;

    __private float partial = 0.0f; // this work-item's share of one output

    // Walk the blocks along ne00; each K-slice handles every N_SIMDGROUP-th block.
    for (uint blk = kslice; blk < n_blocks; blk += N_SIMDGROUP) {

        // Fetch the four words of this row's block.
        const uint q_base = expert_blk_base * 4 + blk * ne01 * 4 + row;
        const uint4 q_words = (uint4)(
            src0_q[q_base],
            src0_q[q_base + ne01],
            src0_q[q_base + ne01 * 2],
            src0_q[q_base + ne01 * 3]);

        // First of the eight float4 activation texels matching this block.
        const uint act_base = act_row * ne00 / 4 + blk * 8;

        half8 w = mxfp4_to_fp16_packed8(as_ushort2(q_words.s0));
        float4 acc = read_imagef(src1, (act_base + 0)) * convert_float4(w.lo);
        acc += read_imagef(src1, (act_base + 1)) * convert_float4(w.hi);

        w = mxfp4_to_fp16_packed8(as_ushort2(q_words.s1));
        acc += read_imagef(src1, (act_base + 2)) * convert_float4(w.lo);
        acc += read_imagef(src1, (act_base + 3)) * convert_float4(w.hi);

        w = mxfp4_to_fp16_packed8(as_ushort2(q_words.s2));
        acc += read_imagef(src1, (act_base + 4)) * convert_float4(w.lo);
        acc += read_imagef(src1, (act_base + 5)) * convert_float4(w.hi);

        w = mxfp4_to_fp16_packed8(as_ushort2(q_words.s3));
        acc += read_imagef(src1, (act_base + 6)) * convert_float4(w.lo);
        acc += read_imagef(src1, (act_base + 7)) * convert_float4(w.hi);

        // Apply the block scale once to the horizontally reduced accumulator.
        const uchar scale = src0_e[blk * ne01 + row + expert_blk_base];
        partial += e8m0_to_fp32(scale) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
    }

    // Reduction in local memory; assumes exactly 4 K-slices.
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (kslice >= 1 && kslice <= 3) {
        reduceLM[SIMDGROUP_WIDTH * (kslice - 1) + lane] = partial;
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    // Slice 0 folds in the other three partial sums and stores the result.
    if (kslice == 0) {
        partial += reduceLM[SIMDGROUP_WIDTH * 0 + lane];
        partial += reduceLM[SIMDGROUP_WIDTH * 1 + lane];
        partial += reduceLM[SIMDGROUP_WIDTH * 2 + lane];

        dst = dst + (offsetd >> 2);
        dst[row + slot * ne01] = partial;
    }

}
|
||||
30
ggml/src/ggml-opencl/kernels/moe_reorder_b.cl
Normal file
30
ggml/src/ggml-opencl/kernels/moe_reorder_b.cl
Normal file
@@ -0,0 +1,30 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define QK4_0 32
|
||||
|
||||
// Gather activation rows into post-router (expert-sorted) order.
//
// Work-item (k_4, post_router_idx) copies one float4 of the activation row
// selected by router[post_router_idx] into row post_router_idx of dst.
// Padding slots (router value 0xFFFFFFFF) are filled with zeros.
//
//   src         - source activations, rows of K floats (read as float4)
//   router      - post-router index table
//   dst         - reordered activations, rows of K floats (written as float4)
//   total_tiles - total_tiles[0] * tile_size is the number of valid slots
//   K           - row length in floats
//   map_ratio   - divisor mapping a router index to a source row
//   tile_size   - number of slots per tile
kernel void kernel_moe_reorder_b(
    global float4 * src,
    global uint * router,
    global float4 * dst,
    global int * total_tiles,
    uint K,
    ushort map_ratio,
    uint tile_size
) {
    uint k_4 = get_global_id(0);
    uint post_router_idx = get_global_id(1);

    if ((k_4 >= (K / 4)) || (post_router_idx >= total_tiles[0] * tile_size)) {
        return;
    }

    uint router_idx = router[post_router_idx];

    float4 out = (float4)(0);
    if (router_idx != 0xFFFFFFFF) {
        // Keep the full 32-bit index: a 16-bit temporary would wrap for source
        // rows above 65535 and silently read the wrong activation.
        uint activation_idx = router_idx / map_ratio;
        out = src[activation_idx * K / 4 + k_4];
    }

    dst[post_router_idx * K / 4 + k_4] = out;
}
|
||||
82
ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl
Normal file
82
ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl
Normal file
@@ -0,0 +1,82 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
// Count how many (n, k) entries select each expert.
// Work-item (n, k) reads input[n * n_experts + k] and atomically increments
// the matching bin of hist. Work-items outside N x topK do nothing.
__kernel void kernel_moe_histogram(
    __global const int * input,
    __global int * hist,
    uint N,
    uint topK,
    uint n_experts
) {
    const uint n = get_global_id(0);
    const uint k = get_global_id(1);

    if (n < N && k < topK) {
        // Rows of input are n_experts entries apart.
        atomic_inc(hist + input[n * n_experts + k]);
    }
}
|
||||
|
||||
// Turn the per-expert histogram into per-expert tile offsets (exclusive prefix
// sum over ceil(count / tile_size)) and reset the counters.
//
// The loop is sequential and has no work-item guard.
// NOTE(review): presumably launched with a single work-item - confirm on the host side.
//
//   hist         - in: per-expert counts; out: all zero
//   tile_offset  - out: index of the first tile of each expert
//   total_tiles  - out: total_tiles[0] = sum of tiles over all experts
//   slot_counter - out: all zero
//   tile_size    - rows per tile
//   n_experts    - number of entries in hist / tile_offset / slot_counter
__kernel void kernel_moe_scan(
        __global int * hist,
        __global int * tile_offset,
        __global int * total_tiles,
        __global int * slot_counter,
        int            tile_size,
        uint           n_experts
) {
    int next_tile = 0;

    int e = 0;
    while (e < n_experts) {
        // Round up so a partially filled tile still takes a whole tile.
        const int n_tiles = (hist[e] + tile_size - 1) / tile_size;

        tile_offset[e]   = next_tile;
        next_tile       += n_tiles;

        hist[e]         = 0;
        slot_counter[e] = 0;

        e++;
    }

    total_tiles[0] = next_tile;
}
|
||||
|
||||
// Place every (n, k) entry into the next free slot of its expert's tiles.
//
// One work-item per (n, k) pair. The slot within an expert is claimed with an
// atomic increment of slot_counter[expert], so the order of entries inside an
// expert depends on execution order.
//
// NOTE(review): the tile width is fixed at 32 here, while kernel_moe_scan and
// kernel_moe_fill receive tile_size as an argument - the host must keep the
// two in agreement.
//
//   input        - expert ids, row stride n_experts
//   post_router  - out: post_router[tile * 32 + lane] = n * topK + k
//   emap         - out: emap[tile] = expert id owning that tile
//   tile_offset  - index of the first tile of each expert
//   slot_counter - per-expert slot counters, incremented atomically
//   N, topK      - extent of the (n, k) grid
//   n_experts    - row stride of `input`
__kernel void kernel_moe_scatter(
        __global const int * input,
        __global int       * post_router,
        __global ushort    * emap,
        __global const int * tile_offset,
        __global int       * slot_counter,
        int                  N,
        int                  topK,
        uint                 n_experts
) {
    const uint n = get_global_id(0);
    const uint k = get_global_id(1);

    if (n < N && k < topK) {
        const int expert = input[n * n_experts + k];

        // atomic_inc returns the previous value, i.e. this entry's slot.
        const int slot = atomic_inc(&slot_counter[expert]);
        const int tile = tile_offset[expert] + slot / 32;

        post_router[tile * 32 + slot % 32] = n * topK + k;

        // Every entry of a tile stores the same expert id here.
        emap[tile] = expert;
    }
}
|
||||
|
||||
// Fill every slot of the first total_tiles[0] tiles of `post_router` with the
// sentinel 0xFFFFFFFF.
//
// One work-item per (tile, slot-in-tile) pair.
//
//   post_router - out: tile_size entries per tile
//   total_tiles - total_tiles[0] is the number of tiles to fill
//   tile_size   - entries per tile
__kernel void kernel_moe_fill(
        __global int * post_router,
        __global int * total_tiles,
        int            tile_size
) {
    int tile_id        = get_global_id(0);
    int vec_id_in_tile = get_global_id(1);

    // Guard both dimensions: if the launch size in dimension 1 exceeds
    // tile_size, an unguarded write would land in the following tile, and past
    // the filled range for the last tile.
    if (tile_id < total_tiles[0] && vec_id_in_tile < tile_size) {
        post_router[tile_id * tile_size + vec_id_in_tile] = 0xFFFFFFFF;
    }
}
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "virtgpu-shm.h"
|
||||
|
||||
#include "virtgpu.h"
|
||||
#include "ggml-remoting.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "virtgpu.h"
|
||||
#include "ggml-remoting.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
|
||||
@@ -18,8 +18,6 @@
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "ggml-remoting.h"
|
||||
|
||||
#define VIRGL_RENDERER_UNSTABLE_APIS 1
|
||||
#include "apir_hw.h"
|
||||
#include <drm/virtgpu_drm.h>
|
||||
|
||||
@@ -1779,12 +1779,12 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_mul_mat_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.use_subgroup_matrix = context.supports_subgroup_matrix;
|
||||
|
||||
auto it = mul_mat_fast_pipelines.find(key);
|
||||
@@ -2143,6 +2143,9 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
// variant suffix for src1 type
|
||||
variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);
|
||||
|
||||
|
||||
@@ -55,8 +55,13 @@
|
||||
|
||||
uint64_t ggml_graph_next_uid(void) {
|
||||
#ifdef _MSC_VER
|
||||
#if defined(_WIN32)
|
||||
static volatile LONG counter = 1;
|
||||
return (uint64_t) InterlockedIncrement(&counter) - 1;
|
||||
#else
|
||||
static volatile long long counter = 1;
|
||||
return (uint64_t) _InterlockedIncrement64(&counter) - 1;
|
||||
#endif
|
||||
#else
|
||||
static uint64_t counter = 1;
|
||||
return __atomic_fetch_add(&counter, 1, __ATOMIC_RELAXED);
|
||||
|
||||
@@ -1 +1 @@
|
||||
387fa29fbbf3149f06a631c7850b6c35c24b0232
|
||||
19eac6f0edaf285506eb6228d31bb9caeda9aba1
|
||||
|
||||
@@ -2253,6 +2253,28 @@ public:
|
||||
llama_io_write_buffer(
|
||||
uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
||||
|
||||
// Performs the tensor reads queued in `winfos`: each entry is copied out with
// ggml_backend_tensor_get(tensor, ptr, offset, size), in the order it was queued.
~llama_io_write_buffer() {
#if 1
    // TODO: add backend support to batch tensor_get? or some other way to speed this up
    for (size_t i = 0; i < winfos.size(); ++i) {
        const auto & w = winfos[i];
        ggml_backend_tensor_get(w.tensor, w.ptr, w.offset, w.size);
    }
#else
    // flush the writes asynchronously
    // this helps on Macs, but on other devices - it does not. just an example
    std::vector<std::future<void>> pending;
    pending.reserve(winfos.size());
    for (const auto & w : winfos) {
        pending.push_back(std::async(std::launch::async, [w]() {
            ggml_backend_tensor_get(w.tensor, w.ptr, w.offset, w.size);
        }));
    }
    for (auto & task : pending) {
        task.wait();
    }
#endif
}
|
||||
|
||||
void write(const void * src, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
@@ -2267,7 +2289,10 @@ public:
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
}
|
||||
ggml_backend_tensor_get(tensor, ptr, offset, size);
|
||||
|
||||
// save the write for later during destruction
|
||||
winfos.push_back({tensor, ptr, size, offset});
|
||||
|
||||
ptr += size;
|
||||
size_written += size;
|
||||
buf_size -= size;
|
||||
@@ -2281,25 +2306,48 @@ private:
|
||||
uint8_t * ptr;
|
||||
size_t buf_size = 0;
|
||||
size_t size_written = 0;
|
||||
|
||||
// One queued tensor read. The destructor passes these fields to
// ggml_backend_tensor_get(tensor, ptr, offset, size).
// Field order matters: entries are created by aggregate initialisation as
// {tensor, ptr, size, offset}.
struct write_info {
|
||||
const ggml_tensor * tensor; // tensor to read from
|
||||
uint8_t * ptr; // destination inside the output buffer
|
||||
size_t size; // number of bytes to copy
|
||||
size_t offset; // offset argument passed to ggml_backend_tensor_get
|
||||
};
|
||||
std::vector<write_info> winfos;
|
||||
};
|
||||
|
||||
class llama_io_read_buffer : public llama_io_read_i {
|
||||
public:
|
||||
llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
||||
|
||||
const uint8_t * read(size_t size) override {
|
||||
const uint8_t * base_ptr = ptr;
|
||||
~llama_io_read_buffer() {
|
||||
// flush the reads
|
||||
for (const auto & info : rinfos) {
|
||||
ggml_backend_tensor_set(info.tensor, info.ptr, info.offset, info.size);
|
||||
}
|
||||
}
|
||||
|
||||
void read(void * dst, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
}
|
||||
memcpy(dst, ptr, size);
|
||||
ptr += size;
|
||||
size_read += size;
|
||||
buf_size -= size;
|
||||
return base_ptr;
|
||||
}
|
||||
|
||||
void read_to(void * dst, size_t size) override {
|
||||
memcpy(dst, read(size), size);
|
||||
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
}
|
||||
|
||||
// save for later during destruction
|
||||
rinfos.push_back({tensor, ptr, size, offset});
|
||||
|
||||
ptr += size;
|
||||
size_read += size;
|
||||
buf_size -= size;
|
||||
}
|
||||
|
||||
size_t n_bytes() override {
|
||||
@@ -2310,6 +2358,14 @@ private:
|
||||
const uint8_t * ptr;
|
||||
size_t buf_size = 0;
|
||||
size_t size_read = 0;
|
||||
|
||||
// One queued tensor upload. The destructor passes these fields to
// ggml_backend_tensor_set(tensor, ptr, offset, size).
// Field order matters: entries are created by aggregate initialisation as
// {tensor, ptr, size, offset}.
struct read_info {
|
||||
ggml_tensor * tensor; // tensor to write into
|
||||
const uint8_t * ptr; // source bytes inside the input buffer
|
||||
size_t size; // number of bytes to copy
|
||||
size_t offset; // offset argument passed to ggml_backend_tensor_set
|
||||
};
|
||||
std::vector<read_info> rinfos;
|
||||
};
|
||||
|
||||
class llama_io_write_file : public llama_io_write_i {
|
||||
@@ -2341,15 +2397,15 @@ class llama_io_read_file : public llama_io_read_i {
|
||||
public:
|
||||
llama_io_read_file(llama_file * f) : file(f) {}
|
||||
|
||||
void read_to(void * dst, size_t size) override {
|
||||
void read(void * dst, size_t size) override {
|
||||
file->read_raw(dst, size);
|
||||
size_read += size;
|
||||
}
|
||||
|
||||
const uint8_t * read(size_t size) override {
|
||||
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
temp_buffer.resize(size);
|
||||
read_to(temp_buffer.data(), size);
|
||||
return temp_buffer.data();
|
||||
read(temp_buffer.data(), size);
|
||||
ggml_backend_tensor_set(tensor, temp_buffer.data(), offset, size);
|
||||
}
|
||||
|
||||
size_t n_bytes() override {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "llama-io.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
void llama_io_write_i::write_string(const std::string & str) {
|
||||
uint32_t str_size = str.size();
|
||||
|
||||
@@ -9,7 +11,10 @@ void llama_io_write_i::write_string(const std::string & str) {
|
||||
|
||||
void llama_io_read_i::read_string(std::string & str) {
|
||||
uint32_t str_size;
|
||||
read_to(&str_size, sizeof(str_size));
|
||||
read(&str_size, sizeof(str_size));
|
||||
|
||||
str.assign((const char *) read(str_size), str_size);
|
||||
std::vector<char> buf(str_size);
|
||||
read(buf.data(), str_size);
|
||||
|
||||
str.assign(buf.data(), str_size);
|
||||
}
|
||||
|
||||
@@ -25,8 +25,8 @@ public:
|
||||
llama_io_read_i() = default;
|
||||
virtual ~llama_io_read_i() = default;
|
||||
|
||||
virtual const uint8_t * read(size_t size) = 0;
|
||||
virtual void read_to(void * dst, size_t size) = 0;
|
||||
virtual void read(void * dst, size_t size) = 0;
|
||||
virtual void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0;
|
||||
|
||||
// bytes read so far
|
||||
virtual size_t n_bytes() = 0;
|
||||
|
||||
@@ -1900,14 +1900,14 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
|
||||
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
|
||||
|
||||
uint32_t n_stream_cur;
|
||||
io.read_to(&n_stream_cur, sizeof(n_stream_cur));
|
||||
io.read(&n_stream_cur, sizeof(n_stream_cur));
|
||||
if (n_stream_cur != n_stream) {
|
||||
throw std::runtime_error("n_stream mismatch");
|
||||
}
|
||||
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
uint32_t cell_count;
|
||||
io.read_to(&cell_count, sizeof(cell_count));
|
||||
io.read(&cell_count, sizeof(cell_count));
|
||||
|
||||
if (cell_count == 0) {
|
||||
continue;
|
||||
@@ -2082,8 +2082,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id != 1) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||
@@ -2092,7 +2092,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
if (hparams.n_pos_per_embd() > 1) {
|
||||
llama_kv_cell_ext ext;
|
||||
io.read_to(&ext, sizeof(ext));
|
||||
io.read(&ext, sizeof(ext));
|
||||
|
||||
ubatch.pos[i + ubatch.n_tokens] = ext.y;
|
||||
ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
|
||||
@@ -2101,7 +2101,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
// read the sequence id, but directly discard it - we will use dest_seq_id instead
|
||||
{
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
io.read(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
|
||||
ubatch.pos[i] = pos;
|
||||
@@ -2143,20 +2143,20 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
cells.pos_set(i, pos);
|
||||
|
||||
if (hparams.n_pos_per_embd() > 1) {
|
||||
llama_kv_cell_ext ext;
|
||||
io.read_to(&ext, sizeof(ext));
|
||||
io.read(&ext, sizeof(ext));
|
||||
cells.ext_set(i, ext);
|
||||
}
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
io.read(&seq_id, sizeof(seq_id));
|
||||
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
|
||||
@@ -2189,8 +2189,8 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
uint32_t v_trans;
|
||||
uint32_t n_layer;
|
||||
|
||||
io.read_to(&v_trans, sizeof(v_trans));
|
||||
io.read_to(&n_layer, sizeof(n_layer));
|
||||
io.read(&v_trans, sizeof(v_trans));
|
||||
io.read(&n_layer, sizeof(n_layer));
|
||||
|
||||
if (n_layer != layers.size()) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
|
||||
@@ -2217,7 +2217,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read type of key
|
||||
int32_t k_type_i_ref;
|
||||
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
||||
io.read(&k_type_i_ref, sizeof(k_type_i_ref));
|
||||
const int32_t k_type_i = (int32_t) k->type;
|
||||
if (k_type_i != k_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
||||
@@ -2226,7 +2226,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read row size of key
|
||||
uint64_t k_size_row_ref;
|
||||
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
||||
io.read(&k_size_row_ref, sizeof(k_size_row_ref));
|
||||
const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
|
||||
if (k_size_row != k_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
||||
@@ -2236,13 +2236,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
if (cell_count) {
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells, single memcpy
|
||||
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
|
||||
io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row);
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
const void * src = io.read(cell_count * k_size_row);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
|
||||
ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
|
||||
io.read_tensor(k, dst_offset, k_size_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2261,7 +2260,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
const int32_t v_type_i = (int32_t) v->type;
|
||||
if (v_type_i != v_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||
@@ -2270,7 +2269,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read row size of value
|
||||
uint64_t v_size_row_ref;
|
||||
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
||||
io.read(&v_size_row_ref, sizeof(v_size_row_ref));
|
||||
const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
|
||||
if (v_size_row != v_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
||||
@@ -2280,13 +2279,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
if (cell_count) {
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells, single memcpy
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
|
||||
io.read_tensor(v, sinfo.head() * v_size_row, cell_count * v_size_row);
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
const void * src = io.read(cell_count * v_size_row);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
|
||||
ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
|
||||
io.read_tensor(v, dst_offset, v_size_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2305,7 +2303,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
const int32_t v_type_i = (int32_t) v->type;
|
||||
if (v_type_i != v_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||
@@ -2314,7 +2312,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read element size of value
|
||||
uint32_t v_size_el_ref;
|
||||
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
||||
io.read(&v_size_el_ref, sizeof(v_size_el_ref));
|
||||
const size_t v_size_el = ggml_type_size(v->type);
|
||||
if (v_size_el != v_size_el_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
||||
@@ -2323,7 +2321,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read GQA embedding size
|
||||
uint32_t n_embd_v_gqa_ref;
|
||||
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
||||
io.read(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
||||
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
||||
return false;
|
||||
@@ -2335,15 +2333,14 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
const uint32_t h = sinfo.head();
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (h + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
io.read_tensor(v, dst_offset, cell_count * v_size_el);
|
||||
}
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const void * src = io.read(cell_count * v_size_el);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
|
||||
io.read_tensor(v, dst_offset, v_size_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -743,7 +743,7 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i
|
||||
GGML_UNUSED(flags);
|
||||
|
||||
uint32_t cell_count;
|
||||
io.read_to(&cell_count, sizeof(cell_count));
|
||||
io.read(&cell_count, sizeof(cell_count));
|
||||
|
||||
bool res = true;
|
||||
|
||||
@@ -879,8 +879,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id != 0) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||
@@ -920,14 +920,14 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
cell.pos = pos;
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
io.read(&seq_id, sizeof(seq_id));
|
||||
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max);
|
||||
@@ -961,8 +961,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
||||
uint32_t s_trans;
|
||||
uint32_t n_layer;
|
||||
io.read_to(&s_trans, sizeof(s_trans));
|
||||
io.read_to(&n_layer, sizeof(n_layer));
|
||||
io.read(&s_trans, sizeof(s_trans));
|
||||
io.read(&n_layer, sizeof(n_layer));
|
||||
|
||||
if (n_layer != hparams.n_layer) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
|
||||
@@ -984,7 +984,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read type of key
|
||||
int32_t r_type_i_ref;
|
||||
io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
|
||||
io.read(&r_type_i_ref, sizeof(r_type_i_ref));
|
||||
const int32_t r_type_i = (int32_t) r_l[il]->type;
|
||||
if (r_type_i != r_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
|
||||
@@ -993,7 +993,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read row size of key
|
||||
uint64_t r_size_row_ref;
|
||||
io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
|
||||
io.read(&r_size_row_ref, sizeof(r_size_row_ref));
|
||||
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
||||
if (r_size_row != r_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
|
||||
@@ -1002,7 +1002,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the keys for the whole cell range
|
||||
ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
|
||||
io.read_tensor(r_l[il], head * r_size_row, cell_count * r_size_row);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1013,7 +1013,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read type of value
|
||||
int32_t s_type_i_ref;
|
||||
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
|
||||
if (s_type_i != s_type_i_ref) {
|
||||
@@ -1023,7 +1023,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read row size of value
|
||||
uint64_t s_size_row_ref;
|
||||
io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||
io.read(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
if (s_size_row != s_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
||||
@@ -1032,7 +1032,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the values for the whole cell range
|
||||
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
|
||||
io.read_tensor(s_l[il], head * s_size_row, cell_count * s_size_row);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -1045,7 +1045,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read type of value
|
||||
int32_t s_type_i_ref;
|
||||
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
if (s_type_i != s_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||
@@ -1054,7 +1054,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read element size of value
|
||||
uint32_t s_size_el_ref;
|
||||
io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
|
||||
io.read(&s_size_el_ref, sizeof(s_size_el_ref));
|
||||
const size_t s_size_el = ggml_type_size(s_l[il]->type);
|
||||
if (s_size_el != s_size_el_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
|
||||
@@ -1063,7 +1063,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read state embedding size
|
||||
uint32_t n_embd_s_ref;
|
||||
io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
|
||||
io.read(&n_embd_s_ref, sizeof(n_embd_s_ref));
|
||||
if (n_embd_s != n_embd_s_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
|
||||
return false;
|
||||
@@ -1073,7 +1073,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
// For each row in the transposed matrix, read the values for the whole cell range
|
||||
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
||||
const size_t dst_offset = (head + j * size) * s_size_el;
|
||||
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
|
||||
io.read_tensor(s_l[il], dst_offset, cell_count * s_size_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1994,7 +1994,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
}
|
||||
}
|
||||
|
||||
if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) {
|
||||
if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false)) {
|
||||
// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
||||
// cancel the factor from the convert script
|
||||
hparams.rope_yarn_log_mul /= 0.1f;
|
||||
@@ -2868,7 +2868,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
|
||||
|
||||
hparams.f_attn_temp_offset = 0.0f;
|
||||
|
||||
|
||||
@@ -683,9 +683,9 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_mod
|
||||
LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n",
|
||||
__func__, tensor_name.c_str(), ggml_type_name(new_type), ggml_type_name(qtype));
|
||||
new_type = qtype;
|
||||
manual = true;
|
||||
break;
|
||||
}
|
||||
manual = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -36,7 +36,7 @@ using json = nlohmann::ordered_json;
|
||||
|
||||
constexpr int HTTP_POLLING_SECONDS = 1;
|
||||
|
||||
static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
|
||||
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
|
||||
if (pos_min == -1) {
|
||||
pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
|
||||
}
|
||||
@@ -46,19 +46,15 @@ static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int i
|
||||
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
auto cur = server_prompt_checkpoint {
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.n_tokens = */ n_tokens,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
};
|
||||
ckpt.pos_min = pos_min;
|
||||
ckpt.pos_max = pos_max;
|
||||
ckpt.n_tokens = n_tokens;
|
||||
ckpt.data.resize(checkpoint_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
if (n != checkpoint_size) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
|
||||
@@ -364,7 +360,12 @@ struct server_slot {
|
||||
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
const auto n_tokens = prompt.tokens.size();
|
||||
|
||||
spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens);
|
||||
|
||||
//const int64_t t_total = ggml_time_us() - t_start;
|
||||
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
|
||||
|
||||
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
|
||||
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
|
||||
@@ -1836,7 +1837,8 @@ private:
|
||||
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
||||
}
|
||||
|
||||
const auto & cur = slot.prompt.checkpoints.emplace_back(server_get_checkpoint(ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max));
|
||||
auto & cur = slot.prompt.checkpoints.emplace_back();
|
||||
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
|
||||
|
||||
SLT_WRN(slot,
|
||||
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<script lang="ts">
|
||||
import * as Tooltip from '../src/lib/components/ui/tooltip';
|
||||
import * as Tooltip from '../../src/lib/components/ui/tooltip';
|
||||
|
||||
interface Props {
|
||||
children: any;
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { Preview } from '@storybook/sveltekit';
|
||||
import '../src/app.css';
|
||||
import ModeWatcherDecorator from './ModeWatcherDecorator.svelte';
|
||||
import TooltipProviderDecorator from './TooltipProviderDecorator.svelte';
|
||||
import ModeWatcherDecorator from './decorators/ModeWatcherDecorator.svelte';
|
||||
import TooltipProviderDecorator from './decorators/TooltipProviderDecorator.svelte';
|
||||
|
||||
const preview: Preview = {
|
||||
parameters: {
|
||||
|
||||
6
tools/server/webui/package-lock.json
generated
6
tools/server/webui/package-lock.json
generated
@@ -3640,9 +3640,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/bits-ui": {
|
||||
"version": "2.17.3",
|
||||
"resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-2.17.3.tgz",
|
||||
"integrity": "sha512-Bef41uY9U2jaBJHPhcPvmBNkGec5Wx2z6eioDsTmsaR2vH4QoaOcPi75gzCG3+/2TNr6v/qBwzgWNPYCxNtrEA==",
|
||||
"version": "2.18.0",
|
||||
"resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-2.18.0.tgz",
|
||||
"integrity": "sha512-GLOBZRVy3hxNHIQ2MpD/+5aK9KcBFZRhUJtZ1UDABXdlVR4K6zFpgt4T+Rwuhf2sQzlc6yK1q/DprHPjwT4Pjw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
||||
2
tools/server/webui/src/app.d.ts
vendored
2
tools/server/webui/src/app.d.ts
vendored
@@ -28,7 +28,6 @@ import type {
|
||||
ApiRouterModelsUnloadResponse,
|
||||
// Chat types
|
||||
ChatAttachmentDisplayItem,
|
||||
ChatAttachmentPreviewItem,
|
||||
ChatMessageType,
|
||||
ChatRole,
|
||||
ChatUploadedFile,
|
||||
@@ -92,7 +91,6 @@ declare global {
|
||||
ApiRouterModelsUnloadResponse,
|
||||
// Chat types
|
||||
ChatAttachmentDisplayItem,
|
||||
ChatAttachmentPreviewItem,
|
||||
ChatMessagePromptProgress,
|
||||
ChatMessageSiblingInfo,
|
||||
ChatMessageTimings,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import { isElementInViewport } from '$lib/utils/viewport';
|
||||
|
||||
/**
|
||||
* Svelte action that fades in an element when it enters the viewport.
|
||||
* Uses IntersectionObserver for efficient viewport detection.
|
||||
@@ -12,17 +14,8 @@ export function fadeInView(
|
||||
) {
|
||||
const { duration = 300, y = 0, skipIfVisible = false } = options;
|
||||
|
||||
if (skipIfVisible) {
|
||||
const rect = node.getBoundingClientRect();
|
||||
const isAlreadyVisible =
|
||||
rect.top < window.innerHeight &&
|
||||
rect.bottom > 0 &&
|
||||
rect.left < window.innerWidth &&
|
||||
rect.right > 0;
|
||||
|
||||
if (isAlreadyVisible) {
|
||||
return;
|
||||
}
|
||||
if (skipIfVisible && isElementInViewport(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
node.style.opacity = '0';
|
||||
|
||||
11
tools/server/webui/src/lib/components/app/SKILL.md
Normal file
11
tools/server/webui/src/lib/components/app/SKILL.md
Normal file
@@ -0,0 +1,11 @@
|
||||
---
|
||||
name: app
|
||||
description: Opinionated app components building on top of ./ui primitives
|
||||
---
|
||||
|
||||
- Can include business logic and state management
|
||||
- Can include data fetching and caching logic
|
||||
- Should use original spelling for HTML-native events and `camelCase` for custom events
|
||||
- Props and markup attributes should be listed alphabetically
|
||||
- Use JS Objects and Arrays for CSS classes and styles when they are dynamic
|
||||
- Whenever there can be repetition in the component's markup, if it's too small to be decoupled as a separate component — use Svelte 5's `{#snippet}` + `{@render}`
|
||||
@@ -5,15 +5,16 @@
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
icon: Component;
|
||||
tooltip: string;
|
||||
variant?: ButtonVariant;
|
||||
size?: ButtonSize;
|
||||
iconSize?: string;
|
||||
ariaLabel?: string;
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
icon: Component;
|
||||
iconSize?: string;
|
||||
onclick: (e?: MouseEvent) => void;
|
||||
'aria-label'?: string;
|
||||
size?: ButtonSize;
|
||||
stopPropagationOnClick?: boolean;
|
||||
tooltip: string;
|
||||
variant?: ButtonVariant;
|
||||
tooltipSide?: TooltipSide;
|
||||
}
|
||||
|
||||
@@ -26,8 +27,9 @@
|
||||
disabled = false,
|
||||
iconSize = 'h-3 w-3',
|
||||
tooltipSide = TooltipSide.TOP,
|
||||
stopPropagationOnClick = false,
|
||||
onclick,
|
||||
'aria-label': ariaLabel
|
||||
ariaLabel
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
@@ -37,13 +39,18 @@
|
||||
{variant}
|
||||
{size}
|
||||
{disabled}
|
||||
{onclick}
|
||||
onclick={(e: MouseEvent) => {
|
||||
if (stopPropagationOnClick) e.stopPropagation();
|
||||
|
||||
onclick?.(e);
|
||||
}}
|
||||
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
|
||||
aria-label={ariaLabel || tooltip}
|
||||
>
|
||||
{@const IconComponent = icon}
|
||||
|
||||
<IconComponent class={iconSize} />
|
||||
{#if icon}
|
||||
{@const IconComponent = icon}
|
||||
<IconComponent class={iconSize} />
|
||||
{/if}
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
<script lang="ts">
|
||||
import { Copy } from '@lucide/svelte';
|
||||
import { copyToClipboard } from '$lib/utils';
|
||||
import ActionIcon from './ActionIcon.svelte';
|
||||
|
||||
interface Props {
|
||||
ariaLabel?: string;
|
||||
canCopy?: boolean;
|
||||
text: string;
|
||||
}
|
||||
|
||||
let { ariaLabel = 'Copy to clipboard', canCopy = true, text }: Props = $props();
|
||||
export let ariaLabel: string = 'Copy to clipboard';
|
||||
export let canCopy: boolean = true;
|
||||
export let text: string;
|
||||
</script>
|
||||
|
||||
<Copy
|
||||
class="h-3 w-3 flex-shrink-0 cursor-{canCopy ? 'pointer' : 'not-allowed'}"
|
||||
aria-label={ariaLabel}
|
||||
<ActionIcon
|
||||
icon={Copy}
|
||||
tooltip={ariaLabel}
|
||||
iconSize="h-4 w-4"
|
||||
disabled={!canCopy}
|
||||
onclick={() => canCopy && copyToClipboard(text)}
|
||||
/>
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
onRemove?: (id: string) => void;
|
||||
class?: string;
|
||||
iconSize?: number;
|
||||
}
|
||||
|
||||
let { id, onRemove, class: className = '', iconSize = 3 }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="icon-sm"
|
||||
class="bg-white/20 p-0 hover:bg-white/30 {className}"
|
||||
onclick={(e: MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
onRemove?.(id);
|
||||
}}
|
||||
aria-label="Remove file"
|
||||
>
|
||||
<X class="h-{iconSize} w-{iconSize}" />
|
||||
</Button>
|
||||
@@ -1,46 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { Eye } from '@lucide/svelte';
|
||||
import { ActionIconCopyToClipboard } from '$lib/components/app';
|
||||
import { FileTypeText } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
code: string;
|
||||
language: string;
|
||||
disabled?: boolean;
|
||||
onPreview?: (code: string, language: string) => void;
|
||||
}
|
||||
|
||||
let { code, language, disabled = false, onPreview }: Props = $props();
|
||||
|
||||
const showPreview = $derived(language?.toLowerCase() === FileTypeText.HTML);
|
||||
|
||||
function handlePreview() {
|
||||
if (disabled) return;
|
||||
onPreview?.(code, language);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="code-block-actions">
|
||||
<div class="copy-code-btn" class:opacity-50={disabled} class:!cursor-not-allowed={disabled}>
|
||||
<ActionIconCopyToClipboard
|
||||
text={code}
|
||||
canCopy={!disabled}
|
||||
ariaLabel={disabled ? 'Code incomplete' : 'Copy code'}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{#if showPreview}
|
||||
<button
|
||||
class="preview-code-btn"
|
||||
class:opacity-50={disabled}
|
||||
class:!cursor-not-allowed={disabled}
|
||||
title={disabled ? 'Code incomplete' : 'Preview code'}
|
||||
aria-label="Preview code"
|
||||
aria-disabled={disabled}
|
||||
type="button"
|
||||
onclick={handlePreview}
|
||||
>
|
||||
<Eye size={16} />
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -9,11 +9,5 @@
|
||||
/** Styled icon button for action triggers with tooltip. */
|
||||
export { default as ActionIcon } from './ActionIcon.svelte';
|
||||
|
||||
/** Code block actions component (copy, preview). */
|
||||
export { default as ActionIconsCodeBlock } from './ActionIconsCodeBlock.svelte';
|
||||
|
||||
/** Copy-to-clipboard icon button with click handler. */
|
||||
/** Copy-to-clipboard icon button with clipboard logic. */
|
||||
export { default as ActionIconCopyToClipboard } from './ActionIconCopyToClipboard.svelte';
|
||||
|
||||
/** Remove/delete icon button with X icon. */
|
||||
export { default as ActionIconRemove } from './ActionIconRemove.svelte';
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
<script lang="ts">
|
||||
import { cn } from '$lib/components/ui/utils';
|
||||
import type { Snippet } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
@@ -13,10 +12,10 @@
|
||||
</script>
|
||||
|
||||
<button
|
||||
class={cn(
|
||||
class={[
|
||||
'inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75',
|
||||
className
|
||||
)}
|
||||
]}
|
||||
{onclick}
|
||||
>
|
||||
{#if icon}
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { ModelModality } from '$lib/enums';
|
||||
import { MODALITY_ICONS, MODALITY_LABELS } from '$lib/constants';
|
||||
import { cn } from '$lib/components/ui/utils';
|
||||
|
||||
type DisplayableModality = ModelModality.VISION | ModelModality.AUDIO;
|
||||
|
||||
interface Props {
|
||||
modalities: ModelModality[];
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { modalities, class: className = '' }: Props = $props();
|
||||
|
||||
// Filter to only modalities that have icons (VISION, AUDIO)
|
||||
const displayableModalities = $derived(
|
||||
modalities.filter(
|
||||
(m): m is DisplayableModality => m === ModelModality.VISION || m === ModelModality.AUDIO
|
||||
)
|
||||
);
|
||||
</script>
|
||||
|
||||
{#each displayableModalities as modality, index (index)}
|
||||
{@const IconComponent = MODALITY_ICONS[modality]}
|
||||
{@const label = MODALITY_LABELS[modality]}
|
||||
|
||||
<span
|
||||
class={cn(
|
||||
'inline-flex items-center gap-1 rounded-md bg-muted px-2 py-1 text-xs font-medium',
|
||||
className
|
||||
)}
|
||||
>
|
||||
{#if IconComponent}
|
||||
<IconComponent class="h-3 w-3" />
|
||||
{/if}
|
||||
|
||||
{label}
|
||||
</span>
|
||||
{/each}
|
||||
@@ -0,0 +1,32 @@
|
||||
<script lang="ts">
|
||||
import { Eye, Mic } from '@lucide/svelte';
|
||||
import { ModelModality } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
modalities: ModelModality[];
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { modalities, class: className = '' }: Props = $props();
|
||||
</script>
|
||||
|
||||
{#each modalities as modality (modality)}
|
||||
{#if modality === ModelModality.VISION || modality === ModelModality.AUDIO}
|
||||
<span
|
||||
class={[
|
||||
'inline-flex items-center gap-1 rounded-md bg-muted px-2 py-1 text-xs font-medium',
|
||||
className
|
||||
]}
|
||||
>
|
||||
{#if modality === ModelModality.VISION}
|
||||
<Eye class="h-3 w-3" />
|
||||
|
||||
Vision
|
||||
{:else}
|
||||
<Mic class="h-3 w-3" />
|
||||
|
||||
Audio
|
||||
{/if}
|
||||
</span>
|
||||
{/if}
|
||||
{/each}
|
||||
@@ -6,11 +6,8 @@
|
||||
*
|
||||
*/
|
||||
|
||||
/** Badge displaying chat statistics (tokens, timing). */
|
||||
export { default as BadgeChatStatistic } from './BadgeChatStatistic.svelte';
|
||||
|
||||
/** Generic info badge with optional tooltip and click handler. */
|
||||
export { default as BadgeInfo } from './BadgeInfo.svelte';
|
||||
|
||||
/** Badge indicating model modality (vision, audio, tools). */
|
||||
export { default as BadgeModality } from './BadgeModality.svelte';
|
||||
export { default as BadgesModality } from './BadgesModality.svelte';
|
||||
|
||||
@@ -1,284 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Alert from '$lib/components/ui/alert';
|
||||
import { SyntaxHighlightedCode } from '$lib/components/app';
|
||||
import { FileText, Image, Music, FileIcon, Eye, Info } from '@lucide/svelte';
|
||||
import {
|
||||
isTextFile,
|
||||
isImageFile,
|
||||
isPdfFile,
|
||||
isAudioFile,
|
||||
getLanguageFromFilename,
|
||||
createBase64DataUrl
|
||||
} from '$lib/utils';
|
||||
import { convertPDFToImage } from '$lib/utils/browser-only';
|
||||
import { modelsStore } from '$lib/stores/models.svelte';
|
||||
|
||||
interface Props {
|
||||
// Either an uploaded file or a stored attachment
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
// For uploaded files
|
||||
preview?: string;
|
||||
name?: string;
|
||||
textContent?: string;
|
||||
// For checking vision modality
|
||||
activeModelId?: string;
|
||||
}
|
||||
|
||||
let { uploadedFile, attachment, preview, name, textContent, activeModelId }: Props = $props();
|
||||
|
||||
let hasVisionModality = $derived(
|
||||
activeModelId ? modelsStore.modelSupportsVision(activeModelId) : false
|
||||
);
|
||||
|
||||
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');
|
||||
|
||||
// Determine file type from uploaded file or attachment
|
||||
let isAudio = $derived(isAudioFile(attachment, uploadedFile));
|
||||
let isImage = $derived(isImageFile(attachment, uploadedFile));
|
||||
let isPdf = $derived(isPdfFile(attachment, uploadedFile));
|
||||
let isText = $derived(isTextFile(attachment, uploadedFile));
|
||||
|
||||
let displayPreview = $derived(
|
||||
uploadedFile?.preview ||
|
||||
(isImage && attachment && 'base64Url' in attachment ? attachment.base64Url : preview)
|
||||
);
|
||||
|
||||
let displayTextContent = $derived(
|
||||
uploadedFile?.textContent ||
|
||||
(attachment && 'content' in attachment ? attachment.content : textContent)
|
||||
);
|
||||
|
||||
let language = $derived(getLanguageFromFilename(displayName));
|
||||
|
||||
let IconComponent = $derived(() => {
|
||||
if (isImage) return Image;
|
||||
if (isText || isPdf) return FileText;
|
||||
if (isAudio) return Music;
|
||||
|
||||
return FileIcon;
|
||||
});
|
||||
|
||||
let pdfViewMode = $state<'text' | 'pages'>('pages');
|
||||
|
||||
let pdfImages = $state<string[]>([]);
|
||||
|
||||
let pdfImagesLoading = $state(false);
|
||||
|
||||
let pdfImagesError = $state<string | null>(null);
|
||||
|
||||
async function loadPdfImages() {
|
||||
if (!isPdf || pdfImages.length > 0 || pdfImagesLoading) return;
|
||||
|
||||
pdfImagesLoading = true;
|
||||
pdfImagesError = null;
|
||||
|
||||
try {
|
||||
let file: File | null = null;
|
||||
|
||||
if (uploadedFile?.file) {
|
||||
file = uploadedFile.file;
|
||||
} else if (isPdf && attachment) {
|
||||
// Check if we have pre-processed images
|
||||
if (
|
||||
'images' in attachment &&
|
||||
attachment.images &&
|
||||
Array.isArray(attachment.images) &&
|
||||
attachment.images.length > 0
|
||||
) {
|
||||
pdfImages = attachment.images;
|
||||
return;
|
||||
}
|
||||
|
||||
// Convert base64 back to File for processing
|
||||
if ('base64Data' in attachment && attachment.base64Data) {
|
||||
const base64Data = attachment.base64Data;
|
||||
const byteCharacters = atob(base64Data);
|
||||
const byteNumbers = new Array(byteCharacters.length);
|
||||
for (let i = 0; i < byteCharacters.length; i++) {
|
||||
byteNumbers[i] = byteCharacters.charCodeAt(i);
|
||||
}
|
||||
const byteArray = new Uint8Array(byteNumbers);
|
||||
file = new File([byteArray], displayName, { type: 'application/pdf' });
|
||||
}
|
||||
}
|
||||
|
||||
if (file) {
|
||||
pdfImages = await convertPDFToImage(file);
|
||||
} else {
|
||||
throw new Error('No PDF file available for conversion');
|
||||
}
|
||||
} catch (error) {
|
||||
pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
|
||||
} finally {
|
||||
pdfImagesLoading = false;
|
||||
}
|
||||
}
|
||||
|
||||
export function reset() {
|
||||
pdfImages = [];
|
||||
pdfImagesLoading = false;
|
||||
pdfImagesError = null;
|
||||
pdfViewMode = 'pages';
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (isPdf && pdfViewMode === 'pages') {
|
||||
loadPdfImages();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="space-y-4">
|
||||
<div class="flex items-center justify-end gap-6">
|
||||
{#if isPdf}
|
||||
<div class="flex items-center gap-2">
|
||||
<Button
|
||||
variant={pdfViewMode === 'text' ? 'default' : 'outline'}
|
||||
size="sm"
|
||||
onclick={() => (pdfViewMode = 'text')}
|
||||
disabled={pdfImagesLoading}
|
||||
>
|
||||
<FileText class="mr-1 h-4 w-4" />
|
||||
|
||||
Text
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant={pdfViewMode === 'pages' ? 'default' : 'outline'}
|
||||
size="sm"
|
||||
onclick={() => {
|
||||
pdfViewMode = 'pages';
|
||||
loadPdfImages();
|
||||
}}
|
||||
disabled={pdfImagesLoading}
|
||||
>
|
||||
{#if pdfImagesLoading}
|
||||
<div
|
||||
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
|
||||
></div>
|
||||
{:else}
|
||||
<Eye class="mr-1 h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
Pages
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="flex-1 overflow-auto">
|
||||
{#if isImage && displayPreview}
|
||||
<div class="flex items-center justify-center">
|
||||
<img
|
||||
src={displayPreview}
|
||||
alt={displayName}
|
||||
class="max-h-full rounded-lg object-contain shadow-lg"
|
||||
/>
|
||||
</div>
|
||||
{:else if isPdf && pdfViewMode === 'pages'}
|
||||
{#if !hasVisionModality && activeModelId}
|
||||
<Alert.Root class="mb-4">
|
||||
<Info class="h-4 w-4" />
|
||||
<Alert.Title>Preview only</Alert.Title>
|
||||
<Alert.Description>
|
||||
<span class="inline-flex">
|
||||
The selected model does not support vision. Only the extracted
|
||||
<!-- svelte-ignore a11y_click_events_have_key_events -->
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<span class="mx-1 cursor-pointer underline" onclick={() => (pdfViewMode = 'text')}>
|
||||
text
|
||||
</span>
|
||||
will be sent to the model.
|
||||
</span>
|
||||
</Alert.Description>
|
||||
</Alert.Root>
|
||||
{/if}
|
||||
|
||||
{#if pdfImagesLoading}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<div
|
||||
class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"
|
||||
></div>
|
||||
|
||||
<p class="text-muted-foreground">Converting PDF to images...</p>
|
||||
</div>
|
||||
</div>
|
||||
{:else if pdfImagesError}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
<p class="mb-4 text-muted-foreground">Failed to load PDF images</p>
|
||||
|
||||
<p class="text-sm text-muted-foreground">{pdfImagesError}</p>
|
||||
|
||||
<Button class="mt-4" onclick={() => (pdfViewMode = 'text')}>View as Text</Button>
|
||||
</div>
|
||||
</div>
|
||||
{:else if pdfImages.length > 0}
|
||||
<div class="max-h-[70vh] space-y-4 overflow-auto">
|
||||
{#each pdfImages as image, index (image)}
|
||||
<div class="text-center">
|
||||
<p class="mb-2 text-sm text-muted-foreground">Page {index + 1}</p>
|
||||
|
||||
<img
|
||||
src={image}
|
||||
alt="PDF Page {index + 1}"
|
||||
class="mx-auto max-w-full rounded-lg shadow-lg"
|
||||
/>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
<p class="mb-4 text-muted-foreground">No PDF pages available</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent}
|
||||
<SyntaxHighlightedCode code={displayTextContent} {language} maxWidth="calc(69rem - 2rem)" />
|
||||
{:else if isAudio}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="w-full max-w-md text-center">
|
||||
<Music class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
{#if uploadedFile?.preview}
|
||||
<audio controls class="mb-4 w-full" src={uploadedFile.preview}>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
{:else if isAudio && attachment && 'mimeType' in attachment && 'base64Data' in attachment}
|
||||
<audio
|
||||
controls
|
||||
class="mb-4 w-full"
|
||||
src={createBase64DataUrl(attachment.mimeType, attachment.base64Data)}
|
||||
>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
{:else}
|
||||
<p class="mb-4 text-muted-foreground">Audio preview not available</p>
|
||||
{/if}
|
||||
|
||||
<p class="text-sm text-muted-foreground">
|
||||
{displayName}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
{#if IconComponent}
|
||||
<IconComponent class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
{/if}
|
||||
|
||||
<p class="mb-4 text-muted-foreground">Preview not available for this file type</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -1,165 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { ActionIconRemove } from '$lib/components/app';
|
||||
import { formatFileSize, getFileTypeLabel, getPreviewText, isTextFile } from '$lib/utils';
|
||||
import { AttachmentType } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
id: string;
|
||||
onClick?: (event?: MouseEvent) => void;
|
||||
onRemove?: (id: string) => void;
|
||||
name: string;
|
||||
readonly?: boolean;
|
||||
size?: number;
|
||||
textContent?: string;
|
||||
// Either uploaded file or stored attachment
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
id,
|
||||
onClick,
|
||||
onRemove,
|
||||
name,
|
||||
readonly = false,
|
||||
size,
|
||||
textContent,
|
||||
uploadedFile,
|
||||
attachment
|
||||
}: Props = $props();
|
||||
|
||||
let isText = $derived(isTextFile(attachment, uploadedFile));
|
||||
|
||||
let fileTypeLabel = $derived.by(() => {
|
||||
if (uploadedFile?.type) {
|
||||
return getFileTypeLabel(uploadedFile.type);
|
||||
}
|
||||
|
||||
if (attachment) {
|
||||
if ('mimeType' in attachment && attachment.mimeType) {
|
||||
return getFileTypeLabel(attachment.mimeType);
|
||||
}
|
||||
|
||||
if (attachment.type) {
|
||||
return getFileTypeLabel(attachment.type);
|
||||
}
|
||||
}
|
||||
|
||||
return getFileTypeLabel(name);
|
||||
});
|
||||
|
||||
let pdfProcessingMode = $derived.by(() => {
|
||||
if (attachment?.type === AttachmentType.PDF) {
|
||||
const pdfAttachment = attachment as DatabaseMessageExtraPdfFile;
|
||||
|
||||
return pdfAttachment.processedAsImages ? 'Sent as Image' : 'Sent as Text';
|
||||
}
|
||||
return null;
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if isText}
|
||||
{#if readonly}
|
||||
<!-- Readonly mode (ChatMessage) -->
|
||||
<button
|
||||
class="cursor-pointer rounded-lg border border-border bg-muted p-3 transition-shadow hover:shadow-md {className} w-full max-w-2xl"
|
||||
onclick={onClick}
|
||||
aria-label={`Preview ${name}`}
|
||||
type="button"
|
||||
>
|
||||
<div class="flex items-start gap-3">
|
||||
<div class="flex min-w-0 flex-1 flex-col items-start text-left">
|
||||
<span class="w-full truncate text-sm font-medium text-foreground">{name}</span>
|
||||
|
||||
{#if size}
|
||||
<span class="text-xs text-muted-foreground">{formatFileSize(size)}</span>
|
||||
{/if}
|
||||
|
||||
{#if textContent}
|
||||
<div class="relative mt-2 w-full">
|
||||
<div
|
||||
class="overflow-hidden font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground"
|
||||
>
|
||||
{getPreviewText(textContent)}
|
||||
</div>
|
||||
|
||||
{#if textContent.length > 150}
|
||||
<div
|
||||
class="pointer-events-none absolute right-0 bottom-0 left-0 h-6 bg-gradient-to-t from-muted to-transparent"
|
||||
></div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
{:else}
|
||||
<!-- Non-readonly mode (ChatForm) -->
|
||||
<button
|
||||
class="group relative rounded-lg border border-border bg-muted p-3 {className} {textContent
|
||||
? 'max-h-24 max-w-72'
|
||||
: 'max-w-36'} cursor-pointer text-left"
|
||||
onclick={onClick}
|
||||
>
|
||||
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
|
||||
<ActionIconRemove {id} {onRemove} />
|
||||
</div>
|
||||
|
||||
<div class="pr-8">
|
||||
<span class="mb-3 block truncate text-sm font-medium text-foreground">{name}</span>
|
||||
|
||||
{#if textContent}
|
||||
<div class="relative">
|
||||
<div
|
||||
class="overflow-hidden font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground"
|
||||
style="max-height: 3rem; line-height: 1.2em;"
|
||||
>
|
||||
{getPreviewText(textContent)}
|
||||
</div>
|
||||
|
||||
{#if textContent.length > 150}
|
||||
<div
|
||||
class="pointer-events-none absolute right-0 bottom-0 left-0 h-4 bg-gradient-to-t from-muted to-transparent"
|
||||
></div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</button>
|
||||
{/if}
|
||||
{:else}
|
||||
<button
|
||||
class="group flex items-center gap-3 rounded-lg border border-border bg-muted p-3 {className} relative"
|
||||
onclick={onClick}
|
||||
>
|
||||
<div
|
||||
class="flex h-8 w-8 items-center justify-center rounded bg-primary/10 text-xs font-medium text-primary"
|
||||
>
|
||||
{fileTypeLabel}
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-0.5">
|
||||
<span
|
||||
class="max-w-24 truncate text-sm font-medium text-foreground {readonly
|
||||
? ''
|
||||
: 'group-hover:pr-6'} md:max-w-32"
|
||||
>
|
||||
{name}
|
||||
</span>
|
||||
|
||||
{#if pdfProcessingMode}
|
||||
<span class="text-left text-xs text-muted-foreground">{pdfProcessingMode}</span>
|
||||
{:else if size}
|
||||
<span class="text-left text-xs text-muted-foreground">{formatFileSize(size)}</span>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if !readonly}
|
||||
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
|
||||
<ActionIconRemove {id} {onRemove} />
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
@@ -1,287 +0,0 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
ChatAttachmentMcpPrompt,
|
||||
ChatAttachmentMcpResource,
|
||||
ChatAttachmentThumbnailImage,
|
||||
ChatAttachmentThumbnailFile,
|
||||
HorizontalScrollCarousel,
|
||||
DialogChatAttachmentPreview,
|
||||
DialogChatAttachmentsViewAll,
|
||||
DialogMcpResourcePreview
|
||||
} from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { AttachmentType } from '$lib/enums';
|
||||
import type {
|
||||
DatabaseMessageExtraMcpPrompt,
|
||||
DatabaseMessageExtraMcpResource,
|
||||
MCPResourceAttachment
|
||||
} from '$lib/types';
|
||||
import { getAttachmentDisplayItems } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
style?: string;
|
||||
// For ChatMessage - stored attachments
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
readonly?: boolean;
|
||||
// For ChatForm - pending uploads
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
// Image size customization
|
||||
imageClass?: string;
|
||||
imageHeight?: string;
|
||||
imageWidth?: string;
|
||||
// Limit display to single row with "+ X more" button
|
||||
limitToSingleRow?: boolean;
|
||||
// For vision modality check
|
||||
activeModelId?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
style = '',
|
||||
attachments = [],
|
||||
readonly = false,
|
||||
onFileRemove,
|
||||
uploadedFiles = $bindable([]),
|
||||
// Default to small size for form previews
|
||||
imageClass = '',
|
||||
imageHeight = 'h-24',
|
||||
imageWidth = 'w-auto',
|
||||
limitToSingleRow = false,
|
||||
activeModelId
|
||||
}: Props = $props();
|
||||
|
||||
let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments }));
|
||||
|
||||
let carouselRef: HorizontalScrollCarousel | undefined = $state();
|
||||
let isScrollable = $state(false);
|
||||
let previewDialogOpen = $state(false);
|
||||
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
|
||||
let mcpResourcePreviewOpen = $state(false);
|
||||
let mcpResourcePreviewExtra = $state<DatabaseMessageExtraMcpResource | null>(null);
|
||||
let showViewAll = $derived(limitToSingleRow && displayItems.length > 0 && isScrollable);
|
||||
let viewAllDialogOpen = $state(false);
|
||||
|
||||
function openPreview(item: ChatAttachmentDisplayItem, event?: MouseEvent) {
|
||||
event?.stopPropagation();
|
||||
event?.preventDefault();
|
||||
|
||||
previewItem = {
|
||||
uploadedFile: item.uploadedFile,
|
||||
attachment: item.attachment,
|
||||
preview: item.preview,
|
||||
name: item.name,
|
||||
size: item.size,
|
||||
textContent: item.textContent
|
||||
};
|
||||
previewDialogOpen = true;
|
||||
}
|
||||
|
||||
function openMcpResourcePreview(extra: DatabaseMessageExtraMcpResource) {
|
||||
mcpResourcePreviewExtra = extra;
|
||||
mcpResourcePreviewOpen = true;
|
||||
}
|
||||
|
||||
function toMcpResourceAttachment(
|
||||
extra: DatabaseMessageExtraMcpResource,
|
||||
id: string
|
||||
): MCPResourceAttachment {
|
||||
return {
|
||||
id,
|
||||
resource: {
|
||||
uri: extra.uri,
|
||||
name: extra.name,
|
||||
title: extra.name,
|
||||
serverName: extra.serverName
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (carouselRef && displayItems.length) {
|
||||
carouselRef.resetScroll();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if displayItems.length > 0}
|
||||
<div class={className} {style}>
|
||||
{#if limitToSingleRow}
|
||||
<HorizontalScrollCarousel
|
||||
bind:this={carouselRef}
|
||||
onScrollableChange={(scrollable) => (isScrollable = scrollable)}
|
||||
>
|
||||
{#each displayItems as item (item.id)}
|
||||
{#if item.isMcpPrompt}
|
||||
{@const mcpPrompt =
|
||||
item.attachment?.type === AttachmentType.MCP_PROMPT
|
||||
? (item.attachment as DatabaseMessageExtraMcpPrompt)
|
||||
: item.uploadedFile?.mcpPrompt
|
||||
? {
|
||||
type: AttachmentType.MCP_PROMPT as const,
|
||||
name: item.name,
|
||||
serverName: item.uploadedFile.mcpPrompt.serverName,
|
||||
promptName: item.uploadedFile.mcpPrompt.promptName,
|
||||
content: item.textContent ?? '',
|
||||
arguments: item.uploadedFile.mcpPrompt.arguments
|
||||
}
|
||||
: null}
|
||||
{#if mcpPrompt}
|
||||
<ChatAttachmentMcpPrompt
|
||||
class="max-w-[300px] min-w-[200px] flex-shrink-0 {limitToSingleRow
|
||||
? 'first:ml-4 last:mr-4'
|
||||
: ''}"
|
||||
prompt={mcpPrompt}
|
||||
{readonly}
|
||||
isLoading={item.isLoading}
|
||||
loadError={item.loadError}
|
||||
onRemove={onFileRemove ? () => onFileRemove(item.id) : undefined}
|
||||
/>
|
||||
{/if}
|
||||
{:else if item.isMcpResource && item.attachment?.type === AttachmentType.MCP_RESOURCE}
|
||||
{@const mcpResource = item.attachment as DatabaseMessageExtraMcpResource}
|
||||
|
||||
<ChatAttachmentMcpResource
|
||||
class="flex-shrink-0 {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
||||
attachment={toMcpResourceAttachment(mcpResource, item.id)}
|
||||
onClick={() => openMcpResourcePreview(mcpResource)}
|
||||
/>
|
||||
{:else if item.isImage && item.preview}
|
||||
<ChatAttachmentThumbnailImage
|
||||
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{:else}
|
||||
<ChatAttachmentThumbnailFile
|
||||
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
attachment={item.attachment}
|
||||
uploadedFile={item.uploadedFile}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
</HorizontalScrollCarousel>
|
||||
|
||||
{#if showViewAll}
|
||||
<div class="mt-2 -mr-2 flex justify-end px-4">
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 text-xs text-muted-foreground hover:text-foreground"
|
||||
onclick={() => (viewAllDialogOpen = true)}
|
||||
>
|
||||
View all ({displayItems.length})
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
{:else}
|
||||
<div class="flex flex-wrap items-start justify-end gap-3">
|
||||
{#each displayItems as item (item.id)}
|
||||
{#if item.isMcpPrompt}
|
||||
{@const mcpPrompt =
|
||||
item.attachment?.type === AttachmentType.MCP_PROMPT
|
||||
? (item.attachment as DatabaseMessageExtraMcpPrompt)
|
||||
: item.uploadedFile?.mcpPrompt
|
||||
? {
|
||||
type: AttachmentType.MCP_PROMPT as const,
|
||||
name: item.name,
|
||||
serverName: item.uploadedFile.mcpPrompt.serverName,
|
||||
promptName: item.uploadedFile.mcpPrompt.promptName,
|
||||
content: item.textContent ?? '',
|
||||
arguments: item.uploadedFile.mcpPrompt.arguments
|
||||
}
|
||||
: null}
|
||||
|
||||
{#if mcpPrompt}
|
||||
<ChatAttachmentMcpPrompt
|
||||
class="max-w-[300px] min-w-[200px]"
|
||||
prompt={mcpPrompt}
|
||||
{readonly}
|
||||
isLoading={item.isLoading}
|
||||
loadError={item.loadError}
|
||||
onRemove={onFileRemove ? () => onFileRemove(item.id) : undefined}
|
||||
/>
|
||||
{/if}
|
||||
{:else if item.isMcpResource && item.attachment?.type === AttachmentType.MCP_RESOURCE}
|
||||
{@const mcpResource = item.attachment as DatabaseMessageExtraMcpResource}
|
||||
|
||||
<ChatAttachmentMcpResource
|
||||
attachment={toMcpResourceAttachment(mcpResource, item.id)}
|
||||
onClick={() => openMcpResourcePreview(mcpResource)}
|
||||
/>
|
||||
{:else if item.isImage && item.preview}
|
||||
<ChatAttachmentThumbnailImage
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{:else}
|
||||
<ChatAttachmentThumbnailFile
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
attachment={item.attachment}
|
||||
uploadedFile={item.uploadedFile}
|
||||
onClick={(event?: MouseEvent) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if previewItem}
|
||||
<DialogChatAttachmentPreview
|
||||
bind:open={previewDialogOpen}
|
||||
uploadedFile={previewItem.uploadedFile}
|
||||
attachment={previewItem.attachment}
|
||||
preview={previewItem.preview}
|
||||
name={previewItem.name}
|
||||
size={previewItem.size}
|
||||
textContent={previewItem.textContent}
|
||||
{activeModelId}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<DialogChatAttachmentsViewAll
|
||||
bind:open={viewAllDialogOpen}
|
||||
{uploadedFiles}
|
||||
{attachments}
|
||||
{readonly}
|
||||
{onFileRemove}
|
||||
imageHeight="h-64"
|
||||
{imageClass}
|
||||
{activeModelId}
|
||||
/>
|
||||
|
||||
{#if mcpResourcePreviewExtra}
|
||||
<DialogMcpResourcePreview bind:open={mcpResourcePreviewOpen} extra={mcpResourcePreviewExtra} />
|
||||
{/if}
|
||||
@@ -0,0 +1,119 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
ChatAttachmentsListItem,
|
||||
DialogChatAttachmentsPreview,
|
||||
DialogMcpResourcePreview,
|
||||
HorizontalScrollCarousel
|
||||
} from '$lib/components/app';
|
||||
import type { DatabaseMessageExtraMcpResource } from '$lib/types';
|
||||
import { getAttachmentDisplayItems, isMcpPrompt, isMcpResource } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
style?: string;
|
||||
// For ChatMessage - stored attachments
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
readonly?: boolean;
|
||||
// For ChatForm - pending uploads
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
// Image size customization
|
||||
imageClass?: string;
|
||||
imageHeight?: string;
|
||||
imageWidth?: string;
|
||||
// Limit display to single row with "+ X more" button
|
||||
limitToSingleRow?: boolean;
|
||||
// For vision modality check
|
||||
activeModelId?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
style = '',
|
||||
attachments = [],
|
||||
readonly = false,
|
||||
onFileRemove,
|
||||
uploadedFiles = $bindable([]),
|
||||
// Default to small size for form previews
|
||||
imageClass = '',
|
||||
imageHeight = 'h-24',
|
||||
imageWidth = 'w-auto',
|
||||
limitToSingleRow = false,
|
||||
activeModelId
|
||||
}: Props = $props();
|
||||
|
||||
let carouselRef: HorizontalScrollCarousel | undefined = $state();
|
||||
let mcpResourcePreviewOpen = $state(false);
|
||||
let mcpResourcePreviewExtra = $state<DatabaseMessageExtraMcpResource | null>(null);
|
||||
let previewFocusIndex = $state(0);
|
||||
let viewAllDialogOpen = $state(false);
|
||||
|
||||
let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments }));
|
||||
|
||||
/**
 * Open the attachments preview dialog focused on the clicked item.
 *
 * @param item  The display item that was activated.
 * @param event Optional originating mouse event; when present it is stopped
 *              from propagating and its default action is cancelled.
 */
function openPreview(item: ChatAttachmentDisplayItem, event?: MouseEvent) {
	if (event) {
		event.stopPropagation();
		event.preventDefault();
	}

	// The focus index is computed among non-MCP attachments only, so MCP prompts
	// and resources are dropped before looking the item up.
	const position = displayItems
		.filter((candidate) => !(isMcpPrompt(candidate) || isMcpResource(candidate)))
		.findIndex((candidate) => candidate.id === item.id);

	// findIndex yields -1 when the item is not found; fall back to the first entry.
	previewFocusIndex = Math.max(position, 0);
	viewAllDialogOpen = true;
}
|
||||
|
||||
/**
 * Remember which MCP resource should be previewed and open its dialog.
 *
 * @param extra The stored MCP resource attachment to show.
 */
function openMcpResourcePreview(extra: DatabaseMessageExtraMcpResource) {
	// Set the payload first: the dialog is only rendered while it is non-null.
	mcpResourcePreviewExtra = extra;
	mcpResourcePreviewOpen = true;
}
|
||||
|
||||
// Ask the carousel to reset its scroll whenever it is mounted and there is at
// least one item to show. `displayItems.length` is only read once the carousel
// reference exists, matching the original short-circuit order.
$effect(() => {
	if (!carouselRef || !displayItems.length) return;

	carouselRef.resetScroll();
});
|
||||
</script>
|
||||
|
||||
{#snippet attachmentitem(item: ChatAttachmentDisplayItem)}
|
||||
<ChatAttachmentsListItem
|
||||
{imageClass}
|
||||
{imageHeight}
|
||||
{imageWidth}
|
||||
{item}
|
||||
{limitToSingleRow}
|
||||
{onFileRemove}
|
||||
onMcpResourcePreview={openMcpResourcePreview}
|
||||
onPreview={(i: ChatAttachmentDisplayItem, event?: MouseEvent) => openPreview(i, event)}
|
||||
{readonly}
|
||||
/>
|
||||
{/snippet}
|
||||
|
||||
{#if displayItems.length > 0}
|
||||
<div class={className} {style}>
|
||||
{#if limitToSingleRow}
|
||||
<HorizontalScrollCarousel bind:this={carouselRef}>
|
||||
{#each displayItems as item (item.id)}
|
||||
{@render attachmentitem(item)}
|
||||
{/each}
|
||||
</HorizontalScrollCarousel>
|
||||
{:else}
|
||||
<div class="flex flex-wrap items-start justify-end gap-3">
|
||||
{#each displayItems as item (item.id)}
|
||||
{@render attachmentitem(item)}
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<DialogChatAttachmentsPreview
|
||||
{activeModelId}
|
||||
{attachments}
|
||||
bind:open={viewAllDialogOpen}
|
||||
{previewFocusIndex}
|
||||
{uploadedFiles}
|
||||
/>
|
||||
|
||||
{#if mcpResourcePreviewExtra}
|
||||
<DialogMcpResourcePreview extra={mcpResourcePreviewExtra} bind:open={mcpResourcePreviewOpen} />
|
||||
{/if}
|
||||
@@ -0,0 +1,132 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
ChatAttachmentsListItemMcpPrompt,
|
||||
ChatAttachmentsListItemMcpResource,
|
||||
ChatAttachmentsListItemThumbnailImage,
|
||||
ChatAttachmentsListItemThumbnailFile
|
||||
} from '$lib/components/app';
|
||||
import { AttachmentType } from '$lib/enums';
|
||||
import type {
|
||||
ChatAttachmentDisplayItem,
|
||||
DatabaseMessageExtraMcpPrompt,
|
||||
DatabaseMessageExtraMcpResource,
|
||||
MCPResourceAttachment
|
||||
} from '$lib/types';
|
||||
import { isMcpPrompt, isMcpResource, isPdfFile } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
imageClass?: string;
|
||||
imageHeight?: string;
|
||||
imageWidth?: string;
|
||||
item: ChatAttachmentDisplayItem;
|
||||
limitToSingleRow?: boolean;
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
onMcpResourcePreview?: (extra: DatabaseMessageExtraMcpResource) => void;
|
||||
onPreview?: (item: ChatAttachmentDisplayItem) => void;
|
||||
readonly?: boolean;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
imageClass = '',
|
||||
imageHeight = 'h-24',
|
||||
imageWidth = 'w-auto',
|
||||
item,
|
||||
limitToSingleRow = false,
|
||||
onFileRemove,
|
||||
onMcpResourcePreview,
|
||||
onPreview,
|
||||
readonly = false
|
||||
}: Props = $props();
|
||||
|
||||
const scrollClasses = $derived(limitToSingleRow ? 'first:ml-4 last:mr-4' : '');
|
||||
|
||||
/**
 * Adapt a stored MCP resource extra to the attachment shape expected by the
 * MCP resource list item.
 *
 * @param extra Stored MCP resource (provides `uri`, `name`, `serverName`).
 * @param id    Identifier to assign to the resulting attachment.
 * @returns An attachment whose resource `title` mirrors the resource `name`.
 */
function toMcpResourceAttachment(
	extra: DatabaseMessageExtraMcpResource,
	id: string
): MCPResourceAttachment {
	const { uri, name, serverName } = extra;

	return {
		id,
		resource: { uri, name, title: name, serverName }
	};
}
|
||||
</script>
|
||||
|
||||
{#if isMcpPrompt(item)}
|
||||
{@const mcpPrompt =
|
||||
item.attachment?.type === AttachmentType.MCP_PROMPT
|
||||
? (item.attachment as DatabaseMessageExtraMcpPrompt)
|
||||
: item.uploadedFile?.mcpPrompt
|
||||
? {
|
||||
type: AttachmentType.MCP_PROMPT as const,
|
||||
name: item.name,
|
||||
serverName: item.uploadedFile.mcpPrompt.serverName,
|
||||
promptName: item.uploadedFile.mcpPrompt.promptName,
|
||||
content: item.textContent ?? '',
|
||||
arguments: item.uploadedFile.mcpPrompt.arguments
|
||||
}
|
||||
: null}
|
||||
{#if mcpPrompt}
|
||||
<ChatAttachmentsListItemMcpPrompt
|
||||
class="max-w-[300px] min-w-[200px] flex-shrink-0 {className} {scrollClasses}"
|
||||
prompt={mcpPrompt}
|
||||
{readonly}
|
||||
isLoading={item.isLoading}
|
||||
loadError={item.loadError}
|
||||
onRemove={onFileRemove ? () => onFileRemove(item.id) : undefined}
|
||||
/>
|
||||
{/if}
|
||||
{:else if isMcpResource(item)}
|
||||
{@const mcpResource = item.attachment as DatabaseMessageExtraMcpResource}
|
||||
|
||||
<ChatAttachmentsListItemMcpResource
|
||||
class="flex-shrink-0 {className} {scrollClasses}"
|
||||
attachment={toMcpResourceAttachment(mcpResource, item.id)}
|
||||
onclick={() => onMcpResourcePreview?.(mcpResource)}
|
||||
/>
|
||||
{:else if item.isImage && item.preview}
|
||||
<ChatAttachmentsListItemThumbnailImage
|
||||
class="flex-shrink-0 cursor-pointer {className} {scrollClasses}"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onclick={() => onPreview?.(item)}
|
||||
/>
|
||||
{:else if isPdfFile(item.attachment, item.uploadedFile)}
|
||||
<ChatAttachmentsListItemThumbnailFile
|
||||
class="flex-shrink-0 cursor-pointer {className} {scrollClasses}"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
attachment={item.attachment}
|
||||
uploadedFile={item.uploadedFile}
|
||||
onclick={() => onPreview?.(item)}
|
||||
/>
|
||||
{:else}
|
||||
<ChatAttachmentsListItemThumbnailFile
|
||||
class="flex-shrink-0 cursor-pointer {className} {scrollClasses}"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
attachment={item.attachment}
|
||||
uploadedFile={item.uploadedFile}
|
||||
onclick={() => onPreview?.(item)}
|
||||
/>
|
||||
{/if}
|
||||
@@ -1,40 +1,41 @@
|
||||
<script lang="ts">
|
||||
import { ChatMessageMcpPromptContent, ActionIconRemove } from '$lib/components/app';
|
||||
import { ChatMessageMcpPromptContent, ActionIcon } from '$lib/components/app';
|
||||
import { X } from '@lucide/svelte';
|
||||
import type { DatabaseMessageExtraMcpPrompt } from '$lib/types';
|
||||
import { McpPromptVariant } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
prompt: DatabaseMessageExtraMcpPrompt;
|
||||
readonly?: boolean;
|
||||
isLoading?: boolean;
|
||||
loadError?: string;
|
||||
onRemove?: () => void;
|
||||
prompt: DatabaseMessageExtraMcpPrompt;
|
||||
readonly?: boolean;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
prompt,
|
||||
readonly = false,
|
||||
isLoading = false,
|
||||
loadError,
|
||||
onRemove
|
||||
onRemove,
|
||||
prompt,
|
||||
readonly = false
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="group relative {className}">
|
||||
<ChatMessageMcpPromptContent
|
||||
{prompt}
|
||||
variant={McpPromptVariant.ATTACHMENT}
|
||||
{isLoading}
|
||||
{loadError}
|
||||
{prompt}
|
||||
variant={McpPromptVariant.ATTACHMENT}
|
||||
/>
|
||||
|
||||
{#if !readonly && onRemove}
|
||||
<div
|
||||
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
|
||||
>
|
||||
<ActionIconRemove id={prompt.name} onRemove={() => onRemove?.()} />
|
||||
<ActionIcon icon={X} tooltip="Remove" stopPropagationOnClick onclick={() => onRemove?.()} />
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -1,46 +1,47 @@
|
||||
<script lang="ts">
|
||||
import { Loader2, AlertCircle } from '@lucide/svelte';
|
||||
import { cn } from '$lib/components/ui/utils';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import type { MCPResourceAttachment } from '$lib/types';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { ActionIconRemove } from '$lib/components/app';
|
||||
import { ActionIcon } from '$lib/components/app';
|
||||
import { X } from '@lucide/svelte';
|
||||
import { getResourceIcon, getResourceDisplayName } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
attachment: MCPResourceAttachment;
|
||||
onRemove?: (attachmentId: string) => void;
|
||||
onClick?: () => void;
|
||||
class?: string;
|
||||
onclick?: () => void;
|
||||
onRemove?: (attachmentId: string) => void;
|
||||
}
|
||||
|
||||
let { attachment, onRemove, onClick, class: className }: Props = $props();
|
||||
|
||||
function getStatusClass(attachment: MCPResourceAttachment): string {
|
||||
if (attachment.error) return 'border-red-500/50 bg-red-500/10';
|
||||
if (attachment.loading) return 'border-border/50 bg-muted/30';
|
||||
return 'border-border/50 bg-muted/30';
|
||||
}
|
||||
let { attachment, class: className, onclick, onRemove }: Props = $props();
|
||||
|
||||
const ResourceIcon = $derived(
|
||||
getResourceIcon(attachment.resource.mimeType, attachment.resource.uri)
|
||||
);
|
||||
const serverName = $derived(mcpStore.getServerDisplayName(attachment.resource.serverName));
|
||||
const favicon = $derived(mcpStore.getServerFavicon(attachment.resource.serverName));
|
||||
|
||||
/**
 * Pick the border/background classes for a resource chip.
 *
 * @param attachment The MCP resource attachment being rendered.
 * @returns Red-tinted classes when the attachment carries an error, otherwise
 *          the neutral classes (loading and idle currently look the same).
 */
function getStatusClass(attachment: MCPResourceAttachment): string {
	return attachment.error ? 'border-red-500/50 bg-red-500/10' : 'border-border/50 bg-muted/30';
}
|
||||
</script>
|
||||
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<button
|
||||
type="button"
|
||||
class={cn(
|
||||
class={[
|
||||
'flex flex-shrink-0 items-center gap-1.5 rounded-md border px-2 py-0.75 text-sm transition-colors',
|
||||
getStatusClass(attachment),
|
||||
onClick && 'cursor-pointer hover:bg-muted/50',
|
||||
onclick && 'cursor-pointer hover:bg-muted/50',
|
||||
className
|
||||
)}
|
||||
onclick={onClick}
|
||||
disabled={!onClick}
|
||||
]}
|
||||
disabled={!onclick}
|
||||
{onclick}
|
||||
type="button"
|
||||
>
|
||||
{#if attachment.loading}
|
||||
<Loader2 class="h-3 w-3 animate-spin text-muted-foreground" />
|
||||
@@ -55,11 +56,13 @@
|
||||
</span>
|
||||
|
||||
{#if onRemove}
|
||||
<ActionIconRemove
|
||||
<ActionIcon
|
||||
class="-my-2 -mr-1.5 bg-transparent"
|
||||
iconSize={2}
|
||||
id={attachment.id}
|
||||
{onRemove}
|
||||
icon={X}
|
||||
iconSize="h-2 w-2"
|
||||
onclick={() => onRemove?.(attachment.id)}
|
||||
stopPropagationOnClick
|
||||
tooltip="Remove"
|
||||
/>
|
||||
{/if}
|
||||
</button>
|
||||
@@ -69,12 +72,12 @@
|
||||
<div class="flex items-center gap-1 text-xs">
|
||||
{#if favicon}
|
||||
<img
|
||||
src={favicon}
|
||||
alt=""
|
||||
alt={attachment.resource.serverName}
|
||||
class="h-3 w-3 shrink-0 rounded-sm"
|
||||
onerror={(e) => {
|
||||
(e.currentTarget as HTMLImageElement).style.display = 'none';
|
||||
}}
|
||||
src={favicon}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import {
|
||||
formatFileSize,
|
||||
getFileTypeLabel,
|
||||
getPreviewText,
|
||||
isPdfFile,
|
||||
isTextFile
|
||||
} from '$lib/utils';
|
||||
import { ActionIcon } from '$lib/components/app';
|
||||
import { AttachmentType } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
attachment?: DatabaseMessageExtra;
|
||||
class?: string;
|
||||
id: string;
|
||||
onclick?: (event: MouseEvent) => void;
|
||||
onRemove?: (id: string) => void;
|
||||
name: string;
|
||||
readonly?: boolean;
|
||||
size?: number;
|
||||
textContent?: string;
|
||||
// Either uploaded file or stored attachment
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
}
|
||||
|
||||
let {
|
||||
attachment,
|
||||
class: className = '',
|
||||
id,
|
||||
onclick,
|
||||
onRemove,
|
||||
name,
|
||||
readonly = false,
|
||||
size,
|
||||
textContent,
|
||||
uploadedFile
|
||||
}: Props = $props();
|
||||
|
||||
let isPdf = $derived(isPdfFile(attachment, uploadedFile));
|
||||
let isPdfWithContent = $derived(isPdf && !!textContent);
|
||||
|
||||
let isText = $derived(isTextFile(attachment, uploadedFile));
|
||||
let isTextWithContent = $derived(isText && !!textContent);
|
||||
|
||||
let fileTypeLabel = $derived.by(() => {
|
||||
if (uploadedFile?.type) {
|
||||
return getFileTypeLabel(uploadedFile.type);
|
||||
}
|
||||
|
||||
if (attachment) {
|
||||
if ('mimeType' in attachment && attachment.mimeType) {
|
||||
return getFileTypeLabel(attachment.mimeType);
|
||||
}
|
||||
|
||||
if (attachment.type) {
|
||||
return getFileTypeLabel(attachment.type);
|
||||
}
|
||||
}
|
||||
|
||||
return getFileTypeLabel(name);
|
||||
});
|
||||
|
||||
let pdfProcessingMode = $derived.by(() => {
|
||||
if (attachment?.type === AttachmentType.PDF) {
|
||||
const pdfAttachment = attachment as DatabaseMessageExtraPdfFile;
|
||||
|
||||
return pdfAttachment.processedAsImages ? 'Sent as Image' : 'Sent as Text';
|
||||
}
|
||||
|
||||
return null;
|
||||
});
|
||||
</script>
|
||||
|
||||
{#snippet textPreview(content: string)}
|
||||
<div class="relative">
|
||||
<div
|
||||
class="font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground {!readonly
|
||||
? 'max-h-3rem line-height-1.2'
|
||||
: ''}"
|
||||
>
|
||||
{getPreviewText(content)}
|
||||
</div>
|
||||
|
||||
{#if content.length > 150}
|
||||
<div
|
||||
class="pointer-events-none absolute right-0 bottom-0 left-0 h-4 bg-gradient-to-t from-muted to-transparent {readonly
|
||||
? 'h-6'
|
||||
: ''}"
|
||||
></div>
|
||||
{/if}
|
||||
</div>
|
||||
{/snippet}
|
||||
|
||||
{#snippet removeButton()}
|
||||
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
|
||||
<ActionIcon icon={X} tooltip="Remove" stopPropagationOnClick onclick={() => onRemove?.(id)} />
|
||||
</div>
|
||||
{/snippet}
|
||||
|
||||
{#snippet fileIcon()}
|
||||
<div
|
||||
class="flex h-8 w-8 items-center justify-center rounded bg-primary/10 text-xs font-medium text-primary"
|
||||
>
|
||||
{fileTypeLabel}
|
||||
</div>
|
||||
{/snippet}
|
||||
|
||||
{#snippet info(text: string | undefined)}
|
||||
{#if text}
|
||||
<span class="text-xs text-muted-foreground">{text}</span>
|
||||
{/if}
|
||||
{/snippet}
|
||||
|
||||
{#if isTextWithContent || isPdfWithContent}
|
||||
<button
|
||||
aria-label={readonly ? `Preview ${name}` : undefined}
|
||||
class="rounded-lg border border-border bg-muted p-3 {className} cursor-pointer {readonly
|
||||
? 'w-full max-w-2xl transition-shadow hover:shadow-md'
|
||||
: `group relative text-left ${textContent ? 'max-h-24 max-w-72' : 'max-w-36'}`} overflow-hidden"
|
||||
{onclick}
|
||||
type="button"
|
||||
>
|
||||
{#if !readonly}
|
||||
{@render removeButton()}
|
||||
{/if}
|
||||
|
||||
<div class={[!readonly && 'pr-8', 'overflow-hidden']}>
|
||||
{#if readonly}
|
||||
<div class="flex items-start gap-3">
|
||||
<div class="flex min-w-0 flex-1 flex-col items-start text-left">
|
||||
<span class="w-full truncate text-sm font-medium text-foreground">{name}</span>
|
||||
|
||||
{@render info(pdfProcessingMode || (size ? formatFileSize(size) : undefined))}
|
||||
|
||||
{#if textContent}
|
||||
{@render textPreview(textContent)}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<span class="mb-3 block truncate text-sm font-medium text-foreground">{name}</span>
|
||||
|
||||
{#if textContent}
|
||||
{@render textPreview(textContent)}
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="group flex items-center gap-3 rounded-lg border border-border bg-muted p-3 {className} relative"
|
||||
{onclick}
|
||||
type="button"
|
||||
>
|
||||
{@render fileIcon()}
|
||||
|
||||
<div class="flex flex-col items-start gap-0.5">
|
||||
<span
|
||||
class="max-w-24 truncate text-sm font-medium text-foreground {readonly
|
||||
? ''
|
||||
: 'group-hover:pr-6'} md:max-w-32"
|
||||
>
|
||||
{name}
|
||||
</span>
|
||||
|
||||
{@render info(pdfProcessingMode || (size ? formatFileSize(size) : undefined))}
|
||||
</div>
|
||||
|
||||
{#if !readonly}
|
||||
{@render removeButton()}
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
@@ -1,64 +1,65 @@
|
||||
<script lang="ts">
|
||||
import { ActionIconRemove } from '$lib/components/app';
|
||||
import { ActionIcon } from '$lib/components/app';
|
||||
import { X } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
height?: string;
|
||||
id: string;
|
||||
imageClass?: string;
|
||||
onclick?: (event?: MouseEvent) => void;
|
||||
onRemove?: (id: string) => void;
|
||||
name: string;
|
||||
preview: string;
|
||||
readonly?: boolean;
|
||||
onRemove?: (id: string) => void;
|
||||
onClick?: (event?: MouseEvent) => void;
|
||||
class?: string;
|
||||
// Customizable size props
|
||||
width?: string;
|
||||
height?: string;
|
||||
imageClass?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
height = 'h-16',
|
||||
id,
|
||||
imageClass = '',
|
||||
onclick,
|
||||
onRemove,
|
||||
name,
|
||||
preview,
|
||||
readonly = false,
|
||||
onRemove,
|
||||
onClick,
|
||||
class: className = '',
|
||||
// Default to small size for form previews
|
||||
width = 'w-auto',
|
||||
height = 'h-16',
|
||||
imageClass = ''
|
||||
width = 'w-auto'
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
{#snippet image()}
|
||||
<img src={preview} alt={name} class="{height} {width} cursor-pointer object-cover {imageClass}" />
|
||||
{/snippet}
|
||||
|
||||
<div
|
||||
class="group relative overflow-hidden rounded-lg bg-muted shadow-lg dark:border dark:border-muted {className}"
|
||||
>
|
||||
{#if onClick}
|
||||
{#if onclick}
|
||||
<button
|
||||
type="button"
|
||||
class="block h-full w-full rounded-lg focus:ring-2 focus:ring-primary focus:ring-offset-2 focus:outline-none"
|
||||
onclick={onClick}
|
||||
aria-label="Preview {name}"
|
||||
class="block h-full w-full rounded-lg focus:ring-2 focus:ring-primary focus:ring-offset-2 focus:outline-none"
|
||||
{onclick}
|
||||
type="button"
|
||||
>
|
||||
<img
|
||||
src={preview}
|
||||
alt={name}
|
||||
class="{height} {width} cursor-pointer object-cover {imageClass}"
|
||||
/>
|
||||
{@render image()}
|
||||
</button>
|
||||
{:else}
|
||||
<img
|
||||
src={preview}
|
||||
alt={name}
|
||||
class="{height} {width} cursor-pointer object-cover {imageClass}"
|
||||
/>
|
||||
{@render image()}
|
||||
{/if}
|
||||
|
||||
{#if !readonly}
|
||||
<div
|
||||
class="absolute top-1 right-1 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
|
||||
>
|
||||
<ActionIconRemove {id} {onRemove} class="text-white" />
|
||||
<ActionIcon
|
||||
class="text-white"
|
||||
icon={X}
|
||||
onclick={() => onRemove?.(id)}
|
||||
stopPropagationOnClick
|
||||
tooltip="Remove"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -0,0 +1,190 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
ChatAttachmentsPreviewCurrentItem,
|
||||
ChatAttachmentsPreviewFileInfo,
|
||||
ChatAttachmentsPreviewNavButtons,
|
||||
ChatAttachmentsPreviewThumbnailStrip
|
||||
} from '$lib/components/app';
|
||||
import { modelsStore } from '$lib/stores/models.svelte';
|
||||
import {
|
||||
createBase64DataUrl,
|
||||
formatFileSize,
|
||||
getAttachmentDisplayItems,
|
||||
getLanguageFromFilename,
|
||||
isAudioFile,
|
||||
isImageFile,
|
||||
isMcpPrompt,
|
||||
isMcpResource,
|
||||
isPdfFile,
|
||||
isTextFile
|
||||
} from '$lib/utils';
|
||||
|
||||
interface PreviewItem {
|
||||
id: string;
|
||||
name: string;
|
||||
size?: number;
|
||||
preview?: string;
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
textContent?: string;
|
||||
isImage: boolean;
|
||||
isAudio: boolean;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
activeModelId?: string;
|
||||
class?: string;
|
||||
previewFocusIndex?: number;
|
||||
}
|
||||
|
||||
let {
|
||||
uploadedFiles = [],
|
||||
attachments = [],
|
||||
activeModelId,
|
||||
class: className = '',
|
||||
previewFocusIndex = 0
|
||||
}: Props = $props();
|
||||
|
||||
let allItems = $derived(
|
||||
getAttachmentDisplayItems({ uploadedFiles, attachments })
|
||||
.filter((item) => !isMcpPrompt(item) && !isMcpResource(item))
|
||||
.map(
|
||||
(item): PreviewItem => ({
|
||||
...item,
|
||||
isImage: isImageFile(item.attachment, item.uploadedFile),
|
||||
isAudio: isAudioFile(item.attachment, item.uploadedFile)
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
let currentIndex = $state(0);
|
||||
|
||||
$effect(() => {
|
||||
if (previewFocusIndex >= 0 && previewFocusIndex < allItems.length) {
|
||||
currentIndex = previewFocusIndex;
|
||||
}
|
||||
});
|
||||
|
||||
// Listen for 'chat-attachments-nav' events on `document` while this component
// is alive. The event's `detail` is treated as a signed step: negative moves
// to the previous item, anything else to the next one.
$effect(() => {
	function onNavEvent(e: Event) {
		const step = (e as CustomEvent).detail;

		// prev()/next() hold the same wrap-around logic that was inlined here.
		if (step < 0) {
			prev();
		} else {
			next();
		}
	}

	document.addEventListener('chat-attachments-nav', onNavEvent);

	// Cleanup: detach the listener when the effect is torn down.
	return () => document.removeEventListener('chat-attachments-nav', onNavEvent);
});
|
||||
|
||||
$effect(() => {
|
||||
const index = currentIndex;
|
||||
setTimeout(() => {
|
||||
const thumbnail = document.querySelector(`[data-thumbnail-index="${index}"]`);
|
||||
|
||||
thumbnail?.scrollIntoView({ behavior: 'smooth', inline: 'center', block: 'nearest' });
|
||||
}, 0);
|
||||
});
|
||||
|
||||
let currentItem = $derived(allItems[currentIndex] ?? null);
|
||||
let displayName = $derived(
|
||||
currentItem?.name ||
|
||||
currentItem?.uploadedFile?.name ||
|
||||
currentItem?.attachment?.name ||
|
||||
'Unknown File'
|
||||
);
|
||||
let isAudio = $derived(
|
||||
currentItem ? isAudioFile(currentItem.attachment, currentItem.uploadedFile) : false
|
||||
);
|
||||
let isImage = $derived(
|
||||
currentItem ? isImageFile(currentItem.attachment, currentItem.uploadedFile) : false
|
||||
);
|
||||
let isPdf = $derived(
|
||||
currentItem ? isPdfFile(currentItem.attachment, currentItem.uploadedFile) : false
|
||||
);
|
||||
let isText = $derived(
|
||||
currentItem ? isTextFile(currentItem.attachment, currentItem.uploadedFile) : false
|
||||
);
|
||||
|
||||
let displayPreview = $derived(
|
||||
currentItem?.uploadedFile?.preview ||
|
||||
(isImage && currentItem?.attachment && 'base64Url' in currentItem.attachment
|
||||
? currentItem.attachment.base64Url
|
||||
: currentItem?.preview)
|
||||
);
|
||||
|
||||
let displayTextContent = $derived(
|
||||
currentItem?.uploadedFile?.textContent ||
|
||||
(currentItem?.attachment && 'content' in currentItem.attachment
|
||||
? currentItem.attachment.content
|
||||
: currentItem?.textContent)
|
||||
);
|
||||
|
||||
let language = $derived(getLanguageFromFilename(displayName));
|
||||
|
||||
let fileSize = $derived(currentItem?.size ? formatFileSize(currentItem.size) : '');
|
||||
|
||||
let hasVisionModality = $derived(
|
||||
currentItem && activeModelId ? modelsStore.modelSupportsVision(activeModelId) : false
|
||||
);
|
||||
|
||||
let audioSrc = $derived(
|
||||
isAudio && currentItem
|
||||
? (currentItem.uploadedFile?.preview ??
|
||||
(currentItem.attachment &&
|
||||
'mimeType' in currentItem.attachment &&
|
||||
'base64Data' in currentItem.attachment
|
||||
? createBase64DataUrl(
|
||||
currentItem.attachment.mimeType,
|
||||
currentItem.attachment.base64Data
|
||||
)
|
||||
: null))
|
||||
: null
|
||||
);
|
||||
|
||||
/**
 * Move the preview to the previous item, wrapping from the first item to the
 * last one.
 *
 * Does nothing when there are no items, so `currentIndex` is never set to -1.
 */
export function prev() {
	const count = allItems.length;

	if (count === 0) return;

	currentIndex = currentIndex > 0 ? currentIndex - 1 : count - 1;
}
|
||||
|
||||
/**
 * Move the preview to the next item, wrapping from the last item back to the
 * first one.
 */
export function next() {
	const lastIndex = allItems.length - 1;

	if (currentIndex < lastIndex) {
		currentIndex = currentIndex + 1;
	} else {
		currentIndex = 0;
	}
}
|
||||
|
||||
function onNavigate(index: number) {
|
||||
currentIndex = index;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="{className} flex flex-col text-white">
|
||||
<div class="relative flex min-h-0 flex-1 items-center justify-center overflow-hidden">
|
||||
<ChatAttachmentsPreviewNavButtons onPrev={prev} onNext={next} show={allItems.length > 1} />
|
||||
|
||||
<div class="flex h-full w-full flex-col items-center justify-start overflow-auto py-4">
|
||||
{#if currentItem}
|
||||
<ChatAttachmentsPreviewFileInfo {displayName} {fileSize} />
|
||||
|
||||
<ChatAttachmentsPreviewCurrentItem
|
||||
{currentItem}
|
||||
{isImage}
|
||||
{isAudio}
|
||||
{isPdf}
|
||||
{isText}
|
||||
{displayPreview}
|
||||
{displayTextContent}
|
||||
{audioSrc}
|
||||
{language}
|
||||
{hasVisionModality}
|
||||
{activeModelId}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<ChatAttachmentsPreviewThumbnailStrip items={allItems} {currentIndex} {onNavigate} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,65 @@
|
||||
<script lang="ts">
|
||||
import type { ChatAttachmentDisplayItem } from '$lib/types';
|
||||
import { Image, Music, FileText, FileIcon } from '@lucide/svelte';
|
||||
import ChatAttachmentsPreviewCurrentItemPdf from './ChatAttachmentsPreviewCurrentItemPdf.svelte';
|
||||
import ChatAttachmentsPreviewCurrentItemImage from './ChatAttachmentsPreviewCurrentItemImage.svelte';
|
||||
import ChatAttachmentsPreviewCurrentItemAudio from './ChatAttachmentsPreviewCurrentItemAudio.svelte';
|
||||
import ChatAttachmentsPreviewCurrentItemText from './ChatAttachmentsPreviewCurrentItemText.svelte';
|
||||
import ChatAttachmentsPreviewCurrentItemUnavailable from './ChatAttachmentsPreviewCurrentItemUnavailable.svelte';
|
||||
|
||||
interface Props {
|
||||
currentItem: ChatAttachmentDisplayItem | null;
|
||||
isImage: boolean;
|
||||
isAudio: boolean;
|
||||
isPdf: boolean;
|
||||
isText: boolean;
|
||||
displayPreview: string | undefined;
|
||||
displayTextContent: string | undefined;
|
||||
audioSrc: string | null;
|
||||
language: string;
|
||||
hasVisionModality: boolean;
|
||||
activeModelId?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
currentItem,
|
||||
isImage,
|
||||
isAudio,
|
||||
isPdf,
|
||||
isText,
|
||||
displayPreview,
|
||||
displayTextContent,
|
||||
audioSrc,
|
||||
language,
|
||||
hasVisionModality,
|
||||
activeModelId
|
||||
}: Props = $props();
|
||||
|
||||
let IconComponent = $derived(
|
||||
isImage ? Image : isText || isPdf ? FileText : isAudio ? Music : FileIcon
|
||||
);
|
||||
|
||||
let isUnavailable = $derived(!isPdf && !isImage && !(isText && displayTextContent) && !isAudio);
|
||||
</script>
|
||||
|
||||
{#if currentItem}
|
||||
{#key currentItem.id}
|
||||
{#if isPdf}
|
||||
<ChatAttachmentsPreviewCurrentItemPdf
|
||||
{currentItem}
|
||||
displayName={currentItem.name}
|
||||
{displayTextContent}
|
||||
{hasVisionModality}
|
||||
{activeModelId}
|
||||
/>
|
||||
{:else if isImage}
|
||||
<ChatAttachmentsPreviewCurrentItemImage {currentItem} {displayPreview} />
|
||||
{:else if isText && displayTextContent}
|
||||
<ChatAttachmentsPreviewCurrentItemText {displayTextContent} {language} />
|
||||
{:else if isAudio}
|
||||
<ChatAttachmentsPreviewCurrentItemAudio {currentItem} {audioSrc} />
|
||||
{:else if isUnavailable}
|
||||
<ChatAttachmentsPreviewCurrentItemUnavailable {IconComponent} />
|
||||
{/if}
|
||||
{/key}
|
||||
{/if}
|
||||
@@ -0,0 +1,26 @@
|
||||
<script lang="ts">
|
||||
import { Music } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
currentItem: { name?: string } | null;
|
||||
audioSrc: string | null;
|
||||
}
|
||||
|
||||
let { currentItem, audioSrc }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="flex flex-1 items-center justify-center p-8">
|
||||
<div class="w-full max-w-md text-center">
|
||||
<Music class="mx-auto mb-4 h-16 w-16 text-white/50" />
|
||||
|
||||
{#if audioSrc}
|
||||
<audio controls class="mb-4 w-full" src={audioSrc}>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
{:else}
|
||||
<p class="mb-4 text-white/70">Audio preview not available</p>
|
||||
{/if}
|
||||
|
||||
<p class="text-sm text-white/50">{currentItem?.name || 'Audio'}</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,18 @@
|
||||
<script lang="ts">
|
||||
interface Props {
|
||||
currentItem: { name?: string } | null;
|
||||
displayPreview: string | undefined;
|
||||
}
|
||||
|
||||
let { currentItem, displayPreview }: Props = $props();
|
||||
</script>
|
||||
|
||||
{#if displayPreview}
|
||||
<div class="flex flex-1 items-center justify-center">
|
||||
<img
|
||||
src={displayPreview}
|
||||
alt={currentItem?.name || 'preview'}
|
||||
class="max-h-[80vh] max-w-[80vw] rounded-lg object-contain shadow-lg"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -0,0 +1,174 @@
|
||||
<script lang="ts">
|
||||
import type { ChatAttachmentDisplayItem } from '$lib/types';
|
||||
import { FileText, Eye, Info } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Alert from '$lib/components/ui/alert';
|
||||
import { SyntaxHighlightedCode } from '$lib/components/app';
|
||||
import { getLanguageFromFilename } from '$lib/utils';
|
||||
import { convertPDFToImage } from '$lib/utils/browser-only';
|
||||
import { PdfViewMode } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
currentItem: ChatAttachmentDisplayItem | null;
|
||||
displayName: string;
|
||||
displayTextContent: string | undefined;
|
||||
hasVisionModality: boolean;
|
||||
activeModelId?: string;
|
||||
}
|
||||
|
||||
let { currentItem, displayName, displayTextContent, hasVisionModality, activeModelId }: Props =
|
||||
$props();
|
||||
|
||||
let pdfViewMode = $state<PdfViewMode>(PdfViewMode.PAGES);
|
||||
let pdfImages = $state<string[]>([]);
|
||||
let pdfImagesLoading = $state(false);
|
||||
let pdfImagesError = $state<string | null>(null);
|
||||
|
||||
let language = $derived(getLanguageFromFilename(displayName));
|
||||
|
||||
/**
 * Populate `pdfImages` with the page images for the current PDF item.
 *
 * Sources are tried in this order:
 *   1. the original `File` of a freshly uploaded item,
 *   2. page images already stored on the attachment (`attachment.images`),
 *   3. the attachment's `base64Data`, decoded back into a `File`.
 * A `File` obtained from (1) or (3) is rendered through `convertPDFToImage`.
 *
 * Failures are not thrown to the caller; the message is stored in
 * `pdfImagesError`. `pdfImagesLoading` is always reset when the call settles.
 *
 * NOTE(review): after a failure `pdfImages` stays empty and the loading flag
 * is cleared, so any later call retries the conversion. This function is
 * invoked from a `$effect`; confirm that a persistent failure cannot make the
 * effect re-trigger it in a loop.
 */
async function loadPdfImages() {
	// Skip when pages are already present, a load is in flight, or nothing is selected.
	if (pdfImages.length > 0 || pdfImagesLoading || !currentItem) return;

	pdfImagesLoading = true;
	pdfImagesError = null;

	try {
		let file: File | null = null;

		if (currentItem.uploadedFile?.file) {
			file = currentItem.uploadedFile.file;
		} else if (currentItem.attachment) {
			const attachment = currentItem.attachment;

			// Prefer pre-processed page images: no conversion needed.
			if (
				'images' in attachment &&
				attachment.images &&
				Array.isArray(attachment.images) &&
				attachment.images.length > 0
			) {
				pdfImages = attachment.images;
				return;
			}

			// Otherwise rebuild a File from the stored base64 payload.
			if ('base64Data' in attachment && attachment.base64Data) {
				// atob() yields one character per byte, so each char code is the byte value.
				const bytes = Uint8Array.from(atob(attachment.base64Data), (char) => char.charCodeAt(0));
				file = new File([bytes], displayName, { type: 'application/pdf' });
			}
		}

		if (file) {
			pdfImages = await convertPDFToImage(file);
		} else {
			throw new Error('No PDF file available for conversion');
		}
	} catch (error) {
		pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
	} finally {
		pdfImagesLoading = false;
	}
}
|
||||
|
||||
// Load the page images whenever the viewer is in PAGES mode.
// loadPdfImages() itself returns early when pages are already present or a load is running.
$effect(() => {
|
||||
if (pdfViewMode === PdfViewMode.PAGES) {
|
||||
loadPdfImages();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="mb-4 flex items-center justify-end gap-2">
|
||||
<Button
|
||||
variant={pdfViewMode === PdfViewMode.TEXT ? 'default' : 'outline'}
|
||||
size="sm"
|
||||
onclick={() => (pdfViewMode = PdfViewMode.TEXT)}
|
||||
disabled={pdfImagesLoading}
|
||||
>
|
||||
<FileText class="mr-1 h-4 w-4" />
|
||||
Text
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant={pdfViewMode === PdfViewMode.PAGES ? 'default' : 'outline'}
|
||||
size="sm"
|
||||
onclick={() => (pdfViewMode = PdfViewMode.PAGES)}
|
||||
disabled={pdfImagesLoading}
|
||||
>
|
||||
{#if pdfImagesLoading}
|
||||
<div
|
||||
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
|
||||
></div>
|
||||
{:else}
|
||||
<Eye class="mr-1 h-4 w-4" />
|
||||
{/if}
|
||||
Pages
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{#if !hasVisionModality && activeModelId && currentItem}
|
||||
<Alert.Root class="mb-4 max-w-4xl">
|
||||
<Info class="h-4 w-4" />
|
||||
<Alert.Title>Preview only</Alert.Title>
|
||||
<Alert.Description>
|
||||
<span class="inline-flex">
|
||||
The selected model does not support vision. Only the extracted
|
||||
<!-- svelte-ignore a11y_click_events_have_key_events -->
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<span
|
||||
class="mx-1 cursor-pointer underline"
|
||||
onclick={() => (pdfViewMode = PdfViewMode.TEXT)}
|
||||
>
|
||||
text
|
||||
</span>
|
||||
will be sent to the model.
|
||||
</span>
|
||||
</Alert.Description>
|
||||
</Alert.Root>
|
||||
{/if}
|
||||
|
||||
<!--
	Page-image view. Rendered only in PAGES mode so the Text/Pages toggle actually
	switches views (the text view is gated on PdfViewMode.TEXT in the same way).
-->
{#if pdfViewMode === PdfViewMode.PAGES}
	{#if pdfImagesLoading}
		<div class="flex flex-1 items-center justify-center p-8">
			<div class="text-center">
				<div
					class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-white border-t-transparent"
				></div>
				<p class="text-white/70">Converting PDF to images...</p>
			</div>
		</div>
	{:else if pdfImagesError}
		<div class="flex flex-1 items-center justify-center p-8">
			<div class="text-center">
				<FileText class="mx-auto mb-4 h-16 w-16 text-white/50" />
				<p class="mb-4 text-white/70">Failed to load PDF images</p>
				<p class="text-sm text-white/50">{pdfImagesError}</p>
			</div>
		</div>
	{:else if pdfImages.length > 0}
		<!-- Keyed by position: two identical pages (e.g. blank ones) yield identical image strings, which are not unique keys. -->
		{#each pdfImages as image, index (index)}
			<p class="mb-2 text-sm text-white/50">Page {index + 1}</p>
			<img src={image} alt="PDF Page {index + 1}" class="mx-auto max-w-[85vw] rounded-lg shadow-lg" />
			<div class="h-4"></div>
		{/each}
	{:else}
		<div class="flex flex-1 items-center justify-center p-8">
			<div class="text-center">
				<FileText class="mx-auto mb-4 h-16 w-16 text-white/50" />
				<p class="text-white/70">No PDF pages available</p>
			</div>
		</div>
	{/if}
{/if}
|
||||
|
||||
{#if pdfViewMode === PdfViewMode.TEXT && displayTextContent}
|
||||
<div class="px-4 pb-4">
|
||||
<SyntaxHighlightedCode
|
||||
class="max-w-4xl"
|
||||
code={displayTextContent}
|
||||
{language}
|
||||
maxHeight="none"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -0,0 +1,21 @@
|
||||
<script lang="ts">
|
||||
import { SyntaxHighlightedCode } from '$lib/components/app';
|
||||
|
||||
interface Props {
|
||||
displayTextContent: string | undefined;
|
||||
language: string;
|
||||
}
|
||||
|
||||
let { displayTextContent, language }: Props = $props();
|
||||
</script>
|
||||
|
||||
{#if displayTextContent}
|
||||
<div class="px-4 pb-4">
|
||||
<SyntaxHighlightedCode
|
||||
class="max-w-4xl"
|
||||
code={displayTextContent}
|
||||
{language}
|
||||
maxHeight="none"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -0,0 +1,17 @@
|
||||
<script lang="ts">
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
IconComponent: Component;
|
||||
}
|
||||
|
||||
let { IconComponent }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="flex flex-1 items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<IconComponent class="mx-auto mb-4 h-16 w-16 text-white/50" />
|
||||
|
||||
<p class="text-white/70">Preview not available for this file type</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,16 @@
|
||||
<script lang="ts">
|
||||
interface Props {
|
||||
displayName: string;
|
||||
fileSize: string;
|
||||
}
|
||||
|
||||
let { displayName, fileSize }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="sticky top-0 z-[20] mb-4 rounded-lg bg-black/5 px-4 py-2 text-center backdrop-blur-md">
|
||||
<p class="font-medium text-white">{displayName}</p>
|
||||
|
||||
{#if fileSize}
|
||||
<p class="text-xs text-white/60">{fileSize}</p>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -0,0 +1,34 @@
|
||||
<script lang="ts">
|
||||
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
|
||||
interface Props {
|
||||
onPrev: () => void;
|
||||
onNext: () => void;
|
||||
show: boolean;
|
||||
}
|
||||
|
||||
let { onPrev, onNext, show }: Props = $props();
|
||||
</script>
|
||||
|
||||
{#if show}
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="icon"
|
||||
class="absolute top-1/2 left-4 z-10 h-8 w-8 -translate-y-1/2 rounded-full bg-background/5 p-0 text-white!"
|
||||
onclick={onPrev}
|
||||
aria-label="Previous"
|
||||
>
|
||||
<ChevronLeft class="size-4" />
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="icon"
|
||||
class="absolute top-1/2 right-4 z-10 h-8 w-8 -translate-y-1/2 rounded-full bg-background/5 p-0 text-white!"
|
||||
onclick={onNext}
|
||||
aria-label="Next"
|
||||
>
|
||||
<ChevronRight class="size-4" />
|
||||
</Button>
|
||||
{/if}
|
||||
@@ -0,0 +1,63 @@
|
||||
<script lang="ts">
|
||||
import { Music, FileText } from '@lucide/svelte';
|
||||
import { HorizontalScrollCarousel } from '$lib/components/app/misc';
|
||||
|
||||
interface PreviewItem {
|
||||
id: string;
|
||||
name: string;
|
||||
isImage: boolean;
|
||||
isAudio: boolean;
|
||||
preview?: string;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
items: PreviewItem[];
|
||||
currentIndex: number;
|
||||
onNavigate: (index: number) => void;
|
||||
}
|
||||
|
||||
let { items, currentIndex, onNavigate }: Props = $props();
|
||||
|
||||
/**
 * Return the upper-cased text after the last `.` in `name`.
 *
 * Yields an empty string when the name contains no dot, or when the dot is
 * the final character.
 */
function getFileExtension(name: string): string {
	const lastDot = name.lastIndexOf('.');

	// No dot at all: the name has no extension.
	if (lastDot === -1) {
		return '';
	}

	return name.slice(lastDot + 1).toUpperCase();
}
|
||||
</script>
|
||||
|
||||
{#if items.length > 1}
|
||||
<div class="sticky bottom-0 z-10 mt-4 flex-shrink-0">
|
||||
<HorizontalScrollCarousel class="max-w-full">
|
||||
{#each items as item, index (item.id)}
|
||||
<button
|
||||
data-thumbnail-index={index}
|
||||
class={[
|
||||
'relative flex-shrink-0 cursor-pointer overflow-hidden rounded border-2 bg-black/80 backdrop-blur-sm transition-all hover:opacity-90',
|
||||
index === currentIndex ? 'border-white' : 'border-transparent opacity-60',
|
||||
'[&:not(:first-child)]:last:mr-4 [&:not(:last-child)]:first:ml-4'
|
||||
]}
|
||||
onclick={() => onNavigate(index)}
|
||||
aria-label={`Go to ${item.name}`}
|
||||
>
|
||||
{#if item.isImage && item.preview}
|
||||
<img src={item.preview} alt={item.name} class="h-12 w-12 object-cover" />
|
||||
{:else}
|
||||
<div
|
||||
class="bg-foreground-muted/50 flex h-12 w-12 flex-col items-center justify-center gap-0.5 py-1"
|
||||
>
|
||||
{#if item.isAudio}
|
||||
<Music class="h-4 w-4 text-white/70" />
|
||||
{:else}
|
||||
<FileText class="h-4 w-4 text-white/70" />
|
||||
{/if}
|
||||
|
||||
<span class="font-mono text-[9px] text-white/60">{getFileExtension(item.name)}</span>
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
{/each}
|
||||
</HorizontalScrollCarousel>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -1,117 +0,0 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
ChatAttachmentThumbnailImage,
|
||||
ChatAttachmentThumbnailFile,
|
||||
DialogChatAttachmentPreview
|
||||
} from '$lib/components/app';
|
||||
import { getAttachmentDisplayItems } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
readonly?: boolean;
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
imageHeight?: string;
|
||||
imageWidth?: string;
|
||||
imageClass?: string;
|
||||
activeModelId?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
uploadedFiles = [],
|
||||
attachments = [],
|
||||
readonly = false,
|
||||
onFileRemove,
|
||||
imageHeight = 'h-24',
|
||||
imageWidth = 'w-auto',
|
||||
imageClass = '',
|
||||
activeModelId
|
||||
}: Props = $props();
|
||||
|
||||
let previewDialogOpen = $state(false);
|
||||
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
|
||||
|
||||
let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments }));
|
||||
let imageItems = $derived(displayItems.filter((item) => item.isImage));
|
||||
let fileItems = $derived(displayItems.filter((item) => !item.isImage));
|
||||
|
||||
/**
 * Open the preview dialog for `item`.
 *
 * When an event is supplied, its default action and propagation are stopped
 * first. The fields shown by the dialog are copied from the display item into
 * `previewItem`, then the dialog is opened.
 */
function openPreview(item: (typeof displayItems)[0], event?: Event) {
	event?.preventDefault();
	event?.stopPropagation();

	const { uploadedFile, attachment, preview, name, size, textContent } = item;

	previewItem = { uploadedFile, attachment, preview, name, size, textContent };
	previewDialogOpen = true;
}
|
||||
</script>
|
||||
|
||||
<div class="space-y-4">
|
||||
<div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1">
|
||||
{#if fileItems.length > 0}
|
||||
<div>
|
||||
<h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3>
|
||||
<div class="flex flex-wrap items-start gap-3">
|
||||
{#each fileItems as item (item.id)}
|
||||
<ChatAttachmentThumbnailFile
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
attachment={item.attachment}
|
||||
uploadedFile={item.uploadedFile}
|
||||
onClick={(event?: MouseEvent) => openPreview(item, event)}
|
||||
/>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if imageItems.length > 0}
|
||||
<div>
|
||||
<h3 class="mb-3 text-sm font-medium text-foreground">Images ({imageItems.length})</h3>
|
||||
<div class="flex flex-wrap items-start gap-3">
|
||||
{#each imageItems as item (item.id)}
|
||||
{#if item.preview}
|
||||
<ChatAttachmentThumbnailImage
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if previewItem}
|
||||
<DialogChatAttachmentPreview
|
||||
bind:open={previewDialogOpen}
|
||||
uploadedFile={previewItem.uploadedFile}
|
||||
attachment={previewItem.attachment}
|
||||
preview={previewItem.preview}
|
||||
name={previewItem.name}
|
||||
size={previewItem.size}
|
||||
textContent={previewItem.textContent}
|
||||
{activeModelId}
|
||||
/>
|
||||
{/if}
|
||||
@@ -1,14 +1,13 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
ChatAttachmentsList,
|
||||
ChatAttachmentMcpResources,
|
||||
ChatFormActions,
|
||||
ChatFormFileInputInvisible,
|
||||
ChatFormPromptPicker,
|
||||
ChatFormResourcePicker,
|
||||
ChatFormTextarea
|
||||
ChatFormMcpResourcesList,
|
||||
ChatFormPickers,
|
||||
ChatFormTextarea,
|
||||
DialogMcpResourcesBrowser
|
||||
} from '$lib/components/app';
|
||||
import { DialogMcpResources } from '$lib/components/app/dialogs';
|
||||
import {
|
||||
CLIPBOARD_CONTENT_QUOTE_PREFIX,
|
||||
INPUT_CLASSES,
|
||||
@@ -54,6 +53,8 @@
|
||||
isLoading?: boolean;
|
||||
placeholder?: string;
|
||||
showMcpPromptButton?: boolean;
|
||||
showAddButton?: boolean;
|
||||
showModelSelector?: boolean;
|
||||
|
||||
// Event Handlers
|
||||
onAttachmentRemove?: (index: number) => void;
|
||||
@@ -73,6 +74,8 @@
|
||||
isLoading = false,
|
||||
placeholder = 'Type a message...',
|
||||
showMcpPromptButton = false,
|
||||
showAddButton = true,
|
||||
showModelSelector = true,
|
||||
uploadedFiles = $bindable([]),
|
||||
value = $bindable(''),
|
||||
onAttachmentRemove,
|
||||
@@ -85,31 +88,21 @@
|
||||
onValueChange
|
||||
}: Props = $props();
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* STATE
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
// Component References
|
||||
let audioRecorder: AudioRecorder | undefined;
|
||||
let chatFormActionsRef: ChatFormActions | undefined = $state(undefined);
|
||||
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
|
||||
let promptPickerRef: ChatFormPromptPicker | undefined = $state(undefined);
|
||||
let resourcePickerRef: ChatFormResourcePicker | undefined = $state(undefined);
|
||||
let pickersRef: { handleKeydown: (event: KeyboardEvent) => boolean } | undefined =
|
||||
$state(undefined);
|
||||
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
|
||||
|
||||
// Audio Recording State
|
||||
let isRecording = $state(false);
|
||||
let recordingSupported = $state(false);
|
||||
|
||||
// Prompt Picker State
|
||||
// Picker State
|
||||
let isPromptPickerOpen = $state(false);
|
||||
let promptSearchQuery = $state('');
|
||||
|
||||
// Inline Resource Picker State (triggered by @)
|
||||
let isInlineResourcePickerOpen = $state(false);
|
||||
let resourceSearchQuery = $state('');
|
||||
|
||||
@@ -117,22 +110,12 @@
|
||||
let isResourceDialogOpen = $state(false);
|
||||
let preSelectedResourceUri = $state<string | undefined>(undefined);
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* DERIVED STATE
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
// Configuration
|
||||
let currentConfig = $derived(config());
|
||||
let pasteLongTextToFileLength = $derived.by(() => {
|
||||
const n = Number(currentConfig.pasteLongTextToFileLen);
|
||||
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
|
||||
});
|
||||
|
||||
// Model Selection Logic
|
||||
let isRouter = $derived(isRouterMode());
|
||||
let conversationModel = $derived(
|
||||
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
|
||||
@@ -158,7 +141,6 @@
|
||||
return null;
|
||||
});
|
||||
|
||||
// Form Validation State
|
||||
let hasModelSelected = $derived(!isRouter || !!conversationModel || !!selectedModelId());
|
||||
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
|
||||
let hasAttachments = $derived(
|
||||
@@ -166,27 +148,11 @@
|
||||
);
|
||||
let canSubmit = $derived(value.trim().length > 0 || hasAttachments);
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* LIFECYCLE
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
onMount(() => {
|
||||
recordingSupported = isAudioRecordingSupported();
|
||||
audioRecorder = new AudioRecorder();
|
||||
});
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* PUBLIC API
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
export function focus() {
|
||||
textareaRef?.focus();
|
||||
}
|
||||
@@ -199,10 +165,6 @@
|
||||
chatFormActionsRef?.openModelSelector();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model is selected, open selector if not
|
||||
* @returns true if model is selected, false otherwise
|
||||
*/
|
||||
export function checkModelSelected(): boolean {
|
||||
if (!hasModelSelected) {
|
||||
chatFormActionsRef?.openModelSelector();
|
||||
@@ -211,14 +173,6 @@
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* EVENT HANDLERS - File Management
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
function handleFileSelect(files: File[]) {
|
||||
onFilesAdd?.(files);
|
||||
}
|
||||
@@ -238,14 +192,6 @@
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* EVENT HANDLERS - Input & Keyboard
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
function handleInput() {
|
||||
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
|
||||
const hasServers = mcpStore.hasEnabledServers(perChatOverrides);
|
||||
@@ -273,11 +219,7 @@
|
||||
}
|
||||
|
||||
function handleKeydown(event: KeyboardEvent) {
|
||||
if (isPromptPickerOpen && promptPickerRef?.handleKeydown(event)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (isInlineResourcePickerOpen && resourcePickerRef?.handleKeydown(event)) {
|
||||
if (pickersRef?.handleKeydown(event)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -388,14 +330,6 @@
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* EVENT HANDLERS - Prompt Picker
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
function handlePromptLoadStart(
|
||||
placeholderId: string,
|
||||
promptInfo: MCPPromptInfo,
|
||||
@@ -474,14 +408,6 @@
|
||||
textareaRef?.focus();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* EVENT HANDLERS - Inline Resource Picker
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
function handleInlineResourcePickerClose() {
|
||||
isInlineResourcePickerOpen = false;
|
||||
resourceSearchQuery = '';
|
||||
@@ -489,7 +415,6 @@
|
||||
}
|
||||
|
||||
function handleInlineResourceSelect() {
|
||||
// Clear the @query from input after resource is attached
|
||||
if (value.startsWith(RESOURCE_TRIGGER_PREFIX)) {
|
||||
value = '';
|
||||
onValueChange?.('');
|
||||
@@ -512,14 +437,6 @@
|
||||
isResourceDialogOpen = true;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* EVENT HANDLERS - Audio Recording
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
async function handleMicClick() {
|
||||
if (!audioRecorder || !recordingSupported) {
|
||||
console.warn('Audio recording not supported');
|
||||
@@ -552,29 +469,27 @@
|
||||
|
||||
<form
|
||||
class="relative {className}"
|
||||
onsubmit={(e) => {
|
||||
e.preventDefault();
|
||||
onsubmit={(event) => {
|
||||
event.preventDefault();
|
||||
|
||||
if (!canSubmit || disabled || hasLoadingAttachments) return;
|
||||
|
||||
onSubmit?.();
|
||||
}}
|
||||
>
|
||||
<ChatFormPromptPicker
|
||||
bind:this={promptPickerRef}
|
||||
isOpen={isPromptPickerOpen}
|
||||
searchQuery={promptSearchQuery}
|
||||
onClose={handlePromptPickerClose}
|
||||
<ChatFormPickers
|
||||
bind:this={pickersRef}
|
||||
{isPromptPickerOpen}
|
||||
{promptSearchQuery}
|
||||
{isInlineResourcePickerOpen}
|
||||
{resourceSearchQuery}
|
||||
onPromptPickerClose={handlePromptPickerClose}
|
||||
onInlineResourcePickerClose={handleInlineResourcePickerClose}
|
||||
onInlineResourceSelect={handleInlineResourceSelect}
|
||||
onPromptLoadStart={handlePromptLoadStart}
|
||||
onPromptLoadComplete={handlePromptLoadComplete}
|
||||
onPromptLoadError={handlePromptLoadError}
|
||||
/>
|
||||
|
||||
<ChatFormResourcePicker
|
||||
bind:this={resourcePickerRef}
|
||||
isOpen={isInlineResourcePickerOpen}
|
||||
searchQuery={resourceSearchQuery}
|
||||
onClose={handleInlineResourcePickerClose}
|
||||
onResourceSelect={handleInlineResourceSelect}
|
||||
onBrowse={handleBrowseResources}
|
||||
onInlineResourceBrowse={handleBrowseResources}
|
||||
/>
|
||||
|
||||
<div
|
||||
@@ -611,7 +526,7 @@
|
||||
/>
|
||||
|
||||
{#if mcpHasResourceAttachments()}
|
||||
<ChatAttachmentMcpResources
|
||||
<ChatFormMcpResourcesList
|
||||
class="mb-3"
|
||||
onResourceClick={(uri) => {
|
||||
preSelectedResourceUri = uri;
|
||||
@@ -624,10 +539,11 @@
|
||||
class="px-3"
|
||||
bind:this={chatFormActionsRef}
|
||||
canSend={canSubmit}
|
||||
hasText={value.trim().length > 0}
|
||||
{disabled}
|
||||
{isLoading}
|
||||
{isRecording}
|
||||
{showAddButton}
|
||||
{showModelSelector}
|
||||
{uploadedFiles}
|
||||
onFileUpload={handleFileUpload}
|
||||
onMicClick={handleMicClick}
|
||||
@@ -640,7 +556,7 @@
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<DialogMcpResources
|
||||
<DialogMcpResourcesBrowser
|
||||
bind:open={isResourceDialogOpen}
|
||||
preSelectedUri={preSelectedResourceUri}
|
||||
onAttach={(resource: MCPResourceInfo) => {
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
<script lang="ts">
|
||||
import { Plus } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { ATTACHMENT_TOOLTIP_TEXT } from '$lib/constants';
|
||||
|
||||
interface Props {
|
||||
disabled?: boolean;
|
||||
onclick?: (e: MouseEvent) => void;
|
||||
}
|
||||
|
||||
let { disabled = false, onclick }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
{disabled}
|
||||
{onclick}
|
||||
variant="secondary"
|
||||
type="button"
|
||||
>
|
||||
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
|
||||
|
||||
<Plus class="h-4 w-4" />
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>{ATTACHMENT_TOOLTIP_TEXT}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
@@ -1,17 +1,18 @@
|
||||
<script lang="ts">
|
||||
import { Plus } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import type { Snippet } from 'svelte';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import {
|
||||
ATTACHMENT_FILE_ITEMS,
|
||||
ATTACHMENT_EXTRA_ITEMS,
|
||||
ATTACHMENT_MCP_ITEMS,
|
||||
ATTACHMENT_TOOLTIP_TEXT,
|
||||
TOOLTIP_DELAY_DURATION
|
||||
} from '$lib/constants';
|
||||
import { AttachmentMenuItemId } from '$lib/enums';
|
||||
import { ChatFormActionToolsSubmenu, ChatFormActionMcpServersSubmenu } from '$lib/components/app';
|
||||
import {
|
||||
ChatFormActionAddToolsSubmenu,
|
||||
ChatFormActionAddMcpServersSubmenu
|
||||
} from '$lib/components/app';
|
||||
|
||||
import { useAttachmentMenu } from '$lib/hooks/use-attachment-menu.svelte';
|
||||
|
||||
@@ -27,6 +28,7 @@
|
||||
onMcpPromptClick?: () => void;
|
||||
onMcpSettingsClick?: () => void;
|
||||
onMcpResourcesClick?: () => void;
|
||||
trigger: Snippet<[{ disabled: boolean }]>;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -40,7 +42,8 @@
|
||||
onSystemPromptClick,
|
||||
onMcpPromptClick,
|
||||
onMcpSettingsClick,
|
||||
onMcpResourcesClick
|
||||
onMcpResourcesClick,
|
||||
trigger
|
||||
}: Props = $props();
|
||||
|
||||
let dropdownOpen = $state(false);
|
||||
@@ -62,24 +65,7 @@
|
||||
<div class="flex items-center gap-1 {className}">
|
||||
<DropdownMenu.Root bind:open={dropdownOpen}>
|
||||
<DropdownMenu.Trigger name="Attach files" {disabled}>
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
{disabled}
|
||||
variant="secondary"
|
||||
type="button"
|
||||
>
|
||||
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
|
||||
|
||||
<Plus class="h-4 w-4" />
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>{ATTACHMENT_TOOLTIP_TEXT}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{@render trigger({ disabled })}
|
||||
</DropdownMenu.Trigger>
|
||||
|
||||
<DropdownMenu.Content align="start" class="w-48">
|
||||
@@ -161,9 +147,9 @@
|
||||
{/if}
|
||||
{/each}
|
||||
|
||||
<ChatFormActionToolsSubmenu />
|
||||
<ChatFormActionAddToolsSubmenu />
|
||||
|
||||
<ChatFormActionMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
|
||||
<ChatFormActionAddMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
|
||||
|
||||
{#each ATTACHMENT_MCP_ITEMS as item (item.id)}
|
||||
{#if attachmentMenu.isItemVisible(item.visibleWhen)}
|
||||
@@ -0,0 +1,149 @@
|
||||
<script lang="ts">
|
||||
import { Settings, Plus } from '@lucide/svelte';
|
||||
import { Switch } from '$lib/components/ui/switch';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import { McpLogo, DropdownMenuSearchable } from '$lib/components/app';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { HealthCheckStatus } from '$lib/enums';
|
||||
import type { MCPServerSettingsEntry } from '$lib/types';
|
||||
import { goto } from '$app/navigation';
|
||||
|
||||
interface Props {
|
||||
onMcpSettingsClick?: () => void;
|
||||
}
|
||||
|
||||
let { onMcpSettingsClick }: Props = $props();
|
||||
|
||||
let mcpSearchQuery = $state('');
|
||||
let allMcpServers = $derived(mcpStore.getServersSorted());
|
||||
let mcpServers = $derived(allMcpServers.filter((s) => s.enabled));
|
||||
let hasMcpServers = $derived(mcpServers.length > 0);
|
||||
// let hasAnyMcpServers = $derived(allMcpServers.length > 0);
|
||||
// Enabled servers narrowed by the search box: a server matches when its label
// or its URL contains the trimmed, lower-cased query. An empty query keeps all.
let filteredMcpServers = $derived.by(() => {
	const needle = mcpSearchQuery.toLowerCase().trim();

	if (needle === '') return mcpServers;

	return mcpServers.filter((server) => {
		const haystacks = [getServerLabel(server).toLowerCase(), server.url.toLowerCase()];

		return haystacks.some((text) => text.includes(needle));
	});
});
|
||||
|
||||
/** Return the display label for `server`, as provided by `mcpStore.getServerLabel`. */
function getServerLabel(server: MCPServerSettingsEntry): string {
|
||||
return mcpStore.getServerLabel(server);
|
||||
}
|
||||
|
||||
/** Return whether the server with `serverId` is enabled for the chat, per `conversationsStore`. */
function isServerEnabledForChat(serverId: string): boolean {
|
||||
return conversationsStore.isMcpServerEnabledForChat(serverId);
|
||||
}
|
||||
|
||||
/** Toggle the server with `serverId` for the chat via `conversationsStore`, awaiting completion. */
async function toggleServerForChat(serverId: string) {
|
||||
await conversationsStore.toggleMcpServerForChat(serverId);
|
||||
}
|
||||
|
||||
/**
 * React to the MCP submenu opening or closing.
 *
 * On open, the search box is cleared and health checks are started for every
 * known server (enabled or not). Closing does nothing.
 */
function handleMcpSubMenuOpen(open: boolean) {
	if (!open) return;

	mcpSearchQuery = '';
	mcpStore.runHealthChecksForServers(allMcpServers);
}
|
||||
|
||||
/**
 * Notify the parent through the optional callback, then navigate to the MCP
 * settings route. When no enabled server exists, the `?add` query is prepended
 * to the target.
 */
function handleMcpSettingsClick() {
	onMcpSettingsClick?.();

	const query = hasMcpServers ? '' : '?add';

	goto(`${query}#/settings/mcp`);
}
|
||||
</script>
|
||||
|
||||
<DropdownMenu.Root>
|
||||
<DropdownMenu.Sub onOpenChange={handleMcpSubMenuOpen}>
|
||||
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
|
||||
<McpLogo class="h-4 w-4" />
|
||||
|
||||
<span>MCP Servers</span>
|
||||
</DropdownMenu.SubTrigger>
|
||||
|
||||
<DropdownMenu.SubContent class="w-72 pt-0">
|
||||
{#if hasMcpServers}
|
||||
<DropdownMenuSearchable
|
||||
placeholder="Search servers..."
|
||||
bind:searchValue={mcpSearchQuery}
|
||||
emptyMessage="No servers found"
|
||||
isEmpty={filteredMcpServers.length === 0}
|
||||
>
|
||||
<div class="max-h-64 overflow-y-auto">
|
||||
{#each filteredMcpServers as server (server.id)}
|
||||
{@const healthState = mcpStore.getHealthCheckState(server.id)}
|
||||
{@const hasError = healthState.status === HealthCheckStatus.ERROR}
|
||||
{@const isEnabledForChat = isServerEnabledForChat(server.id)}
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="flex w-full items-center justify-between gap-2 rounded-sm px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
|
||||
onclick={() => !hasError && toggleServerForChat(server.id)}
|
||||
disabled={hasError}
|
||||
>
|
||||
<div class="flex min-w-0 flex-1 items-center gap-2">
|
||||
{#if mcpStore.getServerFavicon(server.id)}
|
||||
<img
|
||||
src={mcpStore.getServerFavicon(server.id)}
|
||||
alt=""
|
||||
class="h-4 w-4 shrink-0 rounded-sm"
|
||||
onerror={(e) => {
|
||||
(e.currentTarget as HTMLImageElement).style.display = 'none';
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<span class="truncate text-sm">{getServerLabel(server)}</span>
|
||||
|
||||
{#if hasError}
|
||||
<span
|
||||
class="shrink-0 rounded bg-destructive/15 px-1.5 py-0.5 text-xs text-destructive"
|
||||
>
|
||||
Error
|
||||
</span>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<Switch
|
||||
checked={isEnabledForChat}
|
||||
disabled={hasError}
|
||||
onclick={(e) => e.stopPropagation()}
|
||||
onCheckedChange={() => toggleServerForChat(server.id)}
|
||||
/>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
{#snippet footer()}
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={handleMcpSettingsClick}
|
||||
>
|
||||
<Settings class="h-4 w-4" />
|
||||
|
||||
<span>Manage MCP Servers</span>
|
||||
</DropdownMenu.Item>
|
||||
{/snippet}
|
||||
</DropdownMenuSearchable>
|
||||
{:else}
|
||||
<div class="px-2 py-3 text-center text-sm text-muted-foreground">
|
||||
No MCP servers configured
|
||||
</div>
|
||||
|
||||
<DropdownMenu.Separator />
|
||||
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={handleMcpSettingsClick}
|
||||
>
|
||||
<Plus class="h-4 w-4" />
|
||||
|
||||
<span>Add MCP Servers</span>
|
||||
</DropdownMenu.Item>
|
||||
{/if}
|
||||
</DropdownMenu.SubContent>
|
||||
</DropdownMenu.Sub>
|
||||
</DropdownMenu.Root>
|
||||
@@ -1,18 +1,17 @@
|
||||
<script lang="ts">
|
||||
import { Plus } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import type { Snippet } from 'svelte';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import * as Sheet from '$lib/components/ui/sheet';
|
||||
import { TOOLTIP_DELAY_DURATION } from '$lib/constants';
|
||||
import {
|
||||
ATTACHMENT_FILE_ITEMS,
|
||||
ATTACHMENT_EXTRA_ITEMS,
|
||||
ATTACHMENT_MCP_ITEMS,
|
||||
ATTACHMENT_TOOLTIP_TEXT
|
||||
ATTACHMENT_MCP_ITEMS
|
||||
} from '$lib/constants/attachment-menu';
|
||||
import { ChatFormActionToolsSubmenu, ChatFormActionMcpServersSubmenu } from '$lib/components/app';
|
||||
import { McpLogo } from '$lib/components/app';
|
||||
import { useAttachmentMenu } from '$lib/hooks/use-attachment-menu.svelte';
|
||||
import { AttachmentMenuItemId } from '$lib/enums';
|
||||
import { PencilRuler } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -24,8 +23,8 @@
|
||||
onFileUpload?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
onMcpPromptClick?: () => void;
|
||||
onMcpSettingsClick?: () => void;
|
||||
onMcpResourcesClick?: () => void;
|
||||
trigger: Snippet<[{ disabled: boolean; onclick?: () => void }]>;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -38,8 +37,8 @@
|
||||
onFileUpload,
|
||||
onSystemPromptClick,
|
||||
onMcpPromptClick,
|
||||
onMcpSettingsClick,
|
||||
onMcpResourcesClick
|
||||
onMcpResourcesClick,
|
||||
trigger
|
||||
}: Props = $props();
|
||||
|
||||
let sheetOpen = $state(false);
|
||||
@@ -52,28 +51,14 @@
|
||||
}
|
||||
);
|
||||
|
||||
function handleMcpSettingsClick() {
|
||||
sheetOpen = false;
|
||||
onMcpSettingsClick?.();
|
||||
}
|
||||
|
||||
const sheetItemClass =
|
||||
'flex w-full items-center gap-3 rounded-md px-3 py-2.5 text-left text-sm transition-colors hover:bg-accent active:bg-accent disabled:cursor-not-allowed disabled:opacity-50';
|
||||
</script>
|
||||
|
||||
<div class="flex items-center gap-1 {className}">
|
||||
<Sheet.Root bind:open={sheetOpen}>
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
{disabled}
|
||||
variant="secondary"
|
||||
type="button"
|
||||
onclick={() => (sheetOpen = true)}
|
||||
>
|
||||
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
|
||||
|
||||
<Plus class="h-4 w-4" />
|
||||
</Button>
|
||||
{@render trigger({ disabled, onclick: () => (sheetOpen = true) })}
|
||||
<!-- <ChatFormActionAddButton {disabled} onclick={() => (sheetOpen = true)} /> -->
|
||||
|
||||
<Sheet.Content side="bottom" class="max-h-[85vh] gap-0 overflow-y-auto">
|
||||
<Sheet.Header>
|
||||
@@ -161,9 +146,17 @@
|
||||
|
||||
<div class="my-2 border-t"></div>
|
||||
|
||||
<ChatFormActionToolsSubmenu />
|
||||
<a href="#/settings/mcp" class="flex items-center gap-3 px-3 py-2">
|
||||
<McpLogo class="inline h-4 w-4" />
|
||||
|
||||
<ChatFormActionMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
|
||||
<span class="text-sm">MCP Servers</span>
|
||||
</a>
|
||||
|
||||
<a href="#/settings/chat/tools" class="flex items-center gap-3 px-3 py-2">
|
||||
<PencilRuler class="inline h-4 w-4" />
|
||||
|
||||
<span class="text-sm">Tools</span>
|
||||
</a>
|
||||
|
||||
{#each ATTACHMENT_MCP_ITEMS as item (item.id)}
|
||||
{#if attachmentMenu.isItemVisible(item.visibleWhen)}
|
||||
@@ -24,6 +24,7 @@
|
||||
{#if toolsStore.loading}
|
||||
<div class="px-3 py-4 text-center text-sm text-muted-foreground">
|
||||
<Loader2 class="mx-auto mb-1 h-4 w-4 animate-spin" />
|
||||
|
||||
Loading tools...
|
||||
</div>
|
||||
{:else if toolsStore.isToolsEndpointUnreachable}
|
||||
@@ -31,19 +32,21 @@
|
||||
<span class="flex gap-2">
|
||||
<Info class="mt-0.5 h-4 w-4 shrink-0" />
|
||||
|
||||
<span
|
||||
>Run llama-server with <code>--tools</code> flag to enable
|
||||
<strong>Built-in Tools</strong>.</span
|
||||
>
|
||||
<span>
|
||||
Run llama-server with <code>--tools</code> flag to enable
|
||||
|
||||
<strong>Built-in Tools</strong>.
|
||||
</span>
|
||||
</span>
|
||||
|
||||
<span class="flex gap-2">
|
||||
<Info class="mt-0.5 h-4 w-4 shrink-0" />
|
||||
|
||||
<span
|
||||
>{hasMcpServersAvailable ? 'Enable' : 'Add'} MCP Server(s) to access
|
||||
<strong>MCP Tools</strong>.</span
|
||||
>
|
||||
<span>
|
||||
{hasMcpServersAvailable ? 'Enable' : 'Add'} MCP Server(s) to access
|
||||
|
||||
<strong>MCP Tools</strong>.
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
{:else if toolsStore.error}
|
||||
@@ -0,0 +1,68 @@
|
||||
<script lang="ts">
|
||||
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
|
||||
import ChatFormActionAddDropdown from './ChatFormActionAddDropdown.svelte';
|
||||
import ChatFormActionAddSheet from './ChatFormActionAddSheet.svelte';
|
||||
import ChatFormActionAddButton from './ChatFormActionAddButton.svelte';
|
||||
|
||||
interface Props {
|
||||
disabled?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasMcpPromptsSupport?: boolean;
|
||||
hasMcpResourcesSupport?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
onFileUpload?: () => void;
|
||||
onMcpPromptClick?: () => void;
|
||||
onMcpResourcesClick?: () => void;
|
||||
onMcpSettingsClick?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasMcpPromptsSupport = false,
|
||||
hasMcpResourcesSupport = false,
|
||||
hasVisionModality = false,
|
||||
onFileUpload,
|
||||
onMcpPromptClick,
|
||||
onMcpResourcesClick,
|
||||
onMcpSettingsClick,
|
||||
onSystemPromptClick
|
||||
}: Props = $props();
|
||||
|
||||
const isMobile = new IsMobile();
|
||||
</script>
|
||||
|
||||
{#if isMobile.current}
|
||||
<ChatFormActionAddSheet
|
||||
{disabled}
|
||||
{hasAudioModality}
|
||||
{hasVisionModality}
|
||||
{hasMcpPromptsSupport}
|
||||
{hasMcpResourcesSupport}
|
||||
{onFileUpload}
|
||||
{onMcpPromptClick}
|
||||
{onMcpResourcesClick}
|
||||
>
|
||||
{#snippet trigger({ disabled, onclick })}
|
||||
<ChatFormActionAddButton {disabled} {onclick} />
|
||||
{/snippet}
|
||||
</ChatFormActionAddSheet>
|
||||
{:else}
|
||||
<ChatFormActionAddDropdown
|
||||
{disabled}
|
||||
{hasAudioModality}
|
||||
{hasVisionModality}
|
||||
{hasMcpPromptsSupport}
|
||||
{hasMcpResourcesSupport}
|
||||
{onFileUpload}
|
||||
{onMcpPromptClick}
|
||||
{onMcpResourcesClick}
|
||||
{onMcpSettingsClick}
|
||||
{onSystemPromptClick}
|
||||
>
|
||||
{#snippet trigger()}
|
||||
<ChatFormActionAddButton {disabled} />
|
||||
{/snippet}
|
||||
</ChatFormActionAddDropdown>
|
||||
{/if}
|
||||
@@ -1,147 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { Settings, Plus } from '@lucide/svelte';
|
||||
import { Switch } from '$lib/components/ui/switch';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import { McpLogo, DropdownMenuSearchable } from '$lib/components/app';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { HealthCheckStatus } from '$lib/enums';
|
||||
import type { MCPServerSettingsEntry } from '$lib/types';
|
||||
import { goto } from '$app/navigation';
|
||||
|
||||
interface Props {
|
||||
onMcpSettingsClick?: () => void;
|
||||
}
|
||||
|
||||
let { onMcpSettingsClick }: Props = $props();
|
||||
|
||||
let mcpSearchQuery = $state('');
|
||||
let allMcpServers = $derived(mcpStore.getServersSorted());
|
||||
let mcpServers = $derived(allMcpServers.filter((s) => s.enabled));
|
||||
let hasMcpServers = $derived(mcpServers.length > 0);
|
||||
// let hasAnyMcpServers = $derived(allMcpServers.length > 0);
|
||||
let filteredMcpServers = $derived.by(() => {
|
||||
const query = mcpSearchQuery.toLowerCase().trim();
|
||||
if (!query) return mcpServers;
|
||||
return mcpServers.filter((s) => {
|
||||
const name = getServerLabel(s).toLowerCase();
|
||||
const url = s.url.toLowerCase();
|
||||
return name.includes(query) || url.includes(query);
|
||||
});
|
||||
});
|
||||
|
||||
function getServerLabel(server: MCPServerSettingsEntry): string {
|
||||
return mcpStore.getServerLabel(server);
|
||||
}
|
||||
|
||||
function isServerEnabledForChat(serverId: string): boolean {
|
||||
return conversationsStore.isMcpServerEnabledForChat(serverId);
|
||||
}
|
||||
|
||||
async function toggleServerForChat(serverId: string) {
|
||||
await conversationsStore.toggleMcpServerForChat(serverId);
|
||||
}
|
||||
|
||||
function handleMcpSubMenuOpen(open: boolean) {
|
||||
if (open) {
|
||||
mcpSearchQuery = '';
|
||||
mcpStore.runHealthChecksForServers(allMcpServers);
|
||||
}
|
||||
}
|
||||
|
||||
function handleMcpSettingsClick() {
|
||||
onMcpSettingsClick?.();
|
||||
|
||||
goto(`${hasMcpServers ? '' : '?add'}#/settings/mcp`);
|
||||
}
|
||||
</script>
|
||||
|
||||
<DropdownMenu.Sub onOpenChange={handleMcpSubMenuOpen}>
|
||||
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
|
||||
<McpLogo class="h-4 w-4" />
|
||||
|
||||
<span>MCP Servers</span>
|
||||
</DropdownMenu.SubTrigger>
|
||||
|
||||
<DropdownMenu.SubContent class="w-72 pt-0">
|
||||
{#if hasMcpServers}
|
||||
<DropdownMenuSearchable
|
||||
placeholder="Search servers..."
|
||||
bind:searchValue={mcpSearchQuery}
|
||||
emptyMessage="No servers found"
|
||||
isEmpty={filteredMcpServers.length === 0}
|
||||
>
|
||||
<div class="max-h-64 overflow-y-auto">
|
||||
{#each filteredMcpServers as server (server.id)}
|
||||
{@const healthState = mcpStore.getHealthCheckState(server.id)}
|
||||
{@const hasError = healthState.status === HealthCheckStatus.ERROR}
|
||||
{@const isEnabledForChat = isServerEnabledForChat(server.id)}
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="flex w-full items-center justify-between gap-2 rounded-sm px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
|
||||
onclick={() => !hasError && toggleServerForChat(server.id)}
|
||||
disabled={hasError}
|
||||
>
|
||||
<div class="flex min-w-0 flex-1 items-center gap-2">
|
||||
{#if mcpStore.getServerFavicon(server.id)}
|
||||
<img
|
||||
src={mcpStore.getServerFavicon(server.id)}
|
||||
alt=""
|
||||
class="h-4 w-4 shrink-0 rounded-sm"
|
||||
onerror={(e) => {
|
||||
(e.currentTarget as HTMLImageElement).style.display = 'none';
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<span class="truncate text-sm">{getServerLabel(server)}</span>
|
||||
|
||||
{#if hasError}
|
||||
<span
|
||||
class="shrink-0 rounded bg-destructive/15 px-1.5 py-0.5 text-xs text-destructive"
|
||||
>
|
||||
Error
|
||||
</span>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<Switch
|
||||
checked={isEnabledForChat}
|
||||
disabled={hasError}
|
||||
onclick={(e: MouseEvent) => e.stopPropagation()}
|
||||
onCheckedChange={() => toggleServerForChat(server.id)}
|
||||
/>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
{#snippet footer()}
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={handleMcpSettingsClick}
|
||||
>
|
||||
<Settings class="h-4 w-4" />
|
||||
|
||||
<span>Manage MCP Servers</span>
|
||||
</DropdownMenu.Item>
|
||||
{/snippet}
|
||||
</DropdownMenuSearchable>
|
||||
{:else}
|
||||
<div class="px-2 py-3 text-center text-sm text-muted-foreground">
|
||||
No MCP servers configured
|
||||
</div>
|
||||
|
||||
<DropdownMenu.Separator />
|
||||
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={handleMcpSettingsClick}
|
||||
>
|
||||
<Plus class="h-4 w-4" />
|
||||
|
||||
<span>Add MCP Servers</span>
|
||||
</DropdownMenu.Item>
|
||||
{/if}
|
||||
</DropdownMenu.SubContent>
|
||||
</DropdownMenu.Sub>
|
||||
@@ -0,0 +1,160 @@
|
||||
<script lang="ts">
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
|
||||
import { isRouterMode, serverError } from '$lib/stores/server.svelte';
|
||||
import { ModelsSelectorDropdown, ModelsSelectorSheet } from '$lib/components/app';
|
||||
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
|
||||
import { activeMessages } from '$lib/stores/conversations.svelte';
|
||||
|
||||
interface Props {
|
||||
currentModel?: string;
|
||||
disabled?: boolean;
|
||||
forceForegroundText?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
hasModelSelected?: boolean;
|
||||
isSelectedModelInCache?: boolean;
|
||||
submitTooltip?: string;
|
||||
useGlobalSelection?: boolean;
|
||||
}
|
||||
|
||||
let {
|
||||
currentModel,
|
||||
disabled = false,
|
||||
forceForegroundText = false,
|
||||
hasAudioModality = $bindable(false),
|
||||
hasVisionModality = $bindable(false),
|
||||
hasModelSelected = $bindable(false),
|
||||
isSelectedModelInCache = $bindable(true),
|
||||
submitTooltip = $bindable(''),
|
||||
useGlobalSelection = false
|
||||
}: Props = $props();
|
||||
|
||||
let isRouter = $derived(isRouterMode());
|
||||
let isOffline = $derived(!!serverError());
|
||||
|
||||
let conversationModel = $derived(
|
||||
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
|
||||
);
|
||||
|
||||
let lastSyncedConversationModel: string | null = null;
|
||||
|
||||
$effect(() => {
|
||||
if (conversationModel && conversationModel !== lastSyncedConversationModel) {
|
||||
lastSyncedConversationModel = conversationModel;
|
||||
|
||||
modelsStore.selectModelByName(conversationModel);
|
||||
} else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) {
|
||||
lastSyncedConversationModel = null;
|
||||
// auto-select the first loaded model only when nothing is selected yet
|
||||
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
|
||||
|
||||
if (first) modelsStore.selectModelById(first.id);
|
||||
}
|
||||
});
|
||||
|
||||
let activeModelId = $derived.by(() => {
|
||||
const options = modelOptions();
|
||||
|
||||
if (!isRouter) {
|
||||
return options.length > 0 ? options[0].model : null;
|
||||
}
|
||||
|
||||
const selectedId = selectedModelId();
|
||||
|
||||
if (selectedId) {
|
||||
const model = options.find((m) => m.id === selectedId);
|
||||
|
||||
if (model) return model.model;
|
||||
}
|
||||
|
||||
if (conversationModel) {
|
||||
const model = options.find((m) => m.model === conversationModel);
|
||||
|
||||
if (model) return model.model;
|
||||
}
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
let modelPropsVersion = $state(0); // Used to trigger reactivity after fetch
|
||||
|
||||
$effect(() => {
|
||||
if (activeModelId) {
|
||||
const cached = modelsStore.getModelProps(activeModelId);
|
||||
|
||||
if (!cached) {
|
||||
modelsStore.fetchModelProps(activeModelId).then(() => {
|
||||
modelPropsVersion++;
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
hasAudioModality = activeModelId ? modelsStore.modelSupportsAudio(activeModelId) : false;
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
void modelPropsVersion;
|
||||
|
||||
hasVisionModality = activeModelId ? modelsStore.modelSupportsVision(activeModelId) : false;
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
hasModelSelected = !isRouter || !!conversationModel || !!selectedModelId();
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (!isRouter) {
|
||||
isSelectedModelInCache = true;
|
||||
} else if (conversationModel) {
|
||||
isSelectedModelInCache = modelOptions().some((option) => option.model === conversationModel);
|
||||
} else {
|
||||
const currentModelId = selectedModelId();
|
||||
|
||||
if (!currentModelId) {
|
||||
isSelectedModelInCache = false;
|
||||
} else {
|
||||
isSelectedModelInCache = modelOptions().some((option) => option.id === currentModelId);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (!hasModelSelected) {
|
||||
submitTooltip = 'Please select a model first';
|
||||
} else if (!isSelectedModelInCache) {
|
||||
submitTooltip = 'Selected model is not available, please select another';
|
||||
} else {
|
||||
submitTooltip = '';
|
||||
}
|
||||
});
|
||||
|
||||
let selectorModelRef: ModelsSelectorDropdown | ModelsSelectorSheet | undefined =
|
||||
$state(undefined);
|
||||
|
||||
let isMobile = new IsMobile();
|
||||
|
||||
export function open() {
|
||||
selectorModelRef?.open();
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if isMobile.current}
|
||||
<ModelsSelectorSheet
|
||||
disabled={disabled || isOffline}
|
||||
bind:this={selectorModelRef}
|
||||
{currentModel}
|
||||
{forceForegroundText}
|
||||
{useGlobalSelection}
|
||||
/>
|
||||
{:else}
|
||||
<ModelsSelectorDropdown
|
||||
disabled={disabled || isOffline}
|
||||
bind:this={selectorModelRef}
|
||||
{currentModel}
|
||||
{forceForegroundText}
|
||||
{useGlobalSelection}
|
||||
/>
|
||||
{/if}
|
||||
@@ -2,7 +2,6 @@
|
||||
import { ArrowUp } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { cn } from '$lib/components/ui/utils';
|
||||
|
||||
interface Props {
|
||||
canSend?: boolean;
|
||||
@@ -20,12 +19,11 @@
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isDisabled}
|
||||
class={cn(
|
||||
class={[
|
||||
'h-8 w-8 rounded-full p-0',
|
||||
showErrorState
|
||||
? 'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
|
||||
: ''
|
||||
)}
|
||||
showErrorState &&
|
||||
'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
|
||||
]}
|
||||
{...props}
|
||||
>
|
||||
<span class="sr-only">Send</span>
|
||||
|
||||
@@ -2,31 +2,27 @@
|
||||
import { Square } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import {
|
||||
ChatFormActionAttachmentsDropdown,
|
||||
ChatFormActionAttachmentsSheet,
|
||||
ChatFormActionsAdd,
|
||||
ChatFormActionModels,
|
||||
ChatFormActionRecord,
|
||||
ChatFormActionSubmit,
|
||||
ModelsSelectorDropdown,
|
||||
ModelsSelectorSheet
|
||||
ChatFormActionSubmit
|
||||
} from '$lib/components/app';
|
||||
import { FileTypeCategory } from '$lib/enums';
|
||||
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
|
||||
import { isRouterMode, serverError } from '$lib/stores/server.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { activeMessages, conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { getFileTypeCategory } from '$lib/utils';
|
||||
import { goto } from '$app/navigation';
|
||||
|
||||
interface Props {
|
||||
canSend?: boolean;
|
||||
canSubmit?: boolean;
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
isLoading?: boolean;
|
||||
isRecording?: boolean;
|
||||
hasText?: boolean;
|
||||
showAddButton?: boolean;
|
||||
showModelSelector?: boolean;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
onFileUpload?: () => void;
|
||||
onMicClick?: () => void;
|
||||
@@ -38,11 +34,13 @@
|
||||
|
||||
let {
|
||||
canSend = false,
|
||||
canSubmit = false,
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
isLoading = false,
|
||||
isRecording = false,
|
||||
hasText = false,
|
||||
showAddButton = true,
|
||||
showModelSelector = true,
|
||||
uploadedFiles = [],
|
||||
onFileUpload,
|
||||
onMicClick,
|
||||
@@ -53,124 +51,6 @@
|
||||
}: Props = $props();
|
||||
|
||||
let currentConfig = $derived(config());
|
||||
let isRouter = $derived(isRouterMode());
|
||||
let isOffline = $derived(!!serverError());
|
||||
|
||||
let conversationModel = $derived(
|
||||
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
|
||||
);
|
||||
|
||||
let lastSyncedConversationModel: string | null = null;
|
||||
|
||||
$effect(() => {
|
||||
if (conversationModel && conversationModel !== lastSyncedConversationModel) {
|
||||
lastSyncedConversationModel = conversationModel;
|
||||
modelsStore.selectModelByName(conversationModel);
|
||||
} else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) {
|
||||
lastSyncedConversationModel = null;
|
||||
// auto-select the first loaded model only when nothing is selected yet
|
||||
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
|
||||
if (first) modelsStore.selectModelById(first.id);
|
||||
}
|
||||
});
|
||||
|
||||
let activeModelId = $derived.by(() => {
|
||||
const options = modelOptions();
|
||||
|
||||
if (!isRouter) {
|
||||
return options.length > 0 ? options[0].model : null;
|
||||
}
|
||||
|
||||
const selectedId = selectedModelId();
|
||||
if (selectedId) {
|
||||
const model = options.find((m) => m.id === selectedId);
|
||||
if (model) return model.model;
|
||||
}
|
||||
|
||||
if (conversationModel) {
|
||||
const model = options.find((m) => m.model === conversationModel);
|
||||
if (model) return model.model;
|
||||
}
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
let modelPropsVersion = $state(0); // Used to trigger reactivity after fetch
|
||||
|
||||
$effect(() => {
|
||||
if (activeModelId) {
|
||||
const cached = modelsStore.getModelProps(activeModelId);
|
||||
|
||||
if (!cached) {
|
||||
modelsStore.fetchModelProps(activeModelId).then(() => {
|
||||
modelPropsVersion++;
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let hasAudioModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
|
||||
return modelsStore.modelSupportsAudio(activeModelId);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
let hasVisionModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
|
||||
return modelsStore.modelSupportsVision(activeModelId);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
let hasAudioAttachments = $derived(
|
||||
uploadedFiles.some((file) => getFileTypeCategory(file.type) === FileTypeCategory.AUDIO)
|
||||
);
|
||||
let shouldShowRecordButton = $derived(
|
||||
hasAudioModality && !hasText && !hasAudioAttachments && currentConfig.autoMicOnEmpty
|
||||
);
|
||||
|
||||
let hasModelSelected = $derived(!isRouter || !!conversationModel || !!selectedModelId());
|
||||
|
||||
let isSelectedModelInCache = $derived.by(() => {
|
||||
if (!isRouter) return true;
|
||||
|
||||
if (conversationModel) {
|
||||
return modelOptions().some((option) => option.model === conversationModel);
|
||||
}
|
||||
|
||||
const currentModelId = selectedModelId();
|
||||
if (!currentModelId) return false;
|
||||
|
||||
return modelOptions().some((option) => option.id === currentModelId);
|
||||
});
|
||||
|
||||
let submitTooltip = $derived.by(() => {
|
||||
if (!hasModelSelected) {
|
||||
return 'Please select a model first';
|
||||
}
|
||||
|
||||
if (!isSelectedModelInCache) {
|
||||
return 'Selected model is not available, please select another';
|
||||
}
|
||||
|
||||
return '';
|
||||
});
|
||||
|
||||
let selectorModelRef: ModelsSelectorDropdown | ModelsSelectorSheet | undefined =
|
||||
$state(undefined);
|
||||
|
||||
let isMobile = new IsMobile();
|
||||
|
||||
export function openModelSelector() {
|
||||
selectorModelRef?.open();
|
||||
}
|
||||
|
||||
let hasMcpPromptsSupport = $derived.by(() => {
|
||||
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
|
||||
@@ -183,25 +63,34 @@
|
||||
|
||||
return mcpStore.hasResourcesCapability(perChatOverrides);
|
||||
});
|
||||
|
||||
let hasAudioModality = $state(false);
|
||||
let hasVisionModality = $state(false);
|
||||
let hasModelSelected = $state(false);
|
||||
let isSelectedModelInCache = $state(true);
|
||||
let submitTooltip = $state('');
|
||||
|
||||
let hasAudioAttachments = $derived(
|
||||
uploadedFiles.some((file) => getFileTypeCategory(file.type) === FileTypeCategory.AUDIO)
|
||||
);
|
||||
let shouldShowRecordButton = $derived(
|
||||
hasAudioModality && !canSubmit && !hasAudioAttachments && currentConfig.autoMicOnEmpty
|
||||
);
|
||||
|
||||
let selectorModelRef: ChatFormActionModels | undefined = $state(undefined);
|
||||
|
||||
export function openModelSelector() {
|
||||
selectorModelRef?.open();
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex w-full items-center gap-3 {className}" style="container-type: inline-size">
|
||||
<div class="mr-auto flex items-center gap-2">
|
||||
{#if isMobile.current}
|
||||
<ChatFormActionAttachmentsSheet
|
||||
{disabled}
|
||||
{hasAudioModality}
|
||||
{hasVisionModality}
|
||||
{hasMcpPromptsSupport}
|
||||
{hasMcpResourcesSupport}
|
||||
{onFileUpload}
|
||||
{onSystemPromptClick}
|
||||
{onMcpPromptClick}
|
||||
onMcpSettingsClick={() => goto('#/settings/mcp')}
|
||||
{onMcpResourcesClick}
|
||||
/>
|
||||
{:else}
|
||||
<ChatFormActionAttachmentsDropdown
|
||||
<div
|
||||
class="flex w-full items-center gap-3 {className} {showAddButton ? '' : 'justify-end'}"
|
||||
style="container-type: inline-size"
|
||||
>
|
||||
{#if showAddButton}
|
||||
<div class="mr-auto flex items-center gap-2">
|
||||
<ChatFormActionsAdd
|
||||
{disabled}
|
||||
{hasAudioModality}
|
||||
{hasVisionModality}
|
||||
@@ -213,30 +102,24 @@
|
||||
{onMcpResourcesClick}
|
||||
onMcpSettingsClick={() => goto('#/settings/mcp')}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="ml-auto flex items-center gap-2">
|
||||
{#if isMobile.current}
|
||||
<ModelsSelectorSheet
|
||||
disabled={disabled || isOffline}
|
||||
bind:this={selectorModelRef}
|
||||
currentModel={conversationModel}
|
||||
forceForegroundText
|
||||
useGlobalSelection
|
||||
/>
|
||||
{:else}
|
||||
<ModelsSelectorDropdown
|
||||
disabled={disabled || isOffline}
|
||||
bind:this={selectorModelRef}
|
||||
currentModel={conversationModel}
|
||||
forceForegroundText
|
||||
useGlobalSelection
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
{#if showModelSelector}
|
||||
<ChatFormActionModels
|
||||
{disabled}
|
||||
bind:this={selectorModelRef}
|
||||
bind:hasAudioModality
|
||||
bind:hasVisionModality
|
||||
bind:hasModelSelected
|
||||
bind:isSelectedModelInCache
|
||||
bind:submitTooltip
|
||||
forceForegroundText
|
||||
useGlobalSelection
|
||||
/>
|
||||
{/if}
|
||||
|
||||
{#if isLoading && !hasText}
|
||||
{#if isLoading && !canSubmit}
|
||||
<Button
|
||||
type="button"
|
||||
variant="secondary"
|
||||
@@ -253,10 +136,10 @@
|
||||
<ChatFormActionRecord {disabled} {hasAudioModality} {isLoading} {isRecording} {onMicClick} />
|
||||
{:else}
|
||||
<ChatFormActionSubmit
|
||||
canSend={canSend && hasModelSelected && isSelectedModelInCache}
|
||||
canSend={canSend && (showModelSelector ? hasModelSelected && isSelectedModelInCache : true)}
|
||||
{disabled}
|
||||
tooltipLabel={submitTooltip}
|
||||
showErrorState={hasModelSelected && !isSelectedModelInCache}
|
||||
showErrorState={showModelSelector && hasModelSelected && !isSelectedModelInCache}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
function handleFileSelect(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
|
||||
if (input.files) {
|
||||
onFileSelect?.(Array.from(input.files));
|
||||
}
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { browser } from '$app/environment';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
show?: boolean;
|
||||
}
|
||||
|
||||
let { class: className = '', show = true }: Props = $props();
|
||||
|
||||
let sendOnEnter = $derived(config().sendOnEnter !== false);
|
||||
let modKey = browser && /Mac|iPhone|iPad|iPod/.test(navigator.platform) ? 'Cmd' : 'Ctrl';
|
||||
</script>
|
||||
|
||||
{#if show}
|
||||
<div class="mt-6 items-center justify-center {className} hidden md:flex">
|
||||
{#if sendOnEnter}
|
||||
<p class="text-xs text-muted-foreground">
|
||||
Press <kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">Enter</kbd> to send,
|
||||
<kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">Shift + Enter</kbd> for new line
|
||||
</p>
|
||||
{:else}
|
||||
<p class="text-xs text-muted-foreground">
|
||||
Press <kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">{modKey} + Enter</kbd> to
|
||||
send,
|
||||
<kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">Enter</kbd> for new line
|
||||
</p>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
@@ -4,7 +4,10 @@
|
||||
mcpResourceAttachments,
|
||||
mcpHasResourceAttachments
|
||||
} from '$lib/stores/mcp-resources.svelte';
|
||||
import { ChatAttachmentMcpResource, HorizontalScrollCarousel } from '$lib/components/app';
|
||||
import {
|
||||
ChatAttachmentsListItemMcpResource,
|
||||
HorizontalScrollCarousel
|
||||
} from '$lib/components/app';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -29,11 +32,11 @@
|
||||
<div class={className}>
|
||||
<HorizontalScrollCarousel gapSize="2">
|
||||
{#each attachments as attachment, i (attachment.id)}
|
||||
<ChatAttachmentMcpResource
|
||||
<ChatAttachmentsListItemMcpResource
|
||||
class={i === 0 ? 'ml-3' : ''}
|
||||
{attachment}
|
||||
onRemove={handleRemove}
|
||||
onClick={() => handleResourceClick(attachment.resource.uri)}
|
||||
onclick={() => handleResourceClick(attachment.resource.uri)}
|
||||
/>
|
||||
{/each}
|
||||
</HorizontalScrollCarousel>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user