Compare commits

..

9 Commits
b8327 ... b8336

Author SHA1 Message Date
Rail Chabdarov
5a32a9b8a5 Fix data race in CUDA's "cpy" kernel (influences GGML's DUP, CONT operations). (#20507)
* Fix datarace in CUDA's "cpy" kernel.

* Remove extra barrier by using more of shared memory.
2026-03-14 13:19:44 +08:00
lhez
3b439504ba opencl: fix l2_norm (#20480) 2026-03-13 22:18:52 -07:00
Adrien Gallouët
463b6a963c tools : enable kvu in perplexity for hellaswag, winogrande, multiple-choice (#19954)
llama-perplexity -hf unsloth/Qwen3-0.6B-GGUF:Q4_K_M -f winogrande-debiased-eval.csv --winogrande

    winogrande_score : tokenizing selected tasks
    winogrande_score : calculating winogrande score over selected tasks.
    split_equal: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)
    decode: failed to find a memory slot for batch of size 46
    failed to decode the batch, n_batch = 2048, ret = 1
    winogrande_score: llama_decode() failed

same for hellaswag:

    split_equal: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)
    decode: failed to find a memory slot for batch of size 99
    failed to decode the batch, n_batch = 2048, ret = 1
    hellaswag_score: llama_decode() failed

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-03-13 21:25:57 +01:00
Georgi Gerganov
e30f1fdf74 graph : remove redundant GDN state transposes (#20443)
* ggml : transpose fused GDN state access for coalesced memory reads (#20436)

The fused Gated Delta Net kernel accessed the [S_v, S_v] state matrix
column-wise on row-major storage, causing strided reads (stride S_v =
128 floats = 512 bytes) that waste GPU cache bandwidth. This produced a
39% regression on Qwen3.5-9B (Metal, M4 Max) compared to the unfused
path.

Transpose the state indexing so threads read contiguously:
- Metal: s_ptr[is*S_v] -> s_ptr[is] (stride 1 vs S_v)
- CUDA:  curr_state[i*S_v+col] -> curr_state[col*S_v+i] (coalesced)
- CPU:   restructured loops for row-wise transposed access

Also add --fused-gdn [on|off|auto] CLI flag (mirrors --flash-attn) so
users can control fused GDN independently of auto-detection.

All GATED_DELTA_NET backend-ops tests pass.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* ggml : use SIMD dot products in CPU GDN kernel, couple AR/chunked fused flags

- Replace scalar inner loops with ggml_vec_dot_f32 for SIMD-optimized
  dot products in the CPU fused GDN kernel (delta and attention output)
- Couple fused_gdn_ar and fused_gdn_ch flags in auto-detection: if one
  path lacks device support, disable both to prevent state layout mismatch
  between transposed (fused) and non-transposed (unfused) formats

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* llama : rever fgdn argument changes

* graph : remove GDN state transposes

* vulkan : adapt

* cuda : remove obsolete smem code

---------

Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Oliver Simons <osimons@nvidia.com>
2026-03-13 22:12:54 +02:00
Piotr Wilkin (ilintar)
1430c35948 common/parser: gracefully handle undetected tool parser, print error message. (#20286) 2026-03-13 20:56:10 +01:00
ZeroV0LT
f17b3be63f llama : fix pooling assertion crash in chunked GDN detection path (#20468)
* llama : fix pooling assertion crash in chunked GDN detection path

The chunked fused Gated Delta Net detection in sched_reserve() calls
graph_reserve(16*n_seqs, n_seqs, n_outputs, ...) where n_outputs = n_seqs.
This creates a dimension mismatch in build_pooling() for embedding models
with mean/rank pooling: build_inp_mean() creates a tensor with shape
[n_tokens=16*n_seqs, ...] while t_embd is reduced to [n_outputs=n_seqs, ...]
via out_ids, causing ggml_mul_mat to assert on ggml_can_mul_mat(a, b).

Fix: pass n_tokens as n_outputs in the chunked GDN graph reservation,
matching the pattern used by the pp/tg worst-case reservations.

Regression introduced by #20340 (d28961d).
Same class of bug as #12517, fixed by #12545.

* server : add mean pooling tests to embedding test suite

Add test_embedding_pooling_mean and test_embedding_pooling_mean_multiple
to cover the --pooling mean codepath, which was previously untested.

These tests would have caught the regression introduced by #20340 where
build_pooling() crashes with a ggml_mul_mat assertion due to mismatched
dimensions in the chunked GDN detection path.

---------

Co-authored-by: Domenico Crupi <domenico@zerovolt.it>
2026-03-13 20:53:42 +02:00
SoftwareRenderer
d7ba99c485 server: reset counter related to kill-switch on client error (#20513)
* server: reset kill-switch on client error

This avoids triggering a server kill switch.

If the client sends a request that exceeds the configured context size, an appropriate HTTP 400 response is provided and no tokens are generated.

However since no tokens are generated, update_slots() increments n_empty_consecutive. If the client sends 3 such messages in a row, the server terminates.

* moved counter reset as per recommendation

* cont : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-03-13 19:58:09 +02:00
rehan-10xengineer
fbaa95bc29 ggml-cpu: add RVV vec dot kernels for quantization types (#18859)
* ggml-cpu: add rvv quantize_row_q8_K kernel

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>

* ggml-cpu: add rvv vec_dot for iq4_nl, mxfp4, iq2_xxs

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>

* ggml-cpu: add rvv vec_dot for iq4_xs, refactor

* ggml-cpu: remove ifunc for rvv vec dot

* ggml-cpu: add vec_dot for iq2_xs, iq3_xxs

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>

* ggml-cpu: refactor quants.c

---------

Co-authored-by: taimur-10x <taimur.ahmad@10xengineers.ai>
Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>
Co-authored-by: Rehan Qasim <rehanbhatti0317@gmail.com>
2026-03-13 17:36:04 +02:00
Adrien Gallouët
b5e1212063 ggml : fix typo gmml (#20512)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-03-13 14:36:13 +01:00
15 changed files with 1530 additions and 623 deletions

View File

@@ -3,6 +3,7 @@
#include "chat.h"
#include "common.h"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "nlohmann/json.hpp"
#include <stdexcept>
@@ -182,7 +183,10 @@ common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const
case tool_format::TAG_WITH_TAGGED:
return build_tool_parser_tag_tagged(ctx);
default:
GGML_ABORT("Unable to create tool parser");
LOG_ERR("[ERROR] Template seems to support tool calls, but failed to determine tool format. Tool calling will not work properly. "
"Check for a fixed template for your model in the models/templates directory of your llama.cpp installation or "
"report an issue at https://github.com/ggml-org/llama.cpp/issues\n");
return ctx.p.eps();
}
}

View File

@@ -253,7 +253,7 @@ option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increas
option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON)
option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON)
set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
"gmml: OpenCL API version to target")
"ggml: OpenCL API version to target")
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")

View File

@@ -199,13 +199,6 @@
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__riscv)
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1

File diff suppressed because it is too large Load Diff

View File

@@ -9624,7 +9624,7 @@ void ggml_compute_forward_win_unpart(
}
}
//gmml_compute_forward_unary
//ggml_compute_forward_unary
void ggml_compute_forward_unary(
const ggml_compute_params * params,
@@ -10477,34 +10477,40 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
// state is stored transposed: s_out[j*S_v + i] = S[i][j]
// so row j of s_out = column j of S (contiguous access)
if (kda) {
// precompute exp(g) into delta scratch (reused below)
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i]));
delta[i] = expf(g_d[i]);
}
// S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
for (int64_t j = 0; j < S_v; ++j) {
ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
}
} else {
ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
}
// delta[j] = sum_i S[j][i] * k[i]
memset(delta, 0, S_v * sizeof(float));
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]);
}
// delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
for (int64_t j = 0; j < S_v; ++j) {
delta[j] = (v_d[j] - delta[j]) * beta_val;
float sum = 0.0f;
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
delta[j] = (v_d[j] - sum) * beta_val;
}
// outer product: S[j][i] += k[i] * delta[j]
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]);
// outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
for (int64_t j = 0; j < S_v; ++j) {
ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
}
// attn_out[j] = sum_i S[j][i] * q[i]
memset(attn_data, 0, S_v * sizeof(float));
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]);
// attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
for (int64_t j = 0; j < S_v; ++j) {
float sum = 0.0f;
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
attn_data[j] = sum * scale;
}
ggml_vec_scale_f32(S_v, attn_data, scale);
attn_data += S_v * H; // advance to next token
}

View File

@@ -56,7 +56,8 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
__shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
__shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
int cur_tile_buf = 0;
#pragma unroll
for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
@@ -70,7 +71,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
if(x < ne01 && y + j < ne00){
const int row = threadIdx.y+j;
const int col = threadIdx.x * sizeof(float)/sizeof(T);
T *tile2 = reinterpret_cast<T*>(tile[row]);
T *tile2 = reinterpret_cast<T*>(tile[cur_tile_buf][row]);
tile2[col] = src[imat*n + (y+j)*ne01 + x];
}
}
@@ -81,10 +82,12 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
if (ty + j < ne01 && tx < ne00) {
const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
const T *tile2 = reinterpret_cast<const T*>(tile[threadIdx.x]);
const T *tile2 = reinterpret_cast<const T*>(tile[cur_tile_buf][threadIdx.x]);
dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
}
}
cur_tile_buf = (cur_tile_buf + 1) % 2;
}
GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,

View File

@@ -45,10 +45,11 @@ __global__ void gated_delta_net_cuda(const float * q,
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
float s_shard[rows_per_lane];
// state is stored transposed: M[col][i] = S[i][col], row col is contiguous
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = curr_state[i * S_v + col];
s_shard[r] = curr_state[col * S_v + i];
}
for (int t = 0; t < n_tokens; t++) {
@@ -126,23 +127,14 @@ __global__ void gated_delta_net_cuda(const float * q,
attn_data += S_v * H;
}
// Write state back to global memory
// Write state back to global memory (transposed layout)
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
state[i * S_v + col] = s_shard[r];
state[col * S_v + i] = s_shard[r];
}
}
static size_t calculate_smem(const int sv, int cc)
{
size_t smem = 0;
if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
smem = sv * sv * sizeof(float);
}
return smem;
}
template <bool KDA>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
@@ -179,18 +171,14 @@ static void launch_gated_delta_net(
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 64: {
constexpr int sv = 64;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
}
case 128: {
constexpr int sv = 128;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);

View File

@@ -2469,13 +2469,14 @@ kernel void kernel_gated_delta_net_impl(
const float scale = 1.0f / sqrt((float)S_v);
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
float ls[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] = s_ptr[is*S_v];
ls[j] = s_ptr[is];
}
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
@@ -2536,11 +2537,11 @@ kernel void kernel_gated_delta_net_impl(
g_ptr += args.ne21*G;
}
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is*S_v] = ls[j];
dst_state[is] = ls[j];
}
#undef S_v

View File

@@ -63,7 +63,7 @@ kernel void kernel_l2_norm_f32(
barrier(CLK_LOCAL_MEM_FENCE);
const float scale = 1.0f/sqrt(max(sum[0], eps));
const float scale = 1.0f/max(sqrt(sum[0]), eps);
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
y[i00] = x[i00] * scale;

View File

@@ -44,7 +44,7 @@ void main() {
FLOAT_TYPE state[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]);
state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
}
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
@@ -123,6 +123,6 @@ void main() {
}
[[unroll]] for (uint i = 0; i < S_V; i++) {
data_dst[s_off + state_base + i * S_V + col] = state[i];
data_dst[s_off + state_base + col * S_V + i] = state[i];
}
}

View File

@@ -512,7 +512,12 @@ void llama_context::sched_reserve() {
if (cparams.fused_gdn_ch) {
// more than one token in the batch per sequence in order to take the chunked path
auto * gf = graph_reserve(16*n_seqs, n_seqs, n_outputs, mctx.get(), true);
// note: n_outputs must match n_tokens for embedding models with mean/rank pooling,
// because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies
// it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens,
// the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553).
const uint32_t n_tokens_ch = 16*n_seqs;
auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
}

View File

@@ -225,9 +225,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
cb(kg_t, "key_gdiff_t", il);
ggml_tensor * s_t = ggml_transpose(ctx0, s);
s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
cb(s_t, "dnet_add_ch_state", il);
s = ggml_reshape_4d(ctx0, s, S_v, S_v, 1, H_v * n_seqs);
cb(s, "dnet_add_ch_state", il);
// [CS, S_v, n_chunks, H_v * n_seqs]
ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
@@ -240,7 +239,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs]
// [CS, S_v, 1, H_v * n_seqs]
ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s);
cb(v_t_p, "v_prime", il);
// [CS, S_v, 1, H_v * n_seqs]
@@ -252,7 +251,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
cb(v_attn, "v_attn", il);
// [S_v, CS, 1, H_v * n_seqs]
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s, ch_q_g_exp);
cb(attn_inter, "attn_inter", il);
// [S_v, CS, 1, H_v * n_seqs]
@@ -268,13 +267,11 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);
s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t);
s_t = ggml_add(ctx0, s_t, kgv);
cb(s_t, "dnet_add_ch_state", il);
s = ggml_mul(ctx0, s, ch_g_last_exp_t);
s = ggml_add(ctx0, s, kgv);
cb(s, "dnet_add_ch_state", il);
}
s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
// truncate padded tokens
ggml_tensor * o = ggml_view_4d(ctx0, v,
S_v, n_tokens, H_v, n_seqs,
@@ -282,7 +279,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
ggml_row_size(v->type, S_v * CS * n_chunks),
ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
s = ggml_transpose(ctx0, s_t);
s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs);
cb(s, "output_state", il);
return {o, s};
@@ -341,11 +338,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
g = ggml_exp(ctx0, g);
s = ggml_mul(ctx0, s, g);
ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
// [1, S_v, H_v, n_seqs]
ggml_tensor * sk;
sk = ggml_mul (ctx0, s_t, k);
sk = ggml_mul (ctx0, s, k);
sk = ggml_sum_rows(ctx0, sk);
// [S_v, 1, H_v, n_seqs]
@@ -362,15 +357,14 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
k = ggml_repeat(ctx0, k, s);
kd = ggml_mul (ctx0, k, d_t);
s_t = ggml_add(ctx0, s_t, kd);
s = ggml_add(ctx0, s, kd);
cb(s_t, "dnet_add_ar_state", il);
cb(s, "dnet_add_ar_state", il);
ggml_tensor * s_q = ggml_mul (ctx0, s_t, q);
ggml_tensor * s_q = ggml_mul (ctx0, s, q);
ggml_tensor * o = ggml_sum_rows(ctx0, s_q);
o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]
return {o, s};
}

View File

@@ -2025,21 +2025,14 @@ int main(int argc, char ** argv) {
return 1;
}
const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
if (ppl || params.kl_divergence) {
const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
const int32_t n_kv = n_seq * n_ctx;
params.n_parallel = n_seq;
params.n_ctx = n_kv;
params.n_batch = std::min(params.n_batch, n_kv);
} else {
params.n_batch = std::min(params.n_batch, params.n_ctx);
// ensure there's at least enough seq_ids for HellaSwag
if (params.hellaswag || params.winogrande || params.multiple_choice) {
params.n_parallel = std::max(4, params.n_parallel);
params.kv_unified = true;
} else { // Perplexity & KL divergence
params.n_parallel = std::max(1, params.n_batch / n_ctx);
}
params.n_ctx = params.n_parallel * n_ctx;
params.n_batch = std::min(params.n_batch, params.n_ctx);
if (params.ppl_stride > 0) {
LOG_INF("Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",

View File

@@ -1189,6 +1189,9 @@ private:
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
: SLOT_STATE_STARTED;
// reset server kill-switch counter
n_empty_consecutive = 0;
SLT_INF(slot, "processing task, is_child = %d\n", slot.task->is_child());
return true;
}

View File

@@ -101,6 +101,40 @@ def test_embedding_mixed_input(input, is_multi_prompt: bool):
assert len(data[0]['embedding']) > 1
def test_embedding_pooling_mean():
global server
server.pooling = 'mean'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
assert len(res.body['data']) == 1
assert 'embedding' in res.body['data'][0]
assert len(res.body['data'][0]['embedding']) > 1
# make sure embedding vector is normalized
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
def test_embedding_pooling_mean_multiple():
global server
server.pooling = 'mean'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI",
"This is a test",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 3
for d in res.body['data']:
assert 'embedding' in d
assert len(d['embedding']) > 1
def test_embedding_pooling_none():
global server
server.pooling = 'none'