Compare commits

16 Commits
b8585 ... b8601

Author SHA1 Message Date
Aldehir Rojas
624733d631 common : gpt-oss handle builtin and unsolicited tool calls (#21213) 2026-03-31 13:52:42 +02:00
lainon1
0b6ff47996 fix: correct misspellings in code comments (#21217)
- emdeddings → embeddings (gemma3.cpp, gemma3n-iswa.cpp,
gemma-embedding.cpp)
- imlpemented → implemented (llama-adapter.cpp)
- interere → interfere (llama-graph.cpp)
- overridde → overridden (chat.cpp)
- stastistics → statistics (ngram-map.h)
- layed → laid (llama-kv-cache.h)
- worster → worst (llama-context.cpp)
- sequantial → sequential (llama-batch.h)
2026-03-31 13:50:51 +02:00
Seungmin Kim
eec6f85d7b CI: Enable CPU and Vulkan ARM64 Release (#21207) 2026-03-31 19:02:56 +08:00
Georgi Gerganov
9281dd135d sync : ggml 2026-03-31 14:00:41 +03:00
Georgi Gerganov
0be6c7c9ce ggml : bump version to 0.9.9 (ggml/1449) 2026-03-31 14:00:41 +03:00
Adrien Gallouët
41361c8599 common : move up common_init() and fix Windows UTF-8 logs (#21176)
The build info is now only for debug, so we avoid the duplicate
with `--version`.

The UTF-8 setup at the beginning is needed to avoid logging
garbage on Windows.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-03-31 12:53:41 +02:00
Neo Zhang
62278cedde sycl : enhance fattn perf (#21185) 2026-03-31 13:31:50 +03:00
mtmcp
90aa83c6bd common: add bounds check in common_init_result::sampler to prevent segfault on failed model load (#21082)
* common: add bounds check in common_init_result::sampler to prevent segfault on failed model load

* Revert a308e584ca

* Add regression test

* Remove regression test for init-fail sampler check
2026-03-31 13:04:42 +03:00
SATISH K C
fcc2d598c8 fix: include API key in CORS proxy requests for MCP connections (#21193)
* fix: include API key in CORS proxy requests for MCP connections

When llama-server is started with --api-key-file and --webui-mcp-proxy,
the /cors-proxy endpoint requires authentication. The WebUI was not
including the Authorization header in proxy requests, causing MCP
connections to fail with 401.

Inject getAuthHeaders() into requestInit when useProxy is true so the
proxy request carries the Bearer token alongside the forwarded target
headers.

Fixes #21167

* fix: simplify headers assignment based on reviewer suggestion

Apply buildProxiedHeaders only when useProxy is true, pass headers
directly to the transport otherwise.
2026-03-31 10:52:34 +02:00
Piotr Wilkin (ilintar)
4453e77561 server/webui: cleanup dual representation approach, simplify to openai-compat (#21090)
* server/webui: cleanup dual representation approach, simplify to openai-compat

* feat: Fix regression for Agentic Loop UI

* chore: update webui build output

* refactor: Post-review code improvements

* chore: update webui build output

* refactor: Cleanup

* chore: update webui build output

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2026-03-31 10:42:06 +02:00
Adrien Gallouët
26dac845cc vendor : update BoringSSL to 0.20260327.0 (#21211)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-03-31 09:21:54 +02:00
Galunid
5ce013cd7e common : Disable backend sampling if reasoning budget is enabled (#21209) 2026-03-31 10:14:01 +03:00
shaofeiqi
08f21453ae opencl: add q4_K gemm and gemv kernels for Adreno (#20919)
* opencl: add q4_K gemm and gemv kernels for Adreno

* opencl: fix whitespace

* opencl: add workarounds for compiler bugs on older devices

* opencl: handle fp16 denorm on X Elite

* opencl: fix kernel build error

* opencl: fix whitespace

* opencl: make q4_K cvt kernels signature consistent

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-03-30 12:19:16 -07:00
Seungmin Kim
84ae8434d0 CI : Enable CUDA and Vulkan ARM64 runners and fix CI/CD (#21122)
* CI: Enable CUDA and Vulkan ARM64 runners and fix CI/CD

Co-authored-by: Ts-sound <44093942+Ts-sound@users.noreply.github.com>

* Obtain source tag name from git tag

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Ts-sound <44093942+Ts-sound@users.noreply.github.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-30 20:24:37 +02:00
Zhihao "Zephyr" Yao
ead417f01c jinja : handle empty expressions correctly (#20913)
* Reject empty computed member expressions before returning slices[0] from parse_member_expression_arguments().

* Treat empty computed member expressions with Jinja2 undefined semantics

Treat empty computed member expressions like `a[]` as undefined instead of
raising a parser error, to match Jinja2 behavior.

- return a noop expression for empty computed member arguments
- return undefined when a computed member key evaluates to undefined
- add Jinja tests covering `a[]|default('fallback')` and `a[] is undefined`

* Handle undefined computed member properties

Move undefined-property handling to the common member access path, and add a test covering `a[undefined] is undefined`.

* Use default undefined value in member access

Initialize val and then return it when property is undefined.

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* empty statement parses to blank_expression instead of noop_statement

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-30 20:08:46 +02:00
Oliver Simons
64ac9ab66a CUDA : Fix CUB's argsort when nrows % block_size == 0 CCCL < 3.1 (#21181)
* CUDA: Fix CUB's argsort when nrows % block_size == 0 CCCL < 3.1

We wrongly calculated offset_grid as `ceildiv(nrows, block_size)`,
while it must be `ceildiv(nrows + 1, block_size)`. As a consequence, we
had uninitialized values in `offset_iterator[nrows]` for the case when
`nrows % block_size == 0`.

Fixes #21162

* Reduce nrows in test case to 256, don't need 768
2026-03-30 16:20:00 +02:00
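
The off-by-one described in the CUDA argsort commit above can be checked with plain integer arithmetic. This is only a sketch, not the CUDA code: it assumes each launched thread initializes one entry of the offset array, that the array holds `nrows + 1` entries (indices `0` through `nrows`), and that `block_size` is 256. The helper names `ceildiv` and `covered` are hypothetical.

```shell
#!/usr/bin/env bash
# Illustration only: the integer arithmetic behind the fix, not the CUDA kernel.
# Assumptions: one thread writes one offset entry; the offset array has
# nrows + 1 entries (indices 0..nrows); block_size = 256 is a made-up value.

# Ceiling division for positive integers.
ceildiv() {
    echo $(( ($1 + $2 - 1) / $2 ))
}

# Number of entries a grid of the given size can initialize.
covered() {
    local grid=$1 block_size=$2
    echo $(( grid * block_size ))
}

nrows=256
block_size=256

old_grid=$(ceildiv "$nrows" "$block_size")          # ceildiv(nrows, block_size)
new_grid=$(ceildiv $(( nrows + 1 )) "$block_size")  # ceildiv(nrows + 1, block_size)

echo "entries needed:      $(( nrows + 1 ))"
echo "covered by old grid: $(covered "$old_grid" "$block_size")"
echo "covered by new grid: $(covered "$new_grid" "$block_size")"
```

With `nrows=256` the old grid covers 256 entries while 257 are needed, so index `nrows` is never written. With `nrows=255` both formulas give one block, which is why the problem only shows up when `nrows % block_size == 0`.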
94 changed files with 2901 additions and 1260 deletions

View File

@@ -36,7 +36,7 @@ RUN mkdir -p /app/full \
FROM ubuntu:$UBUNTU_VERSION AS base
RUN apt-get update \
-&& apt-get install -y libgomp1 curl\
+&& apt-get install -y libgomp1 curl \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \

View File

@@ -1,6 +1,6 @@
ARG UBUNTU_VERSION=24.04
# This needs to generally match the container host's environment.
-ARG CUDA_VERSION=13.1.0
+ARG CUDA_VERSION=13.1.1
# Target the CUDA build image
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
@@ -12,7 +12,9 @@ FROM ${BASE_CUDA_DEV_CONTAINER} AS build
ARG CUDA_DOCKER_ARCH=default
RUN apt-get update && \
-apt-get install -y build-essential cmake python3 python3-pip git libssl-dev libgomp1
+apt-get install -y gcc-14 g++-14 build-essential cmake python3 python3-pip git libssl-dev libgomp1
+ENV CC=gcc-14 CXX=g++-14 CUDAHOSTCXX=g++-14
WORKDIR /app
@@ -39,7 +41,7 @@ RUN mkdir -p /app/full \
FROM ${BASE_CUDA_RUN_CONTAINER} AS base
RUN apt-get update \
-&& apt-get install -y libgomp1 curl\
+&& apt-get install -y libgomp1 curl \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \

View File

@@ -1,6 +1,6 @@
-ARG UBUNTU_VERSION=22.04
+ARG UBUNTU_VERSION=24.04
# This needs to generally match the container host's environment.
-ARG CUDA_VERSION=12.4.0
+ARG CUDA_VERSION=12.8.1
# Target the CUDA build image
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
@@ -12,7 +12,9 @@ FROM ${BASE_CUDA_DEV_CONTAINER} AS build
ARG CUDA_DOCKER_ARCH=default
RUN apt-get update && \
-apt-get install -y build-essential cmake python3 python3-pip git libssl-dev libgomp1
+apt-get install -y gcc-14 g++-14 build-essential cmake python3 python3-pip git libssl-dev libgomp1
+ENV CC=gcc-14 CXX=g++-14 CUDAHOSTCXX=g++-14
WORKDIR /app
@@ -39,7 +41,7 @@ RUN mkdir -p /app/full \
FROM ${BASE_CUDA_RUN_CONTAINER} AS base
RUN apt-get update \
-&& apt-get install -y libgomp1 curl\
+&& apt-get install -y libgomp1 curl \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
@@ -60,7 +62,8 @@ RUN apt-get update \
git \
python3 \
python3-pip \
-&& pip install --upgrade pip setuptools wheel \
+python3-wheel \
+&& pip install --break-system-packages --upgrade setuptools \
&& pip install --break-system-packages -r requirements.txt \
&& apt autoremove -y \
&& apt clean -y \

View File

@@ -51,7 +51,7 @@ RUN mkdir /tmp/neo/ && cd /tmp/neo/ \
&& dpkg --install *.deb
RUN apt-get update \
-&& apt-get install -y libgomp1 curl\
+&& apt-get install -y libgomp1 curl \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \

View File

@@ -46,7 +46,7 @@ RUN mkdir -p /app/full \
FROM ${BASE_MUSA_RUN_CONTAINER} AS base
RUN apt-get update \
-&& apt-get install -y libgomp1 curl\
+&& apt-get install -y libgomp1 curl \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \

View File

@@ -78,7 +78,7 @@ ARG http_proxy
ARG https_proxy
RUN apt-get update \
-&& apt-get install -y libgomp1 libtbb12 curl\
+&& apt-get install -y libgomp1 libtbb12 curl \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \

View File

@@ -58,7 +58,7 @@ RUN mkdir -p /app/full \
FROM ${BASE_ROCM_DEV_CONTAINER} AS base
RUN apt-get update \
-&& apt-get install -y libgomp1 curl\
+&& apt-get install -y libgomp1 curl \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
@@ -79,7 +79,7 @@ RUN apt-get update \
git \
python3-pip \
python3 \
-python3-wheel\
+python3-wheel \
&& pip install --break-system-packages --upgrade setuptools \
&& pip install --break-system-packages -r requirements.txt \
&& apt autoremove -y \

View File

@@ -49,17 +49,20 @@ COPY --from=build /app/full /app
WORKDIR /app
ENV PATH="/root/.venv/bin:/root/.local/bin:${PATH}"
# Flag for compatibility with pip
ARG UV_INDEX_STRATEGY="unsafe-best-match"
RUN apt-get update \
&& apt-get install -y \
build-essential \
curl \
git \
python3.13 \
python3.13-dev \
python3-pip \
python3-wheel \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.13 100 \
&& pip install --break-system-packages --upgrade setuptools \
&& pip install --break-system-packages -r requirements.txt \
ca-certificates \
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
&& uv python install 3.13 \
&& uv venv --python 3.13 /root/.venv \
&& uv pip install --python /root/.venv/bin/python -r requirements.txt \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \

View File

@@ -181,7 +181,7 @@ jobs:
- build: 'x64'
os: ubuntu-22.04
- build: 'arm64'
-os: ubuntu-22.04-arm
+os: ubuntu-24.04-arm
- build: 's390x'
os: ubuntu-24.04-s390x
- build: 'ppc64le'
@@ -207,14 +207,22 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends \
-python3 python3-pip python3-dev \
+python3 python3-pip python3-dev python3-wheel \
libjpeg-dev build-essential libssl-dev \
git-lfs
+- name: Toolchain workaround (GCC 14)
+if: ${{ contains(matrix.os, 'ubuntu-24.04') }}
+run: |
+sudo apt-get install -y gcc-14 g++-14
+echo "CC=gcc-14" >> "$GITHUB_ENV"
+echo "CXX=g++-14" >> "$GITHUB_ENV"
- name: Python Dependencies
id: python_depends
run: |
-python3 -m pip install --upgrade pip
+export PIP_BREAK_SYSTEM_PACKAGES="1"
+python3 -m pip install --upgrade pip setuptools
pip3 install ./gguf-py
- name: Swap Endianness
@@ -292,7 +300,15 @@ jobs:
ctest -L main --verbose
ubuntu-24-vulkan:
-runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
+strategy:
+matrix:
+include:
+- build: 'x64'
+os: ubuntu-24.04
+- build: 'arm64'
+os: ubuntu-24.04-arm
+runs-on: ${{ matrix.os }}
steps:
- name: Clone
@@ -302,7 +318,10 @@ jobs:
- name: Dependencies
id: depends
run: |
-sudo apt-get install -y glslc libvulkan-dev libssl-dev ninja-build
+sudo apt-get update
+sudo apt-get install -y gcc-14 g++-14 build-essential glslc libvulkan-dev libssl-dev ninja-build
+echo "CC=gcc-14" >> "$GITHUB_ENV"
+echo "CXX=g++-14" >> "$GITHUB_ENV"
- name: Configure
id: cmake_configure

View File

@@ -25,184 +25,13 @@ permissions:
packages: write
jobs:
push_to_registry:
name: Push Docker image to Docker Hub
runs-on: ${{ matrix.config.runs_on }}
env:
COMMIT_SHA: ${{ github.sha }}
strategy:
fail-fast: false
matrix:
config:
# Multi-stage build
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/arm64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
- { tag: "cuda cuda12", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04", cuda_version: "12.4.0", ubuntu_version: "22.04" }
- { tag: "cuda13", dockerfile: ".devops/cuda-new.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04", cuda_version: "13.1.0", ubuntu_version: "24.04" }
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
- { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04-s390x" }
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
- { tag: "openvino", dockerfile: ".devops/openvino.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
steps:
- name: Check out the repo
uses: actions/checkout@v6
with:
fetch-depth: 0 # preserve git history, so we can determine the build number
- name: Set up QEMU
if: ${{ matrix.config.tag != 's390x' }}
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3
with:
image: tonistiigi/binfmt:qemu-v10.2.1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
- name: Log in to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Determine source tag name
id: srctag
uses: ./.github/actions/get-tag-name
env:
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
- name: Determine image tag name
id: tag
shell: bash
run: |
REPO_OWNER="${GITHUB_REPOSITORY_OWNER@L}" # to lower case
REPO_NAME="${{ github.event.repository.name }}"
PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:"
# list all tags possible
tags="${{ matrix.config.tag }}"
for tag in $tags; do
if [[ "$tag" == "cpu" ]]; then
TYPE=""
else
TYPE="-$tag"
fi
CACHETAGS="${PREFIX}buildcache${TYPE}"
FULLTAGS="${FULLTAGS:+$FULLTAGS,}${PREFIX}full${TYPE},${PREFIX}full${TYPE}-${{ steps.srctag.outputs.name }}"
LIGHTTAGS="${LIGHTTAGS:+$LIGHTTAGS,}${PREFIX}light${TYPE},${PREFIX}light${TYPE}-${{ steps.srctag.outputs.name }}"
SERVERTAGS="${SERVERTAGS:+$SERVERTAGS,}${PREFIX}server${TYPE},${PREFIX}server${TYPE}-${{ steps.srctag.outputs.name }}"
done
echo "cache_output_tags=$CACHETAGS" >> $GITHUB_OUTPUT
echo "full_output_tags=$FULLTAGS" >> $GITHUB_OUTPUT
echo "light_output_tags=$LIGHTTAGS" >> $GITHUB_OUTPUT
echo "server_output_tags=$SERVERTAGS" >> $GITHUB_OUTPUT
echo "cache_output_tags=$CACHETAGS" # print out for debugging
echo "full_output_tags=$FULLTAGS" # print out for debugging
echo "light_output_tags=$LIGHTTAGS" # print out for debugging
echo "server_output_tags=$SERVERTAGS" # print out for debugging
env:
GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}'
- name: Free Disk Space (Ubuntu)
if: ${{ matrix.config.free_disk_space == true }}
uses: ggml-org/free-disk-space@v1.3.1
with:
# this might remove tools that are actually needed,
# if set to "true" but frees about 6 GB
tool-cache: false
# all of these default to true, but feel free to set to
# "false" if necessary for your workflow
android: true
dotnet: true
haskell: true
large-packages: true
docker-images: true
swap-storage: true
- name: Build and push Full Docker image (tagged + versioned)
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.full == true }}
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
with:
context: .
push: true
platforms: ${{ matrix.config.platforms }}
# tag list is generated from step above
tags: ${{ steps.tag.outputs.full_output_tags }}
file: ${{ matrix.config.dockerfile }}
target: full
provenance: false
build-args: |
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
# using github experimental cache
#cache-from: type=gha
#cache-to: type=gha,mode=max
# return to this if the experimental github cache is having issues
#cache-to: type=local,dest=/tmp/.buildx-cache
#cache-from: type=local,src=/tmp/.buildx-cache
# using registry cache (no storage limit)
cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max
- name: Build and push Light Docker image (tagged + versioned)
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }}
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
with:
context: .
push: true
platforms: ${{ matrix.config.platforms }}
# tag list is generated from step above
tags: ${{ steps.tag.outputs.light_output_tags }}
file: ${{ matrix.config.dockerfile }}
target: light
provenance: false
build-args: |
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
# using github experimental cache
#cache-from: type=gha
#cache-to: type=gha,mode=max
# return to this if the experimental github cache is having issues
#cache-to: type=local,dest=/tmp/.buildx-cache
#cache-from: type=local,src=/tmp/.buildx-cache
# using registry cache (no storage limit)
cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max
- name: Build and push Server Docker image (tagged + versioned)
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }}
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
with:
context: .
push: true
platforms: ${{ matrix.config.platforms }}
# tag list is generated from step above
tags: ${{ steps.tag.outputs.server_output_tags }}
file: ${{ matrix.config.dockerfile }}
target: server
provenance: false
build-args: |
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
# using github experimental cache
#cache-from: type=gha
#cache-to: type=gha,mode=max
# return to this if the experimental github cache is having issues
#cache-to: type=local,dest=/tmp/.buildx-cache
#cache-from: type=local,src=/tmp/.buildx-cache
# using registry cache (no storage limit)
cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max
create_tag:
name: Create and push git tag
-runs-on: ubuntu-22.04
+runs-on: ubuntu-slim
permissions:
contents: write
outputs:
source_tag: ${{ steps.srctag.outputs.name }}
steps:
- name: Clone
@@ -223,3 +52,391 @@ jobs:
run: |
git tag ${{ steps.srctag.outputs.name }} || exit 0
git push origin ${{ steps.srctag.outputs.name }} || exit 0
prepare_matrices:
name: Prepare Docker matrices
runs-on: ubuntu-24.04
outputs:
build_matrix: ${{ steps.matrices.outputs.build_matrix }}
merge_matrix: ${{ steps.matrices.outputs.merge_matrix }}
steps:
- name: Generate build and merge matrices
id: matrices
shell: bash
run: |
set -euo pipefail
# Keep all build targets in one place and derive merge targets from it.
cat > build-matrix.json <<'JSON'
[
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "musa", "dockerfile": ".devops/musa.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "intel", "dockerfile": ".devops/intel.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "vulkan", "dockerfile": ".devops/vulkan.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
{ "tag": "vulkan", "dockerfile": ".devops/vulkan.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "rocm", "dockerfile": ".devops/rocm.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "openvino", "dockerfile": ".devops/openvino.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" }
]
JSON
BUILD_MATRIX="$(jq -c . build-matrix.json)"
MERGE_MATRIX="$(jq -c '
reduce .[] as $entry ({}; .[$entry.tag] |= (
. // {
tag: $entry.tag,
arches: [],
full: false,
light: false,
server: false
}
| .full = (.full or ($entry.full // false))
| .light = (.light or ($entry.light // false))
| .server = (.server or ($entry.server // false))
| .arches += [($entry.platforms | sub("^linux/"; ""))]
))
# Backward compatibility: s390x tags are aliases of cpu for the linux/s390x platform.
| if (has("cpu") and (((.cpu.arches // []) | index("s390x")) != null)) then
. + {
s390x: {
tag: "s390x",
arches: ["s390x"],
full: .cpu.full,
light: .cpu.light,
server: .cpu.server
}
}
else
.
end
| [.[] | .arches = (.arches | unique | sort | join(" "))]
' build-matrix.json)"
echo "build_matrix=$BUILD_MATRIX" >> "$GITHUB_OUTPUT"
echo "merge_matrix=$MERGE_MATRIX" >> "$GITHUB_OUTPUT"
push_to_registry:
name: Push Docker image to Docker Registry
needs: [prepare_matrices, create_tag]
runs-on: ${{ matrix.config.runs_on }}
strategy:
fail-fast: false
matrix:
config: ${{ fromJSON(needs.prepare_matrices.outputs.build_matrix) }}
steps:
- name: Check out the repo
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ needs.create_tag.outputs.source_tag }}
- name: Set up QEMU
if: ${{ contains(matrix.config.platforms, 'linux/amd64') }}
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4
with:
image: tonistiigi/binfmt:qemu-v10.2.1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Log in to Docker Registry
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Determine image metadata
id: meta
shell: bash
run: |
set -euo pipefail
REPO_OWNER="${GITHUB_REPOSITORY_OWNER@L}" # to lower case
REPO_NAME="${{ github.event.repository.name }}"
IMAGE_REPO="ghcr.io/${REPO_OWNER}/${REPO_NAME}"
PREFIX="${IMAGE_REPO}:"
PLATFORM="${{ matrix.config.platforms }}"
ARCH_SUFFIX="${PLATFORM#linux/}"
# list all tags possible
tags="${{ matrix.config.tag }}"
for tag in $tags; do
if [[ "$tag" == "cpu" ]]; then
TYPE=""
else
TYPE="-$tag"
fi
CACHETAG="${PREFIX}buildcache${TYPE}-${ARCH_SUFFIX}"
done
SAFE_TAGS="$(echo "$tags" | tr ' ' '_')"
echo "image_repo=$IMAGE_REPO" >> $GITHUB_OUTPUT
echo "arch_suffix=$ARCH_SUFFIX" >> $GITHUB_OUTPUT
echo "cache_output_tag=$CACHETAG" >> $GITHUB_OUTPUT
echo "digest_artifact_suffix=${SAFE_TAGS}-${ARCH_SUFFIX}" >> $GITHUB_OUTPUT
echo "cache_output_tag=$CACHETAG" # print out for debugging
env:
GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}'
- name: Free Disk Space (Ubuntu)
if: ${{ matrix.config.free_disk_space == true }}
uses: ggml-org/free-disk-space@v1.3.1
with:
# this might remove tools that are actually needed,
# if set to "true" but frees about 6 GB
tool-cache: false
# all of these default to true, but feel free to set to
# "false" if necessary for your workflow
android: true
dotnet: true
haskell: true
large-packages: true
docker-images: true
swap-storage: true
- name: Build and push Full Docker image by digest
id: build_full
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.full == true }}
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7
with:
context: .
platforms: ${{ matrix.config.platforms }}
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true
file: ${{ matrix.config.dockerfile }}
target: full
provenance: false
build-args: |
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
# using github experimental cache
#cache-from: type=gha
#cache-to: type=gha,mode=max
# return to this if the experimental github cache is having issues
#cache-to: type=local,dest=/tmp/.buildx-cache
#cache-from: type=local,src=/tmp/.buildx-cache
# using registry cache (no storage limit)
cache-from: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }}
cache-to: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }},mode=max
- name: Build and push Light Docker image by digest
id: build_light
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }}
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7
with:
context: .
platforms: ${{ matrix.config.platforms }}
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true
file: ${{ matrix.config.dockerfile }}
target: light
provenance: false
build-args: |
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
# using github experimental cache
#cache-from: type=gha
#cache-to: type=gha,mode=max
# return to this if the experimental github cache is having issues
#cache-to: type=local,dest=/tmp/.buildx-cache
#cache-from: type=local,src=/tmp/.buildx-cache
# using registry cache (no storage limit)
cache-from: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }}
cache-to: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }},mode=max
- name: Build and push Server Docker image by digest
id: build_server
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }}
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7
with:
context: .
platforms: ${{ matrix.config.platforms }}
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true
file: ${{ matrix.config.dockerfile }}
target: server
provenance: false
build-args: |
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
# using github experimental cache
#cache-from: type=gha
#cache-to: type=gha,mode=max
# return to this if the experimental github cache is having issues
#cache-to: type=local,dest=/tmp/.buildx-cache
#cache-from: type=local,src=/tmp/.buildx-cache
# using registry cache (no storage limit)
cache-from: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }}
cache-to: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }},mode=max
- name: Export digest metadata
shell: bash
run: |
set -euo pipefail
TAGS="${{ matrix.config.tag }}"
ARCH_SUFFIX="${{ steps.meta.outputs.arch_suffix }}"
DIGEST_FILE="/tmp/digests/${{ steps.meta.outputs.digest_artifact_suffix }}.tsv"
mkdir -p /tmp/digests
add_digest_rows() {
local image_type="$1"
local digest="$2"
if [[ -z "$digest" ]]; then
echo "Missing digest for image_type=${image_type}" >&2
exit 1
fi
for tag in $TAGS; do
printf '%s\t%s\t%s\t%s\n' "$tag" "$ARCH_SUFFIX" "$image_type" "$digest" >> "$DIGEST_FILE"
done
}
if [[ "${{ matrix.config.full }}" == "true" ]]; then
add_digest_rows "full" "${{ steps.build_full.outputs.digest }}"
fi
if [[ "${{ matrix.config.light }}" == "true" ]]; then
add_digest_rows "light" "${{ steps.build_light.outputs.digest }}"
fi
if [[ "${{ matrix.config.server }}" == "true" ]]; then
add_digest_rows "server" "${{ steps.build_server.outputs.digest }}"
fi
- name: Upload digest metadata
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7
with:
name: digests-${{ steps.meta.outputs.digest_artifact_suffix }}
path: /tmp/digests/${{ steps.meta.outputs.digest_artifact_suffix }}.tsv
if-no-files-found: error
merge_arch_tags:
name: Create shared tags from digests
needs: [prepare_matrices, push_to_registry, create_tag]
runs-on: ubuntu-24.04
strategy:
fail-fast: false
matrix:
config: ${{ fromJSON(needs.prepare_matrices.outputs.merge_matrix) }}
steps:
- name: Check out the repo
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Download digest metadata
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with:
pattern: digests-*
path: /tmp/digests
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Log in to Docker Registry
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Create tags from digests
shell: bash
run: |
set -euo pipefail
REPO_OWNER="${GITHUB_REPOSITORY_OWNER@L}" # to lower case
REPO_NAME="${{ github.event.repository.name }}"
IMAGE_REPO="ghcr.io/${REPO_OWNER}/${REPO_NAME}"
PREFIX="${IMAGE_REPO}:"
SRC_TAG="${{ needs.create_tag.outputs.source_tag }}"
TAGS="${{ matrix.config.tag }}"
ARCHES="${{ matrix.config.arches }}"
DIGEST_GLOB="/tmp/digests/*.tsv"
if ! ls ${DIGEST_GLOB} >/dev/null 2>&1; then
echo "No digest metadata found in /tmp/digests" >&2
exit 1
fi
if [[ -z "$SRC_TAG" ]]; then
echo "Missing source tag from create_tag" >&2
exit 1
fi
find_digest() {
local tag_name="$1"
local arch="$2"
local image_type="$3"
local digest
digest="$(awk -F '\t' -v t="$tag_name" -v a="$arch" -v i="$image_type" '$1 == t && $2 == a && $3 == i { print $4; exit }' ${DIGEST_GLOB})"
# Backward compatibility: s390x tags are aliases of cpu for the linux/s390x platform.
if [[ -z "$digest" && "$tag_name" == "s390x" && "$arch" == "s390x" ]]; then
digest="$(awk -F '\t' -v t="cpu" -v a="$arch" -v i="$image_type" '$1 == t && $2 == a && $3 == i { print $4; exit }' ${DIGEST_GLOB})"
fi
if [[ -z "$digest" ]]; then
echo "Missing digest for tag=${tag_name} arch=${arch} image_type=${image_type}" >&2
exit 1
fi
echo "$digest"
}
create_manifest_tags() {
local image_type="$1"
local tag_name="$2"
local suffix="$3"
local merged_tag="${PREFIX}${image_type}${suffix}"
local merged_versioned_tag="${merged_tag}-${SRC_TAG}"
local refs=()
for arch in $ARCHES; do
local digest
digest="$(find_digest "$tag_name" "$arch" "$image_type")"
refs+=("${IMAGE_REPO}@${digest}")
done
echo "Creating ${merged_tag} from ${refs[*]}"
docker buildx imagetools create --tag "${merged_tag}" "${refs[@]}"
echo "Creating ${merged_versioned_tag} from ${refs[*]}"
docker buildx imagetools create --tag "${merged_versioned_tag}" "${refs[@]}"
}
for tag in $TAGS; do
if [[ "$tag" == "cpu" ]]; then
TYPE=""
else
TYPE="-$tag"
fi
if [[ "${{ matrix.config.full }}" == "true" ]]; then
create_manifest_tags "full" "$tag" "$TYPE"
fi
if [[ "${{ matrix.config.light }}" == "true" ]]; then
create_manifest_tags "light" "$tag" "$TYPE"
fi
if [[ "${{ matrix.config.server }}" == "true" ]]; then
create_manifest_tags "server" "$tag" "$TYPE"
fi
done
env:
GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}'
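
The `find_digest` shell function above is a keyed lookup with a single fallback. The following is a minimal C++ sketch of that logic, assuming the TSV rows have already been parsed into memory. `DigestRow`, `lookup` and this `find_digest` are hypothetical names used for illustration only; the workflow itself does this with `awk`.

```cpp
#include <optional>
#include <string>
#include <vector>

// One row of the digest TSV: tag, arch, image type, digest (hypothetical struct).
struct DigestRow {
    std::string tag;
    std::string arch;
    std::string image_type;
    std::string digest;
};

// Return the digest of the first row matching all three keys, if any.
static std::optional<std::string> lookup(const std::vector<DigestRow> & rows,
                                         const std::string & tag,
                                         const std::string & arch,
                                         const std::string & image_type) {
    for (const auto & r : rows) {
        if (r.tag == tag && r.arch == arch && r.image_type == image_type) {
            return r.digest;
        }
    }
    return std::nullopt;
}

// Mirrors the shell logic: exact match first, then the s390x -> cpu alias.
std::optional<std::string> find_digest(const std::vector<DigestRow> & rows,
                                       const std::string & tag,
                                       const std::string & arch,
                                       const std::string & image_type) {
    auto digest = lookup(rows, tag, arch, image_type);
    if (!digest && tag == "s390x" && arch == "s390x") {
        digest = lookup(rows, "cpu", arch, image_type);
    }
    return digest; // an empty result corresponds to the shell's "exit 1" path
}
```

An empty result here plays the role of the "Missing digest" error in the script: the caller is expected to stop rather than create a manifest from an incomplete set of references.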


@@ -131,17 +131,16 @@ jobs:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz
name: llama-bin-macos-x64.tar.gz
ubuntu-22-cpu:
ubuntu-cpu:
strategy:
matrix:
include:
- build: 'x64'
os: ubuntu-22.04
- build: 'arm64'
os: ubuntu-24.04-arm
- build: 's390x'
os: ubuntu-24.04-s390x
# GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
# - build: 'arm64'
# os: ubuntu-22.04-arm
runs-on: ${{ matrix.os }}
@@ -165,6 +164,13 @@ jobs:
sudo apt-get update
sudo apt-get install build-essential libssl-dev
- name: Toolchain workaround (GCC 14)
if: ${{ contains(matrix.os, 'ubuntu-24.04') }}
run: |
sudo apt-get install -y gcc-14 g++-14
echo "CC=gcc-14" >> "$GITHUB_ENV"
echo "CXX=g++-14" >> "$GITHUB_ENV"
- name: Build
id: cmake_build
run: |
@@ -194,8 +200,16 @@ jobs:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz
name: llama-bin-ubuntu-${{ matrix.build }}.tar.gz
ubuntu-22-vulkan:
runs-on: ubuntu-22.04
ubuntu-vulkan:
strategy:
matrix:
include:
- build: 'x64'
os: ubuntu-22.04
- build: 'arm64'
os: ubuntu-24.04-arm
runs-on: ${{ matrix.os }}
steps:
- name: Clone
@@ -207,16 +221,23 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.21
with:
key: ubuntu-22-vulkan
key: ubuntu-vulkan-${{ matrix.build }}
evict-old-files: 1d
- name: Dependencies
id: depends
run: |
wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
sudo apt-get update -y
sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libssl-dev
if [[ "${{ matrix.os }}" =~ "ubuntu-22.04" ]]; then
wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
sudo apt-get update -y
sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libssl-dev
else
sudo apt-get update -y
sudo apt-get install -y gcc-14 g++-14 build-essential glslc libvulkan-dev libssl-dev ninja-build
echo "CC=gcc-14" >> "$GITHUB_ENV"
echo "CXX=g++-14" >> "$GITHUB_ENV"
fi
- name: Build
id: cmake_build
@@ -239,13 +260,13 @@ jobs:
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz
name: llama-bin-ubuntu-vulkan-x64.tar.gz
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-${{ matrix.build }}.tar.gz
name: llama-bin-ubuntu-vulkan-${{ matrix.build }}.tar.gz
ubuntu-24-openvino:
runs-on: ubuntu-24.04
@@ -977,8 +998,8 @@ jobs:
- windows-sycl
- windows-hip
- ubuntu-22-rocm
- ubuntu-22-cpu
- ubuntu-22-vulkan
- ubuntu-cpu
- ubuntu-vulkan
- ubuntu-24-openvino
- macOS-arm64
- macOS-x64
@@ -1061,9 +1082,11 @@ jobs:
**Linux:**
- [Ubuntu x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.tar.gz)
- [Ubuntu x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz)
- [Ubuntu x64 (ROCm 7.2)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-7.2-x64.tar.gz)
- [Ubuntu arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-arm64.tar.gz)
- [Ubuntu s390x (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-s390x.tar.gz)
- [Ubuntu x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz)
- [Ubuntu arm64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-arm64.tar.gz)
- [Ubuntu x64 (ROCm 7.2)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-7.2-x64.tar.gz)
- [Ubuntu x64 (OpenVINO)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ needs.ubuntu-24-openvino.outputs.openvino_version }}-x64.tar.gz)
**Windows:**


@@ -221,7 +221,7 @@ using chat_template_caps = jinja::caps;
struct common_chat_templates {
bool add_bos;
bool add_eos;
bool has_explicit_template; // Model had builtin template or template overridde was specified.
bool has_explicit_template; // Model had builtin template or template overridden was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
@@ -989,6 +989,10 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
auto analysis = p.ref("analysis");
auto preamble = p.rule("preamble", p.literal("<|channel|>commentary<|message|>") + p.content(content) + end);
auto final_msg = p.rule("final", p.literal("<|channel|>final<|message|>") + p.content(content));
// Consume any unsolicited tool calls, e.g. builtin functions
auto unsolicited = p.rule("unsolicited", p.atomic(p.optional(channel) + p.literal(" to=") + content + end));
auto any = p.rule("any", preamble | analysis);
if (has_response_format) {
@@ -1032,7 +1036,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
return p.zero_or_more(start + any) + start + (tool_call | final_msg);
}
return p.zero_or_more(start + any) + start + final_msg;
return p.zero_or_more(start + any) + start + (final_msg | unsolicited);
});
data.parser = parser.save();


@@ -359,6 +359,11 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
}
void common_init() {
#if defined(_WIN32)
SetConsoleOutputCP(CP_UTF8);
SetConsoleCP(CP_UTF8);
#endif
llama_log_set(common_log_default_callback, NULL);
#ifdef NDEBUG
@@ -367,7 +372,7 @@ void common_init() {
const char * build_type = " (debug)";
#endif
LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
LOG_DBG("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
}
std::string common_params_get_system_info(const common_params & params) {
@@ -1243,6 +1248,9 @@ llama_context * common_init_result::context() {
}
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
if (seq_id < 0 || seq_id >= (int) pimpl->samplers.size()) {
return nullptr;
}
return pimpl->samplers[seq_id].get();
}
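
The bounds check added to `common_init_result::sampler` can be illustrated in isolation. This is a minimal sketch, assuming a holder that owns a vector of per-sequence samplers; `Sampler` and `InitResult` are hypothetical stand-ins, not the real llama.cpp types.

```cpp
#include <memory>
#include <vector>

// Hypothetical stand-in for a per-sequence sampler object.
struct Sampler {
    int id;
};

// Hypothetical holder, loosely modeled on the pimpl->samplers vector above.
struct InitResult {
    std::vector<std::unique_ptr<Sampler>> samplers;

    // Return nullptr instead of indexing out of range, e.g. when the model
    // failed to load and no samplers were ever created.
    Sampler * sampler(int seq_id) const {
        if (seq_id < 0 || seq_id >= static_cast<int>(samplers.size())) {
            return nullptr;
        }
        return samplers[seq_id].get();
    }
};
```

With this shape, a caller that previously dereferenced the result unconditionally needs a null check, but an empty sampler list no longer leads to an out-of-range access.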


@@ -539,6 +539,9 @@ private:
statement_ptr step = slices.size() > 2 ? std::move(slices[2]) : nullptr;
return mk_stmt<slice_expression>(start_pos, std::move(start), std::move(stop), std::move(step));
}
if (slices.empty()) {
return mk_stmt<blank_expression>(start_pos);
}
return std::move(slices[0]);
}


@@ -771,10 +771,15 @@ value member_expression::execute_impl(context & ctx) {
}
JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
ensure_key_type_allowed(property);
value val = mk_val<value_undefined>("object_property");
if (property->is_undefined()) {
JJ_DEBUG("%s", "Member expression property is undefined, returning undefined");
return val;
}
ensure_key_type_allowed(property);
if (is_val<value_undefined>(object)) {
JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined");
return val;


@@ -263,6 +263,14 @@ struct comment_statement : public statement {
// Expressions
// Represents an omitted expression in a computed member, e.g. `a[]`.
struct blank_expression : public expression {
std::string type() const override { return "BlankExpression"; }
value execute_impl(context &) override {
return mk_val<value_undefined>();
}
};
struct member_expression : public expression {
statement_ptr object;
statement_ptr property;


@@ -51,7 +51,7 @@ struct common_ngram_map_value {
// statistics of a n-gram
struct common_ngram_map_key {
size_t key_idx; // index of key n-gram in token-history
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
size_t stat_idx; // index of last token of statistics computation (key_num, values)
uint16_t key_num; // number of occurrences of this key n-gram in token-history
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key


@@ -383,6 +383,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
params.backend_sampling = false;
}
if (rbudget && params.backend_sampling) {
LOG_WRN("%s: backend sampling is not compatible with reasoning budget, disabling\n", __func__);
params.backend_sampling = false;
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,


@@ -13,24 +13,30 @@ We have three Docker images available for this project:
Additionally, there are the following images, similar to the above:
- `ghcr.io/ggml-org/llama.cpp:full-cuda`: Same as `full` but compiled with CUDA support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-cuda`: Same as `light` but compiled with CUDA support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-cuda`: Same as `server` but compiled with CUDA support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:full-rocm`: Same as `full` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:light-rocm`: Same as `light` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:server-rocm`: Same as `server` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:full-cuda`: Same as `full` but compiled with CUDA 12 support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:full-cuda13`: Same as `full` but compiled with CUDA 13 support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:light-cuda`: Same as `light` but compiled with CUDA 12 support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:light-cuda13`: Same as `light` but compiled with CUDA 13 support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:server-cuda`: Same as `server` but compiled with CUDA 12 support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:server-cuda13`: Same as `server` but compiled with CUDA 13 support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:full-rocm`: Same as `full` but compiled with ROCm support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-rocm`: Same as `light` but compiled with ROCm support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-rocm`: Same as `server` but compiled with ROCm support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:full-musa`: Same as `full` but compiled with MUSA support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-musa`: Same as `light` but compiled with MUSA support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-musa`: Same as `server` but compiled with MUSA support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:full-intel`: Same as `full` but compiled with SYCL support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-intel`: Same as `light` but compiled with SYCL support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-intel`: Same as `server` but compiled with SYCL support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:full-vulkan`: Same as `full` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-vulkan`: Same as `light` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-vulkan`: Same as `server` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:full-vulkan`: Same as `full` but compiled with Vulkan support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:light-vulkan`: Same as `light` but compiled with Vulkan support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:server-vulkan`: Same as `server` but compiled with Vulkan support. (platforms: `linux/amd64`, `linux/arm64`)
- `ghcr.io/ggml-org/llama.cpp:full-openvino`: Same as `full` but compiled with OpenVINO support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-openvino`: Same as `light` but compiled with OpenVINO support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-openvino`: Same as `server` but compiled with OpenVINO support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:full-s390x`: Identical to `full`, an alias for the `s390x` platform. (platforms: `linux/s390x`)
- `ghcr.io/ggml-org/llama.cpp:light-s390x`: Identical to `light`, an alias for the `s390x` platform. (platforms: `linux/s390x`)
- `ghcr.io/ggml-org/llama.cpp:server-s390x`: Identical to `server`, an alias for the `s390x` platform. (platforms: `linux/s390x`)
The GPU-enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library), you'll need to build the images locally for now.
@@ -82,7 +88,7 @@ You may want to pass in some different `ARGS`, depending on the CUDA environment
The defaults are:
- `CUDA_VERSION` set to `12.4.0`
- `CUDA_VERSION` set to `12.8.1`
- `CUDA_DOCKER_ARCH` set to the cmake build default, which includes all the supported architectures
The resulting images are essentially the same as the non-CUDA images:


@@ -24,12 +24,12 @@ int main(int argc, char ** argv) {
params.prompt = "Hello my name is";
params.n_predict = 32;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BATCHED, print_usage)) {
return 1;
}
common_init();
// number of parallel batches
int n_parallel = params.n_parallel;


@@ -213,12 +213,12 @@ static bool run(llama_context * ctx, const common_params & params) {
int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) {
return 1;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);


@@ -545,11 +545,12 @@ int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
return 1;
}
common_init();
llama_backend_init();
llama_model_params model_params = llama_model_default_params();


@@ -99,12 +99,12 @@ int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) {
return 1;
}
common_init();
params.embedding = true;
// get max number of sequences per batch


@@ -37,12 +37,12 @@ int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);


@@ -19,12 +19,12 @@ static void print_usage(int /*argc*/, char ** argv) {
int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
return 1;
}
common_init();
// init LLM
llama_backend_init();


@@ -43,12 +43,12 @@ int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
common_init();
const int W = 15; // lookahead window
const int N = 5; // n-gram size
const int G = 15; // max verification n-grams


@@ -12,6 +12,8 @@ int main(int argc, char ** argv){
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
return 1;
}


@@ -18,12 +18,12 @@ int main(int argc, char ** argv){
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
return 1;
}
common_init();
const int n_draft = params.speculative.n_max;
// init llama.cpp


@@ -18,12 +18,12 @@ int main(int argc, char ** argv){
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
return 1;
}
common_init();
// max. number of additional tokens to draft if match is found
const int n_draft = params.speculative.n_max;


@@ -163,12 +163,12 @@ int main(int argc, char ** argv) {
params.n_predict = 128;
params.n_junk = 1;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
return 1;
}
common_init();
// number of simultaneous "clients" to simulate
const int32_t n_clients = params.n_parallel;


@@ -25,12 +25,12 @@ int main(int argc, char ** argv) {
params.n_keep = 32;
params.i_pos = -1;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PASSKEY, print_usage)) {
return 1;
}
common_init();
int n_junk = params.n_junk;
int n_keep = params.n_keep;
int n_grp = params.grp_attn_n;


@@ -117,12 +117,12 @@ int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_RETRIEVAL, print_usage)) {
return 1;
}
common_init();
// For BERT models, batch size must be equal to ubatch size
params.n_ubatch = params.n_batch;
params.embedding = true;


@@ -17,6 +17,8 @@ int main(int argc, char ** argv) {
const std::string_view state_file = "dump_state.bin";
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
@@ -27,8 +29,6 @@ int main(int argc, char ** argv) {
params.kv_unified = true;
}
common_init();
if (params.n_predict < 0) {
params.n_predict = 16;
}


@@ -16,6 +16,8 @@ int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
return 1;
}
@@ -25,8 +27,6 @@ int main(int argc, char ** argv) {
return 1;
}
common_init();
if (params.speculative.mparams_dft.path.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;


@@ -38,6 +38,8 @@ int main(int argc, char ** argv) {
// needed to get candidate probs even for temp <= 0.0
params.sampling.n_probs = 128;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
return 1;
}
@@ -47,8 +49,6 @@ int main(int argc, char ** argv) {
return 1;
}
common_init();
if (params.speculative.mparams_dft.path.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;


@@ -20,6 +20,8 @@ int main(int argc, char ** argv) {
common_params params;
params.escape = false;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
return 1;
}
@@ -38,7 +40,6 @@ int main(int argc, char ** argv) {
params.cache_type_v = GGML_TYPE_F32;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);
// load the model and apply lora adapter, if any


@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 8)
set(GGML_VERSION_PATCH 9)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)


@@ -47,9 +47,11 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
#ifdef STRIDED_ITERATOR_AVAILABLE
auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
#else
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
// offset_iterator needs to populate nrows + 1 elements, so we also have to ceildiv nrows + 1 by block_size
const int nrows_offset = nrows + 1;
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows_offset);
int * offset_iterator = offsets_alloc.get();
const dim3 offset_grid((nrows + block_size - 1) / block_size);
const dim3 offset_grid((nrows_offset + block_size - 1) / block_size);
init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
#endif
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
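
The grid-size change above matters only when `nrows` is an exact multiple of `block_size`. A minimal host-side sketch of the arithmetic follows; `ceil_div`, `grid_old` and `grid_new` are hypothetical helper names, not functions from the CUDA backend.

```cpp
// Number of blocks needed so that blocks * block_size >= n
// (assumes n >= 0 and block_size > 0).
constexpr int ceil_div(int n, int block_size) {
    return (n + block_size - 1) / block_size;
}

// Grid size before the fix: sized for nrows offsets only.
constexpr int grid_old(int nrows, int block_size) {
    return ceil_div(nrows, block_size);
}

// Grid size after the fix: sized for the nrows + 1 offsets that are written.
constexpr int grid_new(int nrows, int block_size) {
    return ceil_div(nrows + 1, block_size);
}
```

For example, with `block_size = 1024` and `nrows = 1024`, the old grid launches one block (1024 threads) while 1025 offsets need to be populated, so the last offset would be left unwritten; the new grid launches two blocks. For `nrows = 1000` both versions launch a single block.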


@@ -114,6 +114,8 @@ set(GGML_OPENCL_KERNELS
gemv_noshuffle_q4_1_f32
gemm_noshuffle_q4_1_f32
gemv_noshuffle_general_q8_0_f32
gemv_noshuffle_q4_k_f32
gemm_noshuffle_q4_k_f32
gemv_noshuffle_q6_k_f32
gemm_noshuffle_q6_k_f32
mul


@@ -538,6 +538,8 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_restore_block_q4_0_noshuffle;
cl_kernel kernel_convert_block_q4_1_noshuffle;
cl_kernel kernel_restore_block_q4_1_noshuffle;
cl_kernel kernel_convert_block_q4_K_noshuffle;
cl_kernel kernel_restore_block_q4_K_noshuffle;
cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K;
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
@@ -720,6 +722,8 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_gemm_noshuffle_q4_1_f32;
cl_kernel kernel_mul_mm_q8_0_f32_8x4;
cl_kernel CL_mul_mat_vec_q8_0_f32;
cl_kernel kernel_gemv_noshuffle_q4_k_f32;
cl_kernel kernel_gemm_noshuffle_q4_k_f32;
cl_kernel kernel_gemv_noshuffle_q6_K_f32;
cl_kernel kernel_gemm_noshuffle_q6_K_f32;
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
@@ -932,6 +936,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K_noshuffle", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K_noshuffle", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err));
@@ -2619,6 +2625,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// gemm_noshuffle_q4_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemm_noshuffle_q4_k_f32.cl.h"
};
#else
const std::string kernel_src = read_file("gemm_noshuffle_q4_k_f32.cl");
#endif
cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_k_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// gemv_noshuffle_q4_k_f32
{
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
" -cl-mad-enable ";
if (backend_ctx->has_vector_subgroup_broadcast) {
CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST ";
}
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemv_noshuffle_q4_k_f32.cl.h"
};
#else
const std::string kernel_src = read_file("gemv_noshuffle_q4_k_f32.cl");
#endif
cl_program prog = build_program_from_source(
backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts);
CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_k_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
" -cl-mad-enable "
" -cl-fast-relaxed-math";
@@ -5060,12 +5105,25 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K;
if (use_adreno_kernels(backend_ctx, tensor)) {
kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle;
}
#else
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K;
#endif
cl_uchar mask_0F = 0x0F;
cl_uchar mask_F0 = 0xF0;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {64, 1, 1};
@@ -5076,6 +5134,20 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
CL_CHECK(clReleaseMemObject(data_device));
tensor->extra = extra;
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_kernels(backend_ctx, tensor)) {
int M = tensor->ne[1];
int K = tensor->ne[0];
GGML_ASSERT(K % 32 == 0);
// Transpose q, d, dm as ushort
transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M);
transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/256, M);
transpose_2d_as_16b(backend_ctx, extra->dm, extra->dm, size_dm, K/256, M);
}
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
return;
}
if (tensor->type == GGML_TYPE_Q6_K) {
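
The Adreno path above reinterprets the `q`, `d` and `dm` buffers as 16-bit words and transposes them. As a rough host-side illustration of what such a transpose does, here is a minimal sketch. `transpose_16b` is a hypothetical function; the argument order and in-place behavior of the real `transpose_2d_as_16b` are not shown in this diff and are not assumed here.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Transpose a row-major rows x cols matrix of 16-bit words into a
// row-major cols x rows matrix. Host-side illustration only.
std::vector<uint16_t> transpose_16b(const std::vector<uint16_t> & src,
                                    size_t rows, size_t cols) {
    std::vector<uint16_t> dst(src.size());
    for (size_t r = 0; r < rows; ++r) {
        for (size_t c = 0; c < cols; ++c) {
            dst[c * rows + r] = src[r * cols + c];
        }
    }
    return dst;
}
```

Applying the function twice with the dimensions swapped restores the original layout. The `get_tensor` path follows the same pattern: it calls the transpose with the two dimension arguments in the opposite order to the `set_tensor` path.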
@@ -5516,12 +5588,60 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
cl_uchar mask_0F = 0x0F;
cl_uchar mask_F0 = 0xF0;
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_kernels(backend_ctx, tensor)) {
int M = tensor->ne[1];
int K = tensor->ne[0];
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
static ggml_cl_buffer buf_trans_q;
static ggml_cl_buffer buf_trans_d;
static ggml_cl_buffer buf_trans_dm;
buf_trans_q.allocate(backend_ctx->context, size_q);
buf_trans_d.allocate(backend_ctx->context, size_d);
buf_trans_dm.allocate(backend_ctx->context, size_dm);
// Transpose q, d, dm back
transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4);
transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256);
transpose_2d_as_16b(backend_ctx, extra->dm, buf_trans_dm.buffer, size_dm, M, K/256);
cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K_noshuffle;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_dm.buffer));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueReadBuffer(queue, data_device, CL_TRUE, offset,
size, data, 0, NULL, NULL));
CL_CHECK(clReleaseMemObject(data_device));
return;
}
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
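
The buffer sizes computed in the Adreno branch above follow directly from the Q4_K block layout. A minimal sketch of the arithmetic, assuming 256 elements per Q4_K block (consistent with the `K/256` in the transpose calls) and a 2-byte `ggml_fp16_t`; `QK_K_ASSUMED`, `FP16_BYTES`, `Q4KSizes` and `q4k_sizes` are hypothetical names.

```cpp
#include <cstddef>

constexpr size_t QK_K_ASSUMED = 256; // elements per Q4_K block (assumption)
constexpr size_t FP16_BYTES   = 2;   // size of one fp16 scale value (assumption)

struct Q4KSizes {
    size_t size_q;  // packed 4-bit quants: two elements per byte
    size_t size_d;  // one fp16 super-block scale per block
    size_t size_dm; // one fp16 super-block min per block
};

constexpr Q4KSizes q4k_sizes(size_t nelements) {
    const size_t nblocks = nelements / QK_K_ASSUMED;
    return { nblocks * QK_K_ASSUMED / 2, nblocks * FP16_BYTES, nblocks * FP16_BYTES };
}
```

Under these assumptions, a 4096 x 4096 weight tensor has 65536 blocks, so `size_q` is 8388608 bytes and `size_d` and `size_dm` are 131072 bytes each.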
@@ -9688,6 +9808,192 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
#endif
}
static void ggml_cl_mul_mat_q4_k_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(src1);
GGML_ASSERT(src1->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl_q4_K * extra0_q4_k = (ggml_tensor_extra_cl_q4_K *)src0->extra;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne1 = dst->ne[1];
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
cl_context context = backend_ctx->context;
cl_kernel kernel;
cl_int err;
cl_image_format img_fmt;
cl_image_desc img_desc;
cl_buffer_region region;
int M = ne01;
int N = ne1;
int K = ne00;
cl_uchar mask_d6 = 0x3F;
cl_uchar mask_d4 = 0x0F;
cl_uchar mask_hi2 = 0xC0;
if (ne1 == 1) {
cl_mem q_img = nullptr;
cl_mem b_sub_buf = nullptr;
cl_mem b_img = nullptr;
// image for q
img_fmt = { CL_R, CL_UNSIGNED_INT32};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = M * K / 2 / 4;
img_desc.buffer = extra0_q4_k->q;
CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// subbuffer for activations
region.origin = offset1;
region.size = K * N * sizeof(float);
CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// image for activations
img_fmt = {CL_RGBA, CL_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = K * N / 4;
img_desc.buffer = b_sub_buf;
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
kernel = backend_ctx->kernel_gemv_noshuffle_q4_k_f32;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->dm));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->s));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_d6));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_uchar), &mask_d4));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_hi2));
size_t local_work_size[3] = {64, 4, 1};
size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
CL_CHECK(clReleaseMemObject(q_img));
CL_CHECK(clReleaseMemObject(b_sub_buf));
CL_CHECK(clReleaseMemObject(b_img));
} else {
cl_mem b_sub_buf = nullptr;
cl_mem b_sub_buf_trans = nullptr;
cl_mem b_img = nullptr;
cl_mem b_img_trans = nullptr;
// subbuffer for activations
region.origin = offset1;
region.size = K * N * sizeof(float);
CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// image for activations
img_fmt = {CL_RGBA, CL_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = K * N / 4;
img_desc.buffer = b_sub_buf;
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// pad N to multiple of 8
int extra_elements = N % 8;
int padding = 0;
if (extra_elements > 0){
padding = 8 - extra_elements;
}
// subbuffer for transposed activations
region.origin = 0;
region.size = K * (N + padding) * sizeof(float)/2;
backend_ctx->prealloc_act_trans.allocate(context, region.size);
CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// image for transposed activations
img_fmt = {CL_RGBA, CL_HALF_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = K * (N + padding) / 4;
img_desc.buffer = b_sub_buf_trans;
CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err));
// transpose activations
int height_B = N/4;
if (height_B == 0) {
height_B = 1;
}
int width_B = K/4;
int padded_height_B = (N + padding)/4;
kernel = backend_ctx->kernel_transpose_32_16;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B));
size_t local_work_size_t[2] = { 1, 16 };
size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B };
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst);
// gemm
kernel = backend_ctx->kernel_gemm_noshuffle_q4_k_f32;
int padded_N = N + padding;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_k->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->s));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->d));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->dm));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &padded_N));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_d6));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_d4));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_uchar), &mask_hi2));
size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1};
size_t local_work_size[3] = {1, 128, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
CL_CHECK(clReleaseMemObject(b_sub_buf));
CL_CHECK(clReleaseMemObject(b_sub_buf_trans));
CL_CHECK(clReleaseMemObject(b_img));
CL_CHECK(clReleaseMemObject(b_img_trans));
}
#else
GGML_UNUSED(backend);
GGML_UNUSED(src0);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
#endif
}
static void ggml_cl_mul_mat_q6_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
GGML_ASSERT(src0);
@@ -10014,6 +10320,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
return;
}
// q4_k x fp32
if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) {
ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst);
return;
}
// q6_K x fp32
if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) {
ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst);
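The work sizes and buffer sizes in `ggml_cl_mul_mat_q4_k_f32_adreno` come down to a little integer arithmetic. Below is a minimal host-side sketch of that arithmetic, based only on what the hunk shows. The helper names (`ceil_div`, `pad_to_8`, `act_trans_bytes`, `gemv_global_x`) are hypothetical and are not part of the backend.

```cpp
#include <cstddef>

// Hypothetical helpers: only the arithmetic mirrors the diff.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Columns of padding needed to round N up to a multiple of 8.
constexpr int pad_to_8(int n) {
    return (n % 8) > 0 ? 8 - (n % 8) : 0;
}

// Bytes reserved for the transposed activations: K * (N + padding) half-precision values,
// written in the diff as K * (N + padding) * sizeof(float) / 2.
constexpr std::size_t act_trans_bytes(int k, int n) {
    return static_cast<std::size_t>(k) * static_cast<std::size_t>(n + pad_to_8(n)) * sizeof(float) / 2;
}

// GEMV global size along dimension 0: each work-item produces two rows,
// and the count is rounded up to a multiple of the local size 64.
constexpr std::size_t gemv_global_x(int ne01) {
    return static_cast<std::size_t>(ceil_div(ne01 / 2, 64)) * 64;
}
```

For instance, `N = 5` is padded to 8 columns, so with `K = 256` the transposed activation buffer needs 4096 bytes.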


@@ -424,13 +424,17 @@ kernel void kernel_restore_block_q8_0_trans(
// Convert the block_q4_K format to 4 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
// Each thread processes a super block.
// Mask args are just to keep the signature consistent with the no-shuffle
// version and they are not used in this kernel.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_K(
global struct block_q4_K * src0,
global uchar * dst_q,
global uchar * dst_s,
global half * dst_d,
-global half * dst_dm
+global half * dst_dm,
+uchar mask_0F,
+uchar mask_F0
) {
global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0);
@@ -451,12 +455,15 @@ kernel void kernel_convert_block_q4_K(
// Restore block_q4_K from flattened arrays.
// Each thread processes a super block.
// Mask args are just to keep the signature consistent with the no-shuffle ones.
kernel void kernel_restore_block_q4_K(
global uchar * src_q,
global uchar * src_s,
global half * src_d,
global half * src_dm,
-global struct block_q4_K * dst
+global struct block_q4_K * dst,
+uchar mask_0F,
+uchar mask_F0
) {
global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0);
@@ -475,6 +482,70 @@ kernel void kernel_restore_block_q4_K(
}
}
kernel void kernel_convert_block_q4_K_noshuffle(
global struct block_q4_K * src0,
global uchar * dst_q,
global uchar * dst_s,
global half * dst_d,
global half * dst_dm,
uchar mask_0F,
uchar mask_F0
) {
global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK_K/2 * get_global_id(0);
global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE * get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
global half * dm = (global half *) dst_dm + get_global_id(0);
*d = b->d;
*dm = b->dm;
for (int i = 0; i < QK_K / 64; ++i) {
for (int j = 0; j < 16; ++j) {
uchar x0 = b->q[i*32 + 2*j];
uchar x1 = b->q[i*32 + 2*j + 1];
q[i*32 + j] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4);
q[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0);
}
}
for (int i = 0; i < K_SCALE_SIZE; ++i) {
s[i] = b->s[i];
}
}
kernel void kernel_restore_block_q4_K_noshuffle(
global uchar * src_q,
global uchar * src_s,
global half * src_d,
global half * src_dm,
global struct block_q4_K * dst,
uchar mask_0F,
uchar mask_F0
) {
global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + QK_K/2 * get_global_id(0);
global uchar * s = (global uchar *) src_s + K_SCALE_SIZE * get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);
global half * dm = (global half *) src_dm + get_global_id(0);
b->d = *d;
b->dm = *dm;
for (int i = 0; i < QK_K / 64; ++i) {
for (int j = 0; j < 16; ++j) {
uchar lo = q[i*32 + j];
uchar hi = q[i*32 + j + 16];
b->q[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4));
b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0));
}
}
for (int i = 0; i < K_SCALE_SIZE; ++i) {
b->s[i] = s[i];
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q6_K
// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA).
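The two `_noshuffle` kernels above regroup nibbles within each 32-byte group and are meant to be exact inverses. Below is a minimal host-side sketch of that round trip for a single group. It assumes `mask_0F` and `mask_F0` carry the values `0x0F` and `0xF0`, as their names suggest. `QGroup`, `shuffle_group`, `unshuffle_group`, `same` and `sample_group` are hypothetical names for illustration. The sketch does not cover the separate host-side transposition step.

```cpp
#include <cstdint>

// One 64-weight group of a Q4_K super block: 32 bytes of packed nibbles.
struct QGroup { std::uint8_t b[32]; };

// Forward regrouping, following kernel_convert_block_q4_K_noshuffle for one group:
// low nibbles of each byte pair go to byte j, high nibbles to byte j + 16.
constexpr QGroup shuffle_group(const QGroup & src) {
    QGroup out{};
    for (int j = 0; j < 16; ++j) {
        const int x0 = src.b[2*j];
        const int x1 = src.b[2*j + 1];
        out.b[j]      = static_cast<std::uint8_t>((x0 & 0x0F) | ((x1 & 0x0F) << 4));
        out.b[j + 16] = static_cast<std::uint8_t>(((x0 & 0xF0) >> 4) | (x1 & 0xF0));
    }
    return out;
}

// Inverse regrouping, following kernel_restore_block_q4_K_noshuffle for one group.
constexpr QGroup unshuffle_group(const QGroup & src) {
    QGroup out{};
    for (int j = 0; j < 16; ++j) {
        const int lo = src.b[j];
        const int hi = src.b[j + 16];
        out.b[2*j]     = static_cast<std::uint8_t>((lo & 0x0F) | ((hi & 0x0F) << 4));
        out.b[2*j + 1] = static_cast<std::uint8_t>(((lo & 0xF0) >> 4) | (hi & 0xF0));
    }
    return out;
}

// Byte-wise equality of two groups.
constexpr bool same(const QGroup & a, const QGroup & b) {
    for (int i = 0; i < 32; ++i) {
        if (a.b[i] != b.b[i]) return false;
    }
    return true;
}

// Arbitrary test pattern (illustrative data only).
constexpr QGroup sample_group() {
    QGroup g{};
    for (int i = 0; i < 32; ++i) {
        g.b[i] = static_cast<std::uint8_t>((i * 37 + 11) & 0xFF);
    }
    return g;
}
```

Applying `unshuffle_group` to the result of `shuffle_group` should return the original bytes. That is the property `ggml_backend_opencl_buffer_get_tensor` relies on when it restores a tensor.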


@@ -0,0 +1,172 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK_K 256
#define K_SCALE_SIZE 12
inline void get_scale_min_k4(
int j,
global const uchar * q,
uchar * d,
uchar * m,
uchar mask_d6,
uchar mask_d4,
uchar mask_hi2
) {
if (j < 4) {
*d = q[j] & mask_d6;
*m = q[j+4] & mask_d6;
} else {
*d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2);
*m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2);
}
}
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif
kernel void kernel_gemm_noshuffle_q4_k_f32(
global const ushort * src0_q,
global const uchar * src0_s,
global const half * src0_d,
global const half * src0_dm,
read_only image1d_buffer_t src1,
global float * dst,
ulong offsetd,
int m,
int n,
int k,
int n_no_padding,
uchar mask_d6,
uchar mask_d4,
uchar mask_hi2
) {
dst = (global float *)((global char *)dst + offsetd);
int n_4 = n >> 2;
int gy = get_global_id(0);
int gx = get_global_id(1);
int gx_2 = gx << 2;
half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
half8 B;
half4 dequantized_weights;
int num_blocks_K = k / QK_K;
global const ushort * weight_ptr = src0_q + gx_2;
global const half * d_ptr = src0_d + gx_2;
global const half * dm_ptr = src0_dm + gx_2;
for (int i = 0; i < k; i += 32) {
int sb_idx = i / QK_K;
int sub_idx = (i / 32) % 8;
half4 d = vload4(0, d_ptr + sb_idx * m);
half4 dm = vload4(0, dm_ptr + sb_idx * m);
global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3;
get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2);
get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2);
get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2);
get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2);
half4 scale = convert_half4(convert_float4(d) * convert_float4((uchar4)(sv0, sv1, sv2, sv3)));
half4 mval = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3)));
for (int l = 0; l < 32; l += 4) {
int ki = i + l;
ushort4 bits4 = vload4(0, weight_ptr + (ki/4) * m);
// j=0
B.s0123 = read_imageh(src1, gy*2 + (ki+0) * n_4);
B.s4567 = read_imageh(src1, gy*2+1 + (ki+0) * n_4);
dequantized_weights.s0 = (bits4.s0 & 0x000F) * scale.s0 - mval.s0;
dequantized_weights.s1 = (bits4.s1 & 0x000F) * scale.s1 - mval.s1;
dequantized_weights.s2 = (bits4.s2 & 0x000F) * scale.s2 - mval.s2;
dequantized_weights.s3 = (bits4.s3 & 0x000F) * scale.s3 - mval.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
// j=1
B.s0123 = read_imageh(src1, gy*2 + (ki+1) * n_4);
B.s4567 = read_imageh(src1, gy*2+1 + (ki+1) * n_4);
dequantized_weights.s0 = ((bits4.s0 & 0x00F0) >> 4) * scale.s0 - mval.s0;
dequantized_weights.s1 = ((bits4.s1 & 0x00F0) >> 4) * scale.s1 - mval.s1;
dequantized_weights.s2 = ((bits4.s2 & 0x00F0) >> 4) * scale.s2 - mval.s2;
dequantized_weights.s3 = ((bits4.s3 & 0x00F0) >> 4) * scale.s3 - mval.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
// j=2
B.s0123 = read_imageh(src1, gy*2 + (ki+2) * n_4);
B.s4567 = read_imageh(src1, gy*2+1 + (ki+2) * n_4);
dequantized_weights.s0 = ((bits4.s0 & 0x0F00) >> 8) * scale.s0 - mval.s0;
dequantized_weights.s1 = ((bits4.s1 & 0x0F00) >> 8) * scale.s1 - mval.s1;
dequantized_weights.s2 = ((bits4.s2 & 0x0F00) >> 8) * scale.s2 - mval.s2;
dequantized_weights.s3 = ((bits4.s3 & 0x0F00) >> 8) * scale.s3 - mval.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
// j=3
B.s0123 = read_imageh(src1, gy*2 + (ki+3) * n_4);
B.s4567 = read_imageh(src1, gy*2+1 + (ki+3) * n_4);
dequantized_weights.s0 = ((bits4.s0 & 0xF000) >> 12) * scale.s0 - mval.s0;
dequantized_weights.s1 = ((bits4.s1 & 0xF000) >> 12) * scale.s1 - mval.s1;
dequantized_weights.s2 = ((bits4.s2 & 0xF000) >> 12) * scale.s2 - mval.s2;
dequantized_weights.s3 = ((bits4.s3 & 0xF000) >> 12) * scale.s3 - mval.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
}
}
int idx = (gy<<3)*m + (gx<<2);
if (idx+3 < m*n_no_padding) {
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
idx += m;
}
if (idx+3 < m*n_no_padding) {
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
idx += m;
}
if (idx+3 < m*n_no_padding) {
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
idx += m;
}
if (idx+3 < m*n_no_padding) {
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
idx += m;
}
if (idx+3 < m*n_no_padding) {
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
idx += m;
}
if (idx+3 < m*n_no_padding) {
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
idx += m;
}
if (idx+3 < m*n_no_padding) {
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
idx += m;
}
if (idx+3 < m*n_no_padding) {
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
}
}
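`get_scale_min_k4` in the kernel above unpacks eight 6-bit scales and eight 6-bit mins from the 12-byte scale array of a super block. Below is a minimal host-side sketch of the same bit manipulation, with the masks written as the literals the host passes (`0x3F`, `0x0F`, `0xC0`). `unpack_scale_min` mirrors the kernel helper. `pack_example` is a hypothetical inverse, written only to produce test data, and its ramp values are arbitrary.

```cpp
#include <cstdint>

struct ScaleMin     { std::uint8_t scale; std::uint8_t min; };
struct PackedScales { std::uint8_t q[12]; };

constexpr std::uint8_t u8(int v) { return static_cast<std::uint8_t>(v); }

// Sub-blocks 0..3 keep their 6-bit values in the low bits of q[j] and q[j+4].
// Sub-blocks 4..7 keep 4 low bits in q[j+4] and 2 high bits in the top of q[j-4] / q[j].
constexpr ScaleMin unpack_scale_min(int j, const PackedScales & p) {
    return j < 4
        ? ScaleMin{ u8(p.q[j] & 0x3F), u8(p.q[j + 4] & 0x3F) }
        : ScaleMin{ u8((p.q[j + 4] & 0x0F) | ((p.q[j - 4] & 0xC0) >> 2)),
                    u8(((p.q[j + 4] >> 4) & 0x0F) | ((p.q[j] & 0xC0) >> 2)) };
}

// Hypothetical packer for illustration: scale_j = 5 + 7*j, min_j = 63 - 8*j (all below 64).
constexpr PackedScales pack_example() {
    PackedScales p{};
    for (int j = 0; j < 4; ++j) {
        const int d_lo = 5 + 7 * j;
        const int d_hi = 5 + 7 * (j + 4);
        const int m_lo = 63 - 8 * j;
        const int m_hi = 63 - 8 * (j + 4);
        p.q[j]     = u8(d_lo | ((d_hi >> 4) << 6));
        p.q[j + 4] = u8(m_lo | ((m_hi >> 4) << 6));
        p.q[j + 8] = u8((d_hi & 0x0F) | ((m_hi & 0x0F) << 4));
    }
    return p;
}
```

The kernel then multiplies the unpacked scale by the per-super-block `d` and the unpacked min by `dm` to obtain the per-sub-block `scale` and `mval`.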


@@ -0,0 +1,318 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#endif
#define QK_K 256
#define NSUBGROUPS 4
#define SUBGROUP_SIZE 64
inline void get_scale_min_k4(
int j,
global const uchar * q,
uchar * d,
uchar * m,
uchar mask_d6,
uchar mask_d4,
uchar mask_hi2
) {
if (j < 4) {
*d = q[j] & mask_d6;
*m = q[j+4] & mask_d6;
} else {
*d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2);
*m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2);
}
}
#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \
float shared_y; \
shared_y = sub_group_broadcast(y.s0, 0); \
total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s1, 0); \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s2, 0); \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s3, 0); \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s4, 0); \
total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s5, 0); \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s6, 0); \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s7, 0); \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s0, 1); \
total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s1, 1); \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s2, 1); \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s3, 1); \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s4, 1); \
total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s5, 1); \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s6, 1); \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s7, 1); \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y;
#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \
shared_y = sub_group_broadcast(y.s0, 2); \
total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s1, 2); \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s2, 2); \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s3, 2); \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s4, 2); \
total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s5, 2); \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s6, 2); \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s7, 2); \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s0, 3); \
total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s1, 3); \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s2, 3); \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s3, 3); \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s4, 3); \
total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s5, 3); \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s6, 3); \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s7, 3); \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y;
#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \
float8 shared_y; \
shared_y = sub_group_broadcast(y, 0); \
total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \
total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \
total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \
total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \
shared_y = sub_group_broadcast(y, 1); \
total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \
total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \
total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \
total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7;
#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \
shared_y = sub_group_broadcast(y, 2); \
total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \
total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \
total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \
total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \
shared_y = sub_group_broadcast(y, 3); \
total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \
total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \
total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \
total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7;
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_gemv_noshuffle_q4_k_f32(
read_only image1d_buffer_t src0_q,
global half2 * src0_d,
global half2 * src0_m,
global uchar * src0_s,
read_only image1d_buffer_t src1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
uchar mask_d6,
uchar mask_d4,
uchar mask_hi2)
{
uint groupId = get_local_id(1);
uint gid = get_global_id(0);
ushort slid = get_sub_group_local_id();
uint K = ne00;
uint M = ne01;
uint LINE_STRIDE_A = M / 2;
uint BLOCK_STRIDE_A = NSUBGROUPS * M;
uint scales_per_row = (K / QK_K) * 12;
private uint4 regA;
private half2 regS;
private half2 regM;
private float8 regB;
private float2 totalSum = (float2)(0.0f);
for (uint k = groupId; k < (K / 32); k += NSUBGROUPS) {
uint sb = k / 8;
uint j = k % 8;
half2 d = src0_d[gid + sb * LINE_STRIDE_A];
half2 dm = src0_m[gid + sb * LINE_STRIDE_A];
global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12;
global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12;
uchar sv0, mn0, sv1, mn1;
get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2);
get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2);
regS = convert_half2(convert_float2(d) * convert_float2((uchar2)(sv0, sv1)));
regM = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1)));
if (slid < 4) {
regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
}
// load half weights for two blocks in consecutive rows
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAST
dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regM, regB);
#else
dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regM, regB);
#endif // VECTOR_SUB_GROUP_BROADCAST
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAST
dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regM, regB);
#else
dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regM, regB);
#endif // VECTOR_SUB_GROUP_BROADCAST
}
// reduction in local memory, assumes #wave=4
local float2 reduceLM[SUBGROUP_SIZE * 3];
if (groupId == 1) {
reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum;
}
if (groupId == 2) {
reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum;
}
if (groupId == 3) {
reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum;
}
barrier(CLK_LOCAL_MEM_FENCE);
if (groupId == 0) {
totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid];
}
if (groupId == 0) {
totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid];
}
if (groupId == 0) {
totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid];
}
// 2 outputs per fiber in wave 0
if (groupId == 0) {
dst = (global float*)((global char*)dst + offsetd);
vstore2(totalSum, 0, &(dst[gid * 2]));
}
}
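Reviewer note: the tail of the kernel above reduces the per-wave accumulators through local memory, assuming four waves per workgroup. The following host-side C++ model is only a sketch of that pattern; `reduce_across_waves` and the `float2` alias are names invented here, and the kernel's barrier is represented by a comment.

```cpp
#include <array>
#include <cstddef>
#include <vector>

using float2 = std::array<float, 2>;  // stand-in for the OpenCL float2 type

// partial[w][lane] is the accumulator held by lane `lane` of wave `w`.
// Returns the per-lane totals that wave 0 would hold after the reduction.
std::vector<float2> reduce_across_waves(const std::vector<std::vector<float2>> & partial) {
    const std::size_t n_waves = partial.size();
    if (n_waves == 0) {
        return {};
    }
    const std::size_t n_lanes = partial[0].size();

    // Stand-in for `local float2 reduceLM[SUBGROUP_SIZE * 3]`:
    // one slot per lane for every wave except wave 0.
    std::vector<float2> scratch(n_lanes * (n_waves - 1));

    // Phase 1: waves 1..n-1 publish their partial sums.
    for (std::size_t w = 1; w < n_waves; ++w) {
        for (std::size_t lane = 0; lane < n_lanes; ++lane) {
            scratch[(w - 1) * n_lanes + lane] = partial[w][lane];
        }
    }

    // The kernel issues barrier(CLK_LOCAL_MEM_FENCE) at this point.

    // Phase 2: wave 0 adds every published partial sum to its own.
    std::vector<float2> total = partial[0];
    for (std::size_t w = 1; w < n_waves; ++w) {
        for (std::size_t lane = 0; lane < n_lanes; ++lane) {
            total[lane][0] += scratch[(w - 1) * n_lanes + lane][0];
            total[lane][1] += scratch[(w - 1) * n_lanes + lane][1];
        }
    }
    return total;
}
```

Each lane carries two sums because the kernel processes two consecutive rows per fiber, which is why wave 0 finishes with a single `vstore2`.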

@@ -70,6 +70,7 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64)
return 0;
}
@@ -310,11 +311,11 @@ static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const
sycl::half2 * const __restrict__ tile_KV,
const int stride_KV,
const int i_sup) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
auto load = [&] (const int n) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int stride_j = warp_size >> n;
if (stride_j == 0) {
@@ -455,7 +456,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
(K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
#ifdef SYCL_FAST_FP16
static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
@@ -505,7 +506,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
}
if (k_KQ_0 + nbatch_K < DKQ) {
item_ct1.barrier(); // Sync not needed on last iteration.
item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration.
}
}
@@ -545,7 +546,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
const int k_VKQ_max,
const int col_Q_0,
float * KQ_max_new_shared) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
@@ -620,14 +621,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
}
if constexpr (np == 1) {
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
} else {
static_assert(cpw == 1, "bad cpw");
if (item_ct1.get_local_id(2) == 0) {
KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];
}
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];
KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
}
@@ -697,7 +698,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
(V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
#ifdef SYCL_FAST_FP16
#pragma unroll
@@ -765,7 +766,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
}
}
#endif // SYCL_FAST_FP16
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
}
}
@@ -972,7 +973,7 @@ static void flash_attn_tile(const char * Q,
}
}
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
// Main loop over KV cache:
const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
@@ -1051,7 +1052,7 @@ static void flash_attn_tile(const char * Q,
return;
}
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
#pragma unroll
for (int ip = 1; ip < np; ++ip) {
@@ -1193,37 +1194,39 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
constexpr size_t nbytes_shared = 0;
if constexpr (DV <= 256) {
if (Q->ne[1] > 16/ncols2) {
constexpr int cols_per_block = 32;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
if (DV < 512 && Q->ne[1] < 32) {
if constexpr (ncols2 <= 32) {
if (Q->ne[1] > 16/ncols2) {
constexpr int cols_per_block = 32;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
}
}
if (Q->ne[1] > 8/ncols2) {
constexpr int cols_per_block = 16;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
if constexpr (ncols2 <= 8) {
if (Q->ne[1] > 4/ncols2) {
constexpr int cols_per_block = 8;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
if constexpr (ncols2 <= 16) {
if (Q->ne[1] > 8/ncols2) {
constexpr int cols_per_block = 16;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
}
if constexpr (ncols2 <= 8) {
if (Q->ne[1] > 4/ncols2) {
constexpr int cols_per_block = 8;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
}
}
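Reviewer note: the dispatch above picks the widest tile whose threshold the query column count exceeds. The sketch below models only the three threshold checks visible in this hunk (`16/ncols2`, `8/ncols2`, `4/ncols2`) together with their `ncols2` guards. It does not model the `DV`-dependent conditions or the template instantiation, and `pick_cols_per_block` is a hypothetical helper, not a function in the backend.

```cpp
#include <cstdint>

// Returns 32, 16 or 8 following the thresholds visible in the hunk,
// or 0 when none of them matches (the fallback path is not shown in the diff).
// Assumes ncols2 > 0; the divisions are integer divisions, as in the original.
int pick_cols_per_block(std::int64_t n_q_cols, int ncols2) {
    if (ncols2 <= 32 && n_q_cols > 16 / ncols2) {
        return 32;
    }
    if (ncols2 <= 16 && n_q_cols > 8 / ncols2) {
        return 16;
    }
    if (ncols2 <= 8 && n_q_cols > 4 / ncols2) {
        return 8;
    }
    return 0;
}
```

With `ncols2 == 1` this gives 32 columns per block for 17 or more query columns, 16 for 9 to 16, and 8 for 5 to 8.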

@@ -1 +1 @@
c044a8eeae2591faa0950c8b5e514cbc4bbfc4ca
a04eea0761a85d18f3f504d6ab970c5c9dce705f

@@ -294,7 +294,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
}
// get extra buffer types of the CPU
// TODO: a more general solution for non-CPU extra buft should be imlpemented in the future
// TODO: a more general solution for non-CPU extra buft should be implemented in the future
// ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948
std::vector<ggml_backend_buffer_type_t> buft_extra;
{

@@ -18,7 +18,7 @@ struct llama_ubatch {
}
// typical for M-RoPE cases:
// 0 - sequantial position of the tokens/embeddings in the sequence
// 0 - sequential position of the tokens/embeddings in the sequence
// 1 - y position in the image
// 2 - x position in the image
// 3 - other

@@ -586,7 +586,7 @@ void llama_context::sched_reserve() {
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
// TODO: not sure if the following graph would be worster case for multi-stream KV caches:
// TODO: not sure if the following graph would be worst case for multi-stream KV caches:
//
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
//

@@ -1665,7 +1665,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
// note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
// but this would make the graph topology depend on the number of output tokens, which can interere with
// but this would make the graph topology depend on the number of output tokens, which can interfere with
// features that require constant topology such as pipeline parallelism
// ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
//if (n_outputs < n_tokens) {

@@ -333,7 +333,7 @@ public:
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
// store k_cur and v_cur in the cache based on the provided head location
// note: the heads in k_cur and v_cur should be layed out contiguously in memory
// note: the heads in k_cur and v_cur should be laid out contiguously in memory
// - k_cur [n_embd_head_k, n_head_k, n_tokens]
// - k_idxs [n_tokens]
// - v_cur [n_embd_head_v, n_head_v, n_tokens]

@@ -9,7 +9,7 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model,
inpL = build_inp_embd(model.tok_embd);
// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
// important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
cb(inpL, "inp_scaled", -1);

@@ -9,7 +9,7 @@ llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_gr
inpL = build_inp_embd(model.tok_embd);
// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
// important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
cb(inpL, "inp_scaled", -1);

@@ -12,7 +12,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
inpL = build_inp_embd(model.tok_embd);
// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
// important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
cb(inpL, "inp_scaled", -1);

@@ -118,12 +118,12 @@ int main(int argc, char ** argv) {
common_params params;
params.out_file = "tests.txt";
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EXPORT_GRAPH_OPS)) {
return 1;
}
common_init();
// Load CPU-only
ggml_backend_dev_t cpu_device = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
params.devices = { cpu_device, nullptr };

@@ -8424,6 +8424,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 256, 1, 1}, order)); // test ceildiv in CUDA's CUB's DeviceSegmentedSort
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
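Reviewer note: the new `{1025, 256, 1, 1}` case is commented as exercising a ceiling division in the CUDA sort path. As a reminder of the idiom, here is a minimal sketch; `ceil_div` is a hypothetical helper and the divisor 1024 in the usage note is only an assumed chunk size, not a value taken from the CUDA code.

```cpp
#include <cstdint>

// Integer ceiling division for non-negative a and positive b.
constexpr std::int64_t ceil_div(std::int64_t a, std::int64_t b) {
    return (a + b - 1) / b;
}
```

Under that assumption, a row of 1025 elements needs `ceil_div(1025, 1024)` chunks, one more than plain truncating division would report, which is the kind of off-by-one a shape just past a power of two is meant to catch.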

@@ -3077,6 +3077,27 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_reasoning("I need to output the invoice details in JSON")
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
.run();
// Unsolicited tool calls. There is no good way to handle these, so we return empty content.
// Builtin function - recipient in role
tst.test(
"<|channel|>analysis<|message|>I will execute python to say hello<|end|>"
"<|start|>assistant to=container.exec<|channel|>commentary<|message|>python3 -c 'print(\"hello\")'")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect_reasoning("I will execute python to say hello")
.expect_content("")
.run();
// Builtin function - recipient in channel
tst.test(
"<|channel|>analysis<|message|>I will execute python to say hello<|end|>"
"<|start|>assistant<|channel|>commentary to=python <|constrain|>code<|message|>print(\"hello\")")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect_reasoning("I will execute python to say hello")
.expect_content("")
.run();
}
{
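Reviewer note: the two new cases differ only in where the `to=` recipient appears, either after the role or after the channel name. The sketch below illustrates that distinction with a naive substring search over the header text between `<|start|>` and `<|message|>`. It is not the PEG parser under test, and `find_recipient` is a name made up for this example.

```cpp
#include <cstddef>
#include <string>

// Returns the text following the first "to=" up to the next space or '<',
// or an empty string when the header names no recipient.
// Naive by design: it would also match "to=" inside a longer word.
std::string find_recipient(const std::string & header) {
    const std::string key = "to=";
    const std::size_t pos = header.find(key);
    if (pos == std::string::npos) {
        return "";
    }
    const std::size_t begin = pos + key.size();
    const std::size_t end   = header.find_first_of(" <", begin);
    if (end == std::string::npos) {
        return header.substr(begin);
    }
    return header.substr(begin, end - begin);
}
```

For the headers used in the tests above, this yields `container.exec` for the recipient-in-role form and `python` for the recipient-in-channel form.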

@@ -387,6 +387,24 @@ static void test_expressions(testing & t) {
"Bob"
);
test_template(t, "empty computed member defaults to undefined",
"{{ a[]|default('fallback') }}",
{{"a", {{"name", "Bob"}}}},
"fallback"
);
test_template(t, "empty computed member is undefined",
"{{ a[] is undefined }}",
{{"a", {{"name", "Bob"}}}},
"True"
);
test_template(t, "undefined computed member is undefined",
"{{ a[undefined] is undefined }}",
{{"a", {{"name", "Bob"}}}},
"True"
);
test_template(t, "array access",
"{{ items[1] }}",
{{"items", json::array({"a", "b", "c"})}},

@@ -22,12 +22,12 @@ int main(int argc, char ** argv) {
params.n_parallel = 3;
params.n_ctx = 256;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
common_init();
// init
common_init_result_ptr llama_init = common_init_from_params(params);

@@ -16,12 +16,12 @@
int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);

@@ -20,12 +20,12 @@ int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BENCH, print_usage)) {
return 1;
}
common_init();
int is_pp_shared = params.is_pp_shared;
int is_tg_separate = params.is_tg_separate;

@@ -347,6 +347,8 @@ int main(int argc, char ** argv) {
params.verbosity = LOG_LEVEL_ERROR; // by default, less verbose logs
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CLI)) {
return 1;
}
@@ -357,8 +359,6 @@ int main(int argc, char ** argv) {
console::error("please use llama-completion instead\n");
}
common_init();
// struct that contains llama context and inference
cli_context ctx_cli(params);

@@ -90,12 +90,12 @@ int main(int argc, char ** argv) {
common_params params;
g_params = &params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMPLETION, print_usage)) {
return 1;
}
common_init();
auto & sparams = params.sampling;
// save choice to use color for later
@@ -146,19 +146,13 @@ int main(int argc, char ** argv) {
ctx = llama_init->context();
model = llama_init->model();
smpl = llama_init->sampler(0);
if (ctx == NULL) {
LOG_ERR("%s: error: unable to create context\n", __func__);
return 1;
}
if (model == NULL) {
LOG_ERR("%s: error: unable to load model\n", __func__);
return 1;
}
smpl = llama_init->sampler(0);
llama_memory_t mem = llama_get_memory(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
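Reviewer note: the commit message for this change describes a bounds check in `common_init_result::sampler` so that a failed model load no longer segfaults. The real implementation is not part of this hunk, so the following is only a sketch of the idea with hypothetical types (`sampler_stub`, `init_result_sketch`).

```cpp
#include <cstddef>
#include <memory>
#include <vector>

struct sampler_stub {  // hypothetical stand-in for the real sampler type
    int id;
};

struct init_result_sketch {  // hypothetical stand-in for the init result
    std::vector<std::unique_ptr<sampler_stub>> samplers;

    // Return nullptr instead of indexing out of range when initialization
    // failed and no samplers were created.
    sampler_stub * sampler(std::size_t idx) const {
        return idx < samplers.size() ? samplers[idx].get() : nullptr;
    }
};
```

With an accessor of this shape the caller can fetch the sampler before or after the `ctx` and `model` null checks without risking an out-of-range access.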

@@ -400,6 +400,8 @@ int main(int argc, char ** argv) {
params.out_file = "control_vector.gguf";
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CVECTOR_GENERATOR, print_usage)) {
return 1;
}

@@ -418,6 +418,8 @@ int main(int argc, char ** argv) {
params.out_file = "ggml-lora-merged-f16.gguf";
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EXPORT_LORA, print_usage)) {
return 1;
}

@@ -17,11 +17,12 @@ using namespace std::chrono_literals;
int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);
auto mparams = common_model_params_to_llama(params);

@@ -1212,6 +1212,8 @@ int main(int argc, char ** argv) {
params.n_ctx = 512;
params.escape = false;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_IMATRIX, print_usage)) {
return 1;
}
@@ -1223,8 +1225,6 @@ int main(int argc, char ** argv) {
return 0;
}
common_init();
const int32_t n_ctx = params.n_ctx;
if (n_ctx <= 0) {

@@ -54,11 +54,12 @@ int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
return 1;
}
common_init();
mtmd_helper_log_set(common_log_default_callback, nullptr);
if (params.mmproj.path.empty()) {

@@ -281,11 +281,12 @@ int main(int argc, char ** argv) {
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
return 1;
}
common_init();
mtmd_helper_log_set(common_log_default_callback, nullptr);
if (params.mmproj.path.empty()) {

@@ -2012,12 +2012,12 @@ int main(int argc, char ** argv) {
params.n_ctx = 512;
params.escape = false;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
return 1;
}
common_init();
const int32_t n_ctx = params.n_ctx;
if (n_ctx <= 0) {

@@ -58,6 +58,9 @@ static std::vector<float> get_logits(
int main(int argc, char ** argv) {
common_params params;
params.escape = false;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_RESULTS)) {
return 1;
}
@@ -65,7 +68,6 @@ int main(int argc, char ** argv) {
LOG_ERR("%s: an output file must be specified", __func__);
return 1;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);
common_init_result_ptr llama_init = common_init_from_params(params);

Binary file not shown.

@@ -75,6 +75,8 @@ int main(int argc, char ** argv) {
// own arguments required by this example
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
return 1;
}
@@ -100,8 +102,6 @@ int main(int argc, char ** argv) {
params.model_alias.insert(params.model.name);
}
common_init();
// struct that contains llama context and inference
server_context ctx_server;

@@ -4,7 +4,7 @@
import { getChatActionsContext, setMessageEditContext } from '$lib/contexts';
import { chatStore, pendingEditMessageId } from '$lib/stores/chat.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { DatabaseService } from '$lib/services';
import { DatabaseService } from '$lib/services/database.service';
import { SYSTEM_MESSAGE_PLACEHOLDER } from '$lib/constants';
import { MessageRole, AttachmentType } from '$lib/enums';
import {
@@ -19,6 +19,7 @@
interface Props {
class?: string;
message: DatabaseMessage;
toolMessages?: DatabaseMessage[];
isLastAssistantMessage?: boolean;
siblingInfo?: ChatMessageSiblingInfo | null;
}
@@ -26,6 +27,7 @@
let {
class: className = '',
message,
toolMessages = [],
isLastAssistantMessage = false,
siblingInfo = null
}: Props = $props();
@@ -302,6 +304,7 @@
{deletionInfo}
{isLastAssistantMessage}
{message}
{toolMessages}
messageContent={message.content}
onConfirmDelete={handleConfirmDelete}
onContinue={handleContinue}

@@ -6,42 +6,42 @@
SyntaxHighlightedCode
} from '$lib/components/app';
import { config } from '$lib/stores/settings.svelte';
import { Wrench, Loader2, AlertTriangle, Brain } from '@lucide/svelte';
import { AgenticSectionType, AttachmentType, FileTypeText } from '$lib/enums';
import { Wrench, Loader2, Brain } from '@lucide/svelte';
import { AgenticSectionType, FileTypeText } from '$lib/enums';
import { formatJsonPretty } from '$lib/utils';
import { ATTACHMENT_SAVED_REGEX, NEWLINE_SEPARATOR } from '$lib/constants';
import { parseAgenticContent, type AgenticSection } from '$lib/utils';
import type { DatabaseMessage, DatabaseMessageExtraImageFile } from '$lib/types/database';
import {
deriveAgenticSections,
parseToolResultWithImages,
type AgenticSection,
type ToolResultLine
} from '$lib/utils';
import type { DatabaseMessage } from '$lib/types/database';
import type { ChatMessageAgenticTimings, ChatMessageAgenticTurnStats } from '$lib/types/chat';
import { ChatMessageStatsView } from '$lib/enums';
interface Props {
message?: DatabaseMessage;
content: string;
message: DatabaseMessage;
toolMessages?: DatabaseMessage[];
isStreaming?: boolean;
highlightTurns?: boolean;
}
type ToolResultLine = {
text: string;
image?: DatabaseMessageExtraImageFile;
};
let { content, message, isStreaming = false, highlightTurns = false }: Props = $props();
let { message, toolMessages = [], isStreaming = false, highlightTurns = false }: Props = $props();
let expandedStates: Record<number, boolean> = $state({});
const sections = $derived(parseAgenticContent(content));
const showToolCallInProgress = $derived(config().showToolCallInProgress as boolean);
const showThoughtInProgress = $derived(config().showThoughtInProgress as boolean);
// Parse toolResults with images only when sections or message.extra change
const sections = $derived(deriveAgenticSections(message, toolMessages, []));
// Parse tool results with images
const sectionsParsed = $derived(
sections.map((section) => ({
...section,
parsedLines: section.toolResult
? parseToolResultWithImages(section.toolResult, message?.extra)
: []
? parseToolResultWithImages(section.toolResult, section.toolResultExtras || message?.extra)
: ([] as ToolResultLine[])
}))
);
@@ -107,26 +107,6 @@
expandedStates[index] = !currentState;
}
function parseToolResultWithImages(
toolResult: string,
extras?: DatabaseMessage['extra']
): ToolResultLine[] {
const lines = toolResult.split(NEWLINE_SEPARATOR);
return lines.map((line) => {
const match = line.match(ATTACHMENT_SAVED_REGEX);
if (!match || !extras) return { text: line };
const attachmentName = match[1];
const image = extras.find(
(e): e is DatabaseMessageExtraImageFile =>
e.type === AttachmentType.IMAGE && e.name === attachmentName
);
return { text: line, image };
});
}
function buildTurnAgenticTimings(stats: ChatMessageAgenticTurnStats): ChatMessageAgenticTimings {
return {
turns: 1,
@@ -144,9 +124,8 @@
<MarkdownContent content={section.content} attachments={message?.extra} />
</div>
{:else if section.type === AgenticSectionType.TOOL_CALL_STREAMING}
{@const streamingIcon = isStreaming ? Loader2 : AlertTriangle}
{@const streamingIconClass = isStreaming ? 'h-4 w-4 animate-spin' : 'h-4 w-4 text-yellow-500'}
{@const streamingSubtitle = isStreaming ? '' : 'incomplete'}
{@const streamingIcon = isStreaming ? Loader2 : Loader2}
{@const streamingIconClass = isStreaming ? 'h-4 w-4 animate-spin' : 'h-4 w-4'}
<CollapsibleContentBlock
open={isExpanded(index, section)}
@@ -154,7 +133,7 @@
icon={streamingIcon}
iconClass={streamingIconClass}
title={section.toolName || 'Tool call'}
subtitle={streamingSubtitle}
subtitle={isStreaming ? '' : 'incomplete'}
{isStreaming}
onToggle={() => toggleExpanded(index, section)}
>

@@ -15,7 +15,7 @@
import { Check, X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { Checkbox } from '$lib/components/ui/checkbox';
import { AGENTIC_TAGS, INPUT_CLASSES, REASONING_TAGS } from '$lib/constants';
import { INPUT_CLASSES } from '$lib/constants';
import { MessageRole, KeyboardKey, ChatMessageStatsView } from '$lib/enums';
import Label from '$lib/components/ui/label/label.svelte';
import { config } from '$lib/stores/settings.svelte';
@@ -23,6 +23,8 @@
import { modelsStore } from '$lib/stores/models.svelte';
import { ServerModelStatus } from '$lib/enums';
import { hasAgenticContent } from '$lib/utils';
interface Props {
class?: string;
deletionInfo: {
@@ -33,6 +35,7 @@
} | null;
isLastAssistantMessage?: boolean;
message: DatabaseMessage;
toolMessages?: DatabaseMessage[];
messageContent: string | undefined;
onCopy: () => void;
onConfirmDelete: () => void;
@@ -53,6 +56,7 @@
deletionInfo,
isLastAssistantMessage = false,
message,
toolMessages = [],
messageContent,
onConfirmDelete,
onContinue,
@@ -84,10 +88,8 @@
}
}
const hasAgenticMarkers = $derived(
messageContent?.includes(AGENTIC_TAGS.TOOL_CALL_START) ?? false
);
const hasReasoningMarkers = $derived(messageContent?.includes(REASONING_TAGS.START) ?? false);
const isAgentic = $derived(hasAgenticContent(message, toolMessages));
const hasReasoning = $derived(!!message.reasoningContent);
const processingState = useProcessingState();
let currentConfig = $derived(config());
@@ -145,7 +147,7 @@
}
let highlightAgenticTurns = $derived(
hasAgenticMarkers &&
isAgentic &&
(currentConfig.alwaysShowAgenticTurns || activeStatsView === ChatMessageStatsView.SUMMARY)
);
@@ -160,13 +162,14 @@
message?.role === MessageRole.ASSISTANT &&
isActivelyProcessing &&
hasNoContent &&
!isAgentic &&
isLastAssistantMessage
);
let showProcessingInfoBottom = $derived(
message?.role === MessageRole.ASSISTANT &&
isActivelyProcessing &&
!hasNoContent &&
(!hasNoContent || isAgentic) &&
isLastAssistantMessage
);
@@ -252,10 +255,10 @@
<pre class="raw-output">{messageContent || ''}</pre>
{:else}
<ChatMessageAgenticContent
content={messageContent || ''}
{message}
{toolMessages}
isStreaming={isChatStreaming()}
highlightTurns={highlightAgenticTurns}
{message}
/>
{/if}
{:else}
@@ -344,9 +347,7 @@
{onCopy}
{onEdit}
{onRegenerate}
onContinue={currentConfig.enableContinueGeneration && !hasReasoningMarkers
? onContinue
: undefined}
onContinue={currentConfig.enableContinueGeneration && !hasReasoning ? onContinue : undefined}
{onForkConversation}
{onDelete}
{onConfirmDelete}

@@ -6,7 +6,12 @@
import { chatStore } from '$lib/stores/chat.svelte';
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { copyToClipboard, formatMessageForClipboard, getMessageSiblings } from '$lib/utils';
import {
copyToClipboard,
formatMessageForClipboard,
getMessageSiblings,
hasAgenticContent
} from '$lib/utils';
interface Props {
class?: string;
@@ -119,32 +124,75 @@
? messages
: messages.filter((msg) => msg.type !== MessageRole.SYSTEM);
let lastAssistantIndex = -1;
// Build display entries, grouping agentic sessions into single entries.
// An agentic session = assistant(with tool_calls) → tool → assistant → tool → ... → assistant(final)
const result: Array<{
message: DatabaseMessage;
toolMessages: DatabaseMessage[];
isLastAssistantMessage: boolean;
siblingInfo: ChatMessageSiblingInfo;
}> = [];
for (let i = filteredMessages.length - 1; i >= 0; i--) {
if (filteredMessages[i].role === MessageRole.ASSISTANT) {
lastAssistantIndex = i;
for (let i = 0; i < filteredMessages.length; i++) {
const msg = filteredMessages[i];
// Skip tool messages - they're grouped with preceding assistant
if (msg.role === MessageRole.TOOL) continue;
const toolMessages: DatabaseMessage[] = [];
if (msg.role === MessageRole.ASSISTANT && hasAgenticContent(msg)) {
let j = i + 1;
while (j < filteredMessages.length) {
const next = filteredMessages[j];
if (next.role === MessageRole.TOOL) {
toolMessages.push(next);
j++;
} else if (next.role === MessageRole.ASSISTANT) {
toolMessages.push(next);
j++;
} else {
break;
}
}
i = j - 1;
} else if (msg.role === MessageRole.ASSISTANT) {
let j = i + 1;
while (j < filteredMessages.length && filteredMessages[j].role === MessageRole.TOOL) {
toolMessages.push(filteredMessages[j]);
j++;
}
}
const siblingInfo = getMessageSiblings(allConversationMessages, msg.id);
result.push({
message: msg,
toolMessages,
isLastAssistantMessage: false,
siblingInfo: siblingInfo || {
message: msg,
siblingIds: [msg.id],
currentIndex: 0,
totalSiblings: 1
}
});
}
// Mark the last assistant message
for (let i = result.length - 1; i >= 0; i--) {
if (result[i].message.role === MessageRole.ASSISTANT) {
result[i].isLastAssistantMessage = true;
break;
}
}
return filteredMessages.map((message, index) => {
const siblingInfo = getMessageSiblings(allConversationMessages, message.id);
const isLastAssistantMessage =
message.role === MessageRole.ASSISTANT && index === lastAssistantIndex;
return {
message,
isLastAssistantMessage,
siblingInfo: siblingInfo || {
message,
siblingIds: [message.id],
currentIndex: 0,
totalSiblings: 1
}
};
});
return result;
});
</script>
@@ -152,11 +200,12 @@
class="flex h-full flex-col space-y-10 pt-24 {className}"
style="height: auto; min-height: calc(100dvh - 14rem);"
>
{#each displayMessages as { message, isLastAssistantMessage, siblingInfo } (message.id)}
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
<div use:fadeInView>
<ChatMessage
class="mx-auto w-full max-w-[48rem]"
{message}
{toolMessages}
{isLastAssistantMessage}
{siblingInfo}
/>

@@ -425,21 +425,16 @@ export { default as ChatMessage } from './ChatMessages/ChatMessage.svelte';
/**
* **ChatMessageAgenticContent** - Agentic workflow output display
*
* Specialized renderer for assistant messages containing agentic workflow markers.
* Parses structured content and displays tool calls and reasoning blocks as
* interactive collapsible sections with real-time streaming support.
* Specialized renderer for assistant messages with tool calls and reasoning.
* Derives display sections from structured message data (toolCalls, reasoningContent,
* and child tool result messages) and renders them as interactive collapsible sections.
*
* **Architecture:**
* - Uses `parseAgenticContent()` from `$lib/utils` to parse markers
* - Uses `deriveAgenticSections()` from `$lib/utils` to build sections from structured data
* - Renders sections as CollapsibleContentBlock components
* - Handles streaming state for progressive content display
* - Falls back to MarkdownContent for plain text sections
*
* **Marker Format:**
* - Tool calls: in constants/agentic.ts (AGENTIC_TAGS)
* - Reasoning: in constants/agentic.ts (REASONING_TAGS)
* - Partial markers handled gracefully during streaming
*
* **Execution States:**
* - **Streaming**: Animated spinner, block expanded, auto-scroll enabled
* - **Pending**: Waiting indicator for queued tool calls

@@ -15,8 +15,11 @@ export const DEFAULT_AGENTIC_CONFIG: AgenticConfig = {
maxToolPreviewLines: 25
} as const;
// Agentic tool call tag markers
export const AGENTIC_TAGS = {
/**
* @deprecated Legacy marker tags - only used for migration of old stored messages.
* New messages use structured fields (reasoningContent, toolCalls, toolCallId).
*/
export const LEGACY_AGENTIC_TAGS = {
TOOL_CALL_START: '<<<AGENTIC_TOOL_CALL_START>>>',
TOOL_CALL_END: '<<<AGENTIC_TOOL_CALL_END>>>',
TOOL_NAME_PREFIX: '<<<TOOL_NAME:',
@@ -25,39 +28,25 @@ export const AGENTIC_TAGS = {
TAG_SUFFIX: '>>>'
} as const;
export const REASONING_TAGS = {
/**
* @deprecated Legacy reasoning tags - only used for migration of old stored messages.
* New messages use the dedicated reasoningContent field.
*/
export const LEGACY_REASONING_TAGS = {
START: '<<<reasoning_content_start>>>',
END: '<<<reasoning_content_end>>>'
} as const;
// Regex for trimming leading/trailing newlines
export const TRIM_NEWLINES_REGEX = /^\n+|\n+$/g;
// Regex patterns for parsing agentic content
export const AGENTIC_REGEX = {
// Matches completed tool calls (with END marker)
/**
* @deprecated Legacy regex patterns - only used for migration of old stored messages.
*/
export const LEGACY_AGENTIC_REGEX = {
COMPLETED_TOOL_CALL:
/<<<AGENTIC_TOOL_CALL_START>>>\n<<<TOOL_NAME:(.+?)>>>\n<<<TOOL_ARGS_START>>>([\s\S]*?)<<<TOOL_ARGS_END>>>([\s\S]*?)<<<AGENTIC_TOOL_CALL_END>>>/g,
// Matches pending tool call (has NAME and ARGS but no END)
PENDING_TOOL_CALL:
/<<<AGENTIC_TOOL_CALL_START>>>\n<<<TOOL_NAME:(.+?)>>>\n<<<TOOL_ARGS_START>>>([\s\S]*?)<<<TOOL_ARGS_END>>>([\s\S]*)$/,
// Matches partial tool call (has START and NAME, ARGS still streaming)
PARTIAL_WITH_NAME:
/<<<AGENTIC_TOOL_CALL_START>>>\n<<<TOOL_NAME:(.+?)>>>\n<<<TOOL_ARGS_START>>>([\s\S]*)$/,
// Matches early tool call (just START marker)
EARLY_MATCH: /<<<AGENTIC_TOOL_CALL_START>>>([\s\S]*)$/,
// Matches partial marker at end of content
PARTIAL_MARKER: /<<<[A-Za-z_]*$/,
// Matches reasoning content blocks (including tags)
REASONING_BLOCK: /<<<reasoning_content_start>>>[\s\S]*?<<<reasoning_content_end>>>/g,
// Captures the reasoning text between start/end tags
REASONING_EXTRACT: /<<<reasoning_content_start>>>([\s\S]*?)<<<reasoning_content_end>>>/,
// Matches an opening reasoning tag and any remaining content (unterminated)
REASONING_OPEN: /<<<reasoning_content_start>>>[\s\S]*$/,
// Matches a complete agentic tool call display block (start to end marker)
AGENTIC_TOOL_CALL_BLOCK: /\n*<<<AGENTIC_TOOL_CALL_START>>>[\s\S]*?<<<AGENTIC_TOOL_CALL_END>>>/g,
// Matches a pending/partial agentic tool call (start marker with no matching end)
AGENTIC_TOOL_CALL_OPEN: /\n*<<<AGENTIC_TOOL_CALL_START>>>[\s\S]*$/,
// Matches tool name inside content
TOOL_NAME_EXTRACT: /<<<TOOL_NAME:([^>]+)>>>/
HAS_LEGACY_MARKERS: /<<<(?:AGENTIC_TOOL_CALL_START|reasoning_content_start)>>>/
} as const;
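Reviewer note: the deprecated constants above are kept only so that old stored messages can be migrated to the structured fields. The sketch below shows one way a migration step might detect a legacy message and move its reasoning into `reasoningContent`. The two patterns are taken from the diff (the second is `REASONING_EXTRACT` with the global flag added). The `StoredMessage` shape and both helper names are assumptions made for illustration, not the PR's actual `runLegacyMigration` code.

```typescript
// Copied from LEGACY_AGENTIC_REGEX.HAS_LEGACY_MARKERS in the diff above.
const HAS_LEGACY_MARKERS = /<<<(?:AGENTIC_TOOL_CALL_START|reasoning_content_start)>>>/;
// REASONING_EXTRACT from the diff, with the global flag added so every block is handled.
const LEGACY_REASONING_EXTRACT_ALL =
	/<<<reasoning_content_start>>>([\s\S]*?)<<<reasoning_content_end>>>/g;

// Hypothetical, reduced message shape used only for this sketch.
interface StoredMessage {
	id: string;
	content: string;
	reasoningContent?: string;
}

// True when the stored content still carries legacy inline markers.
function needsLegacyMigration(msg: StoredMessage): boolean {
	return HAS_LEGACY_MARKERS.test(msg.content);
}

// Moves inline reasoning blocks out of `content` into `reasoningContent`.
function splitLegacyReasoning(msg: StoredMessage): StoredMessage {
	const parts: string[] = [];
	const content = msg.content.replace(
		LEGACY_REASONING_EXTRACT_ALL,
		(_match: string, inner: string) => {
			parts.push(inner);
			return '';
		}
	);
	return { ...msg, content, reasoningContent: parts.join('') || msg.reasoningContent };
}
```

Tool call markers would need a similar pass; that part is left out here because the target message layout for migrated tool calls is not visible in this hunk.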

View File

@@ -1,6 +1,7 @@
import { getJsonHeaders, formatAttachmentText, isAbortError } from '$lib/utils';
import { getJsonHeaders } from '$lib/utils/api-headers';
import { formatAttachmentText } from '$lib/utils/formatters';
import { isAbortError } from '$lib/utils/abort';
import {
AGENTIC_REGEX,
ATTACHMENT_LABEL_PDF_FILE,
ATTACHMENT_LABEL_MCP_PROMPT,
ATTACHMENT_LABEL_MCP_RESOURCE
@@ -17,38 +18,6 @@ import type { DatabaseMessageExtraMcpPrompt, DatabaseMessageExtraMcpResource } f
import { modelsStore } from '$lib/stores/models.svelte';
export class ChatService {
private static stripReasoningContent(
content: ApiChatMessageData['content'] | null | undefined
): ApiChatMessageData['content'] | null | undefined {
if (!content) {
return content;
}
if (typeof content === 'string') {
return content
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
.replace(AGENTIC_REGEX.REASONING_OPEN, '')
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK, '')
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '');
}
if (!Array.isArray(content)) {
return content;
}
return content.map((part: ApiChatMessageContentPart) => {
if (part.type !== ContentPartType.TEXT || !part.text) return part;
return {
...part,
text: part.text
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
.replace(AGENTIC_REGEX.REASONING_OPEN, '')
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK, '')
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '')
};
});
}
/**
*
*
@@ -57,46 +26,6 @@ export class ChatService {
*
*/
/**
* Extracts reasoning text from content that contains internal reasoning tags.
* Returns the concatenated reasoning content or undefined if none found.
*/
private static extractReasoningFromContent(
content: ApiChatMessageData['content'] | null | undefined
): string | undefined {
if (!content) return undefined;
const extractFromString = (text: string): string => {
const parts: string[] = [];
// Use a fresh regex instance to avoid shared lastIndex state
const re = new RegExp(AGENTIC_REGEX.REASONING_EXTRACT.source);
let match = re.exec(text);
while (match) {
parts.push(match[1]);
// advance past the matched portion and retry
text = text.slice(match.index + match[0].length);
match = re.exec(text);
}
return parts.join('');
};
if (typeof content === 'string') {
const result = extractFromString(content);
return result || undefined;
}
if (!Array.isArray(content)) return undefined;
const parts: string[] = [];
for (const part of content) {
if (part.type === ContentPartType.TEXT && part.text) {
const result = extractFromString(part.text);
if (result) parts.push(result);
}
}
return parts.length > 0 ? parts.join('') : undefined;
}
/**
* Sends a chat completion request to the llama.cpp server.
* Supports both streaming and non-streaming responses with comprehensive parameter configuration.
@@ -201,20 +130,15 @@ export class ChatService {
const requestBody: ApiChatCompletionRequest = {
messages: normalizedMessages.map((msg: ApiChatMessageData) => {
// Always strip internal reasoning/agentic tags from content
const cleanedContent = ChatService.stripReasoningContent(msg.content);
const mapped: ApiChatCompletionRequest['messages'][0] = {
role: msg.role,
content: cleanedContent,
content: msg.content,
tool_calls: msg.tool_calls,
tool_call_id: msg.tool_call_id
};
// When preserving reasoning, extract it from raw content and send as separate field
if (!excludeReasoningFromContext) {
const reasoning = ChatService.extractReasoningFromContent(msg.content);
if (reasoning) {
mapped.reasoning_content = reasoning;
}
// Include reasoning_content from the dedicated field
if (!excludeReasoningFromContext && msg.reasoning_content) {
mapped.reasoning_content = msg.reasoning_content;
}
return mapped;
}),
@@ -726,6 +650,10 @@ export class ChatService {
content: message.content
};
if (message.reasoningContent) {
result.reasoning_content = message.reasoningContent;
}
if (toolCalls && toolCalls.length > 0) {
result.tool_calls = toolCalls;
}
@@ -854,6 +782,9 @@ export class ChatService {
role: message.role as MessageRole,
content: contentParts
};
if (message.reasoningContent) {
result.reasoning_content = message.reasoningContent;
}
if (toolCalls && toolCalls.length > 0) {
result.tool_calls = toolCalls;
}
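Reviewer note: with this change the request builder no longer strips or re-extracts reasoning from `content`; it forwards `content` unchanged and copies `reasoning_content` only when reasoning is kept in context. A minimal sketch of that mapping, assuming a reduced hypothetical `SketchMessage` type (the real `ApiChatMessageData` has more fields) and a hypothetical function name:

```typescript
// Hypothetical, reduced message type for illustration only.
interface SketchMessage {
	role: string;
	content: string;
	reasoning_content?: string;
	tool_call_id?: string;
}

// Mirrors the mapping in the hunk above: content passes through untouched,
// reasoning_content is forwarded only when it is not excluded from context.
function mapForRequest(msg: SketchMessage, excludeReasoningFromContext: boolean): SketchMessage {
	const mapped: SketchMessage = {
		role: msg.role,
		content: msg.content,
		tool_call_id: msg.tool_call_id
	};
	if (!excludeReasoningFromContext && msg.reasoning_content) {
		mapped.reasoning_content = msg.reasoning_content;
	}
	return mapped;
}
```

One consequence worth checking in review: any message that still contains legacy inline markers is now sent to the server as-is, so the migration has to run before such messages reach this code path.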

View File

@@ -42,6 +42,7 @@ import type {
import {
buildProxiedUrl,
buildProxiedHeaders,
getAuthHeaders,
throwIfAborted,
isAbortError,
createBase64DataUrl
@@ -124,7 +125,14 @@ export class MCPService {
const requestInit: RequestInit = {};
if (config.headers) {
requestInit.headers = buildProxiedHeaders(config.headers);
requestInit.headers = config.useProxy ? buildProxiedHeaders(config.headers) : config.headers;
}
if (useProxy) {
requestInit.headers = {
...getAuthHeaders(),
...(requestInit.headers as Record<string, string>)
};
}
if (config.credentials) {
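Reviewer note: the header merge in this hunk spreads the auth headers first and the request headers second, so a header configured on the MCP server entry wins over the auth header with the same key. A small sketch of that ordering, assuming a hypothetical `mergeProxyHeaders` helper and plain string maps (the real `getAuthHeaders()` is not reproduced here):

```typescript
type HeaderMap = Record<string, string>;

// Hypothetical helper showing the spread order used in the hunk above.
function mergeProxyHeaders(authHeaders: HeaderMap, requestHeaders?: HeaderMap): HeaderMap {
	// Later spreads win, so request headers override auth headers with the same key.
	return { ...authHeaders, ...(requestHeaders ?? {}) };
}
```

Plain object keys are case-sensitive, so `Authorization` and `authorization` would both survive the merge; whether that matters depends on how the headers are normalized downstream.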

View File

@@ -7,6 +7,10 @@
* - Session state management
* - Turn limit enforcement
*
* Each agentic turn produces separate DB messages:
* - One assistant message per LLM turn (with tool_calls if any)
* - One tool result message per tool call execution
*
* **Architecture & Relationships:**
* - **ChatService**: Stateless API layer (sendMessage, streaming)
* - **mcpStore**: MCP connection management and tool execution
@@ -16,7 +20,6 @@
* @see mcpStore in stores/mcp.svelte.ts for MCP operations
*/
import { SvelteMap } from 'svelte/reactivity';
import { ChatService } from '$lib/services';
import { config } from '$lib/stores/settings.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
@@ -24,7 +27,6 @@ import { modelsStore } from '$lib/stores/models.svelte';
import { isAbortError } from '$lib/utils';
import {
DEFAULT_AGENTIC_CONFIG,
AGENTIC_TAGS,
NEWLINE_SEPARATOR,
TURN_LIMIT_MESSAGE,
LLM_ERROR_BLOCK_START,
@@ -193,17 +195,6 @@ class AgenticStore {
async runAgenticFlow(params: AgenticFlowParams): Promise<AgenticFlowResult> {
const { conversationId, messages, options = {}, callbacks, signal, perChatOverrides } = params;
const {
onChunk,
onReasoningChunk,
onToolCallChunk,
onAttachments,
onModel,
onComplete,
onError,
onTimings,
onTurnComplete
} = callbacks;
const agenticConfig = this.getConfig(config(), perChatOverrides);
if (!agenticConfig.enabled) return { handled: false };
@@ -253,24 +244,14 @@ class AgenticStore {
options,
tools,
agenticConfig,
callbacks: {
onChunk,
onReasoningChunk,
onToolCallChunk,
onAttachments,
onModel,
onComplete,
onError,
onTimings,
onTurnComplete
},
callbacks,
signal
});
return { handled: true };
} catch (error) {
const normalizedError = error instanceof Error ? error : new Error(String(error));
this.updateSession(conversationId, { lastError: normalizedError });
onError?.(normalizedError);
callbacks.onError?.(normalizedError);
return { handled: true, error: normalizedError };
} finally {
this.updateSession(conversationId, { isRunning: false });
@@ -295,17 +276,20 @@ class AgenticStore {
const {
onChunk,
onReasoningChunk,
onToolCallChunk,
onToolCallsStreaming,
onAttachments,
onModel,
onComplete,
onAssistantTurnComplete,
createToolResultMessage,
createAssistantMessage,
onFlowComplete,
onTimings,
onTurnComplete
} = callbacks;
const sessionMessages: AgenticMessage[] = toAgenticMessages(messages);
const allToolCalls: ApiChatCompletionToolCall[] = [];
let capturedTimings: ChatMessageTimings | undefined;
let totalToolCallCount = 0;
const agenticTimings: ChatMessageAgenticTimings = {
turns: 0,
@@ -316,12 +300,7 @@ class AgenticStore {
llm: { predicted_n: 0, predicted_ms: 0, prompt_n: 0, prompt_ms: 0 }
};
const maxTurns = agenticConfig.maxTurns;
const maxToolPreviewLines = agenticConfig.maxToolPreviewLines;
// Resolve effective model for vision capability checks.
// In ROUTER mode, options.model is always set by the caller.
// In MODEL mode, options.model is undefined; use the single loaded model
// which carries modalities bridged from /props.
const effectiveModel = options.model || modelsStore.models[0]?.model || '';
for (let turn = 0; turn < maxTurns; turn++) {
@@ -329,23 +308,20 @@ class AgenticStore {
agenticTimings.turns = turn + 1;
if (signal?.aborted) {
onComplete?.(
'',
undefined,
this.buildFinalTimings(capturedTimings, agenticTimings),
undefined
);
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
return;
}
// For turns > 0, create a new assistant message via callback
if (turn > 0 && createAssistantMessage) {
await createAssistantMessage();
}
let turnContent = '';
let turnReasoningContent = '';
let turnToolCalls: ApiChatCompletionToolCall[] = [];
let lastStreamingToolCallName = '';
let lastStreamingToolCallArgsLength = 0;
const emittedToolCallStates = new SvelteMap<
number,
{ emittedOnce: boolean; lastArgs: string }
>();
let turnTimings: ChatMessageTimings | undefined;
const turnStats: ChatMessageAgenticTurnStats = {
@@ -366,30 +342,15 @@ class AgenticStore {
turnContent += chunk;
onChunk?.(chunk);
},
onReasoningChunk,
onReasoningChunk: (chunk: string) => {
turnReasoningContent += chunk;
onReasoningChunk?.(chunk);
},
onToolCallChunk: (serialized: string) => {
try {
turnToolCalls = JSON.parse(serialized) as ApiChatCompletionToolCall[];
for (let i = 0; i < turnToolCalls.length; i++) {
const toolCall = turnToolCalls[i];
const toolName = toolCall.function?.name ?? '';
const toolArgs = toolCall.function?.arguments ?? '';
const state = emittedToolCallStates.get(i) || {
emittedOnce: false,
lastArgs: ''
};
if (!state.emittedOnce) {
const output = `\n\n${AGENTIC_TAGS.TOOL_CALL_START}\n${AGENTIC_TAGS.TOOL_NAME_PREFIX}${toolName}${AGENTIC_TAGS.TAG_SUFFIX}\n${AGENTIC_TAGS.TOOL_ARGS_START}\n${toolArgs}`;
onChunk?.(output);
state.emittedOnce = true;
state.lastArgs = toolArgs;
emittedToolCallStates.set(i, state);
} else if (toolArgs.length > state.lastArgs.length) {
onChunk?.(toolArgs.slice(state.lastArgs.length));
state.lastArgs = toolArgs;
emittedToolCallStates.set(i, state);
}
}
onToolCallsStreaming?.(turnToolCalls);
if (turnToolCalls.length > 0 && turnToolCalls[0]?.function) {
const name = turnToolCalls[0].function.name || '';
const args = turnToolCalls[0].function.arguments || '';
@@ -442,77 +403,84 @@ class AgenticStore {
}
} catch (error) {
if (signal?.aborted) {
onComplete?.(
'',
undefined,
// Save whatever we have for this turn before exiting
await onAssistantTurnComplete?.(
turnContent,
turnReasoningContent || undefined,
this.buildFinalTimings(capturedTimings, agenticTimings),
undefined
);
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
return;
}
const normalizedError = error instanceof Error ? error : new Error('LLM stream error');
// Save error as content in the current turn
onChunk?.(`${LLM_ERROR_BLOCK_START}${normalizedError.message}${LLM_ERROR_BLOCK_END}`);
onComplete?.(
'',
undefined,
await onAssistantTurnComplete?.(
turnContent + `${LLM_ERROR_BLOCK_START}${normalizedError.message}${LLM_ERROR_BLOCK_END}`,
turnReasoningContent || undefined,
this.buildFinalTimings(capturedTimings, agenticTimings),
undefined
);
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
throw normalizedError;
}
// No tool calls = final turn, save and complete
if (turnToolCalls.length === 0) {
agenticTimings.perTurn!.push(turnStats);
onComplete?.(
'',
undefined,
this.buildFinalTimings(capturedTimings, agenticTimings),
const finalTimings = this.buildFinalTimings(capturedTimings, agenticTimings);
await onAssistantTurnComplete?.(
turnContent,
turnReasoningContent || undefined,
finalTimings,
undefined
);
if (finalTimings) onTurnComplete?.(finalTimings);
onFlowComplete?.(finalTimings);
return;
}
// Normalize and save assistant turn with tool calls
const normalizedCalls = this.normalizeToolCalls(turnToolCalls);
if (normalizedCalls.length === 0) {
onComplete?.(
'',
undefined,
await onAssistantTurnComplete?.(
turnContent,
turnReasoningContent || undefined,
this.buildFinalTimings(capturedTimings, agenticTimings),
undefined
);
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
return;
}
for (const call of normalizedCalls) {
allToolCalls.push({
id: call.id,
type: call.type,
function: call.function ? { ...call.function } : undefined
});
}
totalToolCallCount += normalizedCalls.length;
this.updateSession(conversationId, { totalToolCalls: totalToolCallCount });
this.updateSession(conversationId, { totalToolCalls: allToolCalls.length });
onToolCallChunk?.(JSON.stringify(allToolCalls));
// Save the assistant message with its tool calls
await onAssistantTurnComplete?.(
turnContent,
turnReasoningContent || undefined,
turnTimings,
normalizedCalls
);
// Add assistant message to session history
sessionMessages.push({
role: MessageRole.ASSISTANT,
content: turnContent || undefined,
tool_calls: normalizedCalls
});
// Execute each tool call and create result messages
for (const toolCall of normalizedCalls) {
if (signal?.aborted) {
onComplete?.(
'',
undefined,
this.buildFinalTimings(capturedTimings, agenticTimings),
undefined
);
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
return;
}
@@ -530,13 +498,7 @@ class AgenticStore {
result = executionResult.content;
} catch (error) {
if (isAbortError(error)) {
onComplete?.(
'',
undefined,
this.buildFinalTimings(capturedTimings, agenticTimings),
undefined
);
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
return;
}
result = `Error: ${error instanceof Error ? error.message : String(error)}`;
@@ -557,21 +519,27 @@ class AgenticStore {
turnStats.toolsMs += Math.round(toolDurationMs);
if (signal?.aborted) {
onComplete?.(
'',
undefined,
this.buildFinalTimings(capturedTimings, agenticTimings),
undefined
);
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
return;
}
const { cleanedResult, attachments } = this.extractBase64Attachments(result);
if (attachments.length > 0) onAttachments?.(attachments);
this.emitToolCallResult(cleanedResult, maxToolPreviewLines, onChunk);
// Create the tool result message in the DB
let toolResultMessage: DatabaseMessage | undefined;
if (createToolResultMessage) {
toolResultMessage = await createToolResultMessage(
toolCall.id,
cleanedResult,
attachments.length > 0 ? attachments : undefined
);
}
if (attachments.length > 0 && toolResultMessage) {
onAttachments?.(toolResultMessage.id, attachments);
}
// Build content parts for session history (including images for vision models)
const contentParts: ApiChatMessageContentPart[] = [
{ type: ContentPartType.TEXT, text: cleanedResult }
];
@@ -605,8 +573,15 @@ class AgenticStore {
}
}
// Turn limit reached
onChunk?.(TURN_LIMIT_MESSAGE);
onComplete?.('', undefined, this.buildFinalTimings(capturedTimings, agenticTimings), undefined);
await onAssistantTurnComplete?.(
TURN_LIMIT_MESSAGE,
undefined,
this.buildFinalTimings(capturedTimings, agenticTimings),
undefined
);
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
}
private buildFinalTimings(
@@ -633,23 +608,6 @@ class AgenticStore {
}));
}
private emitToolCallResult(
result: string,
maxLines: number,
emit?: (chunk: string) => void
): void {
if (!emit) {
return;
}
let output = `${NEWLINE_SEPARATOR}${AGENTIC_TAGS.TOOL_ARGS_END}`;
const lines = result.split(NEWLINE_SEPARATOR);
const trimmedLines = lines.length > maxLines ? lines.slice(-maxLines) : lines;
output += `${NEWLINE_SEPARATOR}${trimmedLines.join(NEWLINE_SEPARATOR)}${NEWLINE_SEPARATOR}${AGENTIC_TAGS.TOOL_CALL_END}${NEWLINE_SEPARATOR}`;
emit(output);
}
private extractBase64Attachments(result: string): {
cleanedResult: string;
attachments: DatabaseMessageExtra[];
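Reviewer note: the agentic loop now produces one assistant message per LLM turn and one tool message per executed tool call, instead of appending marker text to a single message. The sketch below shows that message layout in isolation. It is synchronous for brevity (the real flow is async and callback-driven), all type and function names are hypothetical, and turn-limit and abort handling are left out.

```typescript
interface ToolCall {
	id: string;
	name: string;
	args: string;
}

interface TurnMessage {
	role: 'assistant' | 'tool';
	content: string;
	toolCalls?: ToolCall[];
	toolCallId?: string;
}

interface TurnOutput {
	content: string;
	toolCalls: ToolCall[];
}

// Runs up to maxTurns turns and returns the messages that would be stored.
function runTurns(
	llm: (history: TurnMessage[]) => TurnOutput,
	executeTool: (call: ToolCall) => string,
	maxTurns: number
): TurnMessage[] {
	const history: TurnMessage[] = [];
	for (let turn = 0; turn < maxTurns; turn++) {
		const { content, toolCalls } = llm(history);
		if (toolCalls.length === 0) {
			// No tool calls: this is the final assistant message.
			history.push({ role: 'assistant', content });
			return history;
		}
		// Assistant message that carries its tool calls.
		history.push({ role: 'assistant', content, toolCalls });
		// One tool result message per call, linked by toolCallId.
		for (const call of toolCalls) {
			history.push({ role: 'tool', content: executeTool(call), toolCallId: call.id });
		}
	}
	return history;
}
```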

View File

@@ -12,7 +12,8 @@
*/
import { SvelteMap } from 'svelte/reactivity';
import { DatabaseService, ChatService } from '$lib/services';
import { DatabaseService } from '$lib/services/database.service';
import { ChatService } from '$lib/services/chat.service';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { agenticStore } from '$lib/stores/agentic.svelte';
@@ -34,7 +35,6 @@ import {
import {
MAX_INACTIVE_CONVERSATION_STATES,
INACTIVE_CONVERSATION_STATE_MAX_AGE_MS,
REASONING_TAGS,
SYSTEM_MESSAGE_PLACEHOLDER
} from '$lib/constants';
import type {
@@ -50,15 +50,6 @@ interface ConversationStateEntry {
lastAccessed: number;
}
const countOccurrences = (source: string, token: string): number =>
source ? source.split(token).length - 1 : 0;
const hasUnclosedReasoningTag = (content: string): boolean =>
countOccurrences(content, REASONING_TAGS.START) > countOccurrences(content, REASONING_TAGS.END);
const wrapReasoningContent = (content: string, reasoningContent?: string): string => {
if (!reasoningContent) return content;
return `${REASONING_TAGS.START}${reasoningContent}${REASONING_TAGS.END}${content}`;
};
class ChatStore {
activeProcessingState = $state<ApiProcessingState | null>(null);
currentResponse = $state('');
@@ -557,83 +548,76 @@ class ChatStore {
await modelsStore.fetchModelProps(effectiveModel);
}
let streamedContent = '',
streamedToolCallContent = '',
isReasoningOpen = false,
hasStreamedChunks = false,
resolvedModel: string | null = null,
modelPersisted = false;
let streamedExtras: DatabaseMessageExtra[] = assistantMessage.extra
? JSON.parse(JSON.stringify(assistantMessage.extra))
: [];
// Mutable state for the current message being streamed
let currentMessageId = assistantMessage.id;
let streamedContent = '';
let streamedReasoningContent = '';
let resolvedModel: string | null = null;
let modelPersisted = false;
const convId = assistantMessage.convId;
const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => {
if (!modelName) return;
const n = normalizeModelName(modelName);
if (!n || n === resolvedModel) return;
resolvedModel = n;
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
const idx = conversationsStore.findMessageIndex(currentMessageId);
conversationsStore.updateMessageAtIndex(idx, { model: n });
if (persistImmediately && !modelPersisted) {
modelPersisted = true;
DatabaseService.updateMessage(assistantMessage.id, { model: n }).catch(() => {
DatabaseService.updateMessage(currentMessageId, { model: n }).catch(() => {
modelPersisted = false;
resolvedModel = null;
});
}
};
const updateStreamingContent = () => {
this.setChatStreaming(assistantMessage.convId, streamedContent, assistantMessage.id);
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
const updateStreamingUI = () => {
this.setChatStreaming(convId, streamedContent, currentMessageId);
const idx = conversationsStore.findMessageIndex(currentMessageId);
conversationsStore.updateMessageAtIndex(idx, { content: streamedContent });
};
const appendContentChunk = (chunk: string) => {
if (isReasoningOpen) {
streamedContent += REASONING_TAGS.END;
isReasoningOpen = false;
}
streamedContent += chunk;
hasStreamedChunks = true;
updateStreamingContent();
};
const appendReasoningChunk = (chunk: string) => {
if (!isReasoningOpen) {
streamedContent += REASONING_TAGS.START;
isReasoningOpen = true;
}
streamedContent += chunk;
hasStreamedChunks = true;
updateStreamingContent();
};
const finalizeReasoning = () => {
if (isReasoningOpen) {
streamedContent += REASONING_TAGS.END;
isReasoningOpen = false;
}
const cleanupStreamingState = () => {
this.setStreamingActive(false);
this.setChatLoading(convId, false);
this.clearChatStreaming(convId);
this.setProcessingState(convId, null);
};
this.setStreamingActive(true);
this.setActiveProcessingConversation(assistantMessage.convId);
const abortController = this.getOrCreateAbortController(assistantMessage.convId);
this.setActiveProcessingConversation(convId);
const abortController = this.getOrCreateAbortController(convId);
const streamCallbacks: ChatStreamCallbacks = {
onChunk: (chunk: string) => appendContentChunk(chunk),
onReasoningChunk: (chunk: string) => appendReasoningChunk(chunk),
onToolCallChunk: (chunk: string) => {
const c = chunk.trim();
if (!c) return;
streamedToolCallContent = c;
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
conversationsStore.updateMessageAtIndex(idx, { toolCalls: streamedToolCallContent });
onChunk: (chunk: string) => {
streamedContent += chunk;
updateStreamingUI();
},
onAttachments: (extras: DatabaseMessageExtra[]) => {
onReasoningChunk: (chunk: string) => {
streamedReasoningContent += chunk;
// Update UI to show reasoning is being received
const idx = conversationsStore.findMessageIndex(currentMessageId);
conversationsStore.updateMessageAtIndex(idx, {
reasoningContent: streamedReasoningContent
});
},
onToolCallsStreaming: (toolCalls) => {
const idx = conversationsStore.findMessageIndex(currentMessageId);
conversationsStore.updateMessageAtIndex(idx, { toolCalls: JSON.stringify(toolCalls) });
},
onAttachments: (messageId: string, extras: DatabaseMessageExtra[]) => {
if (!extras.length) return;
streamedExtras = [...streamedExtras, ...extras];
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
conversationsStore.updateMessageAtIndex(idx, { extra: streamedExtras });
DatabaseService.updateMessage(assistantMessage.id, { extra: streamedExtras }).catch(
console.error
);
const idx = conversationsStore.findMessageIndex(messageId);
if (idx === -1) return;
const msg = conversationsStore.activeMessages[idx];
const updatedExtras = [...(msg.extra || []), ...extras];
conversationsStore.updateMessageAtIndex(idx, { extra: updatedExtras });
DatabaseService.updateMessage(messageId, { extra: updatedExtras }).catch(console.error);
},
onModel: (modelName: string) => recordModel(modelName),
onTurnComplete: (intermediateTimings: ChatMessageTimings) => {
// Update the first assistant message with cumulative agentic timings
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
conversationsStore.updateMessageAtIndex(idx, { timings: intermediateTimings });
},
@@ -651,56 +635,104 @@ class ChatStore {
cache_n: timings?.cache_n || 0,
prompt_progress: promptProgress
},
assistantMessage.convId
convId
);
},
onComplete: async (
finalContent?: string,
reasoningContent?: string,
timings?: ChatMessageTimings,
toolCallContent?: string
onAssistantTurnComplete: async (
content: string,
reasoningContent: string | undefined,
timings: ChatMessageTimings | undefined,
toolCalls: import('$lib/types/api').ApiChatCompletionToolCall[] | undefined
) => {
this.setStreamingActive(false);
finalizeReasoning();
const combinedContent = hasStreamedChunks
? streamedContent
: wrapReasoningContent(finalContent || '', reasoningContent);
const updateData: Record<string, unknown> = {
content: combinedContent,
toolCalls: toolCallContent || streamedToolCallContent,
content,
reasoningContent: reasoningContent || undefined,
toolCalls: toolCalls ? JSON.stringify(toolCalls) : '',
timings
};
if (streamedExtras.length > 0) updateData.extra = streamedExtras;
if (resolvedModel && !modelPersisted) updateData.model = resolvedModel;
await DatabaseService.updateMessage(assistantMessage.id, updateData);
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
await DatabaseService.updateMessage(currentMessageId, updateData);
const idx = conversationsStore.findMessageIndex(currentMessageId);
const uiUpdate: Partial<DatabaseMessage> = {
content: combinedContent,
toolCalls: updateData.toolCalls as string
content,
reasoningContent: reasoningContent || undefined,
toolCalls: toolCalls ? JSON.stringify(toolCalls) : ''
};
if (streamedExtras.length > 0) uiUpdate.extra = streamedExtras;
if (timings) uiUpdate.timings = timings;
if (resolvedModel) uiUpdate.model = resolvedModel;
conversationsStore.updateMessageAtIndex(idx, uiUpdate);
await conversationsStore.updateCurrentNode(assistantMessage.id);
if (onComplete) await onComplete(combinedContent);
this.setChatLoading(assistantMessage.convId, false);
this.clearChatStreaming(assistantMessage.convId);
this.setProcessingState(assistantMessage.convId, null);
await conversationsStore.updateCurrentNode(currentMessageId);
},
createToolResultMessage: async (
toolCallId: string,
content: string,
extras?: DatabaseMessageExtra[]
) => {
const msg = await DatabaseService.createMessageBranch(
{
convId,
type: MessageType.TEXT,
role: MessageRole.TOOL,
content,
toolCallId,
timestamp: Date.now(),
toolCalls: '',
children: [],
extra: extras
},
currentMessageId
);
conversationsStore.addMessageToActive(msg);
await conversationsStore.updateCurrentNode(msg.id);
return msg;
},
createAssistantMessage: async () => {
// Reset streaming state for new message
streamedContent = '';
streamedReasoningContent = '';
const lastMsg =
conversationsStore.activeMessages[conversationsStore.activeMessages.length - 1];
const msg = await DatabaseService.createMessageBranch(
{
convId,
type: MessageType.TEXT,
role: MessageRole.ASSISTANT,
content: '',
timestamp: Date.now(),
toolCalls: '',
children: [],
model: resolvedModel
},
lastMsg.id
);
conversationsStore.addMessageToActive(msg);
currentMessageId = msg.id;
return msg;
},
onFlowComplete: (finalTimings?: ChatMessageTimings) => {
if (finalTimings) {
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
conversationsStore.updateMessageAtIndex(idx, { timings: finalTimings });
DatabaseService.updateMessage(assistantMessage.id, { timings: finalTimings }).catch(
console.error
);
}
cleanupStreamingState();
if (onComplete) onComplete(streamedContent);
if (isRouterMode()) modelsStore.fetchRouterModels().catch(console.error);
},
onError: (error: Error) => {
this.setStreamingActive(false);
if (isAbortError(error)) {
this.setChatLoading(assistantMessage.convId, false);
this.clearChatStreaming(assistantMessage.convId);
this.setProcessingState(assistantMessage.convId, null);
cleanupStreamingState();
return;
}
console.error('Streaming error:', error);
this.setChatLoading(assistantMessage.convId, false);
this.clearChatStreaming(assistantMessage.convId);
this.setProcessingState(assistantMessage.convId, null);
cleanupStreamingState();
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
if (idx !== -1) {
const failedMessage = conversationsStore.removeMessageAtIndex(idx);
@@ -717,12 +749,13 @@ class ChatStore {
if (onError) onError(error);
}
};
const perChatOverrides = conversationsStore.activeConversation?.mcpServerOverrides;
const agenticConfig = agenticStore.getConfig(config(), perChatOverrides);
if (agenticConfig.enabled) {
const agenticResult = await agenticStore.runAgenticFlow({
conversationId: assistantMessage.convId,
conversationId: convId,
messages: allMessages,
options: { ...this.getApiOptions(), ...(effectiveModel ? { model: effectiveModel } : {}) },
callbacks: streamCallbacks,
@@ -732,16 +765,50 @@ class ChatStore {
if (agenticResult.handled) return;
}
const completionOptions = {
...this.getApiOptions(),
...(effectiveModel ? { model: effectiveModel } : {}),
...streamCallbacks
};
// Non-agentic path: direct streaming into the single assistant message
await ChatService.sendMessage(
allMessages,
completionOptions,
assistantMessage.convId,
{
...this.getApiOptions(),
...(effectiveModel ? { model: effectiveModel } : {}),
stream: true,
onChunk: streamCallbacks.onChunk,
onReasoningChunk: streamCallbacks.onReasoningChunk,
onModel: streamCallbacks.onModel,
onTimings: streamCallbacks.onTimings,
onComplete: async (
finalContent?: string,
reasoningContent?: string,
timings?: ChatMessageTimings,
toolCalls?: string
) => {
const content = streamedContent || finalContent || '';
const reasoning = streamedReasoningContent || reasoningContent;
const updateData: Record<string, unknown> = {
content,
reasoningContent: reasoning || undefined,
toolCalls: toolCalls || '',
timings
};
if (resolvedModel && !modelPersisted) updateData.model = resolvedModel;
await DatabaseService.updateMessage(currentMessageId, updateData);
const idx = conversationsStore.findMessageIndex(currentMessageId);
const uiUpdate: Partial<DatabaseMessage> = {
content,
reasoningContent: reasoning || undefined,
toolCalls: toolCalls || ''
};
if (timings) uiUpdate.timings = timings;
if (resolvedModel) uiUpdate.model = resolvedModel;
conversationsStore.updateMessageAtIndex(idx, uiUpdate);
await conversationsStore.updateCurrentNode(currentMessageId);
cleanupStreamingState();
if (onComplete) await onComplete(content);
if (isRouterMode()) modelsStore.fetchRouterModels().catch(console.error);
},
onError: streamCallbacks.onError
},
convId,
abortController.signal
);
}
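Reviewer note: the streaming callbacks above rely on mutable per-message state (`currentMessageId`, `streamedContent`, `streamedReasoningContent`) that is reset whenever a new assistant message is created for a later turn. A compact sketch of that state handling, using a hypothetical `StreamState` class that is not part of the PR:

```typescript
// Hypothetical container for the per-message streaming state.
class StreamState {
	content = '';
	reasoning = '';

	constructor(public currentMessageId: string) {}

	// Content and reasoning are accumulated in separate buffers.
	appendContent(chunk: string): void {
		this.content += chunk;
	}

	appendReasoning(chunk: string): void {
		this.reasoning += chunk;
	}

	// Called when a new assistant message starts: switch target and clear buffers.
	startNewMessage(id: string): void {
		this.currentMessageId = id;
		this.content = '';
		this.reasoning = '';
	}
}
```

Keeping the two buffers separate is what removes the need for the old open/close reasoning tag bookkeeping.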
@@ -1033,56 +1100,40 @@ class ChatStore {
}
const originalContent = dbMessage.content;
const originalReasoning = dbMessage.reasoningContent || '';
const conversationContext = conversationsStore.activeMessages.slice(0, idx);
const contextWithContinue = [
...conversationContext,
{ role: MessageRole.ASSISTANT as const, content: originalContent }
];
let appendedContent = '',
hasReceivedContent = false,
isReasoningOpen = hasUnclosedReasoningTag(originalContent);
let appendedContent = '';
let appendedReasoning = '';
let hasReceivedContent = false;
const updateStreamingContent = (fullContent: string) => {
this.setChatStreaming(msg.convId, fullContent, msg.id);
conversationsStore.updateMessageAtIndex(idx, { content: fullContent });
};
const appendContentChunk = (chunk: string) => {
if (isReasoningOpen) {
appendedContent += REASONING_TAGS.END;
isReasoningOpen = false;
}
appendedContent += chunk;
hasReceivedContent = true;
updateStreamingContent(originalContent + appendedContent);
};
const appendReasoningChunk = (chunk: string) => {
if (!isReasoningOpen) {
appendedContent += REASONING_TAGS.START;
isReasoningOpen = true;
}
appendedContent += chunk;
hasReceivedContent = true;
updateStreamingContent(originalContent + appendedContent);
};
const finalizeReasoning = () => {
if (isReasoningOpen) {
appendedContent += REASONING_TAGS.END;
isReasoningOpen = false;
}
};
const abortController = this.getOrCreateAbortController(msg.convId);
await ChatService.sendMessage(
contextWithContinue,
{
...this.getApiOptions(),
onChunk: (chunk: string) => appendContentChunk(chunk),
onReasoningChunk: (chunk: string) => appendReasoningChunk(chunk),
onChunk: (chunk: string) => {
appendedContent += chunk;
hasReceivedContent = true;
updateStreamingContent(originalContent + appendedContent);
},
onReasoningChunk: (chunk: string) => {
appendedReasoning += chunk;
hasReceivedContent = true;
conversationsStore.updateMessageAtIndex(idx, {
reasoningContent: originalReasoning + appendedReasoning
});
},
onTimings: (timings?: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => {
const tokensPerSecond =
timings?.predicted_ms && timings?.predicted_n
@@ -1105,21 +1156,23 @@ class ChatStore {
reasoningContent?: string,
timings?: ChatMessageTimings
) => {
finalizeReasoning();
const appendedFromCompletion = hasReceivedContent
? appendedContent
: wrapReasoningContent(finalContent || '', reasoningContent);
const fullContent = originalContent + appendedFromCompletion;
const finalAppendedContent = hasReceivedContent ? appendedContent : finalContent || '';
const finalAppendedReasoning = hasReceivedContent
? appendedReasoning
: reasoningContent || '';
const fullContent = originalContent + finalAppendedContent;
const fullReasoning = originalReasoning + finalAppendedReasoning || undefined;
await DatabaseService.updateMessage(msg.id, {
content: fullContent,
reasoningContent: fullReasoning,
timestamp: Date.now(),
timings
});
conversationsStore.updateMessageAtIndex(idx, {
content: fullContent,
reasoningContent: fullReasoning,
timestamp: Date.now(),
timings
});
@@ -1135,11 +1188,13 @@ class ChatStore {
if (hasReceivedContent && appendedContent) {
await DatabaseService.updateMessage(msg.id, {
content: originalContent + appendedContent,
reasoningContent: originalReasoning + appendedReasoning || undefined,
timestamp: Date.now()
});
conversationsStore.updateMessageAtIndex(idx, {
content: originalContent + appendedContent,
reasoningContent: originalReasoning + appendedReasoning || undefined,
timestamp: Date.now()
});
}
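The hunk above stops splicing reasoning tags into `content` and keeps two separate accumulators instead. A minimal sketch of that idea follows. `createContinuationAccumulator` and `ContinuationResult` are hypothetical names introduced for illustration only; they do not appear in the change. The explicit parentheses spell out how `a + b || undefined` is evaluated.

```typescript
// Hypothetical helper (not part of the diff): accumulates streamed
// content and reasoning separately while continuing a message.
interface ContinuationResult {
  content: string;
  reasoningContent: string | undefined;
}

function createContinuationAccumulator(originalContent: string, originalReasoning: string) {
  let appendedContent = '';
  let appendedReasoning = '';
  let hasReceivedContent = false;

  return {
    onChunk(chunk: string): void {
      appendedContent += chunk;
      hasReceivedContent = true;
    },
    onReasoningChunk(chunk: string): void {
      appendedReasoning += chunk;
      hasReceivedContent = true;
    },
    // Falls back to the final payload only when nothing was streamed.
    finish(finalContent?: string, finalReasoning?: string): ContinuationResult {
      const content = hasReceivedContent ? appendedContent : finalContent || '';
      const reasoning = hasReceivedContent ? appendedReasoning : finalReasoning || '';
      return {
        content: originalContent + content,
        // `+` binds tighter than `||`, so an empty string becomes undefined.
        reasoningContent: (originalReasoning + reasoning) || undefined
      };
    }
  };
}
```

Keeping the two strings apart means the persisted `content` never needs to be re-parsed for markers when it is displayed.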

@@ -23,7 +23,7 @@ import { browser } from '$app/environment';
import { toast } from 'svelte-sonner';
import { DatabaseService } from '$lib/services/database.service';
import { config } from '$lib/stores/settings.svelte';
import { filterByLeafNodeId, findLeafNode } from '$lib/utils';
import { filterByLeafNodeId, findLeafNode, runLegacyMigration } from '$lib/utils';
import type { McpServerOverride } from '$lib/types/database';
import { MessageRole } from '$lib/enums';
import {
@@ -128,6 +128,10 @@ class ConversationsStore {
if (this.isInitialized) return;
try {
// @deprecated Legacy migration for old marker-based messages.
// Remove once all users have migrated to the structured format.
await runLegacyMigration();
await this.loadConversations();
this.isInitialized = true;
} catch (error) {

@@ -2,6 +2,7 @@ import type { MessageRole } from '$lib/enums';
import { ToolCallType } from '$lib/enums';
import type {
ApiChatCompletionRequest,
ApiChatCompletionToolCall,
ApiChatMessageContentPart,
ApiChatMessageData
} from './api';
@@ -70,22 +71,48 @@ export interface AgenticSession {
}
/**
* Callbacks for agentic flow execution
* Callbacks for agentic flow execution.
*
* The agentic loop creates separate DB messages for each turn:
* - assistant messages (one per LLM turn, with tool_calls if any)
* - tool result messages (one per tool call execution)
*
* The first assistant message is created by the caller before starting the flow.
* Subsequent messages are created via createToolResultMessage / createAssistantMessage.
*/
export interface AgenticFlowCallbacks {
/** Content chunk for the current assistant message */
onChunk?: (chunk: string) => void;
/** Reasoning content chunk for the current assistant message */
onReasoningChunk?: (chunk: string) => void;
onToolCallChunk?: (serializedToolCalls: string) => void;
onAttachments?: (extras: DatabaseMessageExtra[]) => void;
/** Tool calls being streamed (partial, accumulating) for the current turn */
onToolCallsStreaming?: (toolCalls: ApiChatCompletionToolCall[]) => void;
/** Attachments extracted from tool results */
onAttachments?: (messageId: string, extras: DatabaseMessageExtra[]) => void;
/** Model name detected from response */
onModel?: (model: string) => void;
onComplete?: (
/** Current assistant turn's streaming is complete - save to DB */
onAssistantTurnComplete?: (
content: string,
reasoningContent?: string,
timings?: ChatMessageTimings,
toolCalls?: string
) => void;
reasoningContent: string | undefined,
timings: ChatMessageTimings | undefined,
toolCalls: ApiChatCompletionToolCall[] | undefined
) => Promise<void>;
/** Create a tool result message in the DB tree */
createToolResultMessage?: (
toolCallId: string,
content: string,
extras?: DatabaseMessageExtra[]
) => Promise<DatabaseMessage>;
/** Create a new assistant message for the next agentic turn */
createAssistantMessage?: () => Promise<DatabaseMessage>;
/** Entire agentic flow is complete */
onFlowComplete?: (timings?: ChatMessageTimings) => void;
/** Error during flow */
onError?: (error: Error) => void;
/** Timing updates during streaming */
onTimings?: (timings?: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void;
/** An agentic turn (LLM + tool execution) completed - intermediate timing update */
onTurnComplete?: (intermediateTimings: ChatMessageTimings) => void;
}
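The doc comment on `AgenticFlowCallbacks` describes one assistant message per LLM turn and one tool message per tool call. The sketch below lays out that sequence as a flat, parent-linked list. All names here (`SketchTurn`, `SketchMessage`, `layOutTurns`, the generated ids) are hypothetical, and the linear parent chaining is an assumption modelled on what the migration code in this diff does with `currentParentId`.

```typescript
// Hypothetical types for illustration; the real DatabaseMessage has more fields.
interface SketchToolCall { id: string; result: string }
interface SketchTurn { content: string; toolCalls: SketchToolCall[] }
interface SketchMessage {
  id: string;
  parent: string | null;
  role: 'assistant' | 'tool';
  content: string;
  toolCallId?: string;
}

// Produces: assistant, tool, tool, assistant, tool, ... each chained to the previous one.
function layOutTurns(turns: SketchTurn[]): SketchMessage[] {
  const messages: SketchMessage[] = [];
  let parent: string | null = null;

  for (let turnIdx = 0; turnIdx < turns.length; turnIdx++) {
    const turn = turns[turnIdx];
    const assistantId = `assistant-${turnIdx}`;
    messages.push({ id: assistantId, parent, role: 'assistant', content: turn.content });
    parent = assistantId;

    for (const call of turn.toolCalls) {
      const toolId = `tool-${call.id}`;
      messages.push({
        id: toolId,
        parent,
        role: 'tool',
        content: call.result,
        toolCallId: call.id
      });
      parent = toolId;
    }
  }
  return messages;
}
```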

@@ -1,5 +1,6 @@
import type { ErrorDialogType } from '$lib/enums';
import type { DatabaseMessageExtra } from './database';
import type { ApiChatCompletionToolCall } from './api';
import type { DatabaseMessage, DatabaseMessageExtra } from './database';
export interface ChatUploadedFile {
id: string;
@@ -99,21 +100,28 @@ export interface ChatMessageToolCallTiming {
}
/**
* Callbacks for streaming chat responses
* Callbacks for streaming chat responses (used by both agentic and non-agentic paths)
*/
export interface ChatStreamCallbacks {
onChunk?: (chunk: string) => void;
onReasoningChunk?: (chunk: string) => void;
onToolCallChunk?: (chunk: string) => void;
onAttachments?: (extras: DatabaseMessageExtra[]) => void;
onToolCallsStreaming?: (toolCalls: ApiChatCompletionToolCall[]) => void;
onAttachments?: (messageId: string, extras: DatabaseMessageExtra[]) => void;
onModel?: (model: string) => void;
onTimings?: (timings?: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void;
onComplete?: (
content?: string,
reasoningContent?: string,
timings?: ChatMessageTimings,
toolCallContent?: string
) => void;
onAssistantTurnComplete?: (
content: string,
reasoningContent: string | undefined,
timings: ChatMessageTimings | undefined,
toolCalls: ApiChatCompletionToolCall[] | undefined
) => Promise<void>;
createToolResultMessage?: (
toolCallId: string,
content: string,
extras?: DatabaseMessageExtra[]
) => Promise<DatabaseMessage>;
createAssistantMessage?: () => Promise<DatabaseMessage>;
onFlowComplete?: (timings?: ChatMessageTimings) => void;
onError?: (error: Error) => void;
onTurnComplete?: (intermediateTimings: ChatMessageTimings) => void;
}

@@ -92,6 +92,8 @@ export interface DatabaseMessage {
* @deprecated - left for backward compatibility
*/
thinking?: string;
/** Reasoning content produced by the model (separate from visible content) */
reasoningContent?: string;
/** Serialized JSON array of tool calls made by assistant messages */
toolCalls?: string;
/** Tool call ID for tool result messages (role: 'tool') */

@@ -1,8 +1,15 @@
import { AgenticSectionType } from '$lib/enums';
import { AGENTIC_TAGS, AGENTIC_REGEX, REASONING_TAGS, TRIM_NEWLINES_REGEX } from '$lib/constants';
import { AgenticSectionType, MessageRole } from '$lib/enums';
import { ATTACHMENT_SAVED_REGEX, NEWLINE_SEPARATOR } from '$lib/constants';
import type { ApiChatCompletionToolCall } from '$lib/types/api';
import type {
DatabaseMessage,
DatabaseMessageExtra,
DatabaseMessageExtraImageFile
} from '$lib/types/database';
import { AttachmentType } from '$lib/enums';
/**
* Represents a parsed section of agentic content
* Represents a parsed section of agentic content for display
*/
export interface AgenticSection {
type: AgenticSectionType;
@@ -10,63 +17,70 @@ export interface AgenticSection {
toolName?: string;
toolArgs?: string;
toolResult?: string;
toolResultExtras?: DatabaseMessageExtra[];
}
/**
* Represents a segment of content that may contain reasoning blocks
* Represents a tool result line that may reference an image attachment
*/
type ReasoningSegment = {
type:
| AgenticSectionType.TEXT
| AgenticSectionType.REASONING
| AgenticSectionType.REASONING_PENDING;
content: string;
export type ToolResultLine = {
text: string;
image?: DatabaseMessageExtraImageFile;
};
/**
* Parses agentic content into structured sections
* Derives display sections from a single assistant message and its direct tool results.
*
* Main parsing function that processes content containing:
* - Tool calls (completed, pending, or streaming)
* - Reasoning blocks (completed or streaming)
* - Regular text content
*
* The parser handles chronological display of agentic flow output, maintaining
* the order of operations and properly identifying different states of tool calls
* and reasoning blocks during streaming.
*
* @param rawContent - The raw content string to parse
* @returns Array of structured agentic sections ready for display
*
* @example
* ```typescript
* const content = "Some text <<<AGENTIC_TOOL_CALL>>>tool_name...";
* const sections = parseAgenticContent(content);
* // Returns: [{ type: 'text', content: 'Some text' }, { type: 'tool_call_streaming', ... }]
* ```
* @param message - The assistant message
* @param toolMessages - Tool result messages for this assistant's tool_calls
* @param streamingToolCalls - Partial tool calls during streaming (not yet persisted)
*/
export function parseAgenticContent(rawContent: string): AgenticSection[] {
if (!rawContent) return [];
const segments = splitReasoningSegments(rawContent);
function deriveSingleTurnSections(
message: DatabaseMessage,
toolMessages: DatabaseMessage[] = [],
streamingToolCalls: ApiChatCompletionToolCall[] = []
): AgenticSection[] {
const sections: AgenticSection[] = [];
for (const segment of segments) {
if (segment.type === AgenticSectionType.TEXT) {
sections.push(...parseToolCallContent(segment.content));
continue;
}
if (segment.type === AgenticSectionType.REASONING) {
if (segment.content.trim()) {
sections.push({ type: AgenticSectionType.REASONING, content: segment.content });
}
continue;
}
// 1. Reasoning content (from dedicated field)
if (message.reasoningContent) {
sections.push({
type: AgenticSectionType.REASONING_PENDING,
content: segment.content
type: AgenticSectionType.REASONING,
content: message.reasoningContent
});
}
// 2. Text content
if (message.content?.trim()) {
sections.push({
type: AgenticSectionType.TEXT,
content: message.content
});
}
// 3. Persisted tool calls (from message.toolCalls field)
const toolCalls = parseToolCalls(message.toolCalls);
for (const tc of toolCalls) {
const resultMsg = toolMessages.find((m) => m.toolCallId === tc.id);
sections.push({
type: resultMsg ? AgenticSectionType.TOOL_CALL : AgenticSectionType.TOOL_CALL_PENDING,
content: resultMsg?.content || '',
toolName: tc.function?.name,
toolArgs: tc.function?.arguments,
toolResult: resultMsg?.content,
toolResultExtras: resultMsg?.extra
});
}
// 4. Streaming tool calls (not yet persisted - currently being received)
for (const tc of streamingToolCalls) {
// Skip if already in persisted tool calls
if (tc.id && toolCalls.find((t) => t.id === tc.id)) continue;
sections.push({
type: AgenticSectionType.TOOL_CALL_STREAMING,
content: '',
toolName: tc.function?.name,
toolArgs: tc.function?.arguments
});
}
@@ -74,211 +88,123 @@ export function parseAgenticContent(rawContent: string): AgenticSection[] {
}
/**
* Parses content containing tool call markers
* Derives display sections from structured message data.
*
* Identifies and extracts tool calls from content, handling:
* - Completed tool calls with name, arguments, and results
* - Pending tool calls (execution in progress)
* - Streaming tool calls (arguments being received)
* - Early-stage tool calls (just started)
* Handles both single-turn (one assistant + its tool results) and multi-turn
* agentic sessions (multiple assistant + tool messages grouped together).
*
* @param rawContent - The raw content string to parse
* @returns Array of agentic sections representing tool calls and text
* When `toolMessages` contains continuation assistant messages (from multi-turn
* agentic flows), they are processed in order to produce sections across all turns.
*
* @param message - The first/anchor assistant message
* @param toolMessages - Tool result messages and continuation assistant messages
* @param streamingToolCalls - Partial tool calls during streaming (not yet persisted)
*/
function parseToolCallContent(rawContent: string): AgenticSection[] {
if (!rawContent) return [];
export function deriveAgenticSections(
message: DatabaseMessage,
toolMessages: DatabaseMessage[] = [],
streamingToolCalls: ApiChatCompletionToolCall[] = []
): AgenticSection[] {
const hasAssistantContinuations = toolMessages.some((m) => m.role === MessageRole.ASSISTANT);
if (!hasAssistantContinuations) {
return deriveSingleTurnSections(message, toolMessages, streamingToolCalls);
}
const sections: AgenticSection[] = [];
const completedToolCallRegex = new RegExp(AGENTIC_REGEX.COMPLETED_TOOL_CALL.source, 'g');
const firstTurnToolMsgs = collectToolMessages(toolMessages, 0);
sections.push(...deriveSingleTurnSections(message, firstTurnToolMsgs));
let lastIndex = 0;
let match;
let i = firstTurnToolMsgs.length;
while ((match = completedToolCallRegex.exec(rawContent)) !== null) {
if (match.index > lastIndex) {
const textBefore = rawContent.slice(lastIndex, match.index).trim();
if (textBefore) {
sections.push({ type: AgenticSectionType.TEXT, content: textBefore });
}
while (i < toolMessages.length) {
const msg = toolMessages[i];
if (msg.role === MessageRole.ASSISTANT) {
const turnToolMsgs = collectToolMessages(toolMessages, i + 1);
const isLastTurn = i + 1 + turnToolMsgs.length >= toolMessages.length;
sections.push(
...deriveSingleTurnSections(msg, turnToolMsgs, isLastTurn ? streamingToolCalls : [])
);
i += 1 + turnToolMsgs.length;
} else {
i++;
}
const toolName = match[1];
const toolArgs = match[2];
const toolResult = match[3].replace(TRIM_NEWLINES_REGEX, '');
sections.push({
type: AgenticSectionType.TOOL_CALL,
content: toolResult,
toolName,
toolArgs,
toolResult
});
lastIndex = match.index + match[0].length;
}
const remainingContent = rawContent.slice(lastIndex);
const pendingMatch = remainingContent.match(AGENTIC_REGEX.PENDING_TOOL_CALL);
const partialWithNameMatch = remainingContent.match(AGENTIC_REGEX.PARTIAL_WITH_NAME);
const earlyMatch = remainingContent.match(AGENTIC_REGEX.EARLY_MATCH);
if (pendingMatch) {
const pendingIndex = remainingContent.indexOf(AGENTIC_TAGS.TOOL_CALL_START);
if (pendingIndex > 0) {
const textBefore = remainingContent.slice(0, pendingIndex).trim();
if (textBefore) {
sections.push({ type: AgenticSectionType.TEXT, content: textBefore });
}
}
const toolName = pendingMatch[1];
const toolArgs = pendingMatch[2];
const streamingResult = (pendingMatch[3] || '').replace(TRIM_NEWLINES_REGEX, '');
sections.push({
type: AgenticSectionType.TOOL_CALL_PENDING,
content: streamingResult,
toolName,
toolArgs,
toolResult: streamingResult || undefined
});
} else if (partialWithNameMatch) {
const pendingIndex = remainingContent.indexOf(AGENTIC_TAGS.TOOL_CALL_START);
if (pendingIndex > 0) {
const textBefore = remainingContent.slice(0, pendingIndex).trim();
if (textBefore) {
sections.push({ type: AgenticSectionType.TEXT, content: textBefore });
}
}
const partialArgs = partialWithNameMatch[2] || '';
sections.push({
type: AgenticSectionType.TOOL_CALL_STREAMING,
content: '',
toolName: partialWithNameMatch[1],
toolArgs: partialArgs || undefined,
toolResult: undefined
});
} else if (earlyMatch) {
const pendingIndex = remainingContent.indexOf(AGENTIC_TAGS.TOOL_CALL_START);
if (pendingIndex > 0) {
const textBefore = remainingContent.slice(0, pendingIndex).trim();
if (textBefore) {
sections.push({ type: AgenticSectionType.TEXT, content: textBefore });
}
}
const nameMatch = earlyMatch[1]?.match(AGENTIC_REGEX.TOOL_NAME_EXTRACT);
sections.push({
type: AgenticSectionType.TOOL_CALL_STREAMING,
content: '',
toolName: nameMatch?.[1],
toolArgs: undefined,
toolResult: undefined
});
} else if (lastIndex < rawContent.length) {
let remainingText = rawContent.slice(lastIndex).trim();
const partialMarkerMatch = remainingText.match(AGENTIC_REGEX.PARTIAL_MARKER);
if (partialMarkerMatch) {
remainingText = remainingText.slice(0, partialMarkerMatch.index).trim();
}
if (remainingText) {
sections.push({ type: AgenticSectionType.TEXT, content: remainingText });
}
}
if (sections.length === 0 && rawContent.trim()) {
sections.push({ type: AgenticSectionType.TEXT, content: rawContent });
}
return sections;
}
/**
* Strips partial marker from text content
*
* Removes incomplete agentic markers (e.g., "<<<", "<<<AGENTIC") that may appear
* at the end of streaming content.
*
* @param text - The text content to process
* @returns Text with partial markers removed
* Collect consecutive tool messages starting at `startIndex`.
*/
function stripPartialMarker(text: string): string {
const partialMarkerMatch = text.match(AGENTIC_REGEX.PARTIAL_MARKER);
function collectToolMessages(messages: DatabaseMessage[], startIndex: number): DatabaseMessage[] {
const result: DatabaseMessage[] = [];
if (partialMarkerMatch) {
return text.slice(0, partialMarkerMatch.index).trim();
for (let i = startIndex; i < messages.length; i++) {
if (messages[i].role === MessageRole.TOOL) {
result.push(messages[i]);
} else {
break;
}
}
return text;
return result;
}
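The multi-turn path walks a flat tail of tool and continuation assistant messages and cuts it into turns. A self-contained sketch of that walk is shown here, assuming simplified message objects; `Msg`, `collectToolRun`, and `groupTurns` are illustrative names, not the functions from the change.

```typescript
// Simplified stand-in for DatabaseMessage (assumption: only id and role matter here).
interface Msg { id: string; role: 'assistant' | 'tool' }

// Collects the run of consecutive tool messages starting at startIndex.
function collectToolRun(messages: Msg[], startIndex: number): Msg[] {
  const run: Msg[] = [];
  for (let i = startIndex; i < messages.length && messages[i].role === 'tool'; i++) {
    run.push(messages[i]);
  }
  return run;
}

// Groups [tool..., assistant, tool..., assistant, ...] into per-turn buckets.
function groupTurns(anchor: Msg, tail: Msg[]): Array<{ assistant: Msg; tools: Msg[] }> {
  const firstTools = collectToolRun(tail, 0);
  const turns = [{ assistant: anchor, tools: firstTools }];
  let i = firstTools.length;

  while (i < tail.length) {
    const msg = tail[i];
    if (msg.role === 'assistant') {
      const tools = collectToolRun(tail, i + 1);
      turns.push({ assistant: msg, tools });
      i += 1 + tools.length; // skip the assistant and its tool results
    } else {
      i++; // defensive: unexpected tool message, skip it
    }
  }
  return turns;
}
```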
/**
* Splits raw content into segments based on reasoning blocks
*
* Identifies and extracts reasoning content wrapped in REASONING_TAGS.START/END markers,
* separating it from regular text content. Handles both complete and incomplete
* (streaming) reasoning blocks.
*
* @param rawContent - The raw content string to parse
* @returns Array of reasoning segments with their types and content
* Parse tool result text into lines, matching image attachments by name.
*/
function splitReasoningSegments(rawContent: string): ReasoningSegment[] {
if (!rawContent) return [];
export function parseToolResultWithImages(
toolResult: string,
extras?: DatabaseMessageExtra[]
): ToolResultLine[] {
const lines = toolResult.split(NEWLINE_SEPARATOR);
return lines.map((line) => {
const match = line.match(ATTACHMENT_SAVED_REGEX);
if (!match || !extras) return { text: line };
const segments: ReasoningSegment[] = [];
let cursor = 0;
const attachmentName = match[1];
const image = extras.find(
(e): e is DatabaseMessageExtraImageFile =>
e.type === AttachmentType.IMAGE && e.name === attachmentName
);
while (cursor < rawContent.length) {
const startIndex = rawContent.indexOf(REASONING_TAGS.START, cursor);
return { text: line, image };
});
}
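`parseToolResultWithImages` pairs each tool-result line with an image attachment of the same name. The sketch below shows the matching step in isolation. The pattern `SKETCH_ATTACHMENT_REGEX` is purely hypothetical, since the real `ATTACHMENT_SAVED_REGEX` and `NEWLINE_SEPARATOR` are not visible in this diff; the extra types are simplified stand-ins as well.

```typescript
// Simplified stand-ins for the DatabaseMessageExtra union (assumption).
interface ImageExtra { type: 'image'; name: string; base64Url: string }
interface TextExtra { type: 'text'; name: string; content: string }
type Extra = ImageExtra | TextExtra;

interface ResultLine { text: string; image?: ImageExtra }

// HYPOTHETICAL pattern: the real ATTACHMENT_SAVED_REGEX is not shown in the diff.
const SKETCH_ATTACHMENT_REGEX = /^\[Attachment saved: (.+)\]$/;

function splitResultLines(toolResult: string, extras?: Extra[]): ResultLine[] {
  // Assumes '\n' as the line separator.
  return toolResult.split('\n').map((line) => {
    const match = line.match(SKETCH_ATTACHMENT_REGEX);
    if (!match || !extras) return { text: line };

    // The type predicate narrows the result to image extras only.
    const image = extras.find(
      (e): e is ImageExtra => e.type === 'image' && e.name === match[1]
    );
    return { text: line, image };
  });
}
```

Lines that match the pattern but have no corresponding image keep their text and simply carry no `image`.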
if (startIndex === -1) {
const remainingText = rawContent.slice(cursor);
/**
* Safely parse the toolCalls JSON string from a DatabaseMessage.
*/
function parseToolCalls(toolCallsJson?: string): ApiChatCompletionToolCall[] {
if (!toolCallsJson) return [];
if (remainingText) {
segments.push({ type: AgenticSectionType.TEXT, content: remainingText });
}
try {
const parsed = JSON.parse(toolCallsJson);
break;
}
return Array.isArray(parsed) ? parsed : [];
} catch {
return [];
}
}
if (startIndex > cursor) {
const textBefore = rawContent.slice(cursor, startIndex);
/**
* Check if a message has agentic content (tool calls or is part of an agentic flow).
*/
export function hasAgenticContent(
message: DatabaseMessage,
toolMessages: DatabaseMessage[] = []
): boolean {
if (message.toolCalls) {
const tc = parseToolCalls(message.toolCalls);
if (textBefore) {
segments.push({ type: AgenticSectionType.TEXT, content: textBefore });
}
}
const contentStart = startIndex + REASONING_TAGS.START.length;
const endIndex = rawContent.indexOf(REASONING_TAGS.END, contentStart);
if (endIndex === -1) {
const pendingContent = rawContent.slice(contentStart);
segments.push({
type: AgenticSectionType.REASONING_PENDING,
content: stripPartialMarker(pendingContent)
});
break;
}
const reasoningContent = rawContent.slice(contentStart, endIndex);
segments.push({ type: AgenticSectionType.REASONING, content: reasoningContent });
cursor = endIndex + REASONING_TAGS.END.length;
if (tc.length > 0) return true;
}
return segments;
return toolMessages.length > 0;
}
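Taken together, the new `agentic.ts` derives display sections from structured fields rather than from markers in `content`. The following sketch condenses the single-turn ordering (reasoning, text, persisted tool calls, streaming tool calls) into one self-contained function. The string values of `SectionType` and the simplified `Message` and `ToolCall` shapes are assumptions for illustration; the real code uses `AgenticSectionType`, `DatabaseMessage`, and `ApiChatCompletionToolCall`.

```typescript
// Assumed string values; the real enum is AgenticSectionType.
type SectionType = 'reasoning' | 'text' | 'tool_call' | 'tool_call_pending' | 'tool_call_streaming';

interface ToolCall { id?: string; function?: { name?: string; arguments?: string } }
interface Message { content: string; reasoningContent?: string; toolCalls?: string; toolCallId?: string }
interface Section { type: SectionType; content: string; toolName?: string; toolResult?: string }

// Tolerates missing or malformed JSON by returning an empty list.
function safeParseToolCalls(json?: string): ToolCall[] {
  if (!json) return [];
  try {
    const parsed: unknown = JSON.parse(json);
    return Array.isArray(parsed) ? (parsed as ToolCall[]) : [];
  } catch {
    return [];
  }
}

function deriveSections(message: Message, toolResults: Message[] = [], streaming: ToolCall[] = []): Section[] {
  const sections: Section[] = [];
  if (message.reasoningContent) sections.push({ type: 'reasoning', content: message.reasoningContent });
  if (message.content.trim()) sections.push({ type: 'text', content: message.content });

  const persisted = safeParseToolCalls(message.toolCalls);
  for (const tc of persisted) {
    // A tool call is "pending" until a tool message with the same id exists.
    const result = toolResults.find((m) => m.toolCallId === tc.id);
    sections.push({
      type: result ? 'tool_call' : 'tool_call_pending',
      content: result?.content ?? '',
      toolName: tc.function?.name,
      toolResult: result?.content
    });
  }

  for (const tc of streaming) {
    // Skip streaming entries that have already been persisted.
    if (tc.id && persisted.some((p) => p.id === tc.id)) continue;
    sections.push({ type: 'tool_call_streaming', content: '', toolName: tc.function?.name });
  }
  return sections;
}
```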

@@ -149,8 +149,17 @@ export { parseHeadersToArray, serializeHeaders } from './headers';
// Favicon utilities
export { getFaviconUrl } from './favicon';
// Agentic content parsing utilities
export { parseAgenticContent, type AgenticSection } from './agentic';
// Agentic content utilities (structured section derivation)
export {
deriveAgenticSections,
parseToolResultWithImages,
hasAgenticContent,
type AgenticSection,
type ToolResultLine
} from './agentic';
// Legacy migration utilities
export { runLegacyMigration, isMigrationNeeded } from './legacy-migration';
// Cache utilities
export { TTLCache, ReactiveTTLMap, type TTLCacheOptions } from './cache-ttl';

@@ -0,0 +1,345 @@
/**
* @deprecated Legacy migration utility. Remove once all users have migrated to the new structured agentic message format.
*
* Converts old marker-based agentic messages to the new structured format
* with separate messages per turn.
*
* Old format: Single assistant message with markers in content:
* <<<reasoning_content_start>>>...<<<reasoning_content_end>>>
* <<<AGENTIC_TOOL_CALL_START>>>...<<<AGENTIC_TOOL_CALL_END>>>
*
* New format: Separate messages per turn:
* - assistant (content + reasoningContent + toolCalls)
* - tool (toolCallId + content)
* - assistant (next turn)
* - ...
*/
import { LEGACY_AGENTIC_REGEX, LEGACY_REASONING_TAGS } from '$lib/constants';
import { DatabaseService } from '$lib/services/database.service';
import { MessageRole, MessageType } from '$lib/enums';
import type { DatabaseMessage } from '$lib/types/database';
const MIGRATION_DONE_KEY = 'llama-webui-migration-v2-done';
/**
* @deprecated Part of legacy migration — remove with the migration module.
* Returns true while the migration has not been marked as done; returns false if localStorage is unavailable.
*/
export function isMigrationNeeded(): boolean {
try {
return !localStorage.getItem(MIGRATION_DONE_KEY);
} catch {
return false;
}
}
/**
* Mark migration as done.
*/
function markMigrationDone(): void {
try {
localStorage.setItem(MIGRATION_DONE_KEY, String(Date.now()));
} catch {
// Ignore localStorage errors
}
}
/**
* Check if a message has legacy markers in its content.
*/
function hasLegacyMarkers(message: DatabaseMessage): boolean {
if (!message.content) return false;
return LEGACY_AGENTIC_REGEX.HAS_LEGACY_MARKERS.test(message.content);
}
/**
* Extract reasoning content from legacy marker format.
*/
function extractLegacyReasoning(content: string): { reasoning: string; cleanContent: string } {
let reasoning = '';
let cleanContent = content;
// Extract all reasoning blocks
const re = new RegExp(LEGACY_AGENTIC_REGEX.REASONING_EXTRACT.source, 'g');
let match;
while ((match = re.exec(content)) !== null) {
reasoning += match[1];
}
// Remove reasoning tags from content
cleanContent = cleanContent
.replace(new RegExp(LEGACY_AGENTIC_REGEX.REASONING_BLOCK.source, 'g'), '')
.replace(LEGACY_AGENTIC_REGEX.REASONING_OPEN, '');
return { reasoning, cleanContent };
}
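To make the extraction step concrete, here is a standalone sketch that pulls reasoning out of the legacy marker format. The tag strings come from the module's header comment; the regular expression and the handling of an unclosed block (everything after a dangling start tag is dropped) are assumptions, because `LEGACY_AGENTIC_REGEX` is not shown in this diff. `splitLegacyReasoning` is a hypothetical name.

```typescript
const LEGACY_START = '<<<reasoning_content_start>>>';

function splitLegacyReasoning(content: string): { reasoning: string; cleanContent: string } {
  // Closed blocks: lazy capture keeps adjacent blocks separate.
  const closedBlock = /<<<reasoning_content_start>>>([\s\S]*?)<<<reasoning_content_end>>>/g;
  let reasoning = '';

  const withoutBlocks = content.replace(closedBlock, (_whole: string, body: string) => {
    reasoning += body;
    return '';
  });

  // Assumption: an unclosed trailing block (interrupted stream) is discarded.
  const openIdx = withoutBlocks.indexOf(LEGACY_START);
  const cleanContent = openIdx === -1 ? withoutBlocks : withoutBlocks.slice(0, openIdx);

  return { reasoning, cleanContent };
}
```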
/**
* Parse legacy content with tool call markers into structured turns.
*/
interface ParsedTurn {
textBefore: string;
toolCalls: Array<{
name: string;
args: string;
result: string;
}>;
}
function parseLegacyToolCalls(content: string): ParsedTurn[] {
const turns: ParsedTurn[] = [];
const regex = new RegExp(LEGACY_AGENTIC_REGEX.COMPLETED_TOOL_CALL.source, 'g');
let lastIndex = 0;
let currentTurn: ParsedTurn = { textBefore: '', toolCalls: [] };
let match;
while ((match = regex.exec(content)) !== null) {
const textBefore = content.slice(lastIndex, match.index).trim();
// If there's text between tool calls and we already have tool calls,
// that means a new turn started (text after tool results = new LLM turn)
if (textBefore && currentTurn.toolCalls.length > 0) {
turns.push(currentTurn);
currentTurn = { textBefore, toolCalls: [] };
} else if (textBefore && currentTurn.toolCalls.length === 0) {
currentTurn.textBefore = textBefore;
}
currentTurn.toolCalls.push({
name: match[1],
args: match[2],
result: match[3].replace(/^\n+|\n+$/g, '')
});
lastIndex = match.index + match[0].length;
}
// Any remaining text after the last tool call
const remainingText = content.slice(lastIndex).trim();
if (currentTurn.toolCalls.length > 0) {
turns.push(currentTurn);
}
// If there's text after all tool calls, it's the final assistant response
if (remainingText) {
// Remove any partial/open markers
const cleanRemaining = remainingText
.replace(LEGACY_AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '')
.trim();
if (cleanRemaining) {
turns.push({ textBefore: cleanRemaining, toolCalls: [] });
}
}
// If no tool calls found at all, return the original content as a single turn
if (turns.length === 0) {
turns.push({ textBefore: content.trim(), toolCalls: [] });
}
return turns;
}
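The turn-splitting rule in `parseLegacyToolCalls` is: text that follows at least one tool call starts a new turn. The sketch below isolates that rule by working on already-tokenized pieces instead of the legacy regex, so no marker syntax has to be assumed. `Piece`, `Turn`, and `splitIntoTurns` are hypothetical names, and adjacent text pieces are assumed to have been merged by the tokenizer.

```typescript
type Piece =
  | { kind: 'text'; text: string }
  | { kind: 'tool'; name: string; args: string; result: string };

interface Turn {
  textBefore: string;
  toolCalls: Array<{ name: string; args: string; result: string }>;
}

function splitIntoTurns(pieces: Piece[]): Turn[] {
  const turns: Turn[] = [];
  let current: Turn = { textBefore: '', toolCalls: [] };

  for (const piece of pieces) {
    if (piece.kind === 'tool') {
      current.toolCalls.push({ name: piece.name, args: piece.args, result: piece.result });
      continue;
    }
    const text = piece.text.trim();
    if (!text) continue;
    if (current.toolCalls.length > 0) {
      // Text after tool results means the model started a new turn.
      turns.push(current);
      current = { textBefore: text, toolCalls: [] };
    } else {
      current.textBefore = text;
    }
  }

  // Flush the last turn; always return at least one turn.
  if (current.toolCalls.length > 0 || current.textBefore || turns.length === 0) {
    turns.push(current);
  }
  return turns;
}
```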
/**
* Migrate a single conversation's messages from legacy format to new format.
*/
async function migrateConversation(convId: string): Promise<number> {
const allMessages = await DatabaseService.getConversationMessages(convId);
let migratedCount = 0;
for (const message of allMessages) {
if (message.role !== MessageRole.ASSISTANT) continue;
if (!hasLegacyMarkers(message)) {
// Still check for reasoning-only markers (no tool calls)
if (message.content?.includes(LEGACY_REASONING_TAGS.START)) {
const { reasoning, cleanContent } = extractLegacyReasoning(message.content);
await DatabaseService.updateMessage(message.id, {
content: cleanContent.trim(),
reasoningContent: reasoning || undefined
});
migratedCount++;
}
continue;
}
// Has agentic markers - full migration needed
const { reasoning, cleanContent } = extractLegacyReasoning(message.content);
const turns = parseLegacyToolCalls(cleanContent);
// Parse existing toolCalls JSON to try to match IDs
let existingToolCalls: Array<{ id: string; function?: { name: string; arguments: string } }> =
[];
if (message.toolCalls) {
try {
existingToolCalls = JSON.parse(message.toolCalls);
} catch {
// Ignore
}
}
// First turn uses the existing message
const firstTurn = turns[0];
if (!firstTurn) continue;
// Match tool calls from the first turn to existing IDs
const firstTurnToolCalls = firstTurn.toolCalls.map((tc, i) => {
const existing =
existingToolCalls.find((e) => e.function?.name === tc.name) || existingToolCalls[i];
return {
id: existing?.id || `legacy_tool_${i}`,
type: 'function' as const,
function: { name: tc.name, arguments: tc.args }
};
});
// Update the existing message for the first turn
await DatabaseService.updateMessage(message.id, {
content: firstTurn.textBefore,
reasoningContent: reasoning || undefined,
toolCalls: firstTurnToolCalls.length > 0 ? JSON.stringify(firstTurnToolCalls) : ''
});
let currentParentId = message.id;
let toolCallIdCounter = existingToolCalls.length;
// Create tool result messages for the first turn
for (let i = 0; i < firstTurn.toolCalls.length; i++) {
const tc = firstTurn.toolCalls[i];
const toolCallId = firstTurnToolCalls[i]?.id || `legacy_tool_${i}`;
const toolMsg = await DatabaseService.createMessageBranch(
{
convId,
type: MessageType.TEXT,
role: MessageRole.TOOL,
content: tc.result,
toolCallId,
timestamp: message.timestamp + i + 1,
toolCalls: '',
children: []
},
currentParentId
);
currentParentId = toolMsg.id;
}
// Create messages for subsequent turns
for (let turnIdx = 1; turnIdx < turns.length; turnIdx++) {
const turn = turns[turnIdx];
const turnToolCalls = turn.toolCalls.map((tc, i) => {
const idx = toolCallIdCounter + i;
const existing = existingToolCalls[idx];
return {
id: existing?.id || `legacy_tool_${idx}`,
type: 'function' as const,
function: { name: tc.name, arguments: tc.args }
};
});
toolCallIdCounter += turn.toolCalls.length;
// Create assistant message for this turn
const assistantMsg = await DatabaseService.createMessageBranch(
{
convId,
type: MessageType.TEXT,
role: MessageRole.ASSISTANT,
content: turn.textBefore,
timestamp: message.timestamp + turnIdx * 100,
toolCalls: turnToolCalls.length > 0 ? JSON.stringify(turnToolCalls) : '',
children: [],
model: message.model
},
currentParentId
);
currentParentId = assistantMsg.id;
// Create tool result messages for this turn
for (let i = 0; i < turn.toolCalls.length; i++) {
const tc = turn.toolCalls[i];
const toolCallId = turnToolCalls[i]?.id || `legacy_tool_${toolCallIdCounter + i}`;
const toolMsg = await DatabaseService.createMessageBranch(
{
convId,
type: MessageType.TEXT,
role: MessageRole.TOOL,
content: tc.result,
toolCallId,
timestamp: message.timestamp + turnIdx * 100 + i + 1,
toolCalls: '',
children: []
},
currentParentId
);
currentParentId = toolMsg.id;
}
}
// Re-parent any children of the original message to the last created message
// (the original message's children list was the next user message or similar)
if (message.children.length > 0 && currentParentId !== message.id) {
for (const childId of message.children) {
// Skip children we just created (they were already properly parented)
const child = allMessages.find((m) => m.id === childId);
if (!child) continue;
// Only re-parent non-tool messages that were original children
if (child.role !== MessageRole.TOOL) {
await DatabaseService.updateMessage(childId, { parent: currentParentId });
// Add to new parent's children
const newParent = await DatabaseService.getConversationMessages(convId).then((msgs) =>
msgs.find((m) => m.id === currentParentId)
);
if (newParent && !newParent.children.includes(childId)) {
await DatabaseService.updateMessage(currentParentId, {
children: [...newParent.children, childId]
});
}
}
}
// Clear re-parented children from the original message
await DatabaseService.updateMessage(message.id, { children: [] });
}
migratedCount++;
}
return migratedCount;
}
/**
* @deprecated Part of legacy migration — remove with the migration module.
* Run the full migration across all conversations.
* This should be called once at app startup if migration is needed.
*/
export async function runLegacyMigration(): Promise<void> {
if (!isMigrationNeeded()) return;
console.log('[Migration] Starting legacy message format migration...');
try {
const conversations = await DatabaseService.getAllConversations();
let totalMigrated = 0;
for (const conv of conversations) {
const count = await migrateConversation(conv.id);
totalMigrated += count;
}
if (totalMigrated > 0) {
console.log(
`[Migration] Migrated ${totalMigrated} messages across ${conversations.length} conversations`
);
} else {
console.log('[Migration] No legacy messages found, marking as done');
}
markMigrationDone();
} catch (error) {
console.error('[Migration] Failed to migrate legacy messages:', error);
// Still mark as done to avoid infinite retry loops
markMigrationDone();
}
}
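`runLegacyMigration` is guarded by a localStorage key and marks itself done even when it fails. A minimal sketch of that run-once pattern, with the storage injected so it can run outside a browser, might look like this. `KeyValueStore` and `runOnce` are hypothetical names, and the task is synchronous here for brevity whereas the real migration is async.

```typescript
// Structural subset of the Web Storage API; window.localStorage satisfies it.
interface KeyValueStore {
  getItem(key: string): string | null;
  setItem(key: string, value: string): void;
}

function runOnce(store: KeyValueStore, key: string, task: () => void): 'skipped' | 'done' | 'failed' {
  if (store.getItem(key)) return 'skipped';

  let outcome: 'done' | 'failed' = 'done';
  try {
    task();
  } catch {
    outcome = 'failed';
  }

  // Marked done in both cases so a failing task is not retried on every startup.
  store.setItem(key, String(Date.now()));
  return outcome;
}
```

The trade-off of marking the key on failure is that a partially migrated store is never retried automatically, in exchange for avoiding a retry loop on each startup.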

@@ -0,0 +1,211 @@
import { describe, it, expect } from 'vitest';
import { deriveAgenticSections, hasAgenticContent } from '$lib/utils/agentic';
import { AgenticSectionType, MessageRole } from '$lib/enums';
import type { DatabaseMessage } from '$lib/types/database';
import type { ApiChatCompletionToolCall } from '$lib/types/api';
function makeAssistant(overrides: Partial<DatabaseMessage> = {}): DatabaseMessage {
return {
id: overrides.id ?? 'ast-1',
convId: 'conv-1',
type: 'text',
timestamp: Date.now(),
role: MessageRole.ASSISTANT,
content: overrides.content ?? '',
parent: null,
children: [],
...overrides
} as DatabaseMessage;
}
function makeToolMsg(overrides: Partial<DatabaseMessage> = {}): DatabaseMessage {
return {
id: overrides.id ?? 'tool-1',
convId: 'conv-1',
type: 'text',
timestamp: Date.now(),
role: MessageRole.TOOL,
content: overrides.content ?? 'tool result',
parent: null,
children: [],
toolCallId: overrides.toolCallId ?? 'call_1',
...overrides
} as DatabaseMessage;
}
describe('deriveAgenticSections', () => {
it('returns empty array for assistant with no content', () => {
const msg = makeAssistant({ content: '' });
const sections = deriveAgenticSections(msg);
expect(sections).toEqual([]);
});
it('returns text section for simple assistant message', () => {
const msg = makeAssistant({ content: 'Hello world' });
const sections = deriveAgenticSections(msg);
expect(sections).toHaveLength(1);
expect(sections[0].type).toBe(AgenticSectionType.TEXT);
expect(sections[0].content).toBe('Hello world');
});
it('returns reasoning + text for message with reasoning', () => {
const msg = makeAssistant({
content: 'Answer is 4.',
reasoningContent: 'Let me think...'
});
const sections = deriveAgenticSections(msg);
expect(sections).toHaveLength(2);
expect(sections[0].type).toBe(AgenticSectionType.REASONING);
expect(sections[0].content).toBe('Let me think...');
expect(sections[1].type).toBe(AgenticSectionType.TEXT);
});
it('single turn: assistant with tool calls and results', () => {
const msg = makeAssistant({
content: 'Let me check.',
toolCalls: JSON.stringify([
{ id: 'call_1', type: 'function', function: { name: 'search', arguments: '{"q":"test"}' } }
])
});
const toolResult = makeToolMsg({
toolCallId: 'call_1',
content: 'Found 3 results'
});
const sections = deriveAgenticSections(msg, [toolResult]);
expect(sections).toHaveLength(2);
expect(sections[0].type).toBe(AgenticSectionType.TEXT);
expect(sections[1].type).toBe(AgenticSectionType.TOOL_CALL);
expect(sections[1].toolName).toBe('search');
expect(sections[1].toolResult).toBe('Found 3 results');
});
it('single turn: pending tool call without result', () => {
const msg = makeAssistant({
toolCalls: JSON.stringify([
{ id: 'call_1', type: 'function', function: { name: 'bash', arguments: '{}' } }
])
});
const sections = deriveAgenticSections(msg, []);
expect(sections).toHaveLength(1);
expect(sections[0].type).toBe(AgenticSectionType.TOOL_CALL_PENDING);
expect(sections[0].toolName).toBe('bash');
});
it('multi-turn: two assistant turns grouped as one session', () => {
const assistant1 = makeAssistant({
id: 'ast-1',
content: 'Turn 1 text',
toolCalls: JSON.stringify([
{ id: 'call_1', type: 'function', function: { name: 'search', arguments: '{"q":"foo"}' } }
])
});
const tool1 = makeToolMsg({ id: 'tool-1', toolCallId: 'call_1', content: 'result 1' });
const assistant2 = makeAssistant({
id: 'ast-2',
content: 'Final answer based on results.'
});
// toolMessages contains both tool result and continuation assistant
const sections = deriveAgenticSections(assistant1, [tool1, assistant2]);
expect(sections).toHaveLength(3);
// Turn 1
expect(sections[0].type).toBe(AgenticSectionType.TEXT);
expect(sections[0].content).toBe('Turn 1 text');
expect(sections[1].type).toBe(AgenticSectionType.TOOL_CALL);
expect(sections[1].toolName).toBe('search');
expect(sections[1].toolResult).toBe('result 1');
// Turn 2 (final)
expect(sections[2].type).toBe(AgenticSectionType.TEXT);
expect(sections[2].content).toBe('Final answer based on results.');
});
it('multi-turn: three turns with tool calls', () => {
const assistant1 = makeAssistant({
id: 'ast-1',
content: '',
toolCalls: JSON.stringify([
{ id: 'call_1', type: 'function', function: { name: 'list_files', arguments: '{}' } }
])
});
const tool1 = makeToolMsg({ id: 'tool-1', toolCallId: 'call_1', content: 'file1 file2' });
const assistant2 = makeAssistant({
id: 'ast-2',
content: 'Reading file1...',
toolCalls: JSON.stringify([
{
id: 'call_2',
type: 'function',
function: { name: 'read_file', arguments: '{"path":"file1"}' }
}
])
});
const tool2 = makeToolMsg({ id: 'tool-2', toolCallId: 'call_2', content: 'contents of file1' });
const assistant3 = makeAssistant({
id: 'ast-3',
content: 'Here is the analysis.',
reasoningContent: 'The file contains...'
});
const sections = deriveAgenticSections(assistant1, [tool1, assistant2, tool2, assistant3]);
// Turn 1: tool_call (no text since content is empty)
// Turn 2: text + tool_call
// Turn 3: reasoning + text
expect(sections).toHaveLength(5);
expect(sections[0].type).toBe(AgenticSectionType.TOOL_CALL);
expect(sections[0].toolName).toBe('list_files');
expect(sections[1].type).toBe(AgenticSectionType.TEXT);
expect(sections[1].content).toBe('Reading file1...');
expect(sections[2].type).toBe(AgenticSectionType.TOOL_CALL);
expect(sections[2].toolName).toBe('read_file');
expect(sections[3].type).toBe(AgenticSectionType.REASONING);
expect(sections[4].type).toBe(AgenticSectionType.TEXT);
expect(sections[4].content).toBe('Here is the analysis.');
});
it('multi-turn: streaming tool calls on last turn', () => {
const assistant1 = makeAssistant({
toolCalls: JSON.stringify([
{ id: 'call_1', type: 'function', function: { name: 'search', arguments: '{}' } }
])
});
const tool1 = makeToolMsg({ toolCallId: 'call_1', content: 'result' });
const assistant2 = makeAssistant({ id: 'ast-2', content: '' });
const streamingToolCalls: ApiChatCompletionToolCall[] = [
{ id: 'call_2', type: 'function', function: { name: 'write_file', arguments: '{"pa' } }
];
const sections = deriveAgenticSections(assistant1, [tool1, assistant2], streamingToolCalls);
// Turn 1: tool_call
// Turn 2 (streaming): streaming tool call
expect(sections.some((s) => s.type === AgenticSectionType.TOOL_CALL)).toBe(true);
expect(sections.some((s) => s.type === AgenticSectionType.TOOL_CALL_STREAMING)).toBe(true);
});
});
describe('hasAgenticContent', () => {
it('returns false for plain assistant', () => {
const msg = makeAssistant({ content: 'Just text' });
expect(hasAgenticContent(msg)).toBe(false);
});
it('returns true when message has toolCalls', () => {
const msg = makeAssistant({
toolCalls: JSON.stringify([
{ id: 'call_1', type: 'function', function: { name: 'test', arguments: '{}' } }
])
});
expect(hasAgenticContent(msg)).toBe(true);
});
it('returns true when toolMessages are provided', () => {
const msg = makeAssistant();
const tool = makeToolMsg();
expect(hasAgenticContent(msg, [tool])).toBe(true);
});
it('returns false for empty toolCalls JSON', () => {
const msg = makeAssistant({ toolCalls: '[]' });
expect(hasAgenticContent(msg)).toBe(false);
});
});
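The tests above pin down the ordering contract: per assistant turn, reasoning comes first, then text, then one section per tool call, with each call matched to its result by `toolCallId`. The sketch below is one simplified way to satisfy that contract. It is not the implementation in `$lib/utils/agentic`; `deriveSections`, `Msg`, and `Section` are hypothetical names, string literals stand in for `AgenticSectionType`, and streaming tool calls are omitted.

```typescript
type SectionType = 'text' | 'reasoning' | 'tool_call' | 'tool_call_pending';

interface Msg {
  role: 'assistant' | 'tool';
  content: string;
  reasoningContent?: string;
  toolCalls?: string; // JSON-encoded array, as in the tests
  toolCallId?: string;
}

interface Section {
  type: SectionType;
  content: string;
  toolName?: string;
  toolResult?: string;
}

interface ToolCall {
  id: string;
  function: { name: string; arguments: string };
}

function deriveSections(first: Msg, followUps: Msg[] = []): Section[] {
  // Index tool results by the id of the call they answer.
  const results = new Map<string, string>();
  for (const m of followUps) {
    if (m.role === 'tool' && m.toolCallId) results.set(m.toolCallId, m.content);
  }

  // The first assistant message plus any continuation assistants form one session.
  const turns = [first, ...followUps.filter((m) => m.role === 'assistant')];
  const sections: Section[] = [];

  for (const turn of turns) {
    if (turn.reasoningContent) {
      sections.push({ type: 'reasoning', content: turn.reasoningContent });
    }
    if (turn.content) {
      sections.push({ type: 'text', content: turn.content });
    }
    const calls: ToolCall[] = turn.toolCalls ? JSON.parse(turn.toolCalls) : [];
    for (const call of calls) {
      const result = results.get(call.id);
      sections.push({
        type: result === undefined ? 'tool_call_pending' : 'tool_call',
        content: call.function.arguments,
        toolName: call.function.name,
        toolResult: result
      });
    }
  }
  return sections;
}
```

Because empty `content` produces no text section, a turn that only calls a tool contributes a single `tool_call` entry, which matches the three-turn test.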


@@ -1,17 +1,22 @@
import { describe, it, expect } from 'vitest';
import { AGENTIC_REGEX } from '$lib/constants/agentic';
import { LEGACY_AGENTIC_REGEX } from '$lib/constants/agentic';
// Mirror the logic in ChatService.stripReasoningContent so we can test it in isolation.
// The real function is private static, so we replicate the strip pipeline here.
function stripContextMarkers(content: string): string {
/**
* Tests for legacy marker stripping (used in migration).
* The new system does not embed markers in content - these tests verify
* the legacy regex patterns still work for the migration code.
*/
// Mirror the legacy stripping logic used during migration
function stripLegacyContextMarkers(content: string): string {
return content
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
.replace(AGENTIC_REGEX.REASONING_OPEN, '')
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK, '')
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '');
.replace(new RegExp(LEGACY_AGENTIC_REGEX.REASONING_BLOCK.source, 'g'), '')
.replace(LEGACY_AGENTIC_REGEX.REASONING_OPEN, '')
.replace(new RegExp(LEGACY_AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK.source, 'g'), '')
.replace(LEGACY_AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '');
}
// A realistic complete tool call block as stored in message.content after a turn.
// A realistic complete tool call block as stored in old message.content
const COMPLETE_BLOCK =
'\n\n<<<AGENTIC_TOOL_CALL_START>>>\n' +
'<<<TOOL_NAME:bash_tool>>>\n' +
@@ -30,11 +35,10 @@ const OPEN_BLOCK =
'<<<TOOL_ARGS_END>>>\n' +
'partial output...';
describe('agentic marker stripping for context', () => {
describe('legacy agentic marker stripping (for migration)', () => {
it('strips a complete tool call block, leaving surrounding text', () => {
const input = 'Before.' + COMPLETE_BLOCK + 'After.';
const result = stripContextMarkers(input);
// markers gone; residual newlines between fragments are fine
const result = stripLegacyContextMarkers(input);
expect(result).not.toContain('<<<');
expect(result).toContain('Before.');
expect(result).toContain('After.');
@@ -42,7 +46,7 @@ describe('agentic marker stripping for context', () => {
it('strips multiple complete tool call blocks', () => {
const input = 'A' + COMPLETE_BLOCK + 'B' + COMPLETE_BLOCK + 'C';
const result = stripContextMarkers(input);
const result = stripLegacyContextMarkers(input);
expect(result).not.toContain('<<<');
expect(result).toContain('A');
expect(result).toContain('B');
@@ -51,19 +55,19 @@ describe('agentic marker stripping for context', () => {
it('strips an open/partial tool call block (no END marker)', () => {
const input = 'Lead text.' + OPEN_BLOCK;
const result = stripContextMarkers(input);
const result = stripLegacyContextMarkers(input);
expect(result).toBe('Lead text.');
expect(result).not.toContain('<<<');
});
it('does not alter content with no markers', () => {
const input = 'Just a normal assistant response.';
expect(stripContextMarkers(input)).toBe(input);
expect(stripLegacyContextMarkers(input)).toBe(input);
});
it('strips reasoning block independently', () => {
const input = '<<<reasoning_content_start>>>think hard<<<reasoning_content_end>>>Answer.';
expect(stripContextMarkers(input)).toBe('Answer.');
expect(stripLegacyContextMarkers(input)).toBe('Answer.');
});
it('strips both reasoning and agentic blocks together', () => {
@@ -71,11 +75,21 @@ describe('agentic marker stripping for context', () => {
'<<<reasoning_content_start>>>plan<<<reasoning_content_end>>>' +
'Some text.' +
COMPLETE_BLOCK;
expect(stripContextMarkers(input)).not.toContain('<<<');
expect(stripContextMarkers(input)).toContain('Some text.');
expect(stripLegacyContextMarkers(input)).not.toContain('<<<');
expect(stripLegacyContextMarkers(input)).toContain('Some text.');
});
it('empty string survives', () => {
expect(stripContextMarkers('')).toBe('');
expect(stripLegacyContextMarkers('')).toBe('');
});
it('detects legacy markers', () => {
expect(LEGACY_AGENTIC_REGEX.HAS_LEGACY_MARKERS.test('normal text')).toBe(false);
expect(
LEGACY_AGENTIC_REGEX.HAS_LEGACY_MARKERS.test('text<<<AGENTIC_TOOL_CALL_START>>>more')
).toBe(true);
expect(LEGACY_AGENTIC_REGEX.HAS_LEGACY_MARKERS.test('<<<reasoning_content_start>>>think')).toBe(
true
);
});
});
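The updated helper rebuilds the block patterns with `new RegExp(pattern.source, 'g')`, which suggests that the `LEGACY_AGENTIC_REGEX` constants are defined without the global flag (their definitions are not shown in this diff, so that is an assumption). The short sketch below uses a made-up marker pattern to show why the flag matters: `String.prototype.replace` with a non-global regex removes only the first match.

```typescript
// Hypothetical marker pattern, standing in for the real legacy constants.
const BLOCK = /<<<START>>>[\s\S]*?<<<END>>>/;

// Non-global: only the first block is removed.
function stripFirst(text: string): string {
  return text.replace(BLOCK, '');
}

// Global copy built from the same source: every block is removed.
function stripAll(text: string): string {
  return text.replace(new RegExp(BLOCK.source, 'g'), '');
}

const sample = 'A<<<START>>>x<<<END>>>B<<<START>>>y<<<END>>>C';
const firstOnly = stripFirst(sample); // 'AB<<<START>>>y<<<END>>>C'
const allGone = stripAll(sample); // 'ABC'
```

Keeping the shared constant non-global also keeps `.test()` stateless, since a global regex advances `lastIndex` between calls.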


@@ -1,196 +1,89 @@
import { describe, it, expect } from 'vitest';
import { AGENTIC_REGEX, REASONING_TAGS } from '$lib/constants/agentic';
import { ContentPartType } from '$lib/enums';
import { MessageRole } from '$lib/enums';
// Replicate ChatService.extractReasoningFromContent (private static)
function extractReasoningFromContent(
content: string | Array<{ type: string; text?: string }> | null | undefined
): string | undefined {
if (!content) return undefined;
/**
* Tests for the new reasoning content handling.
* In the new architecture, reasoning content is stored in a dedicated
* `reasoningContent` field on DatabaseMessage, not embedded in content with tags.
* The API sends it as `reasoning_content` on ApiChatMessageData.
*/
const extractFromString = (text: string): string => {
const parts: string[] = [];
const re = new RegExp(AGENTIC_REGEX.REASONING_EXTRACT.source);
let match = re.exec(text);
while (match) {
parts.push(match[1]);
text = text.slice(match.index + match[0].length);
match = re.exec(text);
}
return parts.join('');
};
if (typeof content === 'string') {
const result = extractFromString(content);
return result || undefined;
}
if (!Array.isArray(content)) return undefined;
const parts: string[] = [];
for (const part of content) {
if (part.type === ContentPartType.TEXT && part.text) {
const result = extractFromString(part.text);
if (result) parts.push(result);
}
}
return parts.length > 0 ? parts.join('') : undefined;
}
// Replicate ChatService.stripReasoningContent (private static)
function stripReasoningContent(
content: string | Array<{ type: string; text?: string }> | null | undefined
): typeof content {
if (!content) return content;
if (typeof content === 'string') {
return content
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
.replace(AGENTIC_REGEX.REASONING_OPEN, '')
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK, '')
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '');
}
if (!Array.isArray(content)) return content;
return content.map((part) => {
if (part.type !== ContentPartType.TEXT || !part.text) return part;
return {
...part,
text: part.text
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
.replace(AGENTIC_REGEX.REASONING_OPEN, '')
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK, '')
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '')
describe('reasoning content in new structured format', () => {
it('reasoning is stored as separate field, not in content', () => {
// Simulate what the new chat store does
const message = {
content: 'The answer is 4.',
reasoningContent: 'Let me think: 2+2=4, basic arithmetic.'
};
});
}
// Simulate the message mapping logic from ChatService.sendMessage
function buildApiMessage(
content: string,
excludeReasoningFromContext: boolean
): { role: string; content: string; reasoning_content?: string } {
const cleaned = stripReasoningContent(content) as string;
const mapped: { role: string; content: string; reasoning_content?: string } = {
role: 'assistant',
content: cleaned
};
if (!excludeReasoningFromContext) {
const reasoning = extractReasoningFromContent(content);
if (reasoning) {
mapped.reasoning_content = reasoning;
// Content should be clean
expect(message.content).not.toContain('<<<');
expect(message.content).toBe('The answer is 4.');
// Reasoning in dedicated field
expect(message.reasoningContent).toBe('Let me think: 2+2=4, basic arithmetic.');
});
it('convertDbMessageToApiChatMessageData includes reasoning_content', () => {
// Simulate the conversion logic
const dbMessage = {
role: MessageRole.ASSISTANT,
content: 'The answer is 4.',
reasoningContent: 'Let me think: 2+2=4, basic arithmetic.'
};
const apiMessage: Record<string, unknown> = {
role: dbMessage.role,
content: dbMessage.content
};
if (dbMessage.reasoningContent) {
apiMessage.reasoning_content = dbMessage.reasoningContent;
}
}
return mapped;
}
// Helper: wrap reasoning the same way the chat store does during streaming
function wrapReasoning(reasoning: string, content: string): string {
return `${REASONING_TAGS.START}${reasoning}${REASONING_TAGS.END}${content}`;
}
describe('reasoning content extraction', () => {
it('extracts reasoning from tagged string content', () => {
const input = wrapReasoning('step 1, step 2', 'The answer is 42.');
const result = extractReasoningFromContent(input);
expect(result).toBe('step 1, step 2');
expect(apiMessage.content).toBe('The answer is 4.');
expect(apiMessage.reasoning_content).toBe('Let me think: 2+2=4, basic arithmetic.');
// No internal tags leak into either field
expect(apiMessage.content).not.toContain('<<<');
expect(apiMessage.reasoning_content).not.toContain('<<<');
});
it('returns undefined when no reasoning tags present', () => {
expect(extractReasoningFromContent('Just a normal response.')).toBeUndefined();
it('API message excludes reasoning when excludeReasoningFromContext is true', () => {
const dbMessage = {
role: MessageRole.ASSISTANT,
content: 'The answer is 4.',
reasoningContent: 'internal thinking'
};
const excludeReasoningFromContext = true;
const apiMessage: Record<string, unknown> = {
role: dbMessage.role,
content: dbMessage.content
};
if (!excludeReasoningFromContext && dbMessage.reasoningContent) {
apiMessage.reasoning_content = dbMessage.reasoningContent;
}
expect(apiMessage.content).toBe('The answer is 4.');
expect(apiMessage.reasoning_content).toBeUndefined();
});
it('returns undefined for null/empty input', () => {
expect(extractReasoningFromContent(null)).toBeUndefined();
expect(extractReasoningFromContent(undefined)).toBeUndefined();
expect(extractReasoningFromContent('')).toBeUndefined();
});
it('handles messages with no reasoning', () => {
const dbMessage = {
role: MessageRole.ASSISTANT,
content: 'No reasoning here.',
reasoningContent: undefined
};
it('extracts reasoning from content part arrays', () => {
const input = [
{
type: ContentPartType.TEXT,
text: wrapReasoning('thinking hard', 'result')
}
];
expect(extractReasoningFromContent(input)).toBe('thinking hard');
});
const apiMessage: Record<string, unknown> = {
role: dbMessage.role,
content: dbMessage.content
};
if (dbMessage.reasoningContent) {
apiMessage.reasoning_content = dbMessage.reasoningContent;
}
it('handles multiple reasoning blocks', () => {
const input =
REASONING_TAGS.START +
'block1' +
REASONING_TAGS.END +
'middle' +
REASONING_TAGS.START +
'block2' +
REASONING_TAGS.END +
'end';
expect(extractReasoningFromContent(input)).toBe('block1block2');
});
it('ignores non-text content parts', () => {
const input = [{ type: 'image_url', text: wrapReasoning('hidden', 'img') }];
expect(extractReasoningFromContent(input)).toBeUndefined();
});
});
describe('strip reasoning content', () => {
it('removes reasoning tags from string content', () => {
const input = wrapReasoning('internal thoughts', 'visible answer');
expect(stripReasoningContent(input)).toBe('visible answer');
});
it('removes reasoning from content part arrays', () => {
const input = [
{
type: ContentPartType.TEXT,
text: wrapReasoning('thoughts', 'answer')
}
];
const result = stripReasoningContent(input) as Array<{ type: string; text?: string }>;
expect(result[0].text).toBe('answer');
});
});
describe('API message building with reasoning preservation', () => {
const storedContent = wrapReasoning('Let me think: 2+2=4, basic arithmetic.', 'The answer is 4.');
it('preserves reasoning_content when excludeReasoningFromContext is false', () => {
const msg = buildApiMessage(storedContent, false);
expect(msg.content).toBe('The answer is 4.');
expect(msg.reasoning_content).toBe('Let me think: 2+2=4, basic arithmetic.');
// no internal tags leak into either field
expect(msg.content).not.toContain('<<<');
expect(msg.reasoning_content).not.toContain('<<<');
});
it('strips reasoning_content when excludeReasoningFromContext is true', () => {
const msg = buildApiMessage(storedContent, true);
expect(msg.content).toBe('The answer is 4.');
expect(msg.reasoning_content).toBeUndefined();
});
it('handles content with no reasoning in both modes', () => {
const plain = 'No reasoning here.';
const msgPreserve = buildApiMessage(plain, false);
const msgExclude = buildApiMessage(plain, true);
expect(msgPreserve.content).toBe(plain);
expect(msgPreserve.reasoning_content).toBeUndefined();
expect(msgExclude.content).toBe(plain);
expect(msgExclude.reasoning_content).toBeUndefined();
});
it('cleans agentic tool call blocks from content even when preserving reasoning', () => {
const input =
wrapReasoning('plan', 'text') +
'\n\n<<<AGENTIC_TOOL_CALL_START>>>\n' +
'<<<TOOL_NAME:bash>>>\n' +
'<<<TOOL_ARGS_START>>>\n{}\n<<<TOOL_ARGS_END>>>\nout\n' +
'<<<AGENTIC_TOOL_CALL_END>>>\n';
const msg = buildApiMessage(input, false);
expect(msg.content).not.toContain('<<<');
expect(msg.reasoning_content).toBe('plan');
expect(apiMessage.content).toBe('No reasoning here.');
expect(apiMessage.reasoning_content).toBeUndefined();
});
});
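The new tests inline the same mapping three times: copy `role` and `content`, then attach `reasoning_content` only when reasoning exists and is not excluded. A minimal sketch of that mapping as a single function is shown below. `toApiMessage`, `StoredMessage`, and `ApiMessage` are hypothetical names; the real conversion (referred to in the tests as `convertDbMessageToApiChatMessageData`) is not shown in this diff and may differ.

```typescript
// Hypothetical, trimmed-down shapes of the stored and outgoing messages.
interface StoredMessage {
  role: string;
  content: string;
  reasoningContent?: string;
}

interface ApiMessage {
  role: string;
  content: string;
  reasoning_content?: string;
}

function toApiMessage(msg: StoredMessage, excludeReasoningFromContext: boolean): ApiMessage {
  const api: ApiMessage = { role: msg.role, content: msg.content };
  // Reasoning lives in its own field, so no tag stripping is needed here.
  if (!excludeReasoningFromContext && msg.reasoningContent) {
    api.reasoning_content = msg.reasoningContent;
  }
  return api;
}
```

Since `content` never carries embedded markers in the new format, the function leaves it untouched in both modes.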


@@ -551,6 +551,8 @@ int main(int argc, char ** argv) {
params.sampling.top_k = 4;
params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, };
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
return 1;
}
@@ -558,8 +560,6 @@ int main(int argc, char ** argv) {
const int n_parallel = params.n_parallel;
const int n_predict = params.n_predict;
common_init();
// init LLM
llama_backend_init();


@@ -39,7 +39,7 @@ if (LLAMA_BUILD_BORINGSSL)
set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)")
set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository")
set(BORINGSSL_VERSION "0.20260211.0" CACHE STRING "BoringSSL version")
set(BORINGSSL_VERSION "0.20260327.0" CACHE STRING "BoringSSL version")
message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")