mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-16 16:27:32 +03:00
Compare commits
11 Commits
b8771
...
mtmd-video
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c5b682b25c | ||
|
|
f558360b32 | ||
|
|
75f3bc94e6 | ||
|
|
aa00911d12 | ||
|
|
ce8fd4b1a6 | ||
|
|
9f5e1edb10 | ||
|
|
920b3e78cb | ||
|
|
974c8c94cc | ||
|
|
227ed28e12 | ||
|
|
bafae27654 | ||
|
|
573f2cf58e |
@@ -258,6 +258,9 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
|
||||
if (callback) {
|
||||
callback->on_update(p);
|
||||
if (callback->is_cancelled()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
progress_step = 0;
|
||||
}
|
||||
@@ -373,6 +376,9 @@ static int common_download_file_single_online(const std::string & url,
|
||||
}
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
if (opts.callback && opts.callback->is_cancelled()) {
|
||||
break;
|
||||
}
|
||||
if (i) {
|
||||
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(delay));
|
||||
@@ -412,6 +418,12 @@ static int common_download_file_single_online(const std::string & url,
|
||||
if (opts.callback) {
|
||||
opts.callback->on_done(p, success);
|
||||
}
|
||||
if (opts.callback && opts.callback->is_cancelled() &&
|
||||
std::filesystem::exists(path_temporary)) {
|
||||
if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, path_temporary.c_str());
|
||||
}
|
||||
}
|
||||
if (!success) {
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
return -1; // max attempts reached
|
||||
|
||||
@@ -21,6 +21,7 @@ public:
|
||||
virtual void on_start(const common_download_progress & p) = 0;
|
||||
virtual void on_update(const common_download_progress & p) = 0;
|
||||
virtual void on_done(const common_download_progress & p, bool ok) = 0;
|
||||
virtual bool is_cancelled() const { return false; }
|
||||
};
|
||||
|
||||
struct common_remote_params {
|
||||
|
||||
@@ -58,26 +58,48 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
bool is_capturing = false;
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// Currently (confirmed for CCCL <= 3.2) DeviceSegmentedSort does not support stream capture, while DeviceSegmentedRadixSort does.
|
||||
// See https://github.com/NVIDIA/cccl/issues/5661#issuecomment-3229037149
|
||||
// TODO: constrain this to the CCCL versions that have this issue once it's resolved in a future CCCL release.
|
||||
cudaStreamCaptureStatus capture_status;
|
||||
CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status));
|
||||
is_capturing = (capture_status != cudaStreamCaptureStatusNone);
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(
|
||||
nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
||||
stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,22 +108,33 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2858,11 +2858,10 @@ struct vk_fa_tuning_params {
|
||||
}
|
||||
};
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
GGML_UNUSED(kv_type);
|
||||
|
||||
vk_fa_tuning_params result{};
|
||||
result.path = FA_SCALAR;
|
||||
@@ -2914,7 +2913,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
|
||||
|
||||
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
|
||||
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
|
||||
result.block_rows /= 2;
|
||||
}
|
||||
|
||||
@@ -3445,21 +3444,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
if (device->fp16) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
|
||||
}
|
||||
} else {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
|
||||
}
|
||||
}
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
@@ -8780,7 +8805,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
|
||||
GGML_UNUSED(f32acc);
|
||||
// Needs to be kept up to date on shader changes
|
||||
const uint32_t wg_size = params.workgroup_size;
|
||||
@@ -8789,21 +8814,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
|
||||
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
|
||||
const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
|
||||
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
|
||||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
|
||||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
|
||||
|
||||
// tmpsh is overestimated slightly
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
|
||||
|
||||
const uint32_t masksh = Bc * (Br + 1) * float_type_size;
|
||||
|
||||
const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
uint32_t Qf, kvsh, kblocksh_size;
|
||||
if (mmq) {
|
||||
// block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
|
||||
const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
|
||||
Qf = Br * (hsk / 32) * block_b_size;
|
||||
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
// kvsh uses D = HSV (K goes through kblocksh instead)
|
||||
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
|
||||
// block_a_cache size depends on quant type
|
||||
uint32_t block_a_size;
|
||||
switch (kv_type) {
|
||||
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
|
||||
default: block_a_size = 0; break;
|
||||
}
|
||||
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
|
||||
} else {
|
||||
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
kblocksh_size = 0;
|
||||
}
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
@@ -10,6 +10,13 @@
|
||||
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
|
||||
#endif
|
||||
|
||||
#ifdef MMQ
|
||||
#extension GL_EXT_integer_dot_product : require
|
||||
#extension GL_KHR_shader_subgroup_clustered : require
|
||||
|
||||
#include "mul_mmq_shmem_types.glsl"
|
||||
#endif
|
||||
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
|
||||
@@ -41,15 +48,34 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
|
||||
const uint32_t masksh_stride = Br + 1;
|
||||
shared FLOAT_TYPE masksh[Bc * masksh_stride];
|
||||
|
||||
#ifndef MMQ
|
||||
const uint32_t qf_stride = HSK / 4 + 1;
|
||||
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
|
||||
#else
|
||||
|
||||
const uint32_t qf_stride = HSK / 32;
|
||||
shared block_b_cache Qf[Br * qf_stride];
|
||||
#endif
|
||||
|
||||
#ifndef MMQ
|
||||
const uint32_t D = HSK > HSV ? HSK : HSV;
|
||||
#else
|
||||
const uint32_t D = HSV;
|
||||
#endif
|
||||
const uint32_t kvsh_stride = D / 4 + 1;
|
||||
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
|
||||
|
||||
#ifdef MMQ
|
||||
|
||||
shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1];
|
||||
#endif
|
||||
|
||||
shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
|
||||
|
||||
#ifdef MMQ
|
||||
#include "flash_attn_mmq_funcs.glsl"
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
@@ -82,10 +108,39 @@ void main() {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t r = (idx + tid) / (HSK / 4);
|
||||
if (r < Br && d < HSK / 4 &&
|
||||
i * Br + r < N) {
|
||||
const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N;
|
||||
#ifndef MMQ
|
||||
if (is_in_bounds) {
|
||||
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||
}
|
||||
#else
|
||||
const uint buf_ib = r * qf_stride + d / 8;
|
||||
const uint buf_iqs = d % 8;
|
||||
|
||||
FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f);
|
||||
const FLOAT_TYPEV4 abs_vals = abs(vals);
|
||||
|
||||
const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
|
||||
const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8);
|
||||
const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0);
|
||||
const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0);
|
||||
vals = round(vals * qd_inv);
|
||||
|
||||
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
|
||||
|
||||
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
|
||||
}
|
||||
#else // Q4_0, Q4_1, Q5_0, Q5_1
|
||||
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
|
||||
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
|
||||
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
barrier();
|
||||
|
||||
@@ -195,6 +250,7 @@ void main() {
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
barrier();
|
||||
#ifndef MMQ
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
@@ -214,9 +270,29 @@ void main() {
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
#else // MMQ
|
||||
const uint ints_per_block = 8 / QUANT_R_MMQ;
|
||||
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
|
||||
const uint32_t iqs = (idx + tid) % ints_per_block;
|
||||
const uint32_t ib = (idx + tid) / ints_per_block;
|
||||
const uint32_t c = ib / (HSK / 32);
|
||||
const uint32_t block = ib % (HSK / 32);
|
||||
if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) {
|
||||
const uint buf_ib = c * qf_stride + block;
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
const uint global_ib = (j * Bc + c) * k_stride + block;
|
||||
k_block_to_shmem(buf_ib, global_ib, iqs, k_offset);
|
||||
} else {
|
||||
k_block_to_shmem_zero(buf_ib, iqs);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MMQ
|
||||
barrier();
|
||||
}
|
||||
|
||||
#ifndef MMQ
|
||||
// More d iterations means Q register caching becomes relevant
|
||||
// Few iterations means the additional registers needed are worse than the speed-up from caching
|
||||
if (HSK_per_thread / 4 > 4) {
|
||||
@@ -275,6 +351,110 @@ void main() {
|
||||
}
|
||||
}
|
||||
}
|
||||
#else // MMQ
|
||||
const uint hsk4 = HSK_per_thread / 4;
|
||||
const uint d_per_step = (hsk4 % 8 == 0) ? 8 :
|
||||
(hsk4 % 4 == 0) ? 4 :
|
||||
(hsk4 % 2 == 0) ? 2 : 1;
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) {
|
||||
int32_t k_quants[d_per_step];
|
||||
ACC_TYPEV2 k_dm;
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
|
||||
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
uint vui = kblocksh[buf_ib].qs[d];
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
|
||||
const uint ib = coord / BLOCK_SIZE;
|
||||
const uint iqs = (coord % BLOCK_SIZE);
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
|
||||
#endif
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
#if defined(DATA_A_Q5_0)
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qh[1]));
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
|
||||
#endif
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
|
||||
#else
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
|
||||
#endif
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8;
|
||||
const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8;
|
||||
|
||||
int32_t acc = 0;
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]);
|
||||
}
|
||||
|
||||
Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x;
|
||||
if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) {
|
||||
Sf[r][c] += k_dot_correction(qib, k_dm);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MMQ
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
// Compute sum across the D_split
|
||||
|
||||
@@ -89,6 +89,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
|
||||
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
|
||||
#endif
|
||||
|
||||
#ifndef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 1
|
||||
#endif
|
||||
|
||||
149
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl
Normal file
149
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl
Normal file
@@ -0,0 +1,149 @@
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q4_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
return int32_t(vui & 0x0F0F0F0F);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q5_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qh[1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
uint qh_bits = (qh >> iqs) & 0xF;
|
||||
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
u8vec4 idx = unpack8(vui & 0x0F0F0F0F);
|
||||
return pack32(i8vec4(kvalues_iq4nl_const[idx.x],
|
||||
kvalues_iq4nl_const[idx.y],
|
||||
kvalues_iq4nl_const[idx.z],
|
||||
kvalues_iq4nl_const[idx.w]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
|
||||
}
|
||||
#else
|
||||
FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
|
||||
return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
|
||||
}
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
|
||||
}
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
|
||||
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
|
||||
kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
|
||||
kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
|
||||
kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
|
||||
kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
|
||||
#endif
|
||||
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
|
||||
#else
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
|
||||
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
return kblocksh[buf_ib].qs[pos];
|
||||
#endif
|
||||
}
|
||||
|
||||
ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
|
||||
#else
|
||||
return ACC_TYPE(0.0);
|
||||
#endif
|
||||
}
|
||||
|
||||
void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
|
||||
kblocksh[buf_ib].qs[iqs] = 0;
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
kblocksh[buf_ib].qs[iqs + 4] = 0;
|
||||
#endif
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
|
||||
#else
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -32,6 +32,12 @@ struct block_a_cache {
|
||||
int32_t qs[32/4];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
|
||||
@@ -1692,6 +1692,7 @@ struct block_iq4_nl_packed16
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
#define QUANT_K QUANT_K_IQ4_NL
|
||||
#define QUANT_R QUANT_R_IQ4_NL
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_iq4_nl
|
||||
#define A_TYPE_PACKED16 block_iq4_nl_packed16
|
||||
#endif
|
||||
|
||||
@@ -406,8 +406,8 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
|
||||
}
|
||||
|
||||
static std::vector<std::future<void>> compiles;
|
||||
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
||||
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
||||
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") {
|
||||
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix;
|
||||
std::string out_path = join_paths(output_dir, name + ".spv");
|
||||
|
||||
if (input_filepath == "") {
|
||||
@@ -625,15 +625,16 @@ void process_shaders() {
|
||||
for (const bool& fp16 : {false, true}) {
|
||||
std::map<std::string, std::string> base_dict;
|
||||
if (fp16) {
|
||||
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
|
||||
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
|
||||
} else {
|
||||
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
|
||||
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}};
|
||||
}
|
||||
|
||||
// flash attention
|
||||
for (const bool& f16acc : {false, true}) {
|
||||
std::map<std::string, std::string> fa_base_dict = base_dict;
|
||||
fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
|
||||
fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2";
|
||||
fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
|
||||
if (fp16 && f16acc) {
|
||||
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
@@ -672,6 +673,12 @@ void process_shaders() {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (tname != "f32") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -534,11 +534,7 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
||||
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
ctx->queue.Submit(1, &commands);
|
||||
if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0,
|
||||
ctx->debug_host_buf.GetSize())) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n");
|
||||
return;
|
||||
}
|
||||
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
|
||||
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
|
||||
std::cout << "debug[0]: " << debug_data[0] << "\n";
|
||||
ctx->debug_host_buf.Unmap();
|
||||
|
||||
@@ -8397,6 +8397,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 512, 1, 1}, order)); // test CUDA dispatching to radix sort for nrows > = 1 in graph mode
|
||||
}
|
||||
|
||||
for (int n = 1; n < 5; ++n) {
|
||||
@@ -8579,7 +8580,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
for (int nb : { 1, 3, 32, 75, }) {
|
||||
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
||||
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
|
||||
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL}) {
|
||||
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;
|
||||
test_cases.emplace_back(new test_flash_attn_ext(
|
||||
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
|
||||
|
||||
@@ -32,6 +32,9 @@ struct clip_graph {
|
||||
float kq_scale; // TODO: maybe move this to hparams
|
||||
const clip_flash_attn_type flash_attn_type;
|
||||
|
||||
// TODO [QWEN_VIDEO]: improve this in the future
|
||||
int nt = 1; // number of temporal dim, to be used by Qwen-VL models
|
||||
|
||||
ggml_context_ptr ctx0_ptr;
|
||||
ggml_context * ctx0;
|
||||
ggml_cgraph * gf;
|
||||
|
||||
@@ -448,6 +448,7 @@ struct clip_image_u8_batch {
|
||||
struct clip_image_f32_batch {
|
||||
std::vector<clip_image_f32_ptr> entries;
|
||||
bool is_audio = false;
|
||||
bool is_seq = true;
|
||||
|
||||
// for llava-uhd style models, we need to know the grid size
|
||||
// note: entries.size() == grid_x * grid_y + 1 (one overview image)
|
||||
@@ -458,6 +459,7 @@ struct clip_image_f32_batch {
|
||||
clip_image_f32_batch new_batch{
|
||||
/* entries */ {},
|
||||
/* is_audio */ is_audio,
|
||||
/* is_seq */ is_seq,
|
||||
/* grid_x */ grid_x,
|
||||
/* grid_y */ grid_y,
|
||||
};
|
||||
|
||||
@@ -515,7 +515,7 @@ ggml_tensor * clip_graph::build_inp() {
|
||||
}
|
||||
|
||||
ggml_tensor * clip_graph::build_inp_raw(int channels) {
|
||||
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
|
||||
ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels, nt);
|
||||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
return inp_raw;
|
||||
@@ -951,6 +951,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
GGML_ABORT("missing cgraph builder");
|
||||
}
|
||||
|
||||
// TODO [QWEN_VIDEO]: improve this in the future
|
||||
builder->nt = imgs.entries.size();
|
||||
|
||||
return builder->build();
|
||||
}
|
||||
|
||||
@@ -3042,10 +3045,11 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
|
||||
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
|
||||
const clip_image_f32_batch & imgs = *imgs_c_ptr;
|
||||
int batch_size = imgs.entries.size();
|
||||
bool support_seq = clip_model_supports_seq_input(ctx);
|
||||
|
||||
// TODO @ngxson : implement batch size > 1 as a loop
|
||||
// we don't need true batching support because the cgraph will gonna be big anyway
|
||||
if (batch_size != 1) {
|
||||
if (batch_size != 1 && !support_seq) {
|
||||
return false; // only support batch size of 1
|
||||
}
|
||||
|
||||
@@ -3117,6 +3121,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
// └─────┘ │
|
||||
// ──────┘ x B
|
||||
|
||||
// IMPORTANT: [QWEN_VIDEO] the batch dim is currently used for temporal dim in Qwen-VL models
|
||||
|
||||
for (size_t i = 0; i < imgs.entries.size(); i++) {
|
||||
const int nx = imgs.entries[i]->nx;
|
||||
const int ny = imgs.entries[i]->ny;
|
||||
@@ -3747,6 +3753,17 @@ bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
bool clip_model_supports_seq_input(const struct clip_ctx * ctx) {
|
||||
switch (ctx->proj_type()) {
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
|
||||
clip_image_f32 clip_img;
|
||||
clip_img.buf.resize(h * w * 3);
|
||||
|
||||
@@ -116,3 +116,6 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
|
||||
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
|
||||
bool clip_has_audio_encoder(const struct clip_ctx * ctx);
|
||||
bool clip_has_whisper_encoder(const struct clip_ctx * ctx);
|
||||
|
||||
// true if model graph support image->nt (temporal dimension) as input
|
||||
bool clip_model_supports_seq_input(const struct clip_ctx * ctx);
|
||||
|
||||
@@ -26,10 +26,11 @@ struct clip_graph_pixtral : clip_graph {
|
||||
struct clip_graph_qwen2vl : clip_graph {
|
||||
clip_graph_qwen2vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
ggml_tensor * build_inp_with_temporal_merge();
|
||||
};
|
||||
|
||||
struct clip_graph_qwen3vl : clip_graph {
|
||||
clip_graph_qwen3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
struct clip_graph_qwen3vl : clip_graph_qwen2vl {
|
||||
clip_graph_qwen3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph_qwen2vl(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,31 @@
|
||||
#include "models.h"
|
||||
|
||||
ggml_tensor * clip_graph_qwen2vl::build_inp_with_temporal_merge() {
|
||||
ggml_tensor * inp_raw = build_inp_raw();
|
||||
|
||||
GGML_ASSERT(img.nx % (patch_size * 2) == 0);
|
||||
GGML_ASSERT(img.ny % (patch_size * 2) == 0);
|
||||
|
||||
const size_t nb1 = ggml_row_size(inp_raw->type, img.nx);
|
||||
const size_t nb2 = nb1 * img.ny;
|
||||
|
||||
if (nt == 1) {
|
||||
// still image input
|
||||
return ggml_add(ctx0,
|
||||
ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1),
|
||||
ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1));
|
||||
} else if (nt == 2) {
|
||||
// 2 frames input (video input)
|
||||
ggml_tensor * inp_0 = ggml_view_3d(ctx0, inp_raw, img.nx, img.ny, 3, nb1, nb2, 0);
|
||||
ggml_tensor * inp_1 = ggml_view_3d(ctx0, inp_raw, img.nx, img.ny, 3, nb1, nb2, nb2 * 3);
|
||||
return ggml_add(ctx0,
|
||||
ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_0, patch_size, patch_size, 0, 0, 1, 1),
|
||||
ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_1, patch_size, patch_size, 0, 0, 1, 1));
|
||||
} else {
|
||||
GGML_ASSERT(false && "nt > 2 is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
ggml_cgraph * clip_graph_qwen2vl::build() {
|
||||
GGML_ASSERT(model.patch_bias == nullptr);
|
||||
GGML_ASSERT(model.class_embedding == nullptr);
|
||||
@@ -16,17 +42,10 @@ ggml_cgraph * clip_graph_qwen2vl::build() {
|
||||
|
||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||
|
||||
ggml_tensor * inp_raw = build_inp_raw();
|
||||
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
|
||||
GGML_ASSERT(img.nx % (patch_size * 2) == 0);
|
||||
GGML_ASSERT(img.ny % (patch_size * 2) == 0);
|
||||
ggml_tensor * inp = build_inp_with_temporal_merge();
|
||||
|
||||
// second conv dimension
|
||||
{
|
||||
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, inp_1);
|
||||
|
||||
inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
|
||||
inp = ggml_cont_4d(
|
||||
ctx0, inp,
|
||||
|
||||
@@ -13,17 +13,10 @@ ggml_cgraph * clip_graph_qwen3vl::build() {
|
||||
|
||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||
|
||||
ggml_tensor * inp_raw = build_inp_raw();
|
||||
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
ggml_tensor * inp = build_inp_with_temporal_merge();
|
||||
|
||||
GGML_ASSERT(img.nx % (patch_size * 2) == 0);
|
||||
GGML_ASSERT(img.ny % (patch_size * 2) == 0);
|
||||
|
||||
// second conv dimension
|
||||
// spatial merge
|
||||
{
|
||||
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, inp_1);
|
||||
|
||||
inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
|
||||
inp = ggml_cont_4d(
|
||||
ctx0, inp,
|
||||
|
||||
@@ -274,7 +274,8 @@ int32_t mtmd_helper_decode_image_chunk(
|
||||
batch_embd.set_position_normal(n_past, seq_id);
|
||||
}
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
const bool use_non_causal = mtmd_decode_use_non_causal(ctx, chunk);
|
||||
if (use_non_causal) {
|
||||
llama_set_causal_attn(lctx, false);
|
||||
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
|
||||
}
|
||||
@@ -302,7 +303,7 @@ int32_t mtmd_helper_decode_image_chunk(
|
||||
n_past += mtmd_input_chunk_get_n_pos(chunk);
|
||||
*new_n_past = n_past;
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
if (use_non_causal) {
|
||||
llama_set_causal_attn(lctx, true);
|
||||
}
|
||||
return 0;
|
||||
|
||||
@@ -25,9 +25,11 @@
|
||||
|
||||
// represents raw image data, layout is RGBRGBRGB...
|
||||
// length of data must be nx * ny * 3
|
||||
// for sequence of images (i.e. video): data is nt sequential RGB frames, each nx * ny * 3 bytes
|
||||
struct mtmd_bitmap {
|
||||
uint32_t nx;
|
||||
uint32_t ny;
|
||||
uint32_t nt = 1; // 1 for single images, >= 2 (even) for sequence
|
||||
std::vector<unsigned char> data;
|
||||
std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
|
||||
bool is_audio = false; // true if the bitmap is audio
|
||||
@@ -37,8 +39,8 @@ struct mtmd_image_tokens {
|
||||
uint32_t nx; // number of tokens in x direction
|
||||
uint32_t ny; // number of tokens in y direction
|
||||
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
|
||||
uint32_t n_tokens() const { return nx * ny; }
|
||||
clip_image_f32_batch batch_f32; // preprocessed image patches
|
||||
uint32_t n_tokens() const { return nx * ny; } // TODO [QWEN_VIDEO]: we don't count nt here to be compatible with Qwen-VL, but other models in the future might have different logic
|
||||
std::string id; // optional user-defined ID, useful for KV cache tracking
|
||||
|
||||
mtmd_image_tokens clone() {
|
||||
@@ -875,6 +877,73 @@ struct mtmd_tokenizer {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t add_seq_image(const mtmd_bitmap * bitmap) {
|
||||
GGML_ASSERT(ctx->ctx_v);
|
||||
GGML_ASSERT(bitmap->nt > 1);
|
||||
// TODO [QWEN_VIDEO]: we only support even frames (Qwen-VL style) for now
|
||||
GGML_ASSERT(bitmap->nt % 2 == 0);
|
||||
bool support_seq = clip_model_supports_seq_input(ctx->ctx_v);
|
||||
if (!support_seq) {
|
||||
LOG_ERR("%s: error: model does not support sequential image input (usually requires Qwen-VL style models)\n", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
const uint32_t n_frames = bitmap->nt;
|
||||
const size_t frame_bytes = (size_t)bitmap->nx * bitmap->ny * 3;
|
||||
|
||||
// preprocess each frame individually
|
||||
clip_image_f32_batch all_frames;
|
||||
all_frames.is_seq = true;
|
||||
all_frames.grid_x = 0; // currently, we don't support tiling for video input
|
||||
all_frames.grid_y = 0; // currently, we don't support tiling for video input
|
||||
|
||||
for (uint32_t f = 0; f < n_frames; f++) {
|
||||
clip_image_u8_ptr img_u8(clip_image_u8_init());
|
||||
img_u8->nx = bitmap->nx;
|
||||
img_u8->ny = bitmap->ny;
|
||||
img_u8->buf.resize(frame_bytes);
|
||||
std::memcpy(img_u8->buf.data(), bitmap->data.data() + f * frame_bytes, frame_bytes);
|
||||
|
||||
clip_image_f32_batch frame_batch;
|
||||
bool ok = ctx->image_preproc->preprocess(*img_u8, frame_batch);
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to preprocess image\n");
|
||||
return 2;
|
||||
}
|
||||
GGML_ASSERT(frame_batch.entries.size() == 1);
|
||||
all_frames.entries.push_back(std::move(frame_batch.entries[0]));
|
||||
}
|
||||
|
||||
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
|
||||
if (mtmd_decode_use_mrope(ctx)) {
|
||||
// for Qwen2VL, we need this information for M-RoPE decoding positions
|
||||
image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, all_frames.entries[0].get());
|
||||
image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, all_frames.entries[0].get());
|
||||
image_tokens->use_mrope_pos = true;
|
||||
} else {
|
||||
GGML_ASSERT(false && "not supported");
|
||||
}
|
||||
image_tokens->batch_f32 = std::move(all_frames);
|
||||
image_tokens->id = bitmap->id; // optional
|
||||
|
||||
LOG_DBG("seq_image: nt=%u, nx=%u, ny=%u, n_tokens=%u\n",
|
||||
bitmap->nt, image_tokens->nx, image_tokens->ny, image_tokens->n_tokens());
|
||||
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
{}, // text tokens
|
||||
std::move(image_tokens),
|
||||
nullptr, // audio tokens
|
||||
};
|
||||
cur.entries.emplace_back(std::move(chunk));
|
||||
|
||||
if (!ctx->img_end.empty()) {
|
||||
add_text(ctx->img_end, true);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<mtmd_input_chunk> split_batch_to_chunk(clip_image_f32_batch && batch_f32, const std::string & id) {
|
||||
std::vector<mtmd_input_chunk> chunks;
|
||||
|
||||
@@ -993,6 +1062,7 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
|
||||
|| clip_is_glm(ctx_clip)
|
||||
|| proj_type == PROJECTOR_TYPE_INTERNVL) {
|
||||
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
|
||||
// video: each entry is one frame pair, encoded with per-frame attention
|
||||
const auto & entries = image_tokens->batch_f32.entries;
|
||||
for (size_t i = 0; i < entries.size(); i++) {
|
||||
int n_tokens_per_image = clip_n_output_tokens(ctx_clip, entries[i].get());
|
||||
@@ -1017,8 +1087,12 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
|
||||
return ctx->image_embd_v.data();
|
||||
}
|
||||
|
||||
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
|
||||
switch (ctx->proj_type_v()) {
|
||||
bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
|
||||
auto proj_type = ctx->proj_type_v();
|
||||
if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
proj_type = ctx->proj_type_a();
|
||||
}
|
||||
switch (proj_type) {
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
return true;
|
||||
@@ -1071,17 +1145,54 @@ mtmd_bitmap * mtmd_bitmap_init(uint32_t nx,
|
||||
mtmd_bitmap * bitmap = new mtmd_bitmap;
|
||||
bitmap->nx = nx;
|
||||
bitmap->ny = ny;
|
||||
bitmap->nt = 1;
|
||||
size_t data_size = (size_t)nx * ny * 3;
|
||||
bitmap->data.resize(data_size);
|
||||
std::memcpy(bitmap->data.data(), data, data_size);
|
||||
return bitmap;
|
||||
}
|
||||
|
||||
mtmd_bitmap * mtmd_bitmap_init_from_seq(uint32_t nx,
|
||||
uint32_t ny,
|
||||
uint32_t nt,
|
||||
const unsigned char * data) {
|
||||
if (nt == 0) {
|
||||
LOG_ERR("%s: error: nt must be greater than 0 for sequence input\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
if (nt == 1) {
|
||||
// if nt == 1, it's not really a sequence, we can treat it as a single image
|
||||
return mtmd_bitmap_init(nx, ny, data);
|
||||
}
|
||||
// TODO [QWEN_VIDEO]: we only support Qwen-VL style for now, which requires even number of frames
|
||||
// therefore, we duplicate the last frame if nt is odd, to avoid issues in video preprocessing
|
||||
bool is_odd = (nt % 2 == 1);
|
||||
if (is_odd) {
|
||||
nt += 1;
|
||||
}
|
||||
size_t frame_size = (size_t)nx * ny * 3;
|
||||
mtmd_bitmap * bitmap = new mtmd_bitmap;
|
||||
bitmap->nx = nx;
|
||||
bitmap->ny = ny;
|
||||
bitmap->nt = nt;
|
||||
size_t data_size = frame_size * nt;
|
||||
bitmap->data.resize(data_size);
|
||||
std::memcpy(bitmap->data.data(), data, data_size);
|
||||
if (is_odd) {
|
||||
// duplicate the last frame
|
||||
std::memcpy(bitmap->data.data() + (nt - 1) * frame_size,
|
||||
data + (nt - 2) * frame_size,
|
||||
frame_size);
|
||||
}
|
||||
return bitmap;
|
||||
}
|
||||
|
||||
mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples,
|
||||
const float * data) {
|
||||
mtmd_bitmap * bitmap = new mtmd_bitmap;
|
||||
bitmap->nx = n_samples;
|
||||
bitmap->ny = 1;
|
||||
bitmap->nt = 1;
|
||||
bitmap->is_audio = true;
|
||||
size_t data_size = n_samples * sizeof(float);
|
||||
bitmap->data.resize(data_size);
|
||||
@@ -1097,6 +1208,10 @@ uint32_t mtmd_bitmap_get_ny(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->ny;
|
||||
}
|
||||
|
||||
uint32_t mtmd_bitmap_get_nt(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->nt;
|
||||
}
|
||||
|
||||
const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->data.data();
|
||||
}
|
||||
@@ -1109,6 +1224,10 @@ bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->is_audio;
|
||||
}
|
||||
|
||||
bool mtmd_bitmap_is_seq(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->nt >= 2;
|
||||
}
|
||||
|
||||
const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->id.c_str();
|
||||
}
|
||||
@@ -1251,8 +1370,8 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
|
||||
|
||||
llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
|
||||
if (image_tokens->use_mrope_pos) {
|
||||
// for M-RoPE, temporal dimension = max(t,h,w)
|
||||
// t is omitted as we don't support video input
|
||||
// for M-RoPE, n_pos = max(t, h, w)
|
||||
// t is omitted as we don't support batching
|
||||
return std::max(image_tokens->nx, image_tokens->ny);
|
||||
}
|
||||
return image_tokens->n_tokens();
|
||||
|
||||
@@ -114,7 +114,8 @@ MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
||||
MTMD_API void mtmd_free(mtmd_context * ctx);
|
||||
|
||||
// whether we need to set non-causal mask before llama_decode
|
||||
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
|
||||
// if chunk is nullptr, we assume the default case where chunk is an image chunk
|
||||
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk);
|
||||
|
||||
// whether the current model use M-RoPE for llama_decode
|
||||
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
|
||||
@@ -134,16 +135,23 @@ MTMD_API int mtmd_get_audio_sample_rate(mtmd_context * ctx);
|
||||
// if bitmap is image:
|
||||
// length of data must be nx * ny * 3
|
||||
// the data is in RGBRGBRGB... format
|
||||
// if bitmap is sequence of images (i.e. video):
|
||||
// nt is the number of frames
|
||||
// length of data must be nx * ny * 3 * nt
|
||||
// frames are sequential RGB, each nx * ny * 3 bytes
|
||||
// if bitmap is audio:
|
||||
// length of data must be n_samples * sizeof(float)
|
||||
// the data is in float format (PCM F32)
|
||||
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, uint32_t ny, const unsigned char * data);
|
||||
MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_seq (uint32_t nx, uint32_t ny, uint32_t nt, const unsigned char * data);
|
||||
MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples, const float * data);
|
||||
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
|
||||
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
|
||||
MTMD_API uint32_t mtmd_bitmap_get_nt (const mtmd_bitmap * bitmap);
|
||||
MTMD_API const unsigned char * mtmd_bitmap_get_data (const mtmd_bitmap * bitmap);
|
||||
MTMD_API size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap);
|
||||
MTMD_API bool mtmd_bitmap_is_audio (const mtmd_bitmap * bitmap);
|
||||
MTMD_API bool mtmd_bitmap_is_seq (const mtmd_bitmap * bitmap);
|
||||
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
|
||||
// bitmap ID is optional, but useful for KV cache tracking
|
||||
// these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
|
||||
@@ -276,9 +284,14 @@ struct bitmap {
|
||||
bitmap(uint32_t nx, uint32_t ny, const unsigned char * data) {
|
||||
ptr.reset(mtmd_bitmap_init(nx, ny, data));
|
||||
}
|
||||
bitmap(uint32_t nx, uint32_t ny, uint32_t nt, const unsigned char * data) {
|
||||
ptr.reset(mtmd_bitmap_init_from_seq(nx, ny, nt, data));
|
||||
}
|
||||
~bitmap() = default;
|
||||
uint32_t nx() const { return mtmd_bitmap_get_nx(ptr.get()); }
|
||||
uint32_t ny() const { return mtmd_bitmap_get_ny(ptr.get()); }
|
||||
uint32_t nx() const { return mtmd_bitmap_get_nx(ptr.get()); }
|
||||
uint32_t ny() const { return mtmd_bitmap_get_ny(ptr.get()); }
|
||||
uint32_t nt() const { return mtmd_bitmap_get_nt(ptr.get()); }
|
||||
bool is_seq() const { return mtmd_bitmap_is_seq(ptr.get()); }
|
||||
const unsigned char * data() const { return mtmd_bitmap_get_data(ptr.get()); }
|
||||
size_t n_bytes() const { return mtmd_bitmap_get_n_bytes(ptr.get()); }
|
||||
std::string id() const { return mtmd_bitmap_get_id(ptr.get()); }
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -926,7 +926,8 @@ void server_models_routes::init_routes() {
|
||||
res_ok(res, {
|
||||
// TODO: add support for this on web UI
|
||||
{"role", "router"},
|
||||
{"max_instances", 4}, // dummy value for testing
|
||||
{"max_instances", params.models_max},
|
||||
{"models_autoload", params.models_autoload},
|
||||
// this is a dummy response to make sure webui doesn't break
|
||||
{"model_alias", "llama-server"},
|
||||
{"model_path", "none"},
|
||||
@@ -935,6 +936,7 @@ void server_models_routes::init_routes() {
|
||||
{"n_ctx", 0},
|
||||
}},
|
||||
{"webui_settings", webui_settings},
|
||||
{"build_info", build_info},
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -9,6 +9,19 @@ def create_server():
|
||||
server = ServerPreset.router()
|
||||
|
||||
|
||||
def test_router_props():
|
||||
global server
|
||||
server.models_max = 2
|
||||
server.no_models_autoload = True
|
||||
server.start()
|
||||
res = server.make_request("GET", "/props")
|
||||
assert res.status_code == 200
|
||||
assert res.body["role"] == "router"
|
||||
assert res.body["max_instances"] == 2
|
||||
assert res.body["models_autoload"] is False
|
||||
assert res.body["build_info"].startswith("b")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,success",
|
||||
[
|
||||
|
||||
@@ -89,6 +89,11 @@
|
||||
key: SETTINGS_KEYS.ASK_FOR_TITLE_CONFIRMATION,
|
||||
label: 'Ask for confirmation before changing conversation title',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.TITLE_GENERATION_USE_FIRST_LINE,
|
||||
label: 'Use first non-empty line for conversation title',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@@ -15,6 +15,18 @@
|
||||
let { logs, connectionTimeMs, defaultExpanded = false, class: className }: Props = $props();
|
||||
|
||||
let isExpanded = $derived(defaultExpanded);
|
||||
|
||||
function formatLogDetails(details: unknown): string {
|
||||
if (details == null) {
|
||||
return '';
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.stringify(details, null, 2);
|
||||
} catch {
|
||||
return String(details);
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if logs.length > 0}
|
||||
@@ -53,6 +65,16 @@
|
||||
|
||||
<span class="break-all">{log.message}</span>
|
||||
</div>
|
||||
|
||||
{#if log.details !== undefined}
|
||||
<details class="ml-11">
|
||||
<summary class="cursor-pointer text-[10px] text-muted-foreground"> details </summary>
|
||||
|
||||
<pre
|
||||
class="mt-1 overflow-x-auto rounded bg-background/70 p-2 text-[10px] break-all whitespace-pre-wrap text-foreground/80">
|
||||
{formatLogDetails(log.details)}</pre>
|
||||
</details>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</Collapsible.Content>
|
||||
|
||||
@@ -48,6 +48,26 @@ export const EXPECTED_THEMED_ICON_PAIR_COUNT = 2;
|
||||
/** CORS proxy URL query parameter name */
|
||||
export const CORS_PROXY_URL_PARAM = 'url';
|
||||
|
||||
/** Number of trailing characters to keep visible when partially redacting mcp-session-id */
|
||||
export const MCP_SESSION_ID_VISIBLE_CHARS = 5;
|
||||
|
||||
/** Partial-redaction rules for MCP headers: header name -> visible trailing chars */
|
||||
export const MCP_PARTIAL_REDACT_HEADERS = new Map<string, number>([
|
||||
['mcp-session-id', MCP_SESSION_ID_VISIBLE_CHARS]
|
||||
]);
|
||||
|
||||
/** Header names whose values should be redacted in diagnostic logs */
|
||||
export const REDACTED_HEADERS = new Set([
|
||||
'authorization',
|
||||
'api-key',
|
||||
'cookie',
|
||||
'mcp-session-id',
|
||||
'proxy-authorization',
|
||||
'set-cookie',
|
||||
'x-auth-token',
|
||||
'x-api-key'
|
||||
]);
|
||||
|
||||
/** Human-readable labels for MCP transport types */
|
||||
export const MCP_TRANSPORT_LABELS: Record<MCPTransportType, string> = {
|
||||
[MCPTransportType.WEBSOCKET]: 'WebSocket',
|
||||
|
||||
@@ -15,6 +15,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean |
|
||||
keepStatsVisible: false,
|
||||
showMessageStats: true,
|
||||
askForTitleConfirmation: false,
|
||||
titleGenerationUseFirstLine: false,
|
||||
pasteLongTextToFileLen: 2500,
|
||||
copyTextAttachmentsAsPlainText: false,
|
||||
pdfAsImage: false,
|
||||
@@ -118,6 +119,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
'Display generation statistics (tokens/second, token count, duration) below each assistant message.',
|
||||
askForTitleConfirmation:
|
||||
'Ask for confirmation before automatically changing conversation title when editing the first message.',
|
||||
titleGenerationUseFirstLine:
|
||||
'Use only the first non-empty line of the prompt to generate the conversation title.',
|
||||
pdfAsImage:
|
||||
'Parse PDF as image instead of text. Automatically falls back to text processing for non-vision models.',
|
||||
disableAutoScroll:
|
||||
|
||||
@@ -15,6 +15,7 @@ export const SETTINGS_KEYS = {
|
||||
ENABLE_CONTINUE_GENERATION: 'enableContinueGeneration',
|
||||
PDF_AS_IMAGE: 'pdfAsImage',
|
||||
ASK_FOR_TITLE_CONFIRMATION: 'askForTitleConfirmation',
|
||||
TITLE_GENERATION_USE_FIRST_LINE: 'titleGenerationUseFirstLine',
|
||||
// Display
|
||||
SHOW_MESSAGE_STATS: 'showMessageStats',
|
||||
SHOW_THOUGHT_IN_PROGRESS: 'showThoughtInProgress',
|
||||
|
||||
@@ -15,7 +15,8 @@ import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import {
|
||||
DEFAULT_MCP_CONFIG,
|
||||
DEFAULT_CLIENT_VERSION,
|
||||
DEFAULT_IMAGE_MIME_TYPE
|
||||
DEFAULT_IMAGE_MIME_TYPE,
|
||||
MCP_PARTIAL_REDACT_HEADERS
|
||||
} from '$lib/constants';
|
||||
import {
|
||||
MCPConnectionPhase,
|
||||
@@ -43,9 +44,17 @@ import {
|
||||
buildProxiedUrl,
|
||||
buildProxiedHeaders,
|
||||
getAuthHeaders,
|
||||
sanitizeHeaders,
|
||||
throwIfAborted,
|
||||
isAbortError,
|
||||
createBase64DataUrl
|
||||
createBase64DataUrl,
|
||||
getRequestUrl,
|
||||
getRequestMethod,
|
||||
getRequestBody,
|
||||
summarizeRequestBody,
|
||||
formatDiagnosticErrorMessage,
|
||||
extractJsonRpcMethods,
|
||||
type RequestBodySummary
|
||||
} from '$lib/utils';
|
||||
|
||||
interface ToolResultContentItem {
|
||||
@@ -62,6 +71,16 @@ interface ToolCallResult {
|
||||
_meta?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
/** Log-safe snapshot of an outgoing MCP HTTP request: headers are redacted and the body is reduced to a kind/size summary. */
interface DiagnosticRequestDetails {
	url: string;
	method: string;
	credentials?: RequestCredentials;
	mode?: RequestMode;
	/** Sanitized header map — sensitive values redacted before logging. */
	headers: Record<string, string>;
	/** Size/kind summary only; the raw payload is never stored. */
	body: RequestBodySummary;
	/** JSON-RPC method names parsed from the body, when it is a JSON-RPC message/batch. */
	jsonRpcMethods?: string[];
}
|
||||
|
||||
export class MCPService {
|
||||
/**
|
||||
* Create a connection log entry for phase tracking.
|
||||
@@ -87,6 +106,225 @@ export class MCPService {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Build a structured, log-safe description of an outgoing fetch request.
 *
 * @param input - First argument passed to fetch (url string, URL, or Request)
 * @param init - Per-call RequestInit overrides, if any
 * @param baseInit - Transport-level RequestInit defaults (fallback for credentials/mode)
 * @param requestHeaders - The fully merged headers that will actually be sent
 * @param extraRedactedHeaders - Additional header names to fully redact
 *   (e.g. user-configured auth headers)
 * @returns Request details with redacted headers and a size-only body summary
 */
private static createDiagnosticRequestDetails(
	input: RequestInfo | URL,
	init: RequestInit | undefined,
	baseInit: RequestInit,
	requestHeaders: Headers,
	extraRedactedHeaders?: Iterable<string>
): DiagnosticRequestDetails {
	const body = getRequestBody(input, init);
	const details: DiagnosticRequestDetails = {
		url: getRequestUrl(input),
		// Uppercased for readable log lines; per-call init wins over transport defaults.
		method: getRequestMethod(input, init, baseInit).toUpperCase(),
		credentials: init?.credentials ?? baseInit.credentials,
		mode: init?.mode ?? baseInit.mode,
		headers: sanitizeHeaders(requestHeaders, extraRedactedHeaders, MCP_PARTIAL_REDACT_HEADERS),
		// Only kind/size are logged — never the raw payload.
		body: summarizeRequestBody(body)
	};
	const jsonRpcMethods = extractJsonRpcMethods(body);

	// Field is optional; only attach it when the body actually parsed as JSON-RPC.
	if (jsonRpcMethods) {
		details.jsonRpcMethods = jsonRpcMethods;
	}

	return details;
}
|
||||
|
||||
private static summarizeError(error: unknown): Record<string, unknown> {
|
||||
if (error instanceof Error) {
|
||||
return {
|
||||
name: error.name,
|
||||
message: error.message,
|
||||
cause:
|
||||
error.cause instanceof Error
|
||||
? { name: error.cause.name, message: error.cause.message }
|
||||
: error.cause,
|
||||
stack: error.stack?.split('\n').slice(0, 6).join('\n')
|
||||
};
|
||||
}
|
||||
|
||||
return { value: String(error) };
|
||||
}
|
||||
|
||||
/**
 * Capture browser environment facts relevant to diagnosing an MCP connection
 * failure (origins, protocols, secure-context state, proxy use).
 * Returns undefined outside a browser (no window object).
 */
private static getBrowserContext(
	targetUrl: URL,
	useProxy: boolean
): Record<string, unknown> | undefined {
	if (typeof window === 'undefined') {
		return undefined;
	}

	const { location } = window;

	return {
		location: location.href,
		origin: location.origin,
		protocol: location.protocol,
		isSecureContext: window.isSecureContext,
		targetOrigin: targetUrl.origin,
		targetProtocol: targetUrl.protocol,
		sameOrigin: location.origin === targetUrl.origin,
		useProxy
	};
}
|
||||
|
||||
/**
 * Derive human-readable troubleshooting hints for a failed MCP connection,
 * based on the page/target origins, the server config, and the error text.
 *
 * @param targetUrl - Effective URL the browser tried to reach
 * @param config - Server configuration (headers, credentials, proxy flag)
 * @param error - The failure that triggered hint generation
 * @returns Zero or more hint strings, most specific conditions first
 */
private static getConnectionHints(
	targetUrl: URL,
	config: MCPServerConfig,
	error: unknown
): string[] {
	const hints: string[] = [];
	const message = error instanceof Error ? error.message : String(error);
	const headerNames = Object.keys(config.headers ?? {});

	// Browser-only checks: origin/protocol comparisons need window.location.
	if (typeof window !== 'undefined') {
		// Mixed content: HTTPS page fetching an HTTP endpoint without the proxy.
		if (
			window.location.protocol === 'https:' &&
			targetUrl.protocol === 'http:' &&
			!config.useProxy
		) {
			hints.push(
				'The page is running over HTTPS but the MCP server is HTTP. Browsers often block this as mixed content; enable the proxy or use HTTPS/WSS for the MCP server.'
			);
		}

		// Cross-origin without the proxy: missing CORS headers are the usual suspect.
		if (window.location.origin !== targetUrl.origin && !config.useProxy) {
			hints.push(
				'This is a cross-origin browser request. If the server is reachable from curl or Node but not from the browser, missing CORS headers are the most likely cause.'
			);
		}
	}

	// Custom headers force a CORS preflight (OPTIONS) round trip.
	if (headerNames.length > 0) {
		hints.push(
			`Custom request headers are configured (${headerNames.join(', ')}). That triggers a CORS preflight, so the server must allow OPTIONS and include the matching Access-Control-Allow-Headers response.`
		);
	}

	// Credentialed cross-origin requests have stricter CORS requirements.
	if (config.credentials && config.credentials !== 'omit') {
		hints.push(
			'Credentials are enabled for this connection. Cross-origin credentialed requests need Access-Control-Allow-Credentials: true and cannot use a wildcard Access-Control-Allow-Origin.'
		);
	}

	// "Failed to fetch" is deliberately opaque; enumerate the likely root causes.
	if (message.includes('Failed to fetch')) {
		hints.push(
			'"Failed to fetch" is a browser-level network failure. Common causes are CORS rejection, mixed-content blocking, certificate/TLS errors, DNS failures, or nothing listening on the target port.'
		);
	}

	return hints;
}
|
||||
|
||||
/**
 * Wrap the global fetch with request/response diagnostic logging for the MCP
 * handshake. Returns the instrumented fetch plus a `disable` switch that
 * silences logging once the handshake completes — the wrapper keeps working
 * for long-lived traffic, it just stops emitting log entries.
 *
 * @param serverName - Server label included in log entries
 * @param config - Server configuration (its header NAMES are redacted in logs)
 * @param baseInit - Transport-level RequestInit defaults merged into every call
 * @param targetUrl - Effective target URL, used for browser-context diagnostics
 * @param useProxy - Whether the CORS proxy is in use (for diagnostics only)
 * @param onLog - Sink for generated connection log entries
 */
private static createDiagnosticFetch(
	serverName: string,
	config: MCPServerConfig,
	baseInit: RequestInit,
	targetUrl: URL,
	useProxy: boolean,
	onLog?: (log: MCPConnectionLog) => void
): {
	fetch: typeof fetch;
	disable: () => void;
} {
	// Flipped off by disable(); captured by the closure below.
	let enabled = true;
	const logIfEnabled = (log: MCPConnectionLog) => {
		if (enabled) {
			onLog?.(log);
		}
	};

	return {
		fetch: async (input, init) => {
			const startedAt = performance.now();
			// Merge headers in fetch precedence order: baseInit < Request object < init.
			const requestHeaders = new Headers(baseInit.headers);

			if (typeof Request !== 'undefined' && input instanceof Request) {
				for (const [key, value] of input.headers.entries()) {
					requestHeaders.set(key, value);
				}
			}

			if (init?.headers) {
				for (const [key, value] of new Headers(init.headers).entries()) {
					requestHeaders.set(key, value);
				}
			}

			// Sanitized snapshot; every user-configured header name is fully redacted.
			const request = this.createDiagnosticRequestDetails(
				input,
				init,
				baseInit,
				requestHeaders,
				Object.keys(config.headers ?? {})
			);
			const { method, url } = request;

			// Log the outgoing request before dispatch.
			logIfEnabled(
				this.createLog(
					MCPConnectionPhase.INITIALIZING,
					`HTTP ${method} ${url}`,
					MCPLogLevel.INFO,
					{
						serverName,
						request
					}
				)
			);

			try {
				const response = await fetch(input, {
					...baseInit,
					...init,
					headers: requestHeaders
				});
				const durationMs = Math.round(performance.now() - startedAt);

				// HTTP-level result; non-2xx is logged as WARN, not thrown.
				logIfEnabled(
					this.createLog(
						MCPConnectionPhase.INITIALIZING,
						`HTTP ${response.status} ${method} ${url} (${durationMs}ms)`,
						response.ok ? MCPLogLevel.INFO : MCPLogLevel.WARN,
						{
							response: {
								url,
								status: response.status,
								statusText: response.statusText,
								headers: sanitizeHeaders(response.headers, undefined, MCP_PARTIAL_REDACT_HEADERS),
								durationMs
							}
						}
					)
				);

				return response;
			} catch (error) {
				const durationMs = Math.round(performance.now() - startedAt);

				// Network-level failure: attach environment context and CORS hints.
				logIfEnabled(
					this.createLog(
						MCPConnectionPhase.ERROR,
						`HTTP ${method} ${url} failed: ${formatDiagnosticErrorMessage(error)}`,
						MCPLogLevel.ERROR,
						{
							serverName,
							request,
							error: this.summarizeError(error),
							browser: this.getBrowserContext(targetUrl, useProxy),
							hints: this.getConnectionHints(targetUrl, config, error),
							durationMs
						}
					)
				);

				// Re-throw so the transport sees the original failure.
				throw error;
			}
		},
		disable: () => {
			enabled = false;
		}
	};
}
|
||||
|
||||
/**
|
||||
* Detect if an error indicates an expired/invalidated MCP session.
|
||||
* Per MCP spec 2025-11-25: HTTP 404 means session invalidated, client MUST
|
||||
@@ -113,9 +351,14 @@ export class MCPService {
|
||||
* @returns Object containing the created transport and the transport type used
|
||||
* @throws {Error} If url is missing, WebSocket + proxy combination, or all transports fail
|
||||
*/
|
||||
static createTransport(config: MCPServerConfig): {
|
||||
static createTransport(
|
||||
serverName: string,
|
||||
config: MCPServerConfig,
|
||||
onLog?: (log: MCPConnectionLog) => void
|
||||
): {
|
||||
transport: Transport;
|
||||
type: MCPTransportType;
|
||||
stopPhaseLogging: () => void;
|
||||
} {
|
||||
if (!config.url) {
|
||||
throw new Error('MCP server configuration is missing url');
|
||||
@@ -154,11 +397,20 @@ export class MCPService {
|
||||
|
||||
return {
|
||||
transport: new WebSocketClientTransport(url),
|
||||
type: MCPTransportType.WEBSOCKET
|
||||
type: MCPTransportType.WEBSOCKET,
|
||||
stopPhaseLogging: () => {}
|
||||
};
|
||||
}
|
||||
|
||||
const url = useProxy ? buildProxiedUrl(config.url) : new URL(config.url);
|
||||
const { fetch: diagnosticFetch, disable: stopPhaseLogging } = this.createDiagnosticFetch(
|
||||
serverName,
|
||||
config,
|
||||
requestInit,
|
||||
url,
|
||||
useProxy,
|
||||
onLog
|
||||
);
|
||||
|
||||
if (useProxy && import.meta.env.DEV) {
|
||||
console.log(`[MCPService] Using CORS proxy for ${config.url} -> ${url.href}`);
|
||||
@@ -171,17 +423,24 @@ export class MCPService {
|
||||
|
||||
return {
|
||||
transport: new StreamableHTTPClientTransport(url, {
|
||||
requestInit
|
||||
requestInit,
|
||||
fetch: diagnosticFetch
|
||||
}),
|
||||
type: MCPTransportType.STREAMABLE_HTTP
|
||||
type: MCPTransportType.STREAMABLE_HTTP,
|
||||
stopPhaseLogging
|
||||
};
|
||||
} catch (httpError) {
|
||||
console.warn(`[MCPService] StreamableHTTP failed, trying SSE transport...`, httpError);
|
||||
|
||||
try {
|
||||
return {
|
||||
transport: new SSEClientTransport(url, { requestInit }),
|
||||
type: MCPTransportType.SSE
|
||||
transport: new SSEClientTransport(url, {
|
||||
requestInit,
|
||||
fetch: diagnosticFetch,
|
||||
eventSourceInit: { fetch: diagnosticFetch }
|
||||
}),
|
||||
type: MCPTransportType.SSE,
|
||||
stopPhaseLogging
|
||||
};
|
||||
} catch (sseError) {
|
||||
const httpMsg = httpError instanceof Error ? httpError.message : String(httpError);
|
||||
@@ -263,7 +522,11 @@ export class MCPService {
|
||||
console.log(`[MCPService][${serverName}] Creating transport...`);
|
||||
}
|
||||
|
||||
const { transport, type: transportType } = this.createTransport(serverConfig);
|
||||
const {
|
||||
transport,
|
||||
type: transportType,
|
||||
stopPhaseLogging
|
||||
} = this.createTransport(serverName, serverConfig, (log) => onPhase?.(log.phase, log));
|
||||
|
||||
// Setup WebSocket reconnection handler
|
||||
if (transportType === MCPTransportType.WEBSOCKET) {
|
||||
@@ -294,6 +557,24 @@ export class MCPService {
|
||||
}
|
||||
);
|
||||
|
||||
const runtimeErrorHandler = (error: Error) => {
|
||||
console.error(`[MCPService][${serverName}] Protocol error after initialize:`, error);
|
||||
};
|
||||
|
||||
client.onerror = (error) => {
|
||||
onPhase?.(
|
||||
MCPConnectionPhase.ERROR,
|
||||
this.createLog(
|
||||
MCPConnectionPhase.ERROR,
|
||||
`Protocol error: ${error.message}`,
|
||||
MCPLogLevel.ERROR,
|
||||
{
|
||||
error: this.summarizeError(error)
|
||||
}
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
// Phase: Initializing
|
||||
onPhase?.(
|
||||
MCPConnectionPhase.INITIALIZING,
|
||||
@@ -301,7 +582,49 @@ export class MCPService {
|
||||
);
|
||||
|
||||
console.log(`[MCPService][${serverName}] Connecting to server...`);
|
||||
await client.connect(transport);
|
||||
try {
|
||||
await client.connect(transport);
|
||||
// Transport diagnostics are only for the initial handshake, not long-lived traffic.
|
||||
stopPhaseLogging();
|
||||
client.onerror = runtimeErrorHandler;
|
||||
} catch (error) {
|
||||
client.onerror = runtimeErrorHandler;
|
||||
const url =
|
||||
(serverConfig.useProxy ?? false)
|
||||
? buildProxiedUrl(serverConfig.url)
|
||||
: new URL(serverConfig.url);
|
||||
|
||||
onPhase?.(
|
||||
MCPConnectionPhase.ERROR,
|
||||
this.createLog(
|
||||
MCPConnectionPhase.ERROR,
|
||||
`Connection failed during initialize: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
MCPLogLevel.ERROR,
|
||||
{
|
||||
error: this.summarizeError(error),
|
||||
config: {
|
||||
serverName,
|
||||
configuredUrl: serverConfig.url,
|
||||
effectiveUrl: url.href,
|
||||
transportType,
|
||||
useProxy: serverConfig.useProxy ?? false,
|
||||
headers: sanitizeHeaders(
|
||||
serverConfig.headers,
|
||||
Object.keys(serverConfig.headers ?? {}),
|
||||
MCP_PARTIAL_REDACT_HEADERS
|
||||
),
|
||||
credentials: serverConfig.credentials
|
||||
},
|
||||
browser: this.getBrowserContext(url, serverConfig.useProxy ?? false),
|
||||
hints: this.getConnectionHints(url, serverConfig, error)
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
throw error;
|
||||
}
|
||||
|
||||
const serverVersion = client.getServerVersion();
|
||||
const serverCapabilities = client.getServerCapabilities();
|
||||
|
||||
@@ -130,6 +130,12 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'titleGenerationUseFirstLine',
|
||||
serverKey: 'titleGenerationUseFirstLine',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'disableAutoScroll',
|
||||
serverKey: 'disableAutoScroll',
|
||||
|
||||
@@ -30,7 +30,8 @@ import {
|
||||
findDescendantMessages,
|
||||
findLeafNode,
|
||||
findMessageById,
|
||||
isAbortError
|
||||
isAbortError,
|
||||
generateConversationTitle
|
||||
} from '$lib/utils';
|
||||
import {
|
||||
MAX_INACTIVE_CONVERSATION_STATES,
|
||||
@@ -504,7 +505,10 @@ class ChatStore {
|
||||
allExtras
|
||||
);
|
||||
if (isNewConversation && content)
|
||||
await conversationsStore.updateConversationName(currentConv.id, content.trim());
|
||||
await conversationsStore.updateConversationName(
|
||||
currentConv.id,
|
||||
generateConversationTitle(content, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
const assistantMessage = await this.createAssistantMessage(userMessage.id);
|
||||
conversationsStore.addMessageToActive(assistantMessage);
|
||||
await this.streamChatCompletion(
|
||||
@@ -896,7 +900,7 @@ class ChatStore {
|
||||
if (isFirstUserMessage && newContent.trim())
|
||||
await conversationsStore.updateConversationTitleWithConfirmation(
|
||||
activeConv.id,
|
||||
newContent.trim()
|
||||
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex + 1);
|
||||
for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id);
|
||||
@@ -1317,7 +1321,7 @@ class ChatStore {
|
||||
if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) {
|
||||
await conversationsStore.updateConversationTitleWithConfirmation(
|
||||
activeConv.id,
|
||||
newContent.trim()
|
||||
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1391,7 +1395,7 @@ class ChatStore {
|
||||
if (isFirstUserMessage && newContent.trim())
|
||||
await conversationsStore.updateConversationTitleWithConfirmation(
|
||||
activeConv.id,
|
||||
newContent.trim()
|
||||
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
|
||||
);
|
||||
await conversationsStore.refreshActiveMessages();
|
||||
if (msg.role === MessageRole.USER)
|
||||
|
||||
@@ -23,7 +23,12 @@ import { browser } from '$app/environment';
|
||||
import { toast } from 'svelte-sonner';
|
||||
import { DatabaseService } from '$lib/services/database.service';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { filterByLeafNodeId, findLeafNode, runLegacyMigration } from '$lib/utils';
|
||||
import {
|
||||
filterByLeafNodeId,
|
||||
findLeafNode,
|
||||
runLegacyMigration,
|
||||
generateConversationTitle
|
||||
} from '$lib/utils';
|
||||
import type { McpServerOverride } from '$lib/types/database';
|
||||
import { MessageRole } from '$lib/enums';
|
||||
import {
|
||||
@@ -548,7 +553,10 @@ class ConversationsStore {
|
||||
) {
|
||||
await this.updateConversationTitleWithConfirmation(
|
||||
this.activeConversation.id,
|
||||
newFirstUserMessage.content.trim()
|
||||
generateConversationTitle(
|
||||
newFirstUserMessage.content,
|
||||
Boolean(config().titleGenerationUseFirstLine)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1460,12 +1460,14 @@ class MCPStore {
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : 'Unknown error occurred';
|
||||
|
||||
logs.push({
|
||||
timestamp: new Date(),
|
||||
phase: MCPConnectionPhase.ERROR,
|
||||
message: `Connection failed: ${message}`,
|
||||
level: MCPLogLevel.ERROR
|
||||
});
|
||||
if (logs.at(-1)?.phase !== MCPConnectionPhase.ERROR) {
|
||||
logs.push({
|
||||
timestamp: new Date(),
|
||||
phase: MCPConnectionPhase.ERROR,
|
||||
message: `Connection failed: ${message}`,
|
||||
level: MCPLogLevel.ERROR
|
||||
});
|
||||
}
|
||||
|
||||
this.updateHealthCheck(server.id, {
|
||||
status: HealthCheckStatus.ERROR,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { REDACTED_HEADERS } from '$lib/constants';
|
||||
import { redactValue } from './redact';
|
||||
|
||||
/**
|
||||
* Get authorization headers for API requests
|
||||
@@ -20,3 +22,46 @@ export function getJsonHeaders(): Record<string, string> {
|
||||
...getAuthHeaders()
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitize HTTP headers by redacting sensitive values.
|
||||
* Known sensitive headers (from REDACTED_HEADERS) and any extra headers
|
||||
* specified by the caller are fully redacted. Headers listed in
|
||||
* `partialRedactHeaders` are partially redacted, showing only the
|
||||
* specified number of trailing characters.
|
||||
*
|
||||
* @param headers - Headers to sanitize
|
||||
* @param extraRedactedHeaders - Additional header names to fully redact
|
||||
* @param partialRedactHeaders - Map of header name -> number of trailing chars to keep visible
|
||||
* @returns Object with header names as keys and (possibly redacted) values
|
||||
*/
|
||||
export function sanitizeHeaders(
|
||||
headers?: HeadersInit,
|
||||
extraRedactedHeaders?: Iterable<string>,
|
||||
partialRedactHeaders?: Map<string, number>
|
||||
): Record<string, string> {
|
||||
if (!headers) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const normalized = new Headers(headers);
|
||||
const sanitized: Record<string, string> = {};
|
||||
const redactedHeaders = new Set(
|
||||
Array.from(extraRedactedHeaders ?? [], (header) => header.toLowerCase())
|
||||
);
|
||||
|
||||
for (const [key, value] of normalized.entries()) {
|
||||
const normalizedKey = key.toLowerCase();
|
||||
const partialChars = partialRedactHeaders?.get(normalizedKey);
|
||||
|
||||
if (partialChars !== undefined) {
|
||||
sanitized[key] = redactValue(value, partialChars);
|
||||
} else if (REDACTED_HEADERS.has(normalizedKey) || redactedHeaders.has(normalizedKey)) {
|
||||
sanitized[key] = redactValue(value);
|
||||
} else {
|
||||
sanitized[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized;
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
*/
|
||||
|
||||
// API utilities
|
||||
export { getAuthHeaders, getJsonHeaders } from './api-headers';
|
||||
export { getAuthHeaders, getJsonHeaders, sanitizeHeaders } from './api-headers';
|
||||
export { apiFetch, apiFetchWithParams, apiPost, type ApiFetchOptions } from './api-fetch';
|
||||
export { validateApiKey } from './api-key-validation';
|
||||
|
||||
@@ -55,7 +55,7 @@ export {
|
||||
|
||||
// File preview utilities
|
||||
export { getFileTypeLabel } from './file-preview';
|
||||
export { getPreviewText } from './text';
|
||||
export { getPreviewText, generateConversationTitle } from './text';
|
||||
|
||||
// File type utilities
|
||||
export {
|
||||
@@ -164,6 +164,20 @@ export { runLegacyMigration, isMigrationNeeded } from './legacy-migration';
|
||||
// Cache utilities
|
||||
export { TTLCache, ReactiveTTLMap, type TTLCacheOptions } from './cache-ttl';
|
||||
|
||||
// Redaction utilities
|
||||
export { redactValue } from './redact';
|
||||
|
||||
// Request inspection utilities
|
||||
export {
|
||||
getRequestUrl,
|
||||
getRequestMethod,
|
||||
getRequestBody,
|
||||
summarizeRequestBody,
|
||||
formatDiagnosticErrorMessage,
|
||||
extractJsonRpcMethods,
|
||||
type RequestBodySummary
|
||||
} from './request-helpers';
|
||||
|
||||
// Abort signal utilities
|
||||
export {
|
||||
throwIfAborted,
|
||||
|
||||
14
tools/server/webui/src/lib/utils/redact.ts
Normal file
14
tools/server/webui/src/lib/utils/redact.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
/**
|
||||
* Redacts a sensitive value, optionally showing the last N characters.
|
||||
*
|
||||
* @param value - The value to redact
|
||||
* @param showLastChars - If provided, reveals the last N characters with a leading mask
|
||||
* @returns The redacted string
|
||||
*/
|
||||
export function redactValue(value: string, showLastChars?: number): string {
|
||||
if (showLastChars) {
|
||||
return `....${value.slice(-showLastChars)}`;
|
||||
}
|
||||
|
||||
return '[redacted]';
|
||||
}
|
||||
111
tools/server/webui/src/lib/utils/request-helpers.ts
Normal file
111
tools/server/webui/src/lib/utils/request-helpers.ts
Normal file
@@ -0,0 +1,111 @@
|
||||
/**
|
||||
* HTTP request inspection utilities for diagnostic logging.
|
||||
* These helpers extract metadata from fetch-style request arguments
|
||||
* without exposing sensitive payload data.
|
||||
*/
|
||||
|
||||
/** Size-only description of a request body; the payload itself is never logged. */
export interface RequestBodySummary {
	/** Body category: 'empty' | 'string' | 'blob' | 'urlsearchparams' | 'formdata' | 'arraybuffer' | a typed-array constructor name | typeof fallback. */
	kind: string;
	/** Length in chars/bytes when it can be determined cheaply; absent otherwise. */
	size?: number;
}
|
||||
|
||||
export function getRequestUrl(input: RequestInfo | URL): string {
|
||||
if (typeof input === 'string') {
|
||||
return input;
|
||||
}
|
||||
|
||||
if (input instanceof URL) {
|
||||
return input.href;
|
||||
}
|
||||
|
||||
return input.url;
|
||||
}
|
||||
|
||||
export function getRequestMethod(
|
||||
input: RequestInfo | URL,
|
||||
init?: RequestInit,
|
||||
baseInit?: RequestInit
|
||||
): string {
|
||||
if (init?.method) {
|
||||
return init.method;
|
||||
}
|
||||
|
||||
if (typeof Request !== 'undefined' && input instanceof Request) {
|
||||
return input.method;
|
||||
}
|
||||
|
||||
return baseInit?.method ?? 'GET';
|
||||
}
|
||||
|
||||
export function getRequestBody(
|
||||
input: RequestInfo | URL,
|
||||
init?: RequestInit
|
||||
): BodyInit | null | undefined {
|
||||
if (init?.body !== undefined) {
|
||||
return init.body;
|
||||
}
|
||||
|
||||
if (typeof Request !== 'undefined' && input instanceof Request) {
|
||||
return input.body;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export function summarizeRequestBody(body: BodyInit | null | undefined): RequestBodySummary {
|
||||
if (body == null) {
|
||||
return { kind: 'empty' };
|
||||
}
|
||||
|
||||
if (typeof body === 'string') {
|
||||
return { kind: 'string', size: body.length };
|
||||
}
|
||||
|
||||
if (body instanceof Blob) {
|
||||
return { kind: 'blob', size: body.size };
|
||||
}
|
||||
|
||||
if (body instanceof URLSearchParams) {
|
||||
return { kind: 'urlsearchparams', size: body.toString().length };
|
||||
}
|
||||
|
||||
if (body instanceof FormData) {
|
||||
return { kind: 'formdata' };
|
||||
}
|
||||
|
||||
if (body instanceof ArrayBuffer) {
|
||||
return { kind: 'arraybuffer', size: body.byteLength };
|
||||
}
|
||||
|
||||
if (ArrayBuffer.isView(body)) {
|
||||
return { kind: body.constructor.name, size: body.byteLength };
|
||||
}
|
||||
|
||||
return { kind: typeof body };
|
||||
}
|
||||
|
||||
export function formatDiagnosticErrorMessage(error: unknown): string {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
|
||||
return message.includes('Failed to fetch') ? `${message} (check CORS?)` : message;
|
||||
}
|
||||
|
||||
export function extractJsonRpcMethods(body: BodyInit | null | undefined): string[] | undefined {
|
||||
if (typeof body !== 'string') {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(body);
|
||||
const messages = Array.isArray(parsed) ? parsed : [parsed];
|
||||
const methods = messages
|
||||
.map((message: Record<string, unknown>) =>
|
||||
typeof message?.method === 'string' ? (message.method as string) : undefined
|
||||
)
|
||||
.filter((method: string | undefined): method is string => Boolean(method));
|
||||
|
||||
return methods.length > 0 ? methods : undefined;
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
import { NEWLINE_SEPARATOR } from '$lib/constants';
|
||||
|
||||
/**
|
||||
* Returns a shortened preview of the provided content capped at the given length.
|
||||
* Appends an ellipsis when the content exceeds the maximum.
|
||||
@@ -5,3 +7,16 @@
|
||||
export function getPreviewText(content: string, max = 150): string {
|
||||
return content.length > max ? content.slice(0, max) + '...' : content;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a single-line title from a potentially multi-line prompt.
|
||||
* Uses the first non-empty line if `useFirstLine` is true.
|
||||
*/
|
||||
export function generateConversationTitle(content: string, useFirstLine: boolean = false): string {
|
||||
if (useFirstLine) {
|
||||
const firstLine = content.split(NEWLINE_SEPARATOR).find((line) => line.trim().length > 0);
|
||||
return firstLine ? firstLine.trim() : content.trim();
|
||||
}
|
||||
|
||||
return content.trim();
|
||||
}
|
||||
|
||||
252
tools/server/webui/tests/unit/mcp-service.test.ts
Normal file
252
tools/server/webui/tests/unit/mcp-service.test.ts
Normal file
@@ -0,0 +1,252 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client';
|
||||
import { MCPService } from '$lib/services/mcp.service';
|
||||
import { MCPConnectionPhase, MCPTransportType } from '$lib/enums';
|
||||
import type { MCPConnectionLog, MCPServerConfig } from '$lib/types';
|
||||
|
||||
type DiagnosticFetchFactory = (
|
||||
serverName: string,
|
||||
config: MCPServerConfig,
|
||||
baseInit: RequestInit,
|
||||
targetUrl: URL,
|
||||
useProxy: boolean,
|
||||
onLog?: (log: MCPConnectionLog) => void
|
||||
) => { fetch: typeof fetch; disable: () => void };
|
||||
|
||||
const createDiagnosticFetch = (
|
||||
config: MCPServerConfig,
|
||||
onLog?: (log: MCPConnectionLog) => void,
|
||||
baseInit: RequestInit = {}
|
||||
) =>
|
||||
(
|
||||
MCPService as unknown as { createDiagnosticFetch: DiagnosticFetchFactory }
|
||||
).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), false, onLog);
|
||||
|
||||
describe('MCPService', () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it('stops transport phase logging after handshake diagnostics are disabled', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const response = new Response('{}', {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'application/json' }
|
||||
});
|
||||
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://example.com/mcp',
|
||||
transport: MCPTransportType.STREAMABLE_HTTP
|
||||
};
|
||||
|
||||
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
|
||||
|
||||
await controller.fetch(config.url, { method: 'POST', body: '{}' });
|
||||
expect(logs).toHaveLength(2);
|
||||
expect(logs.every((log) => log.message.includes('https://example.com/mcp'))).toBe(true);
|
||||
|
||||
controller.disable();
|
||||
await controller.fetch(config.url, { method: 'POST', body: '{}' });
|
||||
|
||||
expect(logs).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('redacts all configured custom headers in diagnostic request logs', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const response = new Response('{}', {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'application/json' }
|
||||
});
|
||||
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://example.com/mcp',
|
||||
transport: MCPTransportType.STREAMABLE_HTTP,
|
||||
headers: {
|
||||
'x-auth-token': 'secret-token',
|
||||
'x-vendor-api-key': 'secret-key'
|
||||
}
|
||||
};
|
||||
|
||||
const controller = createDiagnosticFetch(config, (log) => logs.push(log), {
|
||||
headers: config.headers
|
||||
});
|
||||
|
||||
await controller.fetch(config.url, {
|
||||
method: 'POST',
|
||||
headers: { 'content-type': 'application/json' },
|
||||
body: '{}'
|
||||
});
|
||||
|
||||
expect(logs).toHaveLength(2);
|
||||
expect(logs[0].details).toMatchObject({
|
||||
request: {
|
||||
headers: {
|
||||
'x-auth-token': '[redacted]',
|
||||
'x-vendor-api-key': '[redacted]',
|
||||
'content-type': 'application/json'
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('partially redacts mcp-session-id in diagnostic request and response logs', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const response = new Response('{}', {
|
||||
status: 200,
|
||||
headers: {
|
||||
'content-type': 'application/json',
|
||||
'mcp-session-id': 'session-response-67890'
|
||||
}
|
||||
});
|
||||
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://example.com/mcp',
|
||||
transport: MCPTransportType.STREAMABLE_HTTP
|
||||
};
|
||||
|
||||
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
|
||||
|
||||
await controller.fetch(config.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'content-type': 'application/json',
|
||||
'mcp-session-id': 'session-request-12345'
|
||||
},
|
||||
body: '{}'
|
||||
});
|
||||
|
||||
expect(logs).toHaveLength(2);
|
||||
expect(logs[0].details).toMatchObject({
|
||||
request: {
|
||||
headers: {
|
||||
'content-type': 'application/json',
|
||||
'mcp-session-id': '....12345'
|
||||
}
|
||||
}
|
||||
});
|
||||
expect(logs[1].details).toMatchObject({
|
||||
response: {
|
||||
headers: {
|
||||
'content-type': 'application/json',
|
||||
'mcp-session-id': '....67890'
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('extracts JSON-RPC methods without logging the raw request body', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const response = new Response('{}', {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'application/json' }
|
||||
});
|
||||
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://example.com/mcp',
|
||||
transport: MCPTransportType.STREAMABLE_HTTP
|
||||
};
|
||||
|
||||
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
|
||||
|
||||
await controller.fetch(config.url, {
|
||||
method: 'POST',
|
||||
body: JSON.stringify([
|
||||
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
|
||||
{ jsonrpc: '2.0', method: 'notifications/initialized' }
|
||||
])
|
||||
});
|
||||
|
||||
expect(logs[0].details).toMatchObject({
|
||||
request: {
|
||||
method: 'POST',
|
||||
body: {
|
||||
kind: 'string',
|
||||
size: expect.any(Number)
|
||||
},
|
||||
jsonRpcMethods: ['initialize', 'notifications/initialized']
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('adds a CORS hint to Failed to fetch diagnostic log messages', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const fetchError = new TypeError('Failed to fetch');
|
||||
|
||||
vi.stubGlobal('fetch', vi.fn().mockRejectedValue(fetchError));
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
url: 'http://localhost:8000/mcp',
|
||||
transport: MCPTransportType.STREAMABLE_HTTP
|
||||
};
|
||||
|
||||
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
|
||||
|
||||
await expect(controller.fetch(config.url, { method: 'POST', body: '{}' })).rejects.toThrow(
|
||||
'Failed to fetch'
|
||||
);
|
||||
|
||||
expect(logs).toHaveLength(2);
|
||||
expect(logs[1].message).toBe(
|
||||
'HTTP POST http://localhost:8000/mcp failed: Failed to fetch (check CORS?)'
|
||||
);
|
||||
});
|
||||
|
||||
it('detaches phase error logging after the initialize handshake completes', async () => {
|
||||
const phaseLogs: Array<{ phase: MCPConnectionPhase; log: MCPConnectionLog }> = [];
|
||||
const stopPhaseLogging = vi.fn();
|
||||
let emitClientError: ((error: Error) => void) | undefined;
|
||||
|
||||
vi.spyOn(MCPService, 'createTransport').mockReturnValue({
|
||||
transport: {} as never,
|
||||
type: MCPTransportType.WEBSOCKET,
|
||||
stopPhaseLogging
|
||||
});
|
||||
vi.spyOn(MCPService, 'listTools').mockResolvedValue([]);
|
||||
vi.spyOn(Client.prototype, 'getServerVersion').mockReturnValue(undefined);
|
||||
vi.spyOn(Client.prototype, 'getServerCapabilities').mockReturnValue(undefined);
|
||||
vi.spyOn(Client.prototype, 'getInstructions').mockReturnValue(undefined);
|
||||
vi.spyOn(Client.prototype, 'connect').mockImplementation(async function (this: Client) {
|
||||
emitClientError = (error: Error) => this.onerror?.(error);
|
||||
this.onerror?.(new Error('handshake protocol error'));
|
||||
});
|
||||
|
||||
await MCPService.connect(
|
||||
'test-server',
|
||||
{
|
||||
url: 'ws://example.com/mcp',
|
||||
transport: MCPTransportType.WEBSOCKET
|
||||
},
|
||||
undefined,
|
||||
undefined,
|
||||
(phase, log) => phaseLogs.push({ phase, log })
|
||||
);
|
||||
|
||||
expect(stopPhaseLogging).toHaveBeenCalledTimes(1);
|
||||
expect(
|
||||
phaseLogs.filter(
|
||||
({ phase, log }) =>
|
||||
phase === MCPConnectionPhase.ERROR &&
|
||||
log.message === 'Protocol error: handshake protocol error'
|
||||
)
|
||||
).toHaveLength(1);
|
||||
|
||||
emitClientError?.(new Error('runtime protocol error'));
|
||||
|
||||
expect(
|
||||
phaseLogs.filter(
|
||||
({ phase, log }) =>
|
||||
phase === MCPConnectionPhase.ERROR &&
|
||||
log.message === 'Protocol error: runtime protocol error'
|
||||
)
|
||||
).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
20
tools/server/webui/tests/unit/redact.test.ts
Normal file
20
tools/server/webui/tests/unit/redact.test.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { redactValue } from '$lib/utils/redact';
|
||||
|
||||
describe('redactValue', () => {
|
||||
it('returns [redacted] by default', () => {
|
||||
expect(redactValue('secret-token')).toBe('[redacted]');
|
||||
});
|
||||
|
||||
it('shows last N characters when showLastChars is provided', () => {
|
||||
expect(redactValue('session-abc12', 5)).toBe('....abc12');
|
||||
});
|
||||
|
||||
it('handles value shorter than showLastChars', () => {
|
||||
expect(redactValue('ab', 5)).toBe('....ab');
|
||||
});
|
||||
|
||||
it('returns [redacted] when showLastChars is 0', () => {
|
||||
expect(redactValue('secret', 0)).toBe('[redacted]');
|
||||
});
|
||||
});
|
||||
124
tools/server/webui/tests/unit/request-helpers.test.ts
Normal file
124
tools/server/webui/tests/unit/request-helpers.test.ts
Normal file
@@ -0,0 +1,124 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import {
|
||||
getRequestUrl,
|
||||
getRequestMethod,
|
||||
getRequestBody,
|
||||
summarizeRequestBody,
|
||||
formatDiagnosticErrorMessage,
|
||||
extractJsonRpcMethods
|
||||
} from '$lib/utils/request-helpers';
|
||||
|
||||
describe('getRequestUrl', () => {
|
||||
it('returns a plain string input as-is', () => {
|
||||
expect(getRequestUrl('https://example.com/mcp')).toBe('https://example.com/mcp');
|
||||
});
|
||||
|
||||
it('returns href from a URL object', () => {
|
||||
expect(getRequestUrl(new URL('https://example.com/mcp'))).toBe('https://example.com/mcp');
|
||||
});
|
||||
|
||||
it('returns url from a Request object', () => {
|
||||
const req = new Request('https://example.com/mcp');
|
||||
expect(getRequestUrl(req)).toBe('https://example.com/mcp');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getRequestMethod', () => {
|
||||
it('prefers method from init', () => {
|
||||
expect(getRequestMethod('https://example.com', { method: 'POST' })).toBe('POST');
|
||||
});
|
||||
|
||||
it('falls back to Request.method', () => {
|
||||
const req = new Request('https://example.com', { method: 'PUT' });
|
||||
expect(getRequestMethod(req)).toBe('PUT');
|
||||
});
|
||||
|
||||
it('falls back to baseInit.method', () => {
|
||||
expect(getRequestMethod('https://example.com', undefined, { method: 'DELETE' })).toBe('DELETE');
|
||||
});
|
||||
|
||||
it('defaults to GET', () => {
|
||||
expect(getRequestMethod('https://example.com')).toBe('GET');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getRequestBody', () => {
|
||||
it('returns body from init', () => {
|
||||
expect(getRequestBody('https://example.com', { body: 'payload' })).toBe('payload');
|
||||
});
|
||||
|
||||
it('returns undefined when no body is present', () => {
|
||||
expect(getRequestBody('https://example.com')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('summarizeRequestBody', () => {
|
||||
it('returns empty for null', () => {
|
||||
expect(summarizeRequestBody(null)).toEqual({ kind: 'empty' });
|
||||
});
|
||||
|
||||
it('returns empty for undefined', () => {
|
||||
expect(summarizeRequestBody(undefined)).toEqual({ kind: 'empty' });
|
||||
});
|
||||
|
||||
it('returns string kind with size', () => {
|
||||
expect(summarizeRequestBody('hello')).toEqual({ kind: 'string', size: 5 });
|
||||
});
|
||||
|
||||
it('returns blob kind with size', () => {
|
||||
const blob = new Blob(['abc']);
|
||||
expect(summarizeRequestBody(blob)).toEqual({ kind: 'blob', size: 3 });
|
||||
});
|
||||
|
||||
it('returns formdata kind', () => {
|
||||
expect(summarizeRequestBody(new FormData())).toEqual({ kind: 'formdata' });
|
||||
});
|
||||
|
||||
it('returns arraybuffer kind with size', () => {
|
||||
expect(summarizeRequestBody(new ArrayBuffer(8))).toEqual({ kind: 'arraybuffer', size: 8 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('formatDiagnosticErrorMessage', () => {
|
||||
it('appends CORS hint for Failed to fetch', () => {
|
||||
expect(formatDiagnosticErrorMessage(new TypeError('Failed to fetch'))).toBe(
|
||||
'Failed to fetch (check CORS?)'
|
||||
);
|
||||
});
|
||||
|
||||
it('passes through other error messages unchanged', () => {
|
||||
expect(formatDiagnosticErrorMessage(new Error('timeout'))).toBe('timeout');
|
||||
});
|
||||
|
||||
it('handles non-Error values', () => {
|
||||
expect(formatDiagnosticErrorMessage('some string')).toBe('some string');
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractJsonRpcMethods', () => {
|
||||
it('extracts methods from a JSON-RPC array', () => {
|
||||
const body = JSON.stringify([
|
||||
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
|
||||
{ jsonrpc: '2.0', method: 'notifications/initialized' }
|
||||
]);
|
||||
expect(extractJsonRpcMethods(body)).toEqual(['initialize', 'notifications/initialized']);
|
||||
});
|
||||
|
||||
it('extracts method from a single JSON-RPC message', () => {
|
||||
const body = JSON.stringify({ jsonrpc: '2.0', id: 1, method: 'tools/list' });
|
||||
expect(extractJsonRpcMethods(body)).toEqual(['tools/list']);
|
||||
});
|
||||
|
||||
it('returns undefined for non-string body', () => {
|
||||
expect(extractJsonRpcMethods(null)).toBeUndefined();
|
||||
expect(extractJsonRpcMethods(undefined)).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns undefined for invalid JSON', () => {
|
||||
expect(extractJsonRpcMethods('not json')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns undefined when no methods found', () => {
|
||||
expect(extractJsonRpcMethods(JSON.stringify({ foo: 'bar' }))).toBeUndefined();
|
||||
});
|
||||
});
|
||||
55
tools/server/webui/tests/unit/sanitize-headers.test.ts
Normal file
55
tools/server/webui/tests/unit/sanitize-headers.test.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { sanitizeHeaders } from '$lib/utils/api-headers';
|
||||
|
||||
describe('sanitizeHeaders', () => {
|
||||
it('returns empty object for undefined input', () => {
|
||||
expect(sanitizeHeaders()).toEqual({});
|
||||
});
|
||||
|
||||
it('passes through non-sensitive headers', () => {
|
||||
const headers = new Headers({ 'content-type': 'application/json', accept: 'text/html' });
|
||||
expect(sanitizeHeaders(headers)).toEqual({
|
||||
'content-type': 'application/json',
|
||||
accept: 'text/html'
|
||||
});
|
||||
});
|
||||
|
||||
it('redacts known sensitive headers', () => {
|
||||
const headers = new Headers({
|
||||
authorization: 'Bearer secret',
|
||||
'x-api-key': 'key-123',
|
||||
'content-type': 'application/json'
|
||||
});
|
||||
const result = sanitizeHeaders(headers);
|
||||
expect(result.authorization).toBe('[redacted]');
|
||||
expect(result['x-api-key']).toBe('[redacted]');
|
||||
expect(result['content-type']).toBe('application/json');
|
||||
});
|
||||
|
||||
it('partially redacts headers specified in partialRedactHeaders', () => {
|
||||
const headers = new Headers({ 'mcp-session-id': 'session-12345' });
|
||||
const partial = new Map([['mcp-session-id', 5]]);
|
||||
expect(sanitizeHeaders(headers, undefined, partial)['mcp-session-id']).toBe('....12345');
|
||||
});
|
||||
|
||||
it('fully redacts mcp-session-id when no partialRedactHeaders is given', () => {
|
||||
const headers = new Headers({ 'mcp-session-id': 'session-12345' });
|
||||
expect(sanitizeHeaders(headers)['mcp-session-id']).toBe('[redacted]');
|
||||
});
|
||||
|
||||
it('redacts extra headers provided by the caller', () => {
|
||||
const headers = new Headers({
|
||||
'x-vendor-key': 'vendor-secret',
|
||||
'content-type': 'application/json'
|
||||
});
|
||||
const result = sanitizeHeaders(headers, ['x-vendor-key']);
|
||||
expect(result['x-vendor-key']).toBe('[redacted]');
|
||||
expect(result['content-type']).toBe('application/json');
|
||||
});
|
||||
|
||||
it('handles case-insensitive extra header names', () => {
|
||||
const headers = new Headers({ 'X-Custom-Token': 'token-value' });
|
||||
const result = sanitizeHeaders(headers, ['X-CUSTOM-TOKEN']);
|
||||
expect(result['x-custom-token']).toBe('[redacted]');
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user