mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-30 16:47:31 +03:00
Compare commits
15 Commits
b6912
...
gg/clip-fa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d441c31b19 | ||
|
|
cdb3deae76 | ||
|
|
b67a168f10 | ||
|
|
29330dcb55 | ||
|
|
bdb43f6e9c | ||
|
|
b4955f0ae6 | ||
|
|
19116a4b38 | ||
|
|
2f68ce7cfd | ||
|
|
e4a71599e5 | ||
|
|
dd5e8cab51 | ||
|
|
cf659bbb8e | ||
|
|
d8b860a219 | ||
|
|
1ae74882f8 | ||
|
|
a4b54f2697 | ||
|
|
3aa835bfe6 |
@@ -707,6 +707,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
if (op->src[0]->ne[0] != 32 &&
|
||||
op->src[0]->ne[0] != 40 &&
|
||||
op->src[0]->ne[0] != 64 &&
|
||||
op->src[0]->ne[0] != 72 &&
|
||||
op->src[0]->ne[0] != 80 &&
|
||||
op->src[0]->ne[0] != 96 &&
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
|
||||
@@ -5362,6 +5362,7 @@ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, hal
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
|
||||
@@ -5374,6 +5375,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
|
||||
@@ -5387,6 +5389,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
|
||||
@@ -5400,6 +5403,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
|
||||
@@ -5412,6 +5416,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
|
||||
@@ -5424,6 +5429,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
|
||||
@@ -5436,6 +5442,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
|
||||
@@ -5448,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
|
||||
|
||||
@@ -7225,8 +7225,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
|
||||
}
|
||||
|
||||
for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
|
||||
for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
|
||||
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
|
||||
for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
|
||||
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
|
||||
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
|
||||
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
|
||||
|
||||
@@ -154,8 +154,8 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_LFM2,
|
||||
PROJECTOR_TYPE_KIMIVL,
|
||||
PROJECTOR_TYPE_LIGHTONOCR,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
PROJECTOR_TYPE_COGVLM,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
@@ -22,9 +23,16 @@ enum clip_modality {
|
||||
CLIP_MODALITY_AUDIO,
|
||||
};
|
||||
|
||||
// Flash-attention mode requested for a CLIP context
// (stored in clip_context_params::flash_attn_type).
enum clip_flash_attn_type {
    CLIP_FLASH_ATTN_TYPE_AUTO     = -1, // presumably: let the implementation decide - confirm in clip.cpp
    CLIP_FLASH_ATTN_TYPE_DISABLED =  0,
    CLIP_FLASH_ATTN_TYPE_ENABLED  =  1,
};
|
||||
|
||||
struct clip_context_params {
|
||||
bool use_gpu;
|
||||
enum ggml_log_level verbosity;
|
||||
enum clip_flash_attn_type flash_attn_type;
|
||||
};
|
||||
|
||||
struct clip_init_result {
|
||||
|
||||
@@ -136,6 +136,7 @@ struct mtmd_cli_context {
|
||||
mparams.print_timings = true;
|
||||
mparams.n_threads = params.cpuparams.n_threads;
|
||||
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params.flash_attn_type;
|
||||
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
|
||||
if (!ctx_vision.get()) {
|
||||
LOG_ERR("Failed to load vision model from %s\n", clip_path);
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
// represents raw image data, layout is RGBRGBRGB...
|
||||
@@ -92,6 +91,15 @@ const char * mtmd_default_marker() {
|
||||
return "<__media__>";
|
||||
}
|
||||
|
||||
// Translates the llama flash-attention setting into its CLIP counterpart.
// Any value not covered by a case below yields CLIP_FLASH_ATTN_TYPE_AUTO.
static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) {
    clip_flash_attn_type mapped = CLIP_FLASH_ATTN_TYPE_AUTO;

    // No default label on purpose: the compiler can still warn when a new
    // llama_flash_attn_type enumerator is added and not handled here.
    switch (flash_attn_type) {
        case LLAMA_FLASH_ATTN_TYPE_AUTO:
            break;
        case LLAMA_FLASH_ATTN_TYPE_DISABLED:
            mapped = CLIP_FLASH_ATTN_TYPE_DISABLED;
            break;
        case LLAMA_FLASH_ATTN_TYPE_ENABLED:
            mapped = CLIP_FLASH_ATTN_TYPE_ENABLED;
            break;
    }

    return mapped;
}
|
||||
|
||||
mtmd_context_params mtmd_context_params_default() {
|
||||
mtmd_context_params params;
|
||||
params.use_gpu = true;
|
||||
@@ -100,6 +108,7 @@ mtmd_context_params mtmd_context_params_default() {
|
||||
params.verbosity = GGML_LOG_LEVEL_INFO;
|
||||
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
|
||||
params.media_marker = mtmd_default_marker();
|
||||
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
|
||||
return params;
|
||||
}
|
||||
|
||||
@@ -164,6 +173,7 @@ struct mtmd_context {
|
||||
clip_context_params ctx_clip_params;
|
||||
ctx_clip_params.use_gpu = ctx_params.use_gpu;
|
||||
ctx_clip_params.verbosity = ctx_params.verbosity;
|
||||
ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
|
||||
auto res = clip_init(mmproj_fname, ctx_clip_params);
|
||||
ctx_v = res.ctx_v;
|
||||
ctx_a = res.ctx_a;
|
||||
@@ -378,9 +388,7 @@ mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
||||
}
|
||||
|
||||
// Destroys a context previously returned by mtmd_init_from_file().
// Passing a null pointer is allowed and does nothing: `delete` on a null
// pointer is a no-op, so no explicit guard is needed. The context must be
// deleted exactly once; a guarded delete followed by an unconditional one
// would free a non-null context twice.
void mtmd_free(mtmd_context * ctx) {
    delete ctx;
}
|
||||
|
||||
struct mtmd_tokenizer {
|
||||
|
||||
@@ -82,6 +82,7 @@ struct mtmd_context_params {
|
||||
enum ggml_log_level verbosity;
|
||||
const char * image_marker; // deprecated, use media_marker instead
|
||||
const char * media_marker;
|
||||
enum llama_flash_attn_type flash_attn_type;
|
||||
};
|
||||
|
||||
MTMD_API const char * mtmd_default_marker(void);
|
||||
|
||||
Binary file not shown.
@@ -2456,6 +2456,7 @@ struct server_context {
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params_base.cpuparams.n_threads;
|
||||
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params_base.flash_attn_type;
|
||||
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
|
||||
if (mctx == nullptr) {
|
||||
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
|
||||
|
||||
@@ -3,7 +3,16 @@
|
||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||
import { isLoading } from '$lib/stores/chat.svelte';
|
||||
import { fade } from 'svelte/transition';
|
||||
import { Check, Copy, Package, X } from '@lucide/svelte';
|
||||
import {
|
||||
Check,
|
||||
Copy,
|
||||
Package,
|
||||
X,
|
||||
Gauge,
|
||||
Clock,
|
||||
WholeWord,
|
||||
ChartNoAxesColumn
|
||||
} from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import { INPUT_CLASSES } from '$lib/constants/input-classes';
|
||||
@@ -76,8 +85,8 @@
|
||||
let displayedModel = $derived((): string | null => {
|
||||
if (!currentConfig.showModelInfo) return null;
|
||||
|
||||
if (currentConfig.modelSelectorEnabled) {
|
||||
return message.model ?? null;
|
||||
if (message.model) {
|
||||
return message.model;
|
||||
}
|
||||
|
||||
return serverModel;
|
||||
@@ -160,22 +169,58 @@
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if displayedModel()}
|
||||
<span class="mt-6 mb-4 inline-flex items-center gap-1 text-xs text-muted-foreground">
|
||||
<Package class="h-3.5 w-3.5" />
|
||||
<div class="info my-6 grid gap-4">
|
||||
{#if displayedModel()}
|
||||
<span class="inline-flex items-center gap-2 text-xs text-muted-foreground">
|
||||
<span class="inline-flex items-center gap-1">
|
||||
<Package class="h-3.5 w-3.5" />
|
||||
|
||||
<span>Model used:</span>
|
||||
<span>Model used:</span>
|
||||
</span>
|
||||
|
||||
<button
|
||||
class="inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
|
||||
onclick={handleCopyModel}
|
||||
>
|
||||
{displayedModel()}
|
||||
<button
|
||||
class="inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
|
||||
onclick={handleCopyModel}
|
||||
>
|
||||
{displayedModel()}
|
||||
|
||||
<Copy class="ml-1 h-3 w-3 " />
|
||||
</button>
|
||||
</span>
|
||||
{/if}
|
||||
<Copy class="ml-1 h-3 w-3 " />
|
||||
</button>
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
{#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms}
|
||||
{@const tokensPerSecond = (message.timings.predicted_n / message.timings.predicted_ms) * 1000}
|
||||
<span class="inline-flex items-center gap-2 text-xs text-muted-foreground">
|
||||
<span class="inline-flex items-center gap-1">
|
||||
<ChartNoAxesColumn class="h-3.5 w-3.5" />
|
||||
|
||||
<span>Statistics:</span>
|
||||
</span>
|
||||
|
||||
<div class="inline-flex flex-wrap items-center gap-2 text-xs text-muted-foreground">
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
|
||||
>
|
||||
<Gauge class="h-3 w-3" />
|
||||
{tokensPerSecond.toFixed(2)} tokens/s
|
||||
</span>
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
|
||||
>
|
||||
<WholeWord class="h-3 w-3" />
|
||||
{message.timings.predicted_n} tokens
|
||||
</span>
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
|
||||
>
|
||||
<Clock class="h-3 w-3" />
|
||||
{(message.timings.predicted_ms / 1000).toFixed(2)}s
|
||||
</span>
|
||||
</div>
|
||||
</span>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if message.timestamp && !isEditing}
|
||||
<ChatMessageActions
|
||||
|
||||
@@ -52,6 +52,11 @@
|
||||
{ value: 'dark', label: 'Dark', icon: Moon }
|
||||
]
|
||||
},
|
||||
{
|
||||
key: 'showMessageStats',
|
||||
label: 'Show message generation statistics',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'showTokensPerSecond',
|
||||
label: 'Show tokens per second',
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
<script lang="ts">
|
||||
import { Dialog as DialogPrimitive } from 'bits-ui';
|
||||
import XIcon from '@lucide/svelte/icons/x';
|
||||
|
||||
interface Props {
|
||||
open: boolean;
|
||||
code: string;
|
||||
language: string;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
}
|
||||
|
||||
let { open = $bindable(), code, language, onOpenChange }: Props = $props();
|
||||
|
||||
let iframeRef = $state<HTMLIFrameElement | null>(null);
|
||||
|
||||
// Keep the iframe document in step with the dialog state: load the code
// while the dialog is open, and blank the frame once it is closed.
$effect(() => {
	const frame = iframeRef;

	if (!frame) return;

	frame.srcdoc = open ? code : '';
});
|
||||
|
||||
/**
 * Handles open-state changes coming from the dialog primitive: updates the
 * bindable `open` prop, then notifies the optional `onOpenChange` callback.
 */
function handleOpenChange(nextOpen: boolean) {
	open = nextOpen;

	if (onOpenChange) {
		onOpenChange(nextOpen);
	}
}
|
||||
</script>
|
||||
|
||||
<DialogPrimitive.Root {open} onOpenChange={handleOpenChange}>
|
||||
<DialogPrimitive.Portal>
|
||||
<DialogPrimitive.Overlay class="code-preview-overlay" />
|
||||
|
||||
<DialogPrimitive.Content class="code-preview-content">
|
||||
<iframe
|
||||
bind:this={iframeRef}
|
||||
title="Preview {language}"
|
||||
sandbox="allow-scripts"
|
||||
class="code-preview-iframe"
|
||||
></iframe>
|
||||
|
||||
<DialogPrimitive.Close
|
||||
class="code-preview-close absolute top-4 right-4 border-none bg-transparent text-white opacity-70 mix-blend-difference transition-opacity hover:opacity-100 focus-visible:ring-0 focus-visible:ring-offset-0 focus-visible:outline-none disabled:pointer-events-none [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-8"
|
||||
aria-label="Close preview"
|
||||
>
|
||||
<XIcon />
|
||||
<span class="sr-only">Close preview</span>
|
||||
</DialogPrimitive.Close>
|
||||
</DialogPrimitive.Content>
|
||||
</DialogPrimitive.Portal>
|
||||
</DialogPrimitive.Root>
|
||||
|
||||
<style lang="postcss">
|
||||
:global(.code-preview-overlay) {
|
||||
position: fixed;
|
||||
inset: 0;
|
||||
background-color: transparent;
|
||||
z-index: 100000;
|
||||
}
|
||||
|
||||
:global(.code-preview-content) {
|
||||
position: fixed;
|
||||
inset: 0;
|
||||
top: 0 !important;
|
||||
left: 0 !important;
|
||||
width: 100dvw;
|
||||
height: 100dvh;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
border: none;
|
||||
border-radius: 0;
|
||||
background-color: transparent;
|
||||
box-shadow: none;
|
||||
display: block;
|
||||
overflow: hidden;
|
||||
transform: none !important;
|
||||
z-index: 100001;
|
||||
}
|
||||
|
||||
:global(.code-preview-iframe) {
|
||||
display: block;
|
||||
width: 100dvw;
|
||||
height: 100dvh;
|
||||
border: 0;
|
||||
}
|
||||
|
||||
:global(.code-preview-close) {
|
||||
position: absolute;
|
||||
z-index: 100002;
|
||||
}
|
||||
</style>
|
||||
@@ -15,6 +15,7 @@
|
||||
import githubLightCss from 'highlight.js/styles/github.css?inline';
|
||||
import { mode } from 'mode-watcher';
|
||||
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
|
||||
import CodePreviewDialog from './CodePreviewDialog.svelte';
|
||||
|
||||
interface Props {
|
||||
content: string;
|
||||
@@ -25,6 +26,9 @@
|
||||
|
||||
let containerRef = $state<HTMLDivElement>();
|
||||
let processedHtml = $state('');
|
||||
let previewDialogOpen = $state(false);
|
||||
let previewCode = $state('');
|
||||
let previewLanguage = $state('text');
|
||||
|
||||
function loadHighlightTheme(isDark: boolean) {
|
||||
if (!browser) return;
|
||||
@@ -117,7 +121,6 @@
|
||||
|
||||
const rawCode = codeElement.textContent || '';
|
||||
const codeId = `code-${Date.now()}-${index}`;
|
||||
|
||||
codeElement.setAttribute('data-code-id', codeId);
|
||||
codeElement.setAttribute('data-raw-code', rawCode);
|
||||
|
||||
@@ -138,11 +141,30 @@
|
||||
copyButton.setAttribute('type', 'button');
|
||||
|
||||
copyButton.innerHTML = `
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-copy-icon lucide-copy"><rect width="14" height="14" x="8" y="8" rx="2" ry="2"/><path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/></svg>
|
||||
`;
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-copy-icon lucide-copy"><rect width="14" height="14" x="8" y="8" rx="2" ry="2"/><path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/></svg>
|
||||
`;
|
||||
|
||||
const actions = document.createElement('div');
|
||||
actions.className = 'code-block-actions';
|
||||
|
||||
actions.appendChild(copyButton);
|
||||
|
||||
if (language.toLowerCase() === 'html') {
|
||||
const previewButton = document.createElement('button');
|
||||
previewButton.className = 'preview-code-btn';
|
||||
previewButton.setAttribute('data-code-id', codeId);
|
||||
previewButton.setAttribute('title', 'Preview code');
|
||||
previewButton.setAttribute('type', 'button');
|
||||
|
||||
previewButton.innerHTML = `
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-eye lucide-eye-icon"><path d="M2.062 12.345a1 1 0 0 1 0-.69C3.5 7.73 7.36 5 12 5s8.5 2.73 9.938 6.655a1 1 0 0 1 0 .69C20.5 16.27 16.64 19 12 19s-8.5-2.73-9.938-6.655"/><circle cx="12" cy="12" r="3"/></svg>
|
||||
`;
|
||||
|
||||
actions.appendChild(previewButton);
|
||||
}
|
||||
|
||||
header.appendChild(languageLabel);
|
||||
header.appendChild(copyButton);
|
||||
header.appendChild(actions);
|
||||
wrapper.appendChild(header);
|
||||
|
||||
const clonedPre = pre.cloneNode(true) as HTMLElement;
|
||||
@@ -180,49 +202,105 @@
|
||||
}
|
||||
}
|
||||
|
||||
function setupCopyButtons() {
|
||||
/**
 * Looks up the code block that contains `target` and returns its raw source
 * and language label.
 *
 * @param target - element inside a `.code-block-wrapper` (e.g. an action button)
 * @returns `{ rawCode, language }`, or `null` (after logging an error) when the
 *          wrapper, the `code[data-code-id]` element, or its `data-raw-code`
 *          attribute is missing. `language` falls back to `'text'` when the
 *          `.code-language` label is absent or empty.
 */
function getCodeInfoFromTarget(target: HTMLElement) {
	const blockWrapper = target.closest('.code-block-wrapper');

	if (!blockWrapper) {
		console.error('No wrapper found');
		return null;
	}

	const codeNode = blockWrapper.querySelector<HTMLElement>('code[data-code-id]');

	if (!codeNode) {
		console.error('No code element found in wrapper');
		return null;
	}

	const rawCode = codeNode.getAttribute('data-raw-code');

	// Strict null check: an empty string is a valid (empty) code block.
	if (rawCode === null) {
		console.error('No raw code found');
		return null;
	}

	const labelText = blockWrapper
		.querySelector<HTMLElement>('.code-language')
		?.textContent?.trim();

	return { rawCode, language: labelText || 'text' };
}
|
||||
|
||||
/**
 * Click handler for a code block's copy button: resolves the block's raw code
 * and copies it to the clipboard. Copy failures are logged, not rethrown.
 */
async function handleCopyClick(event: Event) {
	event.preventDefault();
	event.stopPropagation();

	// Read currentTarget before any await; it is only set during dispatch.
	const button = event.currentTarget as HTMLButtonElement | null;
	const info = button ? getCodeInfoFromTarget(button) : null;

	if (!info) return;

	try {
		await copyCodeToClipboard(info.rawCode);
	} catch (error) {
		console.error('Failed to copy code:', error);
	}
}
|
||||
|
||||
/**
 * Click handler for a code block's preview button: stores the block's raw code
 * and language in component state and opens the preview dialog.
 */
function handlePreviewClick(event: Event) {
	event.preventDefault();
	event.stopPropagation();

	const button = event.currentTarget as HTMLButtonElement | null;
	const info = button ? getCodeInfoFromTarget(button) : null;

	if (!info) return;

	const { rawCode, language } = info;

	previewCode = rawCode;
	previewLanguage = language;
	previewDialogOpen = true;
}
|
||||
|
||||
function setupCodeBlockActions() {
|
||||
if (!containerRef) return;
|
||||
|
||||
const copyButtons = containerRef.querySelectorAll('.copy-code-btn');
|
||||
const wrappers = containerRef.querySelectorAll<HTMLElement>('.code-block-wrapper');
|
||||
|
||||
for (const button of copyButtons) {
|
||||
button.addEventListener('click', async (e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
for (const wrapper of wrappers) {
|
||||
const copyButton = wrapper.querySelector<HTMLButtonElement>('.copy-code-btn');
|
||||
const previewButton = wrapper.querySelector<HTMLButtonElement>('.preview-code-btn');
|
||||
|
||||
const target = e.currentTarget as HTMLButtonElement;
|
||||
const codeId = target.getAttribute('data-code-id');
|
||||
if (copyButton && copyButton.dataset.listenerBound !== 'true') {
|
||||
copyButton.dataset.listenerBound = 'true';
|
||||
copyButton.addEventListener('click', handleCopyClick);
|
||||
}
|
||||
|
||||
if (!codeId) {
|
||||
console.error('No code ID found on button');
|
||||
return;
|
||||
}
|
||||
if (previewButton && previewButton.dataset.listenerBound !== 'true') {
|
||||
previewButton.dataset.listenerBound = 'true';
|
||||
previewButton.addEventListener('click', handlePreviewClick);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find the code element within the same wrapper
|
||||
const wrapper = target.closest('.code-block-wrapper');
|
||||
if (!wrapper) {
|
||||
console.error('No wrapper found');
|
||||
return;
|
||||
}
|
||||
function handlePreviewDialogOpenChange(open: boolean) {
|
||||
previewDialogOpen = open;
|
||||
|
||||
const codeElement = wrapper.querySelector('code[data-code-id]');
|
||||
if (!codeElement) {
|
||||
console.error('No code element found in wrapper');
|
||||
return;
|
||||
}
|
||||
|
||||
const rawCode = codeElement.getAttribute('data-raw-code');
|
||||
if (!rawCode) {
|
||||
console.error('No raw code found');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await copyCodeToClipboard(rawCode);
|
||||
} catch (error) {
|
||||
console.error('Failed to copy code:', error);
|
||||
}
|
||||
});
|
||||
if (!open) {
|
||||
previewCode = '';
|
||||
previewLanguage = 'text';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,7 +321,7 @@
|
||||
|
||||
$effect(() => {
|
||||
if (containerRef && processedHtml) {
|
||||
setupCopyButtons();
|
||||
setupCodeBlockActions();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
@@ -253,6 +331,13 @@
|
||||
{@html processedHtml}
|
||||
</div>
|
||||
|
||||
<CodePreviewDialog
|
||||
open={previewDialogOpen}
|
||||
code={previewCode}
|
||||
language={previewLanguage}
|
||||
onOpenChange={handlePreviewDialogOpenChange}
|
||||
/>
|
||||
|
||||
<style>
|
||||
/* Base typography styles */
|
||||
div :global(p:not(:last-child)) {
|
||||
@@ -472,7 +557,14 @@
|
||||
letter-spacing: 0.05em;
|
||||
}
|
||||
|
||||
div :global(.copy-code-btn) {
|
||||
div :global(.code-block-actions) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
div :global(.copy-code-btn),
|
||||
div :global(.preview-code-btn) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
@@ -483,11 +575,13 @@
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
div :global(.copy-code-btn:hover) {
|
||||
div :global(.copy-code-btn:hover),
|
||||
div :global(.preview-code-btn:hover) {
|
||||
transform: scale(1.05);
|
||||
}
|
||||
|
||||
div :global(.copy-code-btn:active) {
|
||||
div :global(.copy-code-btn:active),
|
||||
div :global(.preview-code-btn:active) {
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
||||
showThoughtInProgress: false,
|
||||
disableReasoningFormat: false,
|
||||
keepStatsVisible: false,
|
||||
showMessageStats: true,
|
||||
askForTitleConfirmation: false,
|
||||
pasteLongTextToFileLen: 2500,
|
||||
pdfAsImage: false,
|
||||
@@ -82,6 +83,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
disableReasoningFormat:
|
||||
'Show raw LLM output without backend parsing and frontend Markdown rendering to inspect streaming across different models.',
|
||||
keepStatsVisible: 'Keep processing statistics visible after generation finishes.',
|
||||
showMessageStats:
|
||||
'Display generation statistics (tokens/second, token count, duration) below each assistant message.',
|
||||
askForTitleConfirmation:
|
||||
'Ask for confirmation before automatically changing conversation title when editing the first message.',
|
||||
pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).',
|
||||
|
||||
@@ -69,6 +69,10 @@ export const TEXT_FILE_TYPES = {
|
||||
extensions: [FileExtensionText.MD],
|
||||
mimeTypes: [MimeTypeText.MARKDOWN]
|
||||
},
|
||||
[FileTypeText.ASCIIDOC]: {
|
||||
extensions: [FileExtensionText.ADOC],
|
||||
mimeTypes: [MimeTypeText.ASCIIDOC]
|
||||
},
|
||||
[FileTypeText.JAVASCRIPT]: {
|
||||
extensions: [FileExtensionText.JS],
|
||||
mimeTypes: [MimeTypeText.JAVASCRIPT, MimeTypeText.JAVASCRIPT_APP]
|
||||
|
||||
@@ -33,6 +33,7 @@ export enum FileTypePdf {
|
||||
export enum FileTypeText {
|
||||
PLAIN_TEXT = 'plainText',
|
||||
MARKDOWN = 'markdown',
|
||||
ASCIIDOC = 'asciidoc',
|
||||
JAVASCRIPT = 'javascript',
|
||||
TYPESCRIPT = 'typescript',
|
||||
JSX = 'jsx',
|
||||
@@ -86,6 +87,7 @@ export enum FileExtensionPdf {
|
||||
export enum FileExtensionText {
|
||||
TXT = '.txt',
|
||||
MD = '.md',
|
||||
ADOC = '.adoc',
|
||||
JS = '.js',
|
||||
TS = '.ts',
|
||||
JSX = '.jsx',
|
||||
@@ -147,6 +149,7 @@ export enum MimeTypeImage {
|
||||
export enum MimeTypeText {
|
||||
PLAIN = 'text/plain',
|
||||
MARKDOWN = 'text/markdown',
|
||||
ASCIIDOC = 'text/asciidoc',
|
||||
JAVASCRIPT = 'text/javascript',
|
||||
JAVASCRIPT_APP = 'application/javascript',
|
||||
TYPESCRIPT = 'text/typescript',
|
||||
|
||||
@@ -54,6 +54,7 @@ export class ChatService {
|
||||
onError,
|
||||
onReasoningChunk,
|
||||
onModel,
|
||||
onFirstValidChunk,
|
||||
// Generation parameters
|
||||
temperature,
|
||||
max_tokens,
|
||||
@@ -201,6 +202,7 @@ export class ChatService {
|
||||
onError,
|
||||
onReasoningChunk,
|
||||
onModel,
|
||||
onFirstValidChunk,
|
||||
conversationId,
|
||||
abortController.signal
|
||||
);
|
||||
@@ -267,6 +269,7 @@ export class ChatService {
|
||||
onError?: (error: Error) => void,
|
||||
onReasoningChunk?: (chunk: string) => void,
|
||||
onModel?: (model: string) => void,
|
||||
onFirstValidChunk?: () => void,
|
||||
conversationId?: string,
|
||||
abortSignal?: AbortSignal
|
||||
): Promise<void> {
|
||||
@@ -283,6 +286,7 @@ export class ChatService {
|
||||
let lastTimings: ChatMessageTimings | undefined;
|
||||
let streamFinished = false;
|
||||
let modelEmitted = false;
|
||||
let firstValidChunkEmitted = false;
|
||||
|
||||
try {
|
||||
let chunk = '';
|
||||
@@ -311,10 +315,12 @@ export class ChatService {
|
||||
try {
|
||||
const parsed: ApiChatCompletionStreamChunk = JSON.parse(data);
|
||||
|
||||
const chunkModel = this.extractModelName(parsed);
|
||||
if (chunkModel && !modelEmitted) {
|
||||
modelEmitted = true;
|
||||
onModel?.(chunkModel);
|
||||
if (!firstValidChunkEmitted && parsed.object === 'chat.completion.chunk') {
|
||||
firstValidChunkEmitted = true;
|
||||
|
||||
if (!abortSignal?.aborted) {
|
||||
onFirstValidChunk?.();
|
||||
}
|
||||
}
|
||||
|
||||
const content = parsed.choices[0]?.delta?.content;
|
||||
@@ -322,6 +328,12 @@ export class ChatService {
|
||||
const timings = parsed.timings;
|
||||
const promptProgress = parsed.prompt_progress;
|
||||
|
||||
const chunkModel = this.extractModelName(parsed);
|
||||
if (chunkModel && !modelEmitted) {
|
||||
modelEmitted = true;
|
||||
onModel?.(chunkModel);
|
||||
}
|
||||
|
||||
if (timings || promptProgress) {
|
||||
this.updateProcessingState(timings, promptProgress, conversationId);
|
||||
if (timings) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { DatabaseStore } from '$lib/stores/database';
|
||||
import { chatService, slotsService } from '$lib/services';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
import { normalizeModelName } from '$lib/utils/model-names';
|
||||
import { filterByLeafNodeId, findLeafNode, findDescendantMessages } from '$lib/utils/branching';
|
||||
import { browser } from '$app/environment';
|
||||
@@ -362,9 +363,41 @@ class ChatStore {
|
||||
|
||||
let resolvedModel: string | null = null;
|
||||
let modelPersisted = false;
|
||||
const currentConfig = config();
|
||||
const preferServerPropsModel = !currentConfig.modelSelectorEnabled;
|
||||
let serverPropsRefreshed = false;
|
||||
let updateModelFromServerProps: ((persistImmediately?: boolean) => void) | null = null;
|
||||
|
||||
const recordModel = (modelName: string, persistImmediately = true): void => {
|
||||
const normalizedModel = normalizeModelName(modelName);
|
||||
const refreshServerPropsOnce = () => {
|
||||
if (serverPropsRefreshed) {
|
||||
return;
|
||||
}
|
||||
|
||||
serverPropsRefreshed = true;
|
||||
|
||||
const hasExistingProps = serverStore.serverProps !== null;
|
||||
|
||||
serverStore
|
||||
.fetchServerProps({ silent: hasExistingProps })
|
||||
.then(() => {
|
||||
updateModelFromServerProps?.(true);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.warn('Failed to refresh server props after streaming started:', error);
|
||||
});
|
||||
};
|
||||
|
||||
const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => {
|
||||
const serverModelName = serverStore.modelName;
|
||||
const preferredModelSource = preferServerPropsModel
|
||||
? (serverModelName ?? modelName ?? null)
|
||||
: (modelName ?? serverModelName ?? null);
|
||||
|
||||
if (!preferredModelSource) {
|
||||
return;
|
||||
}
|
||||
|
||||
const normalizedModel = normalizeModelName(preferredModelSource);
|
||||
|
||||
if (!normalizedModel || normalizedModel === resolvedModel) {
|
||||
return;
|
||||
@@ -388,6 +421,20 @@ class ChatStore {
|
||||
}
|
||||
};
|
||||
|
||||
if (preferServerPropsModel) {
|
||||
updateModelFromServerProps = (persistImmediately = true) => {
|
||||
const currentServerModel = serverStore.modelName;
|
||||
|
||||
if (!currentServerModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
recordModel(currentServerModel, persistImmediately);
|
||||
};
|
||||
|
||||
updateModelFromServerProps(false);
|
||||
}
|
||||
|
||||
slotsService.startStreaming();
|
||||
slotsService.setActiveConversation(assistantMessage.convId);
|
||||
|
||||
@@ -396,6 +443,9 @@ class ChatStore {
|
||||
{
|
||||
...this.getApiOptions(),
|
||||
|
||||
onFirstValidChunk: () => {
|
||||
refreshServerPropsOnce();
|
||||
},
|
||||
onChunk: (chunk: string) => {
|
||||
streamedContent += chunk;
|
||||
this.setConversationStreaming(
|
||||
|
||||
@@ -52,6 +52,7 @@ class ServerStore {
|
||||
private _error = $state<string | null>(null);
|
||||
private _serverWarning = $state<string | null>(null);
|
||||
private _slotsEndpointAvailable = $state<boolean | null>(null);
|
||||
private fetchServerPropsPromise: Promise<void> | null = null;
|
||||
|
||||
private readCachedServerProps(): ApiLlamaCppServerProps | null {
|
||||
if (!browser) return null;
|
||||
@@ -171,73 +172,65 @@ class ServerStore {
|
||||
/**
|
||||
* Fetches server properties from the server
|
||||
*/
|
||||
async fetchServerProps(): Promise<void> {
|
||||
this._loading = true;
|
||||
this._error = null;
|
||||
this._serverWarning = null;
|
||||
async fetchServerProps(options: { silent?: boolean } = {}): Promise<void> {
|
||||
const { silent = false } = options;
|
||||
const isSilent = silent && this._serverProps !== null;
|
||||
|
||||
try {
|
||||
console.log('Fetching server properties...');
|
||||
const props = await ChatService.getServerProps();
|
||||
this._serverProps = props;
|
||||
this.persistServerProps(props);
|
||||
console.log('Server properties loaded:', props);
|
||||
if (this.fetchServerPropsPromise) {
|
||||
return this.fetchServerPropsPromise;
|
||||
}
|
||||
|
||||
// Check slots endpoint availability after server props are loaded
|
||||
await this.checkSlotsEndpointAvailability();
|
||||
} catch (error) {
|
||||
const hadCachedProps = this._serverProps !== null;
|
||||
let errorMessage = 'Failed to connect to server';
|
||||
let isOfflineLikeError = false;
|
||||
let isServerSideError = false;
|
||||
if (!isSilent) {
|
||||
this._loading = true;
|
||||
this._error = null;
|
||||
this._serverWarning = null;
|
||||
}
|
||||
|
||||
if (error instanceof Error) {
|
||||
// Handle specific error types with user-friendly messages
|
||||
if (error.name === 'TypeError' && error.message.includes('fetch')) {
|
||||
errorMessage = 'Server is not running or unreachable';
|
||||
isOfflineLikeError = true;
|
||||
} else if (error.message.includes('ECONNREFUSED')) {
|
||||
errorMessage = 'Connection refused - server may be offline';
|
||||
isOfflineLikeError = true;
|
||||
} else if (error.message.includes('ENOTFOUND')) {
|
||||
errorMessage = 'Server not found - check server address';
|
||||
isOfflineLikeError = true;
|
||||
} else if (error.message.includes('ETIMEDOUT')) {
|
||||
errorMessage = 'Request timed out - the server took too long to respond';
|
||||
isOfflineLikeError = true;
|
||||
} else if (error.message.includes('503')) {
|
||||
errorMessage = 'Server temporarily unavailable - try again shortly';
|
||||
isServerSideError = true;
|
||||
} else if (error.message.includes('500')) {
|
||||
errorMessage = 'Server error - check server logs';
|
||||
isServerSideError = true;
|
||||
} else if (error.message.includes('404')) {
|
||||
errorMessage = 'Server endpoint not found';
|
||||
} else if (error.message.includes('403') || error.message.includes('401')) {
|
||||
errorMessage = 'Access denied';
|
||||
const hadProps = this._serverProps !== null;
|
||||
|
||||
const fetchPromise = (async () => {
|
||||
try {
|
||||
const props = await ChatService.getServerProps();
|
||||
this._serverProps = props;
|
||||
this.persistServerProps(props);
|
||||
this._error = null;
|
||||
this._serverWarning = null;
|
||||
await this.checkSlotsEndpointAvailability();
|
||||
} catch (error) {
|
||||
if (isSilent && hadProps) {
|
||||
console.warn('Silent server props refresh failed, keeping cached data:', error);
|
||||
return;
|
||||
}
|
||||
|
||||
this.handleFetchServerPropsError(error, hadProps);
|
||||
} finally {
|
||||
if (!isSilent) {
|
||||
this._loading = false;
|
||||
}
|
||||
|
||||
this.fetchServerPropsPromise = null;
|
||||
}
|
||||
})();
|
||||
|
||||
let cachedProps: ApiLlamaCppServerProps | null = null;
|
||||
this.fetchServerPropsPromise = fetchPromise;
|
||||
|
||||
if (!hadCachedProps) {
|
||||
cachedProps = this.readCachedServerProps();
|
||||
if (cachedProps) {
|
||||
this._serverProps = cachedProps;
|
||||
this._error = null;
|
||||
await fetchPromise;
|
||||
}
|
||||
|
||||
if (isOfflineLikeError || isServerSideError) {
|
||||
this._serverWarning = errorMessage;
|
||||
}
|
||||
/**
|
||||
* Handles fetch failures by attempting to recover cached server props and
|
||||
* updating the user-facing error or warning state appropriately.
|
||||
*/
|
||||
private handleFetchServerPropsError(error: unknown, hadProps: boolean): void {
|
||||
const { errorMessage, isOfflineLikeError, isServerSideError } = this.normalizeFetchError(error);
|
||||
|
||||
console.warn(
|
||||
'Failed to refresh server properties, using cached values from localStorage:',
|
||||
errorMessage
|
||||
);
|
||||
} else {
|
||||
this._error = errorMessage;
|
||||
}
|
||||
} else {
|
||||
let cachedProps: ApiLlamaCppServerProps | null = null;
|
||||
|
||||
if (!hadProps) {
|
||||
cachedProps = this.readCachedServerProps();
|
||||
|
||||
if (cachedProps) {
|
||||
this._serverProps = cachedProps;
|
||||
this._error = null;
|
||||
|
||||
if (isOfflineLikeError || isServerSideError) {
|
||||
@@ -245,14 +238,66 @@ class ServerStore {
|
||||
}
|
||||
|
||||
console.warn(
|
||||
'Failed to refresh server properties, continuing with cached values:',
|
||||
'Failed to refresh server properties, using cached values from localStorage:',
|
||||
errorMessage
|
||||
);
|
||||
} else {
|
||||
this._error = errorMessage;
|
||||
}
|
||||
console.error('Error fetching server properties:', error);
|
||||
} finally {
|
||||
this._loading = false;
|
||||
} else {
|
||||
this._error = null;
|
||||
|
||||
if (isOfflineLikeError || isServerSideError) {
|
||||
this._serverWarning = errorMessage;
|
||||
}
|
||||
|
||||
console.warn(
|
||||
'Failed to refresh server properties, continuing with cached values:',
|
||||
errorMessage
|
||||
);
|
||||
}
|
||||
|
||||
console.error('Error fetching server properties:', error);
|
||||
}
|
||||
|
||||
/**
 * Translates a failure into a user-facing message plus two classification
 * flags. Values that are not `Error` instances, and errors matching none of
 * the patterns below, yield the generic connection message with both flags
 * false. Checks run in order and the first match wins.
 */
private normalizeFetchError(error: unknown): {
	errorMessage: string;
	isOfflineLikeError: boolean;
	isServerSideError: boolean;
} {
	// Small builder so each branch reads as a single return.
	const describe = (
		errorMessage: string,
		isOfflineLikeError = false,
		isServerSideError = false
	) => ({ errorMessage, isOfflineLikeError, isServerSideError });

	if (!(error instanceof Error)) {
		return describe('Failed to connect to server');
	}

	const text = error.message || '';
	const mentions = (needle: string): boolean => text.includes(needle);

	// Offline-like failures: the request did not get a usable response.
	if (error.name === 'TypeError' && mentions('fetch')) {
		return describe('Server is not running or unreachable', true);
	}
	if (mentions('ECONNREFUSED')) {
		return describe('Connection refused - server may be offline', true);
	}
	if (mentions('ENOTFOUND')) {
		return describe('Server not found - check server address', true);
	}
	if (mentions('ETIMEDOUT')) {
		return describe('Request timed out - the server took too long to respond', true);
	}

	// Server-side failures, recognised by a status code appearing in the message text.
	if (mentions('503')) {
		return describe('Server temporarily unavailable - try again shortly', false, true);
	}
	if (mentions('500')) {
		return describe('Server error - check server logs', false, true);
	}

	// Remaining recognised codes get a specific message but neither flag.
	if (mentions('404')) {
		return describe('Server endpoint not found');
	}
	if (mentions('403') || mentions('401')) {
		return describe('Access denied');
	}

	return describe('Failed to connect to server');
}
|
||||
|
||||
/**
|
||||
@@ -264,6 +309,7 @@ class ServerStore {
|
||||
this._serverWarning = null;
|
||||
this._loading = false;
|
||||
this._slotsEndpointAvailable = null;
|
||||
this.fetchServerPropsPromise = null;
|
||||
this.persistServerProps(null);
|
||||
}
|
||||
}
|
||||
|
||||
1
tools/server/webui/src/lib/types/api.d.ts
vendored
1
tools/server/webui/src/lib/types/api.d.ts
vendored
@@ -186,6 +186,7 @@ export interface ApiChatCompletionRequest {
|
||||
}
|
||||
|
||||
export interface ApiChatCompletionStreamChunk {
|
||||
object?: string;
|
||||
model?: string;
|
||||
choices: Array<{
|
||||
model?: string;
|
||||
|
||||
@@ -42,6 +42,7 @@ export interface SettingsChatServiceOptions {
|
||||
onChunk?: (chunk: string) => void;
|
||||
onReasoningChunk?: (chunk: string) => void;
|
||||
onModel?: (model: string) => void;
|
||||
onFirstValidChunk?: () => void;
|
||||
onComplete?: (response: string, reasoningContent?: string, timings?: ChatMessageTimings) => void;
|
||||
onError?: (error: Error) => void;
|
||||
}
|
||||
|
||||
264
vendor/cpp-httplib/httplib.h
vendored
264
vendor/cpp-httplib/httplib.h
vendored
@@ -8,8 +8,8 @@
|
||||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||
#define CPPHTTPLIB_HTTPLIB_H
|
||||
|
||||
#define CPPHTTPLIB_VERSION "0.26.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x001A00"
|
||||
#define CPPHTTPLIB_VERSION "0.27.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x001B00"
|
||||
|
||||
/*
|
||||
* Platform compatibility check
|
||||
@@ -1052,6 +1052,9 @@ private:
|
||||
|
||||
ssize_t write_headers(Stream &strm, const Headers &headers);
|
||||
|
||||
std::string make_host_and_port_string(const std::string &host, int port,
|
||||
bool is_ssl);
|
||||
|
||||
} // namespace detail
|
||||
|
||||
class Server {
|
||||
@@ -1129,6 +1132,8 @@ public:
|
||||
Server &
|
||||
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
|
||||
|
||||
Server &set_trusted_proxies(const std::vector<std::string> &proxies);
|
||||
|
||||
Server &set_keep_alive_max_count(size_t count);
|
||||
Server &set_keep_alive_timeout(time_t sec);
|
||||
|
||||
@@ -1167,6 +1172,9 @@ protected:
|
||||
const std::function<void(Request &)> &setup_request);
|
||||
|
||||
std::atomic<socket_t> svr_sock_{INVALID_SOCKET};
|
||||
|
||||
std::vector<std::string> trusted_proxies_;
|
||||
|
||||
size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
|
||||
time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND;
|
||||
time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND;
|
||||
@@ -1719,8 +1727,6 @@ private:
|
||||
const std::string &boundary, const UploadFormDataItems &items,
|
||||
const FormDataProviderItems &provider_items) const;
|
||||
|
||||
std::string adjust_host_string(const std::string &host) const;
|
||||
|
||||
virtual bool
|
||||
process_socket(const Socket &socket,
|
||||
std::chrono::time_point<std::chrono::steady_clock> start_time,
|
||||
@@ -1953,14 +1959,17 @@ public:
|
||||
void update_certs(X509 *cert, EVP_PKEY *private_key,
|
||||
X509_STORE *client_ca_cert_store = nullptr);
|
||||
|
||||
int ssl_last_error() const { return last_ssl_error_; }
|
||||
|
||||
private:
|
||||
bool process_and_close_socket(socket_t sock) override;
|
||||
|
||||
STACK_OF(X509_NAME) * extract_ca_names_from_x509_store(X509_STORE *store);
|
||||
|
||||
SSL_CTX *ctx_;
|
||||
std::mutex ctx_mutex_;
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
|
||||
int last_ssl_error_ = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
class SSLClient final : public ClientImpl {
|
||||
@@ -4596,13 +4605,35 @@ inline bool zstd_decompressor::decompress(const char *data, size_t data_length,
|
||||
}
|
||||
#endif
|
||||
|
||||
inline bool is_prohibited_header_name(const std::string &name) {
|
||||
using udl::operator""_t;
|
||||
|
||||
switch (str2tag(name)) {
|
||||
case "REMOTE_ADDR"_t:
|
||||
case "REMOTE_PORT"_t:
|
||||
case "LOCAL_ADDR"_t:
|
||||
case "LOCAL_PORT"_t: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool has_header(const Headers &headers, const std::string &key) {
|
||||
if (is_prohibited_header_name(key)) { return false; }
|
||||
return headers.find(key) != headers.end();
|
||||
}
|
||||
|
||||
inline const char *get_header_value(const Headers &headers,
|
||||
const std::string &key, const char *def,
|
||||
size_t id) {
|
||||
if (is_prohibited_header_name(key)) {
|
||||
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
|
||||
std::string msg = "Prohibited header name '" + key + "' is specified.";
|
||||
throw std::invalid_argument(msg);
|
||||
#else
|
||||
return "";
|
||||
#endif
|
||||
}
|
||||
|
||||
auto rng = headers.equal_range(key);
|
||||
auto it = rng.first;
|
||||
std::advance(it, static_cast<ssize_t>(id));
|
||||
@@ -7261,6 +7292,30 @@ inline bool RegexMatcher::match(Request &request) const {
|
||||
return std::regex_match(request.path, request.matches, regex_);
|
||||
}
|
||||
|
||||
// Builds the "host[:port]" text for a host/port pair.
//
// - A host containing ':' that does not already start with '[' is treated as
//   a bare IPv6 literal and wrapped in brackets; anything else is used as-is.
// - The port suffix is omitted when it equals the scheme default
//   (80 without SSL, 443 with SSL).
inline std::string make_host_and_port_string(const std::string &host, int port,
                                             bool is_ssl) {
  const bool already_bracketed = !host.empty() && host[0] == '[';
  const bool bare_ipv6 =
      host.find(':') != std::string::npos && !already_bracketed;

  std::string text = bare_ipv6 ? "[" + host + "]" : host;

  const int default_port = is_ssl ? 443 : 80;
  if (port != default_port) { text += ":" + std::to_string(port); }

  return text;
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// HTTP server implementation
|
||||
@@ -7473,6 +7528,12 @@ inline Server &Server::set_header_writer(
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Stores a copy of `proxies` as this server's trusted proxy list and returns
// the server itself so that setter calls can be chained.
inline Server &
Server::set_trusted_proxies(const std::vector<std::string> &proxies) {
  this->trusted_proxies_ = proxies;
  return *this;
}
|
||||
|
||||
inline Server &Server::set_keep_alive_max_count(size_t count) {
|
||||
keep_alive_max_count_ = count;
|
||||
return *this;
|
||||
@@ -8261,6 +8322,40 @@ inline bool Server::dispatch_request_for_content_reader(
|
||||
return false;
|
||||
}
|
||||
|
||||
inline std::string
|
||||
get_client_ip(const std::string &x_forwarded_for,
|
||||
const std::vector<std::string> &trusted_proxies) {
|
||||
// X-Forwarded-For is a comma-separated list per RFC 7239
|
||||
std::vector<std::string> ip_list;
|
||||
detail::split(x_forwarded_for.data(),
|
||||
x_forwarded_for.data() + x_forwarded_for.size(), ',',
|
||||
[&](const char *b, const char *e) {
|
||||
auto r = detail::trim(b, e, 0, static_cast<size_t>(e - b));
|
||||
ip_list.emplace_back(std::string(b + r.first, b + r.second));
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < ip_list.size(); ++i) {
|
||||
auto ip = ip_list[i];
|
||||
|
||||
auto is_trusted_proxy =
|
||||
std::any_of(trusted_proxies.begin(), trusted_proxies.end(),
|
||||
[&](const std::string &proxy) { return ip == proxy; });
|
||||
|
||||
if (is_trusted_proxy) {
|
||||
if (i == 0) {
|
||||
// If the trusted proxy is the first IP, there's no preceding client IP
|
||||
return ip;
|
||||
} else {
|
||||
// Return the IP immediately before the trusted proxy
|
||||
return ip_list[i - 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no trusted proxy is found, return the first IP in the list
|
||||
return ip_list.front();
|
||||
}
|
||||
|
||||
inline bool
|
||||
Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
int remote_port, const std::string &local_addr,
|
||||
@@ -8324,15 +8419,16 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
connection_closed = true;
|
||||
}
|
||||
|
||||
req.remote_addr = remote_addr;
|
||||
if (!trusted_proxies_.empty() && req.has_header("X-Forwarded-For")) {
|
||||
auto x_forwarded_for = req.get_header_value("X-Forwarded-For");
|
||||
req.remote_addr = get_client_ip(x_forwarded_for, trusted_proxies_);
|
||||
} else {
|
||||
req.remote_addr = remote_addr;
|
||||
}
|
||||
req.remote_port = remote_port;
|
||||
req.set_header("REMOTE_ADDR", req.remote_addr);
|
||||
req.set_header("REMOTE_PORT", std::to_string(req.remote_port));
|
||||
|
||||
req.local_addr = local_addr;
|
||||
req.local_port = local_port;
|
||||
req.set_header("LOCAL_ADDR", req.local_addr);
|
||||
req.set_header("LOCAL_PORT", std::to_string(req.local_port));
|
||||
|
||||
if (req.has_header("Accept")) {
|
||||
const auto &accept_header = req.get_header_value("Accept");
|
||||
@@ -8522,7 +8618,7 @@ inline ClientImpl::ClientImpl(const std::string &host, int port,
|
||||
const std::string &client_cert_path,
|
||||
const std::string &client_key_path)
|
||||
: host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port),
|
||||
host_and_port_(adjust_host_string(host_) + ":" + std::to_string(port)),
|
||||
host_and_port_(detail::make_host_and_port_string(host_, port, is_ssl())),
|
||||
client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
|
||||
|
||||
inline ClientImpl::~ClientImpl() {
|
||||
@@ -8703,8 +8799,9 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) {
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(socket_mutex_);
|
||||
|
||||
// Set this to false immediately - if it ever gets set to true by the end of
|
||||
// the request, we know another thread instructed us to close the socket.
|
||||
// Set this to false immediately - if it ever gets set to true by the end
|
||||
// of the request, we know another thread instructed us to close the
|
||||
// socket.
|
||||
socket_should_be_closed_when_request_is_done_ = false;
|
||||
|
||||
auto is_alive = false;
|
||||
@@ -8720,10 +8817,10 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) {
|
||||
#endif
|
||||
|
||||
if (!is_alive) {
|
||||
// Attempt to avoid sigpipe by shutting down non-gracefully if it seems
|
||||
// like the other side has already closed the connection Also, there
|
||||
// cannot be any requests in flight from other threads since we locked
|
||||
// request_mutex_, so safe to close everything immediately
|
||||
// Attempt to avoid sigpipe by shutting down non-gracefully if it
|
||||
// seems like the other side has already closed the connection Also,
|
||||
// there cannot be any requests in flight from other threads since we
|
||||
// locked request_mutex_, so safe to close everything immediately
|
||||
const bool shutdown_gracefully = false;
|
||||
shutdown_ssl(socket_, shutdown_gracefully);
|
||||
shutdown_socket(socket_);
|
||||
@@ -9027,7 +9124,8 @@ inline bool ClientImpl::create_redirect_client(
|
||||
}
|
||||
}
|
||||
|
||||
// New method for robust client setup (based on basic_manual_redirect.cpp logic)
|
||||
// New method for robust client setup (based on basic_manual_redirect.cpp
|
||||
// logic)
|
||||
template <typename ClientType>
|
||||
inline void ClientImpl::setup_redirect_client(ClientType &client) {
|
||||
// Copy basic settings first
|
||||
@@ -9131,18 +9229,8 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req,
|
||||
// curl behavior)
|
||||
if (address_family_ == AF_UNIX) {
|
||||
req.set_header("Host", "localhost");
|
||||
} else if (is_ssl()) {
|
||||
if (port_ == 443) {
|
||||
req.set_header("Host", host_);
|
||||
} else {
|
||||
req.set_header("Host", host_and_port_);
|
||||
}
|
||||
} else {
|
||||
if (port_ == 80) {
|
||||
req.set_header("Host", host_);
|
||||
} else {
|
||||
req.set_header("Host", host_and_port_);
|
||||
}
|
||||
req.set_header("Host", host_and_port_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9409,12 +9497,6 @@ inline Result ClientImpl::send_with_content_provider(
|
||||
#endif
|
||||
}
|
||||
|
||||
inline std::string
|
||||
ClientImpl::adjust_host_string(const std::string &host) const {
|
||||
if (host.find(':') != std::string::npos) { return "[" + host + "]"; }
|
||||
return host;
|
||||
}
|
||||
|
||||
inline void ClientImpl::output_log(const Request &req,
|
||||
const Response &res) const {
|
||||
if (logger_) {
|
||||
@@ -9538,8 +9620,8 @@ inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider(
|
||||
const FormDataProviderItems &provider_items) const {
|
||||
size_t cur_item = 0;
|
||||
size_t cur_start = 0;
|
||||
// cur_item and cur_start are copied to within the std::function and maintain
|
||||
// state between successive calls
|
||||
// cur_item and cur_start are copied to within the std::function and
|
||||
// maintain state between successive calls
|
||||
return [&, cur_item, cur_start](size_t offset,
|
||||
DataSink &sink) mutable -> bool {
|
||||
if (!offset && !items.empty()) {
|
||||
@@ -10251,8 +10333,8 @@ inline void ClientImpl::stop() {
|
||||
// If there is anything ongoing right now, the ONLY thread-safe thing we can
|
||||
// do is to shutdown_socket, so that threads using this socket suddenly
|
||||
// discover they can't read/write any more and error out. Everything else
|
||||
// (closing the socket, shutting ssl down) is unsafe because these actions are
|
||||
// not thread-safe.
|
||||
// (closing the socket, shutting ssl down) is unsafe because these actions
|
||||
// are not thread-safe.
|
||||
if (socket_requests_in_flight_ > 0) {
|
||||
shutdown_socket(socket_);
|
||||
|
||||
@@ -10705,6 +10787,19 @@ inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path,
|
||||
SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path,
|
||||
client_ca_cert_dir_path);
|
||||
|
||||
// Set client CA list to be sent to clients during TLS handshake
|
||||
if (client_ca_cert_file_path) {
|
||||
auto ca_list = SSL_load_client_CA_file(client_ca_cert_file_path);
|
||||
if (ca_list != nullptr) {
|
||||
SSL_CTX_set_client_CA_list(ctx_, ca_list);
|
||||
} else {
|
||||
// Failed to load client CA list, but we continue since
|
||||
// SSL_CTX_load_verify_locations already succeeded and
|
||||
// certificate verification will still work
|
||||
last_ssl_error_ = static_cast<int>(ERR_get_error());
|
||||
}
|
||||
}
|
||||
|
||||
SSL_CTX_set_verify(
|
||||
ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
|
||||
}
|
||||
@@ -10729,6 +10824,15 @@ inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key,
|
||||
} else if (client_ca_cert_store) {
|
||||
SSL_CTX_set_cert_store(ctx_, client_ca_cert_store);
|
||||
|
||||
// Extract CA names from the store and set them as the client CA list
|
||||
auto ca_list = extract_ca_names_from_x509_store(client_ca_cert_store);
|
||||
if (ca_list) {
|
||||
SSL_CTX_set_client_CA_list(ctx_, ca_list);
|
||||
} else {
|
||||
// Failed to extract CA names, record the error
|
||||
last_ssl_error_ = static_cast<int>(ERR_get_error());
|
||||
}
|
||||
|
||||
SSL_CTX_set_verify(
|
||||
ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
|
||||
}
|
||||
@@ -10809,6 +10913,44 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Builds a newly allocated stack holding a duplicate of the subject name of
// every X509 certificate object found in `store`.
//
// Returns nullptr when `store` is null, when the stack cannot be allocated,
// when the store exposes no object list, or when no name could be collected.
// A non-null result is owned by the caller.
inline STACK_OF(X509_NAME) * SSLServer::extract_ca_names_from_x509_store(
    X509_STORE *store) {
  if (!store) { return nullptr; }

  auto ca_list = sk_X509_NAME_new_null();
  if (!ca_list) { return nullptr; }

  // Get all objects from the store
  auto objs = X509_STORE_get0_objects(store);
  if (!objs) {
    sk_X509_NAME_free(ca_list);
    return nullptr;
  }

  // Walk the objects and collect the subject name of each certificate
  for (int i = 0; i < sk_X509_OBJECT_num(objs); i++) {
    auto obj = sk_X509_OBJECT_value(objs, i);
    if (X509_OBJECT_get_type(obj) != X509_LU_X509) { continue; }

    auto cert = X509_OBJECT_get0_X509(obj);
    if (!cert) { continue; }

    auto subject = X509_get_subject_name(cert);
    if (!subject) { continue; }

    auto name_dup = X509_NAME_dup(subject);
    if (!name_dup) { continue; }

    // A failed push leaves ownership of the duplicate with us, so release it
    // here instead of leaking it.
    if (sk_X509_NAME_push(ca_list, name_dup) <= 0) {
      X509_NAME_free(name_dup);
    }
  }

  // If no names were extracted, free the list and return nullptr
  if (sk_X509_NAME_num(ca_list) == 0) {
    sk_X509_NAME_free(ca_list);
    return nullptr;
  }

  return ca_list;
}
|
||||
|
||||
// SSL HTTP client implementation
|
||||
// Host-only convenience constructor: forwards to the four-argument overload
// with port 443 and two empty strings.
inline SSLClient::SSLClient(const std::string &host)
    : SSLClient(host, 443, std::string(), std::string()) {}
|
||||
@@ -10889,7 +11031,8 @@ inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) {
|
||||
if (ca_cert_store) {
|
||||
if (ctx_) {
|
||||
if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) {
|
||||
// Free memory allocated for old cert and use new store `ca_cert_store`
|
||||
// Free memory allocated for old cert and use new store
|
||||
// `ca_cert_store`
|
||||
SSL_CTX_set_cert_store(ctx_, ca_cert_store);
|
||||
ca_cert_store_ = ca_cert_store;
|
||||
}
|
||||
@@ -10911,10 +11054,15 @@ inline long SSLClient::get_openssl_verify_result() const {
|
||||
inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; }
|
||||
|
||||
// Creates and connects the underlying socket for this SSL client.
//
// When is_valid() reports false, `error` is set to Error::SSLConnection and
// false is returned without attempting a connection, so the caller receives
// an explicit failure reason. Otherwise the work is delegated to
// ClientImpl::create_and_connect_socket and its result is returned.
inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) {
  if (!is_valid()) {
    error = Error::SSLConnection;
    return false;
  }
  return ClientImpl::create_and_connect_socket(socket, error);
}
|
||||
|
||||
// Assumes that socket_mutex_ is locked and that there are no requests in flight
|
||||
// Assumes that socket_mutex_ is locked and that there are no requests in
|
||||
// flight
|
||||
inline bool SSLClient::connect_with_proxy(
|
||||
Socket &socket,
|
||||
std::chrono::time_point<std::chrono::steady_clock> start_time,
|
||||
@@ -11128,6 +11276,11 @@ inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ctx_ == nullptr) {
|
||||
error = Error::SSLConnection;
|
||||
last_openssl_error_ = ERR_get_error();
|
||||
}
|
||||
|
||||
shutdown_socket(socket);
|
||||
close_socket(socket);
|
||||
return false;
|
||||
@@ -11221,21 +11374,22 @@ SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const {
|
||||
|
||||
for (decltype(count) i = 0; i < count && !dsn_matched; i++) {
|
||||
auto val = sk_GENERAL_NAME_value(alt_names, i);
|
||||
if (val->type == type) {
|
||||
auto name =
|
||||
reinterpret_cast<const char *>(ASN1_STRING_get0_data(val->d.ia5));
|
||||
auto name_len = static_cast<size_t>(ASN1_STRING_length(val->d.ia5));
|
||||
if (!val || val->type != type) { continue; }
|
||||
|
||||
switch (type) {
|
||||
case GEN_DNS: dsn_matched = check_host_name(name, name_len); break;
|
||||
auto name =
|
||||
reinterpret_cast<const char *>(ASN1_STRING_get0_data(val->d.ia5));
|
||||
if (name == nullptr) { continue; }
|
||||
|
||||
case GEN_IPADD:
|
||||
if (!memcmp(&addr6, name, addr_len) ||
|
||||
!memcmp(&addr, name, addr_len)) {
|
||||
ip_matched = true;
|
||||
}
|
||||
break;
|
||||
auto name_len = static_cast<size_t>(ASN1_STRING_length(val->d.ia5));
|
||||
|
||||
switch (type) {
|
||||
case GEN_DNS: dsn_matched = check_host_name(name, name_len); break;
|
||||
|
||||
case GEN_IPADD:
|
||||
if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) {
|
||||
ip_matched = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user