mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-23 16:37:33 +03:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9281dd135d | ||
|
|
0be6c7c9ce | ||
|
|
41361c8599 | ||
|
|
62278cedde | ||
|
|
90aa83c6bd | ||
|
|
fcc2d598c8 | ||
|
|
4453e77561 | ||
|
|
26dac845cc | ||
|
|
5ce013cd7e | ||
|
|
08f21453ae | ||
|
|
84ae8434d0 | ||
|
|
ead417f01c | ||
|
|
64ac9ab66a | ||
|
|
cad2d3884c | ||
|
|
389c7d4955 |
@@ -36,7 +36,7 @@ RUN mkdir -p /app/full \
|
||||
FROM ubuntu:$UBUNTU_VERSION AS base
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl\
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG CUDA_VERSION=13.1.0
|
||||
ARG CUDA_VERSION=13.1.1
|
||||
# Target the CUDA build image
|
||||
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
@@ -12,7 +12,9 @@ FROM ${BASE_CUDA_DEV_CONTAINER} AS build
|
||||
ARG CUDA_DOCKER_ARCH=default
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential cmake python3 python3-pip git libssl-dev libgomp1
|
||||
apt-get install -y gcc-14 g++-14 build-essential cmake python3 python3-pip git libssl-dev libgomp1
|
||||
|
||||
ENV CC=gcc-14 CXX=g++-14 CUDAHOSTCXX=g++-14
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -39,7 +41,7 @@ RUN mkdir -p /app/full \
|
||||
FROM ${BASE_CUDA_RUN_CONTAINER} AS base
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl\
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG CUDA_VERSION=12.4.0
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
# Target the CUDA build image
|
||||
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
@@ -12,7 +12,9 @@ FROM ${BASE_CUDA_DEV_CONTAINER} AS build
|
||||
ARG CUDA_DOCKER_ARCH=default
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential cmake python3 python3-pip git libssl-dev libgomp1
|
||||
apt-get install -y gcc-14 g++-14 build-essential cmake python3 python3-pip git libssl-dev libgomp1
|
||||
|
||||
ENV CC=gcc-14 CXX=g++-14 CUDAHOSTCXX=g++-14
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -39,7 +41,7 @@ RUN mkdir -p /app/full \
|
||||
FROM ${BASE_CUDA_RUN_CONTAINER} AS base
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl\
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
@@ -60,7 +62,8 @@ RUN apt-get update \
|
||||
git \
|
||||
python3 \
|
||||
python3-pip \
|
||||
&& pip install --upgrade pip setuptools wheel \
|
||||
python3-wheel \
|
||||
&& pip install --break-system-packages --upgrade setuptools \
|
||||
&& pip install --break-system-packages -r requirements.txt \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
|
||||
@@ -51,7 +51,7 @@ RUN mkdir /tmp/neo/ && cd /tmp/neo/ \
|
||||
&& dpkg --install *.deb
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl\
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -46,7 +46,7 @@ RUN mkdir -p /app/full \
|
||||
FROM ${BASE_MUSA_RUN_CONTAINER} AS base
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl\
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -78,7 +78,7 @@ ARG http_proxy
|
||||
ARG https_proxy
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 libtbb12 curl\
|
||||
&& apt-get install -y libgomp1 libtbb12 curl \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -58,7 +58,7 @@ RUN mkdir -p /app/full \
|
||||
FROM ${BASE_ROCM_DEV_CONTAINER} AS base
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl\
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
@@ -79,7 +79,7 @@ RUN apt-get update \
|
||||
git \
|
||||
python3-pip \
|
||||
python3 \
|
||||
python3-wheel\
|
||||
python3-wheel \
|
||||
&& pip install --break-system-packages --upgrade setuptools \
|
||||
&& pip install --break-system-packages -r requirements.txt \
|
||||
&& apt autoremove -y \
|
||||
|
||||
@@ -49,17 +49,20 @@ COPY --from=build /app/full /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ENV PATH="/root/.venv/bin:/root/.local/bin:${PATH}"
|
||||
|
||||
# Flag for compatibility with pip
|
||||
ARG UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
build-essential \
|
||||
curl \
|
||||
git \
|
||||
python3.13 \
|
||||
python3.13-dev \
|
||||
python3-pip \
|
||||
python3-wheel \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.13 100 \
|
||||
&& pip install --break-system-packages --upgrade setuptools \
|
||||
&& pip install --break-system-packages -r requirements.txt \
|
||||
ca-certificates \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& uv python install 3.13 \
|
||||
&& uv venv --python 3.13 /root/.venv \
|
||||
&& uv pip install --python /root/.venv/bin/python -r requirements.txt \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
565
.github/workflows/docker.yml
vendored
565
.github/workflows/docker.yml
vendored
@@ -25,184 +25,13 @@ permissions:
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
push_to_registry:
|
||||
name: Push Docker image to Docker Hub
|
||||
|
||||
runs-on: ${{ matrix.config.runs_on }}
|
||||
env:
|
||||
COMMIT_SHA: ${{ github.sha }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
# Multi-stage build
|
||||
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/arm64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "cuda cuda12", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04", cuda_version: "12.4.0", ubuntu_version: "22.04" }
|
||||
- { tag: "cuda13", dockerfile: ".devops/cuda-new.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04", cuda_version: "13.1.0", ubuntu_version: "24.04" }
|
||||
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04-s390x" }
|
||||
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "openvino", dockerfile: ".devops/openvino.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0 # preserve git history, so we can determine the build number
|
||||
|
||||
- name: Set up QEMU
|
||||
if: ${{ matrix.config.tag != 's390x' }}
|
||||
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3
|
||||
with:
|
||||
image: tonistiigi/binfmt:qemu-v10.2.1
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Determine source tag name
|
||||
id: srctag
|
||||
uses: ./.github/actions/get-tag-name
|
||||
env:
|
||||
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Determine image tag name
|
||||
id: tag
|
||||
shell: bash
|
||||
run: |
|
||||
REPO_OWNER="${GITHUB_REPOSITORY_OWNER@L}" # to lower case
|
||||
REPO_NAME="${{ github.event.repository.name }}"
|
||||
PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:"
|
||||
|
||||
# list all tags possible
|
||||
tags="${{ matrix.config.tag }}"
|
||||
for tag in $tags; do
|
||||
if [[ "$tag" == "cpu" ]]; then
|
||||
TYPE=""
|
||||
else
|
||||
TYPE="-$tag"
|
||||
fi
|
||||
CACHETAGS="${PREFIX}buildcache${TYPE}"
|
||||
FULLTAGS="${FULLTAGS:+$FULLTAGS,}${PREFIX}full${TYPE},${PREFIX}full${TYPE}-${{ steps.srctag.outputs.name }}"
|
||||
LIGHTTAGS="${LIGHTTAGS:+$LIGHTTAGS,}${PREFIX}light${TYPE},${PREFIX}light${TYPE}-${{ steps.srctag.outputs.name }}"
|
||||
SERVERTAGS="${SERVERTAGS:+$SERVERTAGS,}${PREFIX}server${TYPE},${PREFIX}server${TYPE}-${{ steps.srctag.outputs.name }}"
|
||||
done
|
||||
echo "cache_output_tags=$CACHETAGS" >> $GITHUB_OUTPUT
|
||||
echo "full_output_tags=$FULLTAGS" >> $GITHUB_OUTPUT
|
||||
echo "light_output_tags=$LIGHTTAGS" >> $GITHUB_OUTPUT
|
||||
echo "server_output_tags=$SERVERTAGS" >> $GITHUB_OUTPUT
|
||||
echo "cache_output_tags=$CACHETAGS" # print out for debugging
|
||||
echo "full_output_tags=$FULLTAGS" # print out for debugging
|
||||
echo "light_output_tags=$LIGHTTAGS" # print out for debugging
|
||||
echo "server_output_tags=$SERVERTAGS" # print out for debugging
|
||||
env:
|
||||
GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}'
|
||||
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
if: ${{ matrix.config.free_disk_space == true }}
|
||||
uses: ggml-org/free-disk-space@v1.3.1
|
||||
with:
|
||||
# this might remove tools that are actually needed,
|
||||
# if set to "true" but frees about 6 GB
|
||||
tool-cache: false
|
||||
|
||||
# all of these default to true, but feel free to set to
|
||||
# "false" if necessary for your workflow
|
||||
android: true
|
||||
dotnet: true
|
||||
haskell: true
|
||||
large-packages: true
|
||||
docker-images: true
|
||||
swap-storage: true
|
||||
|
||||
- name: Build and push Full Docker image (tagged + versioned)
|
||||
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.full == true }}
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: ${{ matrix.config.platforms }}
|
||||
# tag list is generated from step above
|
||||
tags: ${{ steps.tag.outputs.full_output_tags }}
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
target: full
|
||||
provenance: false
|
||||
build-args: |
|
||||
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
|
||||
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
|
||||
# using github experimental cache
|
||||
#cache-from: type=gha
|
||||
#cache-to: type=gha,mode=max
|
||||
# return to this if the experimental github cache is having issues
|
||||
#cache-to: type=local,dest=/tmp/.buildx-cache
|
||||
#cache-from: type=local,src=/tmp/.buildx-cache
|
||||
# using registry cache (no storage limit)
|
||||
cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
|
||||
cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max
|
||||
|
||||
- name: Build and push Light Docker image (tagged + versioned)
|
||||
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }}
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: ${{ matrix.config.platforms }}
|
||||
# tag list is generated from step above
|
||||
tags: ${{ steps.tag.outputs.light_output_tags }}
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
target: light
|
||||
provenance: false
|
||||
build-args: |
|
||||
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
|
||||
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
|
||||
# using github experimental cache
|
||||
#cache-from: type=gha
|
||||
#cache-to: type=gha,mode=max
|
||||
# return to this if the experimental github cache is having issues
|
||||
#cache-to: type=local,dest=/tmp/.buildx-cache
|
||||
#cache-from: type=local,src=/tmp/.buildx-cache
|
||||
# using registry cache (no storage limit)
|
||||
cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
|
||||
cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max
|
||||
|
||||
- name: Build and push Server Docker image (tagged + versioned)
|
||||
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }}
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: ${{ matrix.config.platforms }}
|
||||
# tag list is generated from step above
|
||||
tags: ${{ steps.tag.outputs.server_output_tags }}
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
target: server
|
||||
provenance: false
|
||||
build-args: |
|
||||
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
|
||||
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
|
||||
# using github experimental cache
|
||||
#cache-from: type=gha
|
||||
#cache-to: type=gha,mode=max
|
||||
# return to this if the experimental github cache is having issues
|
||||
#cache-to: type=local,dest=/tmp/.buildx-cache
|
||||
#cache-from: type=local,src=/tmp/.buildx-cache
|
||||
# using registry cache (no storage limit)
|
||||
cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
|
||||
cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max
|
||||
|
||||
create_tag:
|
||||
name: Create and push git tag
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-slim
|
||||
permissions:
|
||||
contents: write
|
||||
outputs:
|
||||
source_tag: ${{ steps.srctag.outputs.name }}
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -223,3 +52,391 @@ jobs:
|
||||
run: |
|
||||
git tag ${{ steps.srctag.outputs.name }} || exit 0
|
||||
git push origin ${{ steps.srctag.outputs.name }} || exit 0
|
||||
|
||||
prepare_matrices:
|
||||
name: Prepare Docker matrices
|
||||
runs-on: ubuntu-24.04
|
||||
outputs:
|
||||
build_matrix: ${{ steps.matrices.outputs.build_matrix }}
|
||||
merge_matrix: ${{ steps.matrices.outputs.merge_matrix }}
|
||||
|
||||
steps:
|
||||
- name: Generate build and merge matrices
|
||||
id: matrices
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Keep all build targets in one place and derive merge targets from it.
|
||||
cat > build-matrix.json <<'JSON'
|
||||
[
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "musa", "dockerfile": ".devops/musa.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "intel", "dockerfile": ".devops/intel.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "vulkan", "dockerfile": ".devops/vulkan.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "vulkan", "dockerfile": ".devops/vulkan.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "rocm", "dockerfile": ".devops/rocm.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "openvino", "dockerfile": ".devops/openvino.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" }
|
||||
]
|
||||
JSON
|
||||
|
||||
BUILD_MATRIX="$(jq -c . build-matrix.json)"
|
||||
MERGE_MATRIX="$(jq -c '
|
||||
reduce .[] as $entry ({}; .[$entry.tag] |= (
|
||||
. // {
|
||||
tag: $entry.tag,
|
||||
arches: [],
|
||||
full: false,
|
||||
light: false,
|
||||
server: false
|
||||
}
|
||||
| .full = (.full or ($entry.full // false))
|
||||
| .light = (.light or ($entry.light // false))
|
||||
| .server = (.server or ($entry.server // false))
|
||||
| .arches += [($entry.platforms | sub("^linux/"; ""))]
|
||||
))
|
||||
# Backward compatibility: s390x tags are aliases of cpu for the linux/s390x platform.
|
||||
| if (has("cpu") and (((.cpu.arches // []) | index("s390x")) != null)) then
|
||||
. + {
|
||||
s390x: {
|
||||
tag: "s390x",
|
||||
arches: ["s390x"],
|
||||
full: .cpu.full,
|
||||
light: .cpu.light,
|
||||
server: .cpu.server
|
||||
}
|
||||
}
|
||||
else
|
||||
.
|
||||
end
|
||||
| [.[] | .arches = (.arches | unique | sort | join(" "))]
|
||||
' build-matrix.json)"
|
||||
|
||||
echo "build_matrix=$BUILD_MATRIX" >> "$GITHUB_OUTPUT"
|
||||
echo "merge_matrix=$MERGE_MATRIX" >> "$GITHUB_OUTPUT"
|
||||
|
||||
push_to_registry:
|
||||
name: Push Docker image to Docker Registry
|
||||
needs: [prepare_matrices, create_tag]
|
||||
|
||||
runs-on: ${{ matrix.config.runs_on }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config: ${{ fromJSON(needs.prepare_matrices.outputs.build_matrix) }}
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ needs.create_tag.outputs.source_tag }}
|
||||
|
||||
- name: Set up QEMU
|
||||
if: ${{ contains(matrix.config.platforms, 'linux/amd64') }}
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4
|
||||
with:
|
||||
image: tonistiigi/binfmt:qemu-v10.2.1
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
|
||||
|
||||
- name: Log in to Docker Registry
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Determine image metadata
|
||||
id: meta
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
REPO_OWNER="${GITHUB_REPOSITORY_OWNER@L}" # to lower case
|
||||
REPO_NAME="${{ github.event.repository.name }}"
|
||||
IMAGE_REPO="ghcr.io/${REPO_OWNER}/${REPO_NAME}"
|
||||
PREFIX="${IMAGE_REPO}:"
|
||||
PLATFORM="${{ matrix.config.platforms }}"
|
||||
ARCH_SUFFIX="${PLATFORM#linux/}"
|
||||
|
||||
# list all tags possible
|
||||
tags="${{ matrix.config.tag }}"
|
||||
for tag in $tags; do
|
||||
if [[ "$tag" == "cpu" ]]; then
|
||||
TYPE=""
|
||||
else
|
||||
TYPE="-$tag"
|
||||
fi
|
||||
CACHETAG="${PREFIX}buildcache${TYPE}-${ARCH_SUFFIX}"
|
||||
done
|
||||
|
||||
SAFE_TAGS="$(echo "$tags" | tr ' ' '_')"
|
||||
|
||||
echo "image_repo=$IMAGE_REPO" >> $GITHUB_OUTPUT
|
||||
echo "arch_suffix=$ARCH_SUFFIX" >> $GITHUB_OUTPUT
|
||||
echo "cache_output_tag=$CACHETAG" >> $GITHUB_OUTPUT
|
||||
echo "digest_artifact_suffix=${SAFE_TAGS}-${ARCH_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
echo "cache_output_tag=$CACHETAG" # print out for debugging
|
||||
env:
|
||||
GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}'
|
||||
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
if: ${{ matrix.config.free_disk_space == true }}
|
||||
uses: ggml-org/free-disk-space@v1.3.1
|
||||
with:
|
||||
# this might remove tools that are actually needed,
|
||||
# if set to "true" but frees about 6 GB
|
||||
tool-cache: false
|
||||
|
||||
# all of these default to true, but feel free to set to
|
||||
# "false" if necessary for your workflow
|
||||
android: true
|
||||
dotnet: true
|
||||
haskell: true
|
||||
large-packages: true
|
||||
docker-images: true
|
||||
swap-storage: true
|
||||
|
||||
- name: Build and push Full Docker image by digest
|
||||
id: build_full
|
||||
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.full == true }}
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.config.platforms }}
|
||||
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
target: full
|
||||
provenance: false
|
||||
build-args: |
|
||||
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
|
||||
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
|
||||
# using github experimental cache
|
||||
#cache-from: type=gha
|
||||
#cache-to: type=gha,mode=max
|
||||
# return to this if the experimental github cache is having issues
|
||||
#cache-to: type=local,dest=/tmp/.buildx-cache
|
||||
#cache-from: type=local,src=/tmp/.buildx-cache
|
||||
# using registry cache (no storage limit)
|
||||
cache-from: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }}
|
||||
cache-to: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }},mode=max
|
||||
|
||||
- name: Build and push Light Docker image by digest
|
||||
id: build_light
|
||||
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }}
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.config.platforms }}
|
||||
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
target: light
|
||||
provenance: false
|
||||
build-args: |
|
||||
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
|
||||
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
|
||||
# using github experimental cache
|
||||
#cache-from: type=gha
|
||||
#cache-to: type=gha,mode=max
|
||||
# return to this if the experimental github cache is having issues
|
||||
#cache-to: type=local,dest=/tmp/.buildx-cache
|
||||
#cache-from: type=local,src=/tmp/.buildx-cache
|
||||
# using registry cache (no storage limit)
|
||||
cache-from: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }}
|
||||
cache-to: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }},mode=max
|
||||
|
||||
- name: Build and push Server Docker image by digest
|
||||
id: build_server
|
||||
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }}
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.config.platforms }}
|
||||
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
target: server
|
||||
provenance: false
|
||||
build-args: |
|
||||
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
|
||||
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
|
||||
# using github experimental cache
|
||||
#cache-from: type=gha
|
||||
#cache-to: type=gha,mode=max
|
||||
# return to this if the experimental github cache is having issues
|
||||
#cache-to: type=local,dest=/tmp/.buildx-cache
|
||||
#cache-from: type=local,src=/tmp/.buildx-cache
|
||||
# using registry cache (no storage limit)
|
||||
cache-from: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }}
|
||||
cache-to: type=registry,ref=${{ steps.meta.outputs.cache_output_tag }},mode=max
|
||||
|
||||
- name: Export digest metadata
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
TAGS="${{ matrix.config.tag }}"
|
||||
ARCH_SUFFIX="${{ steps.meta.outputs.arch_suffix }}"
|
||||
DIGEST_FILE="/tmp/digests/${{ steps.meta.outputs.digest_artifact_suffix }}.tsv"
|
||||
mkdir -p /tmp/digests
|
||||
|
||||
add_digest_rows() {
|
||||
local image_type="$1"
|
||||
local digest="$2"
|
||||
|
||||
if [[ -z "$digest" ]]; then
|
||||
echo "Missing digest for image_type=${image_type}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
for tag in $TAGS; do
|
||||
printf '%s\t%s\t%s\t%s\n' "$tag" "$ARCH_SUFFIX" "$image_type" "$digest" >> "$DIGEST_FILE"
|
||||
done
|
||||
}
|
||||
|
||||
if [[ "${{ matrix.config.full }}" == "true" ]]; then
|
||||
add_digest_rows "full" "${{ steps.build_full.outputs.digest }}"
|
||||
fi
|
||||
|
||||
if [[ "${{ matrix.config.light }}" == "true" ]]; then
|
||||
add_digest_rows "light" "${{ steps.build_light.outputs.digest }}"
|
||||
fi
|
||||
|
||||
if [[ "${{ matrix.config.server }}" == "true" ]]; then
|
||||
add_digest_rows "server" "${{ steps.build_server.outputs.digest }}"
|
||||
fi
|
||||
|
||||
- name: Upload digest metadata
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7
|
||||
with:
|
||||
name: digests-${{ steps.meta.outputs.digest_artifact_suffix }}
|
||||
path: /tmp/digests/${{ steps.meta.outputs.digest_artifact_suffix }}.tsv
|
||||
if-no-files-found: error
|
||||
|
||||
merge_arch_tags:
|
||||
name: Create shared tags from digests
|
||||
needs: [prepare_matrices, push_to_registry, create_tag]
|
||||
runs-on: ubuntu-24.04
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config: ${{ fromJSON(needs.prepare_matrices.outputs.merge_matrix) }}
|
||||
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Download digest metadata
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
|
||||
with:
|
||||
pattern: digests-*
|
||||
path: /tmp/digests
|
||||
merge-multiple: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
|
||||
|
||||
- name: Log in to Docker Registry
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Create tags from digests
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
REPO_OWNER="${GITHUB_REPOSITORY_OWNER@L}" # to lower case
|
||||
REPO_NAME="${{ github.event.repository.name }}"
|
||||
IMAGE_REPO="ghcr.io/${REPO_OWNER}/${REPO_NAME}"
|
||||
PREFIX="${IMAGE_REPO}:"
|
||||
SRC_TAG="${{ needs.create_tag.outputs.source_tag }}"
|
||||
TAGS="${{ matrix.config.tag }}"
|
||||
ARCHES="${{ matrix.config.arches }}"
|
||||
DIGEST_GLOB="/tmp/digests/*.tsv"
|
||||
|
||||
if ! ls ${DIGEST_GLOB} >/dev/null 2>&1; then
|
||||
echo "No digest metadata found in /tmp/digests" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ -z "$SRC_TAG" ]]; then
|
||||
echo "Missing source tag from create_tag" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
find_digest() {
|
||||
local tag_name="$1"
|
||||
local arch="$2"
|
||||
local image_type="$3"
|
||||
local digest
|
||||
|
||||
digest="$(awk -F '\t' -v t="$tag_name" -v a="$arch" -v i="$image_type" '$1 == t && $2 == a && $3 == i { print $4; exit }' ${DIGEST_GLOB})"
|
||||
|
||||
# Backward compatibility: s390x tags are aliases of cpu for the linux/s390x platform.
|
||||
if [[ -z "$digest" && "$tag_name" == "s390x" && "$arch" == "s390x" ]]; then
|
||||
digest="$(awk -F '\t' -v t="cpu" -v a="$arch" -v i="$image_type" '$1 == t && $2 == a && $3 == i { print $4; exit }' ${DIGEST_GLOB})"
|
||||
fi
|
||||
|
||||
if [[ -z "$digest" ]]; then
|
||||
echo "Missing digest for tag=${tag_name} arch=${arch} image_type=${image_type}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "$digest"
|
||||
}
|
||||
|
||||
create_manifest_tags() {
|
||||
local image_type="$1"
|
||||
local tag_name="$2"
|
||||
local suffix="$3"
|
||||
|
||||
local merged_tag="${PREFIX}${image_type}${suffix}"
|
||||
local merged_versioned_tag="${merged_tag}-${SRC_TAG}"
|
||||
|
||||
local refs=()
|
||||
|
||||
for arch in $ARCHES; do
|
||||
local digest
|
||||
digest="$(find_digest "$tag_name" "$arch" "$image_type")"
|
||||
refs+=("${IMAGE_REPO}@${digest}")
|
||||
done
|
||||
|
||||
echo "Creating ${merged_tag} from ${refs[*]}"
|
||||
docker buildx imagetools create --tag "${merged_tag}" "${refs[@]}"
|
||||
|
||||
echo "Creating ${merged_versioned_tag} from ${refs[*]}"
|
||||
docker buildx imagetools create --tag "${merged_versioned_tag}" "${refs[@]}"
|
||||
}
|
||||
|
||||
for tag in $TAGS; do
|
||||
if [[ "$tag" == "cpu" ]]; then
|
||||
TYPE=""
|
||||
else
|
||||
TYPE="-$tag"
|
||||
fi
|
||||
|
||||
if [[ "${{ matrix.config.full }}" == "true" ]]; then
|
||||
create_manifest_tags "full" "$tag" "$TYPE"
|
||||
fi
|
||||
|
||||
if [[ "${{ matrix.config.light }}" == "true" ]]; then
|
||||
create_manifest_tags "light" "$tag" "$TYPE"
|
||||
fi
|
||||
|
||||
if [[ "${{ matrix.config.server }}" == "true" ]]; then
|
||||
create_manifest_tags "server" "$tag" "$TYPE"
|
||||
fi
|
||||
done
|
||||
env:
|
||||
GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}'
|
||||
|
||||
@@ -359,6 +359,11 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
}
|
||||
|
||||
void common_init() {
|
||||
#if defined(_WIN32)
|
||||
SetConsoleOutputCP(CP_UTF8);
|
||||
SetConsoleCP(CP_UTF8);
|
||||
#endif
|
||||
|
||||
llama_log_set(common_log_default_callback, NULL);
|
||||
|
||||
#ifdef NDEBUG
|
||||
@@ -367,7 +372,7 @@ void common_init() {
|
||||
const char * build_type = " (debug)";
|
||||
#endif
|
||||
|
||||
LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
|
||||
LOG_DBG("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
|
||||
}
|
||||
|
||||
std::string common_params_get_system_info(const common_params & params) {
|
||||
@@ -1243,6 +1248,9 @@ llama_context * common_init_result::context() {
|
||||
}
|
||||
|
||||
// Looks up the sampler stored for sequence `seq_id`.
// Returns nullptr when `seq_id` is negative or not below the number of stored samplers.
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
    const int n_samplers = (int) pimpl->samplers.size();

    const bool in_range = seq_id >= 0 && seq_id < n_samplers;

    return in_range ? pimpl->samplers[seq_id].get() : nullptr;
}
|
||||
|
||||
|
||||
@@ -539,6 +539,9 @@ private:
|
||||
statement_ptr step = slices.size() > 2 ? std::move(slices[2]) : nullptr;
|
||||
return mk_stmt<slice_expression>(start_pos, std::move(start), std::move(stop), std::move(step));
|
||||
}
|
||||
if (slices.empty()) {
|
||||
return mk_stmt<blank_expression>(start_pos);
|
||||
}
|
||||
return std::move(slices[0]);
|
||||
}
|
||||
|
||||
|
||||
@@ -771,10 +771,15 @@ value member_expression::execute_impl(context & ctx) {
|
||||
}
|
||||
|
||||
JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
|
||||
ensure_key_type_allowed(property);
|
||||
|
||||
value val = mk_val<value_undefined>("object_property");
|
||||
|
||||
if (property->is_undefined()) {
|
||||
JJ_DEBUG("%s", "Member expression property is undefined, returning undefined");
|
||||
return val;
|
||||
}
|
||||
|
||||
ensure_key_type_allowed(property);
|
||||
|
||||
if (is_val<value_undefined>(object)) {
|
||||
JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined");
|
||||
return val;
|
||||
|
||||
@@ -263,6 +263,14 @@ struct comment_statement : public statement {
|
||||
|
||||
// Expressions
|
||||
|
||||
// Represents an omitted expression in a computed member, e.g. `a[]`.
|
||||
struct blank_expression : public expression {
|
||||
std::string type() const override { return "BlankExpression"; }
|
||||
value execute_impl(context &) override {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
};
|
||||
|
||||
struct member_expression : public expression {
|
||||
statement_ptr object;
|
||||
statement_ptr property;
|
||||
|
||||
@@ -383,6 +383,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
params.backend_sampling = false;
|
||||
}
|
||||
|
||||
if (rbudget && params.backend_sampling) {
|
||||
LOG_WRN("%s: backend sampling is not compatible with reasoning budget, disabling\n", __func__);
|
||||
|
||||
params.backend_sampling = false;
|
||||
}
|
||||
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ grmr,
|
||||
|
||||
@@ -13,24 +13,30 @@ We have three Docker images available for this project:
|
||||
|
||||
Additionally, there are the following images, similar to the above:
|
||||
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-cuda`: Same as `full` but compiled with CUDA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-cuda`: Same as `light` but compiled with CUDA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-cuda`: Same as `server` but compiled with CUDA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-rocm`: Same as `full` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-rocm`: Same as `light` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-rocm`: Same as `server` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-cuda`: Same as `full` but compiled with CUDA 12 support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-cuda13`: Same as `full` but compiled with CUDA 13 support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-cuda`: Same as `light` but compiled with CUDA 12 support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-cuda13`: Same as `light` but compiled with CUDA 13 support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-cuda`: Same as `server` but compiled with CUDA 12 support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-cuda13`: Same as `server` but compiled with CUDA 13 support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-rocm`: Same as `full` but compiled with ROCm support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-rocm`: Same as `light` but compiled with ROCm support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-rocm`: Same as `server` but compiled with ROCm support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-musa`: Same as `full` but compiled with MUSA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-musa`: Same as `light` but compiled with MUSA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-musa`: Same as `server` but compiled with MUSA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-intel`: Same as `full` but compiled with SYCL support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-intel`: Same as `light` but compiled with SYCL support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-intel`: Same as `server` but compiled with SYCL support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-vulkan`: Same as `full` but compiled with Vulkan support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-vulkan`: Same as `light` but compiled with Vulkan support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-vulkan`: Same as `server` but compiled with Vulkan support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-vulkan`: Same as `full` but compiled with Vulkan support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-vulkan`: Same as `light` but compiled with Vulkan support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-vulkan`: Same as `server` but compiled with Vulkan support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-openvino`: Same as `full` but compiled with OpenVINO support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-openvino`: Same as `light` but compiled with OpenVINO support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-openvino`: Same as `server` but compiled with OpenVINO support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-s390x`: Identical to `full`, an alias for the `s390x` platform. (platforms: `linux/s390x`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-s390x`: Identical to `light`, an alias for the `s390x` platform. (platforms: `linux/s390x`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-s390x`: Identical to `server`, an alias for the `s390x` platform. (platforms: `linux/s390x`)
|
||||
|
||||
The GPU-enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library), you'll need to build the images locally for now.
|
||||
|
||||
@@ -82,7 +88,7 @@ You may want to pass in some different `ARGS`, depending on the CUDA environment
|
||||
|
||||
The defaults are:
|
||||
|
||||
- `CUDA_VERSION` set to `12.4.0`
|
||||
- `CUDA_VERSION` set to `12.8.1`
|
||||
- `CUDA_DOCKER_ARCH` set to the cmake build default, which includes all the supported architectures
|
||||
|
||||
The resulting images are essentially the same as the non-CUDA images:
|
||||
|
||||
@@ -24,12 +24,12 @@ int main(int argc, char ** argv) {
|
||||
params.prompt = "Hello my name is";
|
||||
params.n_predict = 32;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BATCHED, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// number of parallel batches
|
||||
int n_parallel = params.n_parallel;
|
||||
|
||||
|
||||
@@ -213,12 +213,12 @@ static bool run(llama_context * ctx, const common_params & params) {
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
|
||||
@@ -545,11 +545,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
llama_backend_init();
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
|
||||
@@ -99,12 +99,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
params.embedding = true;
|
||||
|
||||
// get max number of sequences per batch
|
||||
|
||||
@@ -37,12 +37,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
|
||||
@@ -19,12 +19,12 @@ static void print_usage(int /*argc*/, char ** argv) {
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// init LLM
|
||||
|
||||
llama_backend_init();
|
||||
|
||||
@@ -43,12 +43,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
const int W = 15; // lookahead window
|
||||
const int N = 5; // n-gram size
|
||||
const int G = 15; // max verification n-grams
|
||||
|
||||
@@ -12,6 +12,8 @@ int main(int argc, char ** argv){
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -18,12 +18,12 @@ int main(int argc, char ** argv){
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
const int n_draft = params.speculative.n_max;
|
||||
|
||||
// init llama.cpp
|
||||
|
||||
@@ -18,12 +18,12 @@ int main(int argc, char ** argv){
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// max. number of additional tokens to draft if match is found
|
||||
const int n_draft = params.speculative.n_max;
|
||||
|
||||
|
||||
@@ -163,12 +163,12 @@ int main(int argc, char ** argv) {
|
||||
params.n_predict = 128;
|
||||
params.n_junk = 1;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// number of simultaneous "clients" to simulate
|
||||
const int32_t n_clients = params.n_parallel;
|
||||
|
||||
|
||||
@@ -25,12 +25,12 @@ int main(int argc, char ** argv) {
|
||||
params.n_keep = 32;
|
||||
params.i_pos = -1;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PASSKEY, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
int n_junk = params.n_junk;
|
||||
int n_keep = params.n_keep;
|
||||
int n_grp = params.grp_attn_n;
|
||||
|
||||
@@ -117,12 +117,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_RETRIEVAL, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// For BERT models, batch size must be equal to ubatch size
|
||||
params.n_ubatch = params.n_batch;
|
||||
params.embedding = true;
|
||||
|
||||
@@ -17,6 +17,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const std::string_view state_file = "dump_state.bin";
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -27,8 +29,6 @@ int main(int argc, char ** argv) {
|
||||
params.kv_unified = true;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.n_predict < 0) {
|
||||
params.n_predict = 16;
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -25,8 +27,6 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.speculative.mparams_dft.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
|
||||
@@ -38,6 +38,8 @@ int main(int argc, char ** argv) {
|
||||
// needed to get candidate probs even for temp <= 0.0
|
||||
params.sampling.n_probs = 128;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -47,8 +49,6 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.speculative.mparams_dft.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
|
||||
@@ -20,6 +20,8 @@ int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
params.escape = false;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -38,7 +40,6 @@ int main(int argc, char ** argv) {
|
||||
params.cache_type_v = GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
common_init();
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
// load the model and apply lora adapter, if any
|
||||
|
||||
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 9)
|
||||
set(GGML_VERSION_PATCH 8)
|
||||
set(GGML_VERSION_PATCH 9)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
@@ -47,9 +47,11 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
#ifdef STRIDED_ITERATOR_AVAILABLE
|
||||
auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
|
||||
#else
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
// offset_iterator needs to populate nrows + 1 elements, so we also have to ceildiv nrows + 1 by block_size
|
||||
const int nrows_offset = nrows + 1;
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows_offset);
|
||||
int * offset_iterator = offsets_alloc.get();
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
const dim3 offset_grid((nrows_offset + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
|
||||
#endif
|
||||
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
@@ -114,6 +114,8 @@ set(GGML_OPENCL_KERNELS
|
||||
gemv_noshuffle_q4_1_f32
|
||||
gemm_noshuffle_q4_1_f32
|
||||
gemv_noshuffle_general_q8_0_f32
|
||||
gemv_noshuffle_q4_k_f32
|
||||
gemm_noshuffle_q4_k_f32
|
||||
gemv_noshuffle_q6_k_f32
|
||||
gemm_noshuffle_q6_k_f32
|
||||
mul
|
||||
|
||||
@@ -538,6 +538,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_restore_block_q4_0_noshuffle;
|
||||
cl_kernel kernel_convert_block_q4_1_noshuffle;
|
||||
cl_kernel kernel_restore_block_q4_1_noshuffle;
|
||||
cl_kernel kernel_convert_block_q4_K_noshuffle;
|
||||
cl_kernel kernel_restore_block_q4_K_noshuffle;
|
||||
cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K;
|
||||
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
|
||||
@@ -720,6 +722,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_gemm_noshuffle_q4_1_f32;
|
||||
cl_kernel kernel_mul_mm_q8_0_f32_8x4;
|
||||
cl_kernel CL_mul_mat_vec_q8_0_f32;
|
||||
cl_kernel kernel_gemv_noshuffle_q4_k_f32;
|
||||
cl_kernel kernel_gemm_noshuffle_q4_k_f32;
|
||||
cl_kernel kernel_gemv_noshuffle_q6_K_f32;
|
||||
cl_kernel kernel_gemm_noshuffle_q6_K_f32;
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
@@ -932,6 +936,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err));
|
||||
@@ -2619,6 +2625,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemm_noshuffle_q4_k_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "gemm_noshuffle_q4_k_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("gemm_noshuffle_q4_k_f32.cl");
|
||||
#endif
|
||||
cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_k_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemv_noshuffle_q4_k_f32
|
||||
{
|
||||
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable ";
|
||||
if (backend_ctx->has_vector_subgroup_broadcast) {
|
||||
CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST ";
|
||||
}
|
||||
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "gemv_noshuffle_q4_k_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("gemv_noshuffle_q4_k_f32.cl");
|
||||
#endif
|
||||
|
||||
cl_program prog = build_program_from_source(
|
||||
backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_k_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable "
|
||||
" -cl-fast-relaxed-math";
|
||||
@@ -5060,12 +5105,25 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K;
|
||||
if (use_adreno_kernels(backend_ctx, tensor)) {
|
||||
kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle;
|
||||
}
|
||||
#else
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K;
|
||||
#endif
|
||||
|
||||
cl_uchar mask_0F = 0x0F;
|
||||
cl_uchar mask_F0 = 0xF0;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
@@ -5076,6 +5134,20 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
|
||||
tensor->extra = extra;
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (use_adreno_kernels(backend_ctx, tensor)) {
|
||||
|
||||
int M = tensor->ne[1];
|
||||
int K = tensor->ne[0];
|
||||
|
||||
GGML_ASSERT(K % 32 == 0);
|
||||
|
||||
// Transpose q, d, dm as ushort
|
||||
transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M);
|
||||
transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/256, M);
|
||||
transpose_2d_as_16b(backend_ctx, extra->dm, extra->dm, size_dm, K/256, M);
|
||||
}
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q6_K) {
|
||||
@@ -5516,12 +5588,60 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_uchar mask_0F = 0x0F;
|
||||
cl_uchar mask_F0 = 0xF0;
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (use_adreno_kernels(backend_ctx, tensor)) {
|
||||
int M = tensor->ne[1];
|
||||
int K = tensor->ne[0];
|
||||
|
||||
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
|
||||
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
|
||||
size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
|
||||
|
||||
static ggml_cl_buffer buf_trans_q;
|
||||
static ggml_cl_buffer buf_trans_d;
|
||||
static ggml_cl_buffer buf_trans_dm;
|
||||
|
||||
buf_trans_q.allocate(backend_ctx->context, size_q);
|
||||
buf_trans_d.allocate(backend_ctx->context, size_d);
|
||||
buf_trans_dm.allocate(backend_ctx->context, size_dm);
|
||||
|
||||
// Transpose q, d, dm back
|
||||
transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4);
|
||||
transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256);
|
||||
transpose_2d_as_16b(backend_ctx, extra->dm, buf_trans_dm.buffer, size_dm, M, K/256);
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K_noshuffle;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_dm.buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
|
||||
global_work_size, local_work_size, 0, NULL, NULL));
|
||||
CL_CHECK(clEnqueueReadBuffer(queue, data_device, CL_TRUE, offset,
|
||||
size, data, 0, NULL, NULL));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
}
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
@@ -9688,6 +9808,192 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_cl_mul_mat_q4_k_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
GGML_ASSERT(src1);
|
||||
GGML_ASSERT(src1->extra);
|
||||
GGML_ASSERT(dst);
|
||||
GGML_ASSERT(dst->extra);
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
ggml_tensor_extra_cl_q4_K * extra0_q4_k = (ggml_tensor_extra_cl_q4_K *)src0->extra;
|
||||
|
||||
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
const int ne00 = src0->ne[0];
|
||||
const int ne01 = src0->ne[1];
|
||||
|
||||
const int ne1 = dst->ne[1];
|
||||
|
||||
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
||||
|
||||
cl_context context = backend_ctx->context;
|
||||
cl_kernel kernel;
|
||||
|
||||
cl_int err;
|
||||
cl_image_format img_fmt;
|
||||
cl_image_desc img_desc;
|
||||
cl_buffer_region region;
|
||||
|
||||
int M = ne01;
|
||||
int N = ne1;
|
||||
int K = ne00;
|
||||
|
||||
cl_uchar mask_d6 = 0x3F;
|
||||
cl_uchar mask_d4 = 0x0F;
|
||||
cl_uchar mask_hi2 = 0xC0;
|
||||
|
||||
if (ne1 == 1) {
|
||||
cl_mem q_img = nullptr;
|
||||
cl_mem b_sub_buf = nullptr;
|
||||
cl_mem b_img = nullptr;
|
||||
|
||||
// image for q
|
||||
img_fmt = { CL_R, CL_UNSIGNED_INT32};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = M * K / 2 / 4;
|
||||
img_desc.buffer = extra0_q4_k->q;
|
||||
CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
// subbuffer for activations
|
||||
region.origin = offset1;
|
||||
region.size = K * N * sizeof(float);
|
||||
CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
|
||||
|
||||
// image for activations
|
||||
img_fmt = {CL_RGBA, CL_FLOAT};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = K * N / 4;
|
||||
img_desc.buffer = b_sub_buf;
|
||||
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
kernel = backend_ctx->kernel_gemv_noshuffle_q4_k_f32;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->dm));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_d6));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_uchar), &mask_d4));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_hi2));
|
||||
|
||||
size_t local_work_size[3] = {64, 4, 1};
|
||||
size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
|
||||
CL_CHECK(clReleaseMemObject(q_img));
|
||||
CL_CHECK(clReleaseMemObject(b_sub_buf));
|
||||
CL_CHECK(clReleaseMemObject(b_img));
|
||||
} else {
|
||||
|
||||
cl_mem b_sub_buf = nullptr;
|
||||
cl_mem b_sub_buf_trans = nullptr;
|
||||
cl_mem b_img = nullptr;
|
||||
cl_mem b_img_trans = nullptr;
|
||||
|
||||
// subbuffer for activations
|
||||
region.origin = offset1;
|
||||
region.size = K * N * sizeof(float);
|
||||
CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
|
||||
|
||||
// image for activations
|
||||
img_fmt = {CL_RGBA, CL_FLOAT};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = K * N / 4;
|
||||
img_desc.buffer = b_sub_buf;
|
||||
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
// pad N to multiple of 8
|
||||
int extra_elements = N % 8;
|
||||
int padding = 0;
|
||||
if (extra_elements > 0){
|
||||
padding = 8 - extra_elements;
|
||||
}
|
||||
|
||||
// subbuffer for transposed activations
|
||||
region.origin = 0;
|
||||
region.size = K * (N + padding) * sizeof(float)/2;
|
||||
backend_ctx->prealloc_act_trans.allocate(context, region.size);
|
||||
CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err));
|
||||
|
||||
// image for transposed activations
|
||||
img_fmt = {CL_RGBA, CL_HALF_FLOAT};
|
||||
memset(&img_desc, 0, sizeof(img_desc));
|
||||
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc.image_width = K * (N + padding) / 4;
|
||||
img_desc.buffer = b_sub_buf_trans;
|
||||
CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err));
|
||||
|
||||
// transpose activations
|
||||
int height_B = N/4;
|
||||
if (height_B == 0) {
|
||||
height_B = 1;
|
||||
}
|
||||
int width_B = K/4;
|
||||
int padded_height_B = (N + padding)/4;
|
||||
|
||||
kernel = backend_ctx->kernel_transpose_32_16;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B));
|
||||
|
||||
size_t local_work_size_t[2] = { 1, 16 };
|
||||
size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B };
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst);
|
||||
|
||||
// gemm
|
||||
kernel = backend_ctx->kernel_gemm_noshuffle_q4_k_f32;
|
||||
int padded_N = N + padding;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_k->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->dm));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &padded_N));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_d6));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_d4));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_uchar), &mask_hi2));
|
||||
|
||||
size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1};
|
||||
size_t local_work_size[3] = {1, 128, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
CL_CHECK(clReleaseMemObject(b_sub_buf));
|
||||
CL_CHECK(clReleaseMemObject(b_sub_buf_trans));
|
||||
CL_CHECK(clReleaseMemObject(b_img));
|
||||
CL_CHECK(clReleaseMemObject(b_img_trans));
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(backend);
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_cl_mul_mat_q6_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
GGML_ASSERT(src0);
|
||||
@@ -10014,6 +10320,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
return;
|
||||
}
|
||||
|
||||
// q4_k x fp32
|
||||
if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) {
|
||||
ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
// q6_K x fp32
|
||||
if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) {
|
||||
ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst);
|
||||
|
||||
@@ -424,13 +424,17 @@ kernel void kernel_restore_block_q8_0_trans(
|
||||
// Convert the block_q4_K format to 4 separate arrays (AOS -> SOA).
|
||||
// This kernel does not deshuffle the bits.
|
||||
// Each thread processes a super block.
|
||||
// Mask args are just to keep the signature consistent with the no-shuffle
|
||||
// version and they are not used in this kernel.
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_convert_block_q4_K(
|
||||
global struct block_q4_K * src0,
|
||||
global uchar * dst_q,
|
||||
global uchar * dst_s,
|
||||
global half * dst_d,
|
||||
global half * dst_dm
|
||||
global half * dst_dm,
|
||||
uchar mask_0F,
|
||||
uchar mask_F0
|
||||
) {
|
||||
global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0);
|
||||
global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0);
|
||||
@@ -451,12 +455,15 @@ kernel void kernel_convert_block_q4_K(
|
||||
|
||||
// Restore block_q4_K from flattened arrays.
|
||||
// Each thread processes a super block.
|
||||
// Mask args are just to keep the signature consistent with the no-shuffle ones.
|
||||
kernel void kernel_restore_block_q4_K(
|
||||
global uchar * src_q,
|
||||
global uchar * src_s,
|
||||
global half * src_d,
|
||||
global half * src_dm,
|
||||
global struct block_q4_K * dst
|
||||
global struct block_q4_K * dst,
|
||||
uchar mask_0F,
|
||||
uchar mask_F0
|
||||
) {
|
||||
global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0);
|
||||
global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0);
|
||||
@@ -475,6 +482,70 @@ kernel void kernel_restore_block_q4_K(
|
||||
}
|
||||
}
|
||||
|
||||
// Split one block_q4_K super block into four flat arrays (quants, scales,
// d, dm) and regroup the 4-bit quants while doing so.
// Each work-item handles the super block selected by get_global_id(0).
//
// Within every 32-byte group of the source quants, source bytes are taken
// two at a time (x0 = byte 2*j, x1 = byte 2*j+1):
//   - output byte j      holds (x0 & mask_0F) in its low nibble and
//                        (x1 & mask_0F) in its high nibble;
//   - output byte j + 16 holds the mask_F0 nibble of x0 in its low nibble
//                        and the mask_F0 nibble of x1 in its high nibble.
// mask_0F / mask_F0 are supplied by the host (presumably 0x0F and 0xF0 --
// the values are not visible in this kernel).
kernel void kernel_convert_block_q4_K_noshuffle(
    global struct block_q4_K * src0,
    global uchar * dst_q,
    global uchar * dst_s,
    global half  * dst_d,
    global half  * dst_dm,
    uchar mask_0F,
    uchar mask_F0
) {
    const size_t blk = get_global_id(0);

    global struct block_q4_K * in = (global struct block_q4_K *) src0 + blk;
    global uchar * out_q  = (global uchar *) dst_q  + QK_K/2 * blk;
    global uchar * out_s  = (global uchar *) dst_s  + K_SCALE_SIZE * blk;
    global half  * out_d  = (global half  *) dst_d  + blk;
    global half  * out_dm = (global half  *) dst_dm + blk;

    // Super-block scale and min scale are copied verbatim.
    *out_d  = in->d;
    *out_dm = in->dm;

    // QK_K/64 groups of 32 quant bytes; 16 byte pairs per group.
    for (int grp = 0; grp < QK_K / 64; ++grp) {
        const int base = grp*32;
        for (int pair = 0; pair < 16; ++pair) {
            const uchar even = in->q[base + 2*pair];
            const uchar odd  = in->q[base + 2*pair + 1];

            const uchar lo_even = convert_uchar(even & mask_0F);
            const uchar lo_odd  = convert_uchar((odd & mask_0F) << 4);
            const uchar hi_even = convert_uchar((even & mask_F0) >> 4);
            const uchar hi_odd  = convert_uchar(odd & mask_F0);

            out_q[base + pair]      = lo_even | lo_odd;
            out_q[base + pair + 16] = hi_even | hi_odd;
        }
    }

    // The packed 6-bit scales/mins are copied byte for byte.
    for (int n = 0; n < K_SCALE_SIZE; ++n) {
        out_s[n] = in->s[n];
    }
}
|
||||
|
||||
// Rebuild one block_q4_K super block from the four flat arrays written by
// the no-shuffle convert kernel. Each work-item handles the super block
// selected by get_global_id(0).
//
// Within every 32-byte group, flat byte j ("lo") and flat byte j + 16
// ("hi") are recombined into two block bytes:
//   - block byte 2*j     = (lo & mask_0F) | ((hi & mask_0F) << 4)
//   - block byte 2*j + 1 = ((lo & mask_F0) >> 4) | (hi & mask_F0)
// mask_0F / mask_F0 are supplied by the host (presumably 0x0F and 0xF0 --
// the values are not visible in this kernel).
kernel void kernel_restore_block_q4_K_noshuffle(
    global uchar * src_q,
    global uchar * src_s,
    global half  * src_d,
    global half  * src_dm,
    global struct block_q4_K * dst,
    uchar mask_0F,
    uchar mask_F0
) {
    const size_t blk = get_global_id(0);

    global struct block_q4_K * out = (global struct block_q4_K *) dst + blk;
    global uchar * in_q  = (global uchar *) src_q  + QK_K/2 * blk;
    global uchar * in_s  = (global uchar *) src_s  + K_SCALE_SIZE * blk;
    global half  * in_d  = (global half  *) src_d  + blk;
    global half  * in_dm = (global half  *) src_dm + blk;

    // Super-block scale and min scale are copied verbatim.
    out->d  = *in_d;
    out->dm = *in_dm;

    // QK_K/64 groups of 32 quant bytes; 16 byte pairs per group.
    for (int grp = 0; grp < QK_K / 64; ++grp) {
        const int base = grp*32;
        for (int pair = 0; pair < 16; ++pair) {
            const uchar lo = in_q[base + pair];
            const uchar hi = in_q[base + pair + 16];

            out->q[base + 2*pair]     = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4));
            out->q[base + 2*pair + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0));
        }
    }

    // The packed scales/mins are copied byte for byte.
    for (int n = 0; n < K_SCALE_SIZE; ++n) {
        out->s[n] = in_s[n];
    }
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_convert_block_q6_K
|
||||
// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA).
|
||||
|
||||
172
ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl
Normal file
172
ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl
Normal file
@@ -0,0 +1,172 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_qcom_reqd_sub_group_size
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
|
||||
// Extract the scale (*d) and min (*m) of sub-block j from the packed scale
// bytes q of one super block.
//   j < 4 : scale = q[j] & mask_d6,  min = q[j+4] & mask_d6
//   j >= 4: the low bits come from q[j+4] (low nibble for the scale, high
//           nibble for the min, both filtered by mask_d4) and the top bits
//           come from (q[j-4] & mask_hi2) >> 2 and (q[j] & mask_hi2) >> 2.
// The three masks are provided by the caller (presumably 0x3F, 0x0F and
// 0xC0 -- their values are not visible here).
inline void get_scale_min_k4(
    int j,
    global const uchar * q,
    uchar * d,
    uchar * m,
    uchar mask_d6,
    uchar mask_d4,
    uchar mask_hi2
) {
    if (j >= 4) {
        const uchar packed = q[j+4];
        *d = (packed & mask_d4)        | ((q[j-4] & mask_hi2) >> 2);
        *m = ((packed >> 4) & mask_d4) | ((q[j]   & mask_hi2) >> 2);
        return;
    }

    *d = q[j]   & mask_d6;
    *m = q[j+4] & mask_d6;
}
|
||||
|
||||
// Nibble extractors for a 16-bit word that packs four 4-bit quants.
#define Q4K_GEMM_NIB0(w) ((w) & 0x000F)
#define Q4K_GEMM_NIB1(w) (((w) & 0x00F0) >> 4)
#define Q4K_GEMM_NIB2(w) (((w) & 0x0F00) >> 8)
#define Q4K_GEMM_NIB3(w) (((w) & 0xF000) >> 12)

// One step along K: fetch eight activations for index ki + t, dequantize the
// matching nibble of each of the four packed words and accumulate.
#define Q4K_GEMM_STEP(t, NIB)                                          \
    act.s0123 = read_imageh(src1, gy*2   + (ki+(t)) * n_4);            \
    act.s4567 = read_imageh(src1, gy*2+1 + (ki+(t)) * n_4);            \
    wq.s0 = NIB(packed.s0) * scale.s0 - mval.s0;                       \
    wq.s1 = NIB(packed.s1) * scale.s1 - mval.s1;                       \
    wq.s2 = NIB(packed.s2) * scale.s2 - mval.s2;                       \
    wq.s3 = NIB(packed.s3) * scale.s3 - mval.s3;                       \
    c0 += act * wq.s0;                                                 \
    c1 += act * wq.s1;                                                 \
    c2 += act * wq.s2;                                                 \
    c3 += act * wq.s3;

// Store one float4 (component `comp` of the four accumulators) and advance
// idx by m, but only while the vector still lies below out_limit.
#define Q4K_GEMM_STORE(comp)                                                     \
    if (idx+3 < out_limit) {                                                     \
        vstore4((float4)(c0.comp, c1.comp, c2.comp, c3.comp), 0, out + idx);     \
        idx += m;                                                                \
    }

// Q4_K x activations matrix multiply on the flattened ("no-shuffle") weight
// arrays.
//
//   src0_q        packed quants, read as 16-bit words (four nibbles each)
//   src0_s        packed scale bytes, K_SCALE_SIZE per super block
//   src0_d/_dm    per-super-block scale and min scale (half)
//   src1          half-precision activations read through an image buffer
//   dst, offsetd  float output buffer and byte offset into it
//   m, n, k       sizes used for addressing; n is shifted right by two to
//                 form the activation row stride (the host passes the padded
//                 N here)
//   n_no_padding  used only to bound the output stores (m * n_no_padding)
//   mask_*        forwarded unchanged to get_scale_min_k4
//
// Each work-item accumulates four half8 vectors (c0..c3) and writes up to
// eight float4 vectors starting at element (gy*8)*m + gx*4, stepping by m.
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif
kernel void kernel_gemm_noshuffle_q4_k_f32(
    global const ushort * src0_q,
    global const uchar  * src0_s,
    global const half   * src0_d,
    global const half   * src0_dm,
    read_only image1d_buffer_t src1,
    global float * dst,
    ulong offsetd,
    int m,
    int n,
    int k,
    int n_no_padding,
    uchar mask_d6,
    uchar mask_d4,
    uchar mask_hi2
) {
    global float * out = (global float *)((global char *)dst + offsetd);

    const int n_4  = n >> 2;
    const int gy   = get_global_id(0);
    const int gx   = get_global_id(1);
    const int gx_2 = gx << 2;

    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
    half8 act;
    half4 wq;

    const int num_blocks_K = k / QK_K;

    global const ushort * weight_ptr = src0_q  + gx_2;
    global const half   * d_ptr      = src0_d  + gx_2;
    global const half   * dm_ptr     = src0_dm + gx_2;

    // Walk K in sub-blocks of 32 values; eight sub-blocks form a super block.
    for (int i = 0; i < k; i += 32) {
        const int sb_idx  = i / QK_K;
        const int sub_idx = (i / 32) % 8;

        const half4 d  = vload4(0, d_ptr  + sb_idx * m);
        const half4 dm = vload4(0, dm_ptr + sb_idx * m);

        // Scale bytes of the current super block for rows gx_2 .. gx_2+3.
        global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
        global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
        global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
        global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;

        uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3;
        get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2);
        get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2);
        get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2);
        get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2);

        // Effective per-row scale and min for this sub-block.
        const half4 scale = convert_half4(convert_float4(d)  * convert_float4((uchar4)(sv0, sv1, sv2, sv3)));
        const half4 mval  = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3)));

        for (int l = 0; l < 32; l += 4) {
            const int ki = i + l;
            const ushort4 packed = vload4(0, weight_ptr + (ki/4) * m);

            Q4K_GEMM_STEP(0, Q4K_GEMM_NIB0)
            Q4K_GEMM_STEP(1, Q4K_GEMM_NIB1)
            Q4K_GEMM_STEP(2, Q4K_GEMM_NIB2)
            Q4K_GEMM_STEP(3, Q4K_GEMM_NIB3)
        }
    }

    int idx = (gy<<3)*m + (gx<<2);
    const int out_limit = m*n_no_padding;

    Q4K_GEMM_STORE(s0)
    Q4K_GEMM_STORE(s1)
    Q4K_GEMM_STORE(s2)
    Q4K_GEMM_STORE(s3)
    Q4K_GEMM_STORE(s4)
    Q4K_GEMM_STORE(s5)
    Q4K_GEMM_STORE(s6)
    if (idx+3 < out_limit) {
        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, out + idx);
    }
}

#undef Q4K_GEMM_STORE
#undef Q4K_GEMM_STEP
#undef Q4K_GEMM_NIB3
#undef Q4K_GEMM_NIB2
#undef Q4K_GEMM_NIB1
#undef Q4K_GEMM_NIB0
|
||||
318
ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl
Normal file
318
ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl
Normal file
@@ -0,0 +1,318 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
#ifdef cl_qcom_reqd_sub_group_size
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#endif
|
||||
|
||||
#define QK_K 256
|
||||
#define NSUBGROUPS 4
|
||||
#define SUBGROUP_SIZE 64
|
||||
|
||||
// Decode the scale (*d) and min (*m) of sub-block j from the packed scale
// bytes q of one super block.
//   j < 4 : both values are read directly: q[j] and q[j+4], each filtered
//           by mask_d6.
//   j >= 4: q[j+4] supplies the low part (low nibble -> scale, high nibble
//           -> min, each filtered by mask_d4); the bits selected by
//           mask_hi2 in q[j-4] (scale) and q[j] (min), shifted right by
//           two, supply the upper part.
// Mask values come from the caller (presumably 0x3F, 0x0F and 0xC0 -- not
// visible here).
inline void get_scale_min_k4(
    int j,
    global const uchar * q,
    uchar * d,
    uchar * m,
    uchar mask_d6,
    uchar mask_d4,
    uchar mask_hi2
) {
    if (j >= 4) {
        const uchar low_bits = q[j+4];
        *d = (low_bits & mask_d4)        | ((q[j-4] & mask_hi2) >> 2);
        *m = ((low_bits >> 4) & mask_d4) | ((q[j]   & mask_hi2) >> 2);
        return;
    }

    *d = q[j]   & mask_d6;
    *m = q[j+4] & mask_d6;
}
|
||||
|
||||
// -----------------------------------------------------------------------------
// Dequantize-and-accumulate macros for the no-shuffle Q4_K GEMV kernel.
//
// Common arguments of the four public macros:
//   total_sums  two-component accumulator (.s0 / .s1)
//   bits4       eight 16-bit words, each packing four 4-bit quants
//   scale, minv two-component scale / min (component 0 pairs with
//               total_sums.s0, component 1 with total_sums.s1)
//   y           eight activations held by each sub-group lane
//
// The *_hi macros declare the temporary `shared_y` and consume the values
// broadcast from sub-group lanes 0 and 1; the *_lo macros reuse that
// `shared_y` and consume lanes 2 and 3, so a *_hi macro must be expanded
// earlier in the same scope than the matching *_lo macro.
// The helper macros below must stay defined wherever the public macros are
// expanded, which is why they are not #undef'd.
// -----------------------------------------------------------------------------

// Nibble extractors for one packed 16-bit word.
#define Q4K_NS_NIB0(w) ((w) & 0x000F)
#define Q4K_NS_NIB1(w) (((w) & 0x00F0) >> 4)
#define Q4K_NS_NIB2(w) (((w) & 0x0F00) >> 8)
#define Q4K_NS_NIB3(w) (((w) & 0xF000) >> 12)

// (nibble * scale - min) * activation
#define Q4K_NS_TERM(NIB, w, sc, mn, act) ((NIB(w) * (sc) - (mn)) * (act))

// Scalar-broadcast flavour: broadcast one activation from `lane`, then update
// both accumulators with the same nibble position of words wa / wb.
#define Q4K_NS_SG1_STEP(acc, wa, wb, NIB, sc, mn, yc, lane)        \
    shared_y = sub_group_broadcast(yc, lane);                      \
    acc.s0 += Q4K_NS_TERM(NIB, wa, sc.s0, mn.s0, shared_y);        \
    acc.s1 += Q4K_NS_TERM(NIB, wb, sc.s1, mn.s1, shared_y);

// All eight activations of one lane: y.s0..s3 pair with the four nibbles of
// (w0, w1), y.s4..s7 with the four nibbles of (w2, w3).
#define Q4K_NS_SG1_LANE(acc, w0, w1, w2, w3, sc, mn, y, lane)          \
    Q4K_NS_SG1_STEP(acc, w0, w1, Q4K_NS_NIB0, sc, mn, y.s0, lane)      \
    Q4K_NS_SG1_STEP(acc, w0, w1, Q4K_NS_NIB1, sc, mn, y.s1, lane)      \
    Q4K_NS_SG1_STEP(acc, w0, w1, Q4K_NS_NIB2, sc, mn, y.s2, lane)      \
    Q4K_NS_SG1_STEP(acc, w0, w1, Q4K_NS_NIB3, sc, mn, y.s3, lane)      \
    Q4K_NS_SG1_STEP(acc, w2, w3, Q4K_NS_NIB0, sc, mn, y.s4, lane)      \
    Q4K_NS_SG1_STEP(acc, w2, w3, Q4K_NS_NIB1, sc, mn, y.s5, lane)      \
    Q4K_NS_SG1_STEP(acc, w2, w3, Q4K_NS_NIB2, sc, mn, y.s6, lane)      \
    Q4K_NS_SG1_STEP(acc, w2, w3, Q4K_NS_NIB3, sc, mn, y.s7, lane)

// Vector-broadcast flavour: eight terms for one accumulator component, taken
// from the four nibbles of wa (shared_y.s0..s3) and of wb (shared_y.s4..s7).
#define Q4K_NS_SG8_ROW(acc_c, wa, wb, sc_c, mn_c)                              \
    acc_c += Q4K_NS_TERM(Q4K_NS_NIB0, wa, sc_c, mn_c, shared_y.s0);            \
    acc_c += Q4K_NS_TERM(Q4K_NS_NIB1, wa, sc_c, mn_c, shared_y.s1);            \
    acc_c += Q4K_NS_TERM(Q4K_NS_NIB2, wa, sc_c, mn_c, shared_y.s2);            \
    acc_c += Q4K_NS_TERM(Q4K_NS_NIB3, wa, sc_c, mn_c, shared_y.s3);            \
    acc_c += Q4K_NS_TERM(Q4K_NS_NIB0, wb, sc_c, mn_c, shared_y.s4);            \
    acc_c += Q4K_NS_TERM(Q4K_NS_NIB1, wb, sc_c, mn_c, shared_y.s5);            \
    acc_c += Q4K_NS_TERM(Q4K_NS_NIB2, wb, sc_c, mn_c, shared_y.s6);            \
    acc_c += Q4K_NS_TERM(Q4K_NS_NIB3, wb, sc_c, mn_c, shared_y.s7);

// Broadcast the whole float8 from `lane`, then update accumulator .s0 from
// words (w0, w2) and accumulator .s1 from words (w1, w3).
#define Q4K_NS_SG8_LANE(acc, w0, w1, w2, w3, sc, mn, y, lane)      \
    shared_y = sub_group_broadcast(y, lane);                       \
    Q4K_NS_SG8_ROW(acc.s0, w0, w2, sc.s0, mn.s0)                   \
    Q4K_NS_SG8_ROW(acc.s1, w1, w3, sc.s1, mn.s1)

// Scalar broadcast, lanes 0 and 1. Declares `float shared_y`.
#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \
    float shared_y; \
    Q4K_NS_SG1_LANE(total_sums, bits4.s0, bits4.s1, bits4.s2, bits4.s3, scale, minv, y, 0) \
    Q4K_NS_SG1_LANE(total_sums, bits4.s4, bits4.s5, bits4.s6, bits4.s7, scale, minv, y, 1)

// Scalar broadcast, lanes 2 and 3. Reuses `shared_y` declared by the _1_hi macro.
#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \
    Q4K_NS_SG1_LANE(total_sums, bits4.s0, bits4.s1, bits4.s2, bits4.s3, scale, minv, y, 2) \
    Q4K_NS_SG1_LANE(total_sums, bits4.s4, bits4.s5, bits4.s6, bits4.s7, scale, minv, y, 3)

// Vector broadcast, lanes 0 and 1. Declares `float8 shared_y`.
#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \
    float8 shared_y; \
    Q4K_NS_SG8_LANE(total_sums, bits4.s0, bits4.s1, bits4.s2, bits4.s3, scale, minv, y, 0) \
    Q4K_NS_SG8_LANE(total_sums, bits4.s4, bits4.s5, bits4.s6, bits4.s7, scale, minv, y, 1)

// Vector broadcast, lanes 2 and 3. Reuses `shared_y` declared by the _8_hi macro.
#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \
    Q4K_NS_SG8_LANE(total_sums, bits4.s0, bits4.s1, bits4.s2, bits4.s3, scale, minv, y, 2) \
    Q4K_NS_SG8_LANE(total_sums, bits4.s4, bits4.s5, bits4.s6, bits4.s7, scale, minv, y, 3)
|
||||
|
||||
// Q4_K x f32 matrix-vector product on the flattened ("no-shuffle") weights.
//
//   src0_q        packed quants, read as 32-bit texels through an image
//   src0_d/_m     per-super-block scale / min scale, two rows per half2
//   src0_s        packed scale bytes, 12 per super block
//   src1          float activations read through an image buffer
//   dst, offsetd  float output buffer and byte offset into it
//   ne00, ne01    used as K and M for all addressing below
//   mask_*        forwarded unchanged to get_scale_min_k4
//
// get_local_id(1) selects which 32-wide slices of K a work-item processes
// (stride NSUBGROUPS); the original comments call this index the wave.
// Every work-item accumulates two outputs. Waves 1..3 hand their partial
// sums to wave 0 through local memory, and wave 0 writes dst[2*gid] and
// dst[2*gid + 1].
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_gemv_noshuffle_q4_k_f32(
        read_only image1d_buffer_t src0_q,
        global half2  * src0_d,
        global half2  * src0_m,
        global uchar  * src0_s,
        read_only image1d_buffer_t src1,
        global float  * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        uchar mask_d6,
        uchar mask_d4,
        uchar mask_hi2)
{
    const uint   wave = get_local_id(1);
    const uint   gid  = get_global_id(0);
    const ushort lane = get_sub_group_local_id();

    const uint K = ne00;
    const uint M = ne01;

    const uint line_stride    = M / 2;
    const uint block_stride   = NSUBGROUPS * M;
    const uint scales_per_row = (K / QK_K) * 12;

    private uint4  wbits;
    private half2  wscale;
    private half2  wmin;
    private float8 act;

    private float2 acc = (float2)(0.0f);

    for (uint k = wave; k < (K / 32); k += NSUBGROUPS) {
        const uint sb = k / 8;   // super block index
        const uint j  = k % 8;   // sub-block inside the super block

        const half2 d  = src0_d[gid + sb * line_stride];
        const half2 dm = src0_m[gid + sb * line_stride];

        // Scale bytes for the two rows handled by this work-item.
        global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12;
        global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12;

        uchar sv0, mn0, sv1, mn1;
        get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2);
        get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2);

        wscale = convert_half2(convert_float2(d)  * convert_float2((uchar2)(sv0, sv1)));
        wmin   = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1)));

        // Only the first four lanes fetch activations; the accumulate macros
        // broadcast from lanes 0..3 to the rest of the sub-group.
        if (lane < 4) {
            act.s0123 = read_imagef(src1, (lane * 2 + k * 8));
            act.s4567 = read_imagef(src1, (1 + lane * 2 + k * 8));
        }

        const uint base = gid + k * block_stride;

        // First half of the weights for the two rows.
        wbits.s0 = read_imageui(src0_q, (base + line_stride * 0)).x;
        wbits.s1 = read_imageui(src0_q, (base + line_stride * 1)).x;
        wbits.s2 = read_imageui(src0_q, (base + line_stride * 2)).x;
        wbits.s3 = read_imageui(src0_q, (base + line_stride * 3)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAST
        dequantizeBlockAccum_ns_sgbroadcast_8_hi(acc, as_ushort8(wbits), wscale, wmin, act);
#else
        dequantizeBlockAccum_ns_sgbroadcast_1_hi(acc, as_ushort8(wbits), wscale, wmin, act);
#endif // VECTOR_SUB_GROUP_BROADCAST

        // Second half of the weights for the two rows.
        wbits.s0 = read_imageui(src0_q, (base + line_stride * 4)).x;
        wbits.s1 = read_imageui(src0_q, (base + line_stride * 5)).x;
        wbits.s2 = read_imageui(src0_q, (base + line_stride * 6)).x;
        wbits.s3 = read_imageui(src0_q, (base + line_stride * 7)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAST
        dequantizeBlockAccum_ns_sgbroadcast_8_lo(acc, as_ushort8(wbits), wscale, wmin, act);
#else
        dequantizeBlockAccum_ns_sgbroadcast_1_lo(acc, as_ushort8(wbits), wscale, wmin, act);
#endif // VECTOR_SUB_GROUP_BROADCAST
    }

    // Reduction through local memory; sized for exactly three producer waves.
    local float2 partial[SUBGROUP_SIZE * 3];
    if (wave >= 1 && wave <= 3) {
        partial[SUBGROUP_SIZE * (wave - 1) + lane] = acc;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Wave 0 folds in the three partial sums and writes two outputs.
    if (wave == 0) {
        acc += partial[SUBGROUP_SIZE * 0 + lane];
        acc += partial[SUBGROUP_SIZE * 1 + lane];
        acc += partial[SUBGROUP_SIZE * 2 + lane];

        global float * out = (global float*)((global char*)dst + offsetd);
        vstore2(acc, 0, &(out[gid * 2]));
    }
}
|
||||
@@ -1340,7 +1340,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
|
||||
if (buffer && buffer->iface.init_tensor) {
|
||||
buffer->iface.init_tensor(buffer, tensor);
|
||||
} else {
|
||||
GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
|
||||
if (!buffer) {
|
||||
GGML_LOG_ERROR("Tensor with null buffer passed to init_tensor function\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (tensor->extra != nullptr) {
|
||||
|
||||
@@ -70,6 +70,7 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -310,11 +311,11 @@ static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const
|
||||
sycl::half2 * const __restrict__ tile_KV,
|
||||
const int stride_KV,
|
||||
const int i_sup) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
auto load = [&] (const int n) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int stride_j = warp_size >> n;
|
||||
|
||||
if (stride_j == 0) {
|
||||
@@ -455,7 +456,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
|
||||
|
||||
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
|
||||
(K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
#ifdef SYCL_FAST_FP16
|
||||
static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
|
||||
@@ -505,7 +506,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
|
||||
}
|
||||
|
||||
if (k_KQ_0 + nbatch_K < DKQ) {
|
||||
item_ct1.barrier(); // Sync not needed on last iteration.
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -545,7 +546,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
|
||||
const int k_VKQ_max,
|
||||
const int col_Q_0,
|
||||
float * KQ_max_new_shared) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
@@ -620,14 +621,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
|
||||
}
|
||||
|
||||
if constexpr (np == 1) {
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
} else {
|
||||
static_assert(cpw == 1, "bad cpw");
|
||||
|
||||
if (item_ct1.get_local_id(2) == 0) {
|
||||
KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];
|
||||
}
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];
|
||||
KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
|
||||
}
|
||||
@@ -697,7 +698,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
|
||||
for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
|
||||
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
|
||||
(V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
#ifdef SYCL_FAST_FP16
|
||||
#pragma unroll
|
||||
@@ -765,7 +766,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
|
||||
}
|
||||
}
|
||||
#endif // SYCL_FAST_FP16
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -972,7 +973,7 @@ static void flash_attn_tile(const char * Q,
|
||||
}
|
||||
}
|
||||
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Main loop over KV cache:
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
|
||||
@@ -1051,7 +1052,7 @@ static void flash_attn_tile(const char * Q,
|
||||
return;
|
||||
}
|
||||
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
#pragma unroll
|
||||
for (int ip = 1; ip < np; ++ip) {
|
||||
@@ -1193,37 +1194,39 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
|
||||
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
if (Q->ne[1] > 16/ncols2) {
|
||||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
if (DV < 512 && Q->ne[1] < 32) {
|
||||
if constexpr (ncols2 <= 32) {
|
||||
if (Q->ne[1] > 16/ncols2) {
|
||||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (Q->ne[1] > 8/ncols2) {
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (ncols2 <= 8) {
|
||||
if (Q->ne[1] > 4/ncols2) {
|
||||
constexpr int cols_per_block = 8;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
if constexpr (ncols2 <= 16) {
|
||||
if (Q->ne[1] > 8/ncols2) {
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if constexpr (ncols2 <= 8) {
|
||||
if (Q->ne[1] > 4/ncols2) {
|
||||
constexpr int cols_per_block = 8;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
c044a8eeae2591faa0950c8b5e514cbc4bbfc4ca
|
||||
a04eea0761a85d18f3f504d6ab970c5c9dce705f
|
||||
|
||||
@@ -118,12 +118,12 @@ int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
params.out_file = "tests.txt";
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EXPORT_GRAPH_OPS)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// Load CPU-only
|
||||
ggml_backend_dev_t cpu_device = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
params.devices = { cpu_device, nullptr };
|
||||
|
||||
@@ -8424,6 +8424,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 256, 1, 1}, order)); // test ceildiv in CUDA's CUB's DeviceSegmentedSort
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
|
||||
|
||||
@@ -387,6 +387,24 @@ static void test_expressions(testing & t) {
|
||||
"Bob"
|
||||
);
|
||||
|
||||
test_template(t, "empty computed member defaults to undefined",
|
||||
"{{ a[]|default('fallback') }}",
|
||||
{{"a", {{"name", "Bob"}}}},
|
||||
"fallback"
|
||||
);
|
||||
|
||||
test_template(t, "empty computed member is undefined",
|
||||
"{{ a[] is undefined }}",
|
||||
{{"a", {{"name", "Bob"}}}},
|
||||
"True"
|
||||
);
|
||||
|
||||
test_template(t, "undefined computed member is undefined",
|
||||
"{{ a[undefined] is undefined }}",
|
||||
{{"a", {{"name", "Bob"}}}},
|
||||
"True"
|
||||
);
|
||||
|
||||
test_template(t, "array access",
|
||||
"{{ items[1] }}",
|
||||
{{"items", json::array({"a", "b", "c"})}},
|
||||
|
||||
@@ -22,12 +22,12 @@ int main(int argc, char ** argv) {
|
||||
params.n_parallel = 3;
|
||||
params.n_ctx = 256;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// init
|
||||
common_init_result_ptr llama_init = common_init_from_params(params);
|
||||
|
||||
|
||||
@@ -16,12 +16,12 @@
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
|
||||
@@ -20,12 +20,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BENCH, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
int is_pp_shared = params.is_pp_shared;
|
||||
int is_tg_separate = params.is_tg_separate;
|
||||
|
||||
|
||||
@@ -347,6 +347,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
params.verbosity = LOG_LEVEL_ERROR; // by default, less verbose logs
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CLI)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -357,8 +359,6 @@ int main(int argc, char ** argv) {
|
||||
console::error("please use llama-completion instead\n");
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// struct that contains llama context and inference
|
||||
cli_context ctx_cli(params);
|
||||
|
||||
|
||||
@@ -90,12 +90,12 @@ int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
g_params = ¶ms;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMPLETION, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
auto & sparams = params.sampling;
|
||||
|
||||
// save choice to use color for later
|
||||
@@ -146,19 +146,13 @@ int main(int argc, char ** argv) {
|
||||
|
||||
ctx = llama_init->context();
|
||||
model = llama_init->model();
|
||||
smpl = llama_init->sampler(0);
|
||||
|
||||
if (ctx == NULL) {
|
||||
LOG_ERR("%s: error: unable to create context\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: error: unable to load model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
smpl = llama_init->sampler(0);
|
||||
|
||||
llama_memory_t mem = llama_get_memory(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
|
||||
@@ -400,6 +400,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
params.out_file = "control_vector.gguf";
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CVECTOR_GENERATOR, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -418,6 +418,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
params.out_file = "ggml-lora-merged-f16.gguf";
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EXPORT_LORA, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -17,11 +17,12 @@ using namespace std::chrono_literals;
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
auto mparams = common_model_params_to_llama(params);
|
||||
|
||||
@@ -1212,6 +1212,8 @@ int main(int argc, char ** argv) {
|
||||
params.n_ctx = 512;
|
||||
params.escape = false;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_IMATRIX, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -1223,8 +1225,6 @@ int main(int argc, char ** argv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
const int32_t n_ctx = params.n_ctx;
|
||||
|
||||
if (n_ctx <= 0) {
|
||||
|
||||
@@ -54,11 +54,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
|
||||
if (params.mmproj.path.empty()) {
|
||||
|
||||
@@ -281,11 +281,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
|
||||
if (params.mmproj.path.empty()) {
|
||||
|
||||
@@ -2012,12 +2012,12 @@ int main(int argc, char ** argv) {
|
||||
params.n_ctx = 512;
|
||||
params.escape = false;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
const int32_t n_ctx = params.n_ctx;
|
||||
|
||||
if (n_ctx <= 0) {
|
||||
|
||||
@@ -58,6 +58,9 @@ static std::vector<float> get_logits(
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
params.escape = false;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_RESULTS)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -65,7 +68,6 @@ int main(int argc, char ** argv) {
|
||||
LOG_ERR("%s: an output file must be specified", __func__);
|
||||
return 1;
|
||||
}
|
||||
common_init();
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
common_init_result_ptr llama_init = common_init_from_params(params);
|
||||
|
||||
Binary file not shown.
@@ -75,6 +75,8 @@ int main(int argc, char ** argv) {
|
||||
// own arguments required by this example
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -100,8 +102,6 @@ int main(int argc, char ** argv) {
|
||||
params.model_alias.insert(params.model.name);
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// struct that contains llama context and inference
|
||||
server_context ctx_server;
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import { getChatActionsContext, setMessageEditContext } from '$lib/contexts';
|
||||
import { chatStore, pendingEditMessageId } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { DatabaseService } from '$lib/services';
|
||||
import { DatabaseService } from '$lib/services/database.service';
|
||||
import { SYSTEM_MESSAGE_PLACEHOLDER } from '$lib/constants';
|
||||
import { MessageRole, AttachmentType } from '$lib/enums';
|
||||
import {
|
||||
@@ -19,6 +19,7 @@
|
||||
interface Props {
|
||||
class?: string;
|
||||
message: DatabaseMessage;
|
||||
toolMessages?: DatabaseMessage[];
|
||||
isLastAssistantMessage?: boolean;
|
||||
siblingInfo?: ChatMessageSiblingInfo | null;
|
||||
}
|
||||
@@ -26,6 +27,7 @@
|
||||
let {
|
||||
class: className = '',
|
||||
message,
|
||||
toolMessages = [],
|
||||
isLastAssistantMessage = false,
|
||||
siblingInfo = null
|
||||
}: Props = $props();
|
||||
@@ -302,6 +304,7 @@
|
||||
{deletionInfo}
|
||||
{isLastAssistantMessage}
|
||||
{message}
|
||||
{toolMessages}
|
||||
messageContent={message.content}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onContinue={handleContinue}
|
||||
|
||||
@@ -6,42 +6,42 @@
|
||||
SyntaxHighlightedCode
|
||||
} from '$lib/components/app';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { Wrench, Loader2, AlertTriangle, Brain } from '@lucide/svelte';
|
||||
import { AgenticSectionType, AttachmentType, FileTypeText } from '$lib/enums';
|
||||
import { Wrench, Loader2, Brain } from '@lucide/svelte';
|
||||
import { AgenticSectionType, FileTypeText } from '$lib/enums';
|
||||
import { formatJsonPretty } from '$lib/utils';
|
||||
import { ATTACHMENT_SAVED_REGEX, NEWLINE_SEPARATOR } from '$lib/constants';
|
||||
import { parseAgenticContent, type AgenticSection } from '$lib/utils';
|
||||
import type { DatabaseMessage, DatabaseMessageExtraImageFile } from '$lib/types/database';
|
||||
import {
|
||||
deriveAgenticSections,
|
||||
parseToolResultWithImages,
|
||||
type AgenticSection,
|
||||
type ToolResultLine
|
||||
} from '$lib/utils';
|
||||
import type { DatabaseMessage } from '$lib/types/database';
|
||||
import type { ChatMessageAgenticTimings, ChatMessageAgenticTurnStats } from '$lib/types/chat';
|
||||
import { ChatMessageStatsView } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
message?: DatabaseMessage;
|
||||
content: string;
|
||||
message: DatabaseMessage;
|
||||
toolMessages?: DatabaseMessage[];
|
||||
isStreaming?: boolean;
|
||||
highlightTurns?: boolean;
|
||||
}
|
||||
|
||||
type ToolResultLine = {
|
||||
text: string;
|
||||
image?: DatabaseMessageExtraImageFile;
|
||||
};
|
||||
|
||||
let { content, message, isStreaming = false, highlightTurns = false }: Props = $props();
|
||||
let { message, toolMessages = [], isStreaming = false, highlightTurns = false }: Props = $props();
|
||||
|
||||
let expandedStates: Record<number, boolean> = $state({});
|
||||
|
||||
const sections = $derived(parseAgenticContent(content));
|
||||
const showToolCallInProgress = $derived(config().showToolCallInProgress as boolean);
|
||||
const showThoughtInProgress = $derived(config().showThoughtInProgress as boolean);
|
||||
|
||||
// Parse toolResults with images only when sections or message.extra change
|
||||
const sections = $derived(deriveAgenticSections(message, toolMessages, []));
|
||||
|
||||
// Parse tool results with images
|
||||
const sectionsParsed = $derived(
|
||||
sections.map((section) => ({
|
||||
...section,
|
||||
parsedLines: section.toolResult
|
||||
? parseToolResultWithImages(section.toolResult, message?.extra)
|
||||
: []
|
||||
? parseToolResultWithImages(section.toolResult, section.toolResultExtras || message?.extra)
|
||||
: ([] as ToolResultLine[])
|
||||
}))
|
||||
);
|
||||
|
||||
@@ -107,26 +107,6 @@
|
||||
expandedStates[index] = !currentState;
|
||||
}
|
||||
|
||||
function parseToolResultWithImages(
|
||||
toolResult: string,
|
||||
extras?: DatabaseMessage['extra']
|
||||
): ToolResultLine[] {
|
||||
const lines = toolResult.split(NEWLINE_SEPARATOR);
|
||||
|
||||
return lines.map((line) => {
|
||||
const match = line.match(ATTACHMENT_SAVED_REGEX);
|
||||
if (!match || !extras) return { text: line };
|
||||
|
||||
const attachmentName = match[1];
|
||||
const image = extras.find(
|
||||
(e): e is DatabaseMessageExtraImageFile =>
|
||||
e.type === AttachmentType.IMAGE && e.name === attachmentName
|
||||
);
|
||||
|
||||
return { text: line, image };
|
||||
});
|
||||
}
|
||||
|
||||
function buildTurnAgenticTimings(stats: ChatMessageAgenticTurnStats): ChatMessageAgenticTimings {
|
||||
return {
|
||||
turns: 1,
|
||||
@@ -144,9 +124,8 @@
|
||||
<MarkdownContent content={section.content} attachments={message?.extra} />
|
||||
</div>
|
||||
{:else if section.type === AgenticSectionType.TOOL_CALL_STREAMING}
|
||||
{@const streamingIcon = isStreaming ? Loader2 : AlertTriangle}
|
||||
{@const streamingIconClass = isStreaming ? 'h-4 w-4 animate-spin' : 'h-4 w-4 text-yellow-500'}
|
||||
{@const streamingSubtitle = isStreaming ? '' : 'incomplete'}
|
||||
{@const streamingIcon = isStreaming ? Loader2 : Loader2}
|
||||
{@const streamingIconClass = isStreaming ? 'h-4 w-4 animate-spin' : 'h-4 w-4'}
|
||||
|
||||
<CollapsibleContentBlock
|
||||
open={isExpanded(index, section)}
|
||||
@@ -154,7 +133,7 @@
|
||||
icon={streamingIcon}
|
||||
iconClass={streamingIconClass}
|
||||
title={section.toolName || 'Tool call'}
|
||||
subtitle={streamingSubtitle}
|
||||
subtitle={isStreaming ? '' : 'incomplete'}
|
||||
{isStreaming}
|
||||
onToggle={() => toggleExpanded(index, section)}
|
||||
>
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import { Check, X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import { AGENTIC_TAGS, INPUT_CLASSES, REASONING_TAGS } from '$lib/constants';
|
||||
import { INPUT_CLASSES } from '$lib/constants';
|
||||
import { MessageRole, KeyboardKey, ChatMessageStatsView } from '$lib/enums';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
@@ -23,6 +23,8 @@
|
||||
import { modelsStore } from '$lib/stores/models.svelte';
|
||||
import { ServerModelStatus } from '$lib/enums';
|
||||
|
||||
import { hasAgenticContent } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
deletionInfo: {
|
||||
@@ -33,6 +35,7 @@
|
||||
} | null;
|
||||
isLastAssistantMessage?: boolean;
|
||||
message: DatabaseMessage;
|
||||
toolMessages?: DatabaseMessage[];
|
||||
messageContent: string | undefined;
|
||||
onCopy: () => void;
|
||||
onConfirmDelete: () => void;
|
||||
@@ -53,6 +56,7 @@
|
||||
deletionInfo,
|
||||
isLastAssistantMessage = false,
|
||||
message,
|
||||
toolMessages = [],
|
||||
messageContent,
|
||||
onConfirmDelete,
|
||||
onContinue,
|
||||
@@ -84,10 +88,8 @@
|
||||
}
|
||||
}
|
||||
|
||||
const hasAgenticMarkers = $derived(
|
||||
messageContent?.includes(AGENTIC_TAGS.TOOL_CALL_START) ?? false
|
||||
);
|
||||
const hasReasoningMarkers = $derived(messageContent?.includes(REASONING_TAGS.START) ?? false);
|
||||
const isAgentic = $derived(hasAgenticContent(message, toolMessages));
|
||||
const hasReasoning = $derived(!!message.reasoningContent);
|
||||
const processingState = useProcessingState();
|
||||
|
||||
let currentConfig = $derived(config());
|
||||
@@ -145,7 +147,7 @@
|
||||
}
|
||||
|
||||
let highlightAgenticTurns = $derived(
|
||||
hasAgenticMarkers &&
|
||||
isAgentic &&
|
||||
(currentConfig.alwaysShowAgenticTurns || activeStatsView === ChatMessageStatsView.SUMMARY)
|
||||
);
|
||||
|
||||
@@ -160,13 +162,14 @@
|
||||
message?.role === MessageRole.ASSISTANT &&
|
||||
isActivelyProcessing &&
|
||||
hasNoContent &&
|
||||
!isAgentic &&
|
||||
isLastAssistantMessage
|
||||
);
|
||||
|
||||
let showProcessingInfoBottom = $derived(
|
||||
message?.role === MessageRole.ASSISTANT &&
|
||||
isActivelyProcessing &&
|
||||
!hasNoContent &&
|
||||
(!hasNoContent || isAgentic) &&
|
||||
isLastAssistantMessage
|
||||
);
|
||||
|
||||
@@ -252,10 +255,10 @@
|
||||
<pre class="raw-output">{messageContent || ''}</pre>
|
||||
{:else}
|
||||
<ChatMessageAgenticContent
|
||||
content={messageContent || ''}
|
||||
{message}
|
||||
{toolMessages}
|
||||
isStreaming={isChatStreaming()}
|
||||
highlightTurns={highlightAgenticTurns}
|
||||
{message}
|
||||
/>
|
||||
{/if}
|
||||
{:else}
|
||||
@@ -344,9 +347,7 @@
|
||||
{onCopy}
|
||||
{onEdit}
|
||||
{onRegenerate}
|
||||
onContinue={currentConfig.enableContinueGeneration && !hasReasoningMarkers
|
||||
? onContinue
|
||||
: undefined}
|
||||
onContinue={currentConfig.enableContinueGeneration && !hasReasoning ? onContinue : undefined}
|
||||
{onForkConversation}
|
||||
{onDelete}
|
||||
{onConfirmDelete}
|
||||
|
||||
@@ -6,7 +6,12 @@
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { copyToClipboard, formatMessageForClipboard, getMessageSiblings } from '$lib/utils';
|
||||
import {
|
||||
copyToClipboard,
|
||||
formatMessageForClipboard,
|
||||
getMessageSiblings,
|
||||
hasAgenticContent
|
||||
} from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -119,32 +124,75 @@
|
||||
? messages
|
||||
: messages.filter((msg) => msg.type !== MessageRole.SYSTEM);
|
||||
|
||||
let lastAssistantIndex = -1;
|
||||
// Build display entries, grouping agentic sessions into single entries.
|
||||
// An agentic session = assistant(with tool_calls) → tool → assistant → tool → ... → assistant(final)
|
||||
const result: Array<{
|
||||
message: DatabaseMessage;
|
||||
toolMessages: DatabaseMessage[];
|
||||
isLastAssistantMessage: boolean;
|
||||
siblingInfo: ChatMessageSiblingInfo;
|
||||
}> = [];
|
||||
|
||||
for (let i = filteredMessages.length - 1; i >= 0; i--) {
|
||||
if (filteredMessages[i].role === MessageRole.ASSISTANT) {
|
||||
lastAssistantIndex = i;
|
||||
for (let i = 0; i < filteredMessages.length; i++) {
|
||||
const msg = filteredMessages[i];
|
||||
|
||||
// Skip tool messages - they're grouped with preceding assistant
|
||||
if (msg.role === MessageRole.TOOL) continue;
|
||||
|
||||
const toolMessages: DatabaseMessage[] = [];
|
||||
if (msg.role === MessageRole.ASSISTANT && hasAgenticContent(msg)) {
|
||||
let j = i + 1;
|
||||
|
||||
while (j < filteredMessages.length) {
|
||||
const next = filteredMessages[j];
|
||||
|
||||
if (next.role === MessageRole.TOOL) {
|
||||
toolMessages.push(next);
|
||||
|
||||
j++;
|
||||
} else if (next.role === MessageRole.ASSISTANT) {
|
||||
toolMessages.push(next);
|
||||
|
||||
j++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
i = j - 1;
|
||||
} else if (msg.role === MessageRole.ASSISTANT) {
|
||||
let j = i + 1;
|
||||
|
||||
while (j < filteredMessages.length && filteredMessages[j].role === MessageRole.TOOL) {
|
||||
toolMessages.push(filteredMessages[j]);
|
||||
j++;
|
||||
}
|
||||
}
|
||||
|
||||
const siblingInfo = getMessageSiblings(allConversationMessages, msg.id);
|
||||
|
||||
result.push({
|
||||
message: msg,
|
||||
toolMessages,
|
||||
isLastAssistantMessage: false,
|
||||
siblingInfo: siblingInfo || {
|
||||
message: msg,
|
||||
siblingIds: [msg.id],
|
||||
currentIndex: 0,
|
||||
totalSiblings: 1
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Mark the last assistant message
|
||||
for (let i = result.length - 1; i >= 0; i--) {
|
||||
if (result[i].message.role === MessageRole.ASSISTANT) {
|
||||
result[i].isLastAssistantMessage = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return filteredMessages.map((message, index) => {
|
||||
const siblingInfo = getMessageSiblings(allConversationMessages, message.id);
|
||||
const isLastAssistantMessage =
|
||||
message.role === MessageRole.ASSISTANT && index === lastAssistantIndex;
|
||||
|
||||
return {
|
||||
message,
|
||||
isLastAssistantMessage,
|
||||
siblingInfo: siblingInfo || {
|
||||
message,
|
||||
siblingIds: [message.id],
|
||||
currentIndex: 0,
|
||||
totalSiblings: 1
|
||||
}
|
||||
};
|
||||
});
|
||||
return result;
|
||||
});
|
||||
</script>
|
||||
|
||||
@@ -152,11 +200,12 @@
|
||||
class="flex h-full flex-col space-y-10 pt-24 {className}"
|
||||
style="height: auto; min-height: calc(100dvh - 14rem);"
|
||||
>
|
||||
{#each displayMessages as { message, isLastAssistantMessage, siblingInfo } (message.id)}
|
||||
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
|
||||
<div use:fadeInView>
|
||||
<ChatMessage
|
||||
class="mx-auto w-full max-w-[48rem]"
|
||||
{message}
|
||||
{toolMessages}
|
||||
{isLastAssistantMessage}
|
||||
{siblingInfo}
|
||||
/>
|
||||
|
||||
@@ -425,21 +425,16 @@ export { default as ChatMessage } from './ChatMessages/ChatMessage.svelte';
|
||||
/**
|
||||
* **ChatMessageAgenticContent** - Agentic workflow output display
|
||||
*
|
||||
* Specialized renderer for assistant messages containing agentic workflow markers.
|
||||
* Parses structured content and displays tool calls and reasoning blocks as
|
||||
* interactive collapsible sections with real-time streaming support.
|
||||
* Specialized renderer for assistant messages with tool calls and reasoning.
|
||||
* Derives display sections from structured message data (toolCalls, reasoningContent,
|
||||
* and child tool result messages) and renders them as interactive collapsible sections.
|
||||
*
|
||||
* **Architecture:**
|
||||
* - Uses `parseAgenticContent()` from `$lib/utils` to parse markers
|
||||
* - Uses `deriveAgenticSections()` from `$lib/utils` to build sections from structured data
|
||||
* - Renders sections as CollapsibleContentBlock components
|
||||
* - Handles streaming state for progressive content display
|
||||
* - Falls back to MarkdownContent for plain text sections
|
||||
*
|
||||
* **Marker Format:**
|
||||
* - Tool calls: in constants/agentic.ts (AGENTIC_TAGS)
|
||||
* - Reasoning: in constants/agentic.ts (REASONING_TAGS)
|
||||
* - Partial markers handled gracefully during streaming
|
||||
*
|
||||
* **Execution States:**
|
||||
* - **Streaming**: Animated spinner, block expanded, auto-scroll enabled
|
||||
* - **Pending**: Waiting indicator for queued tool calls
|
||||
|
||||
@@ -15,8 +15,11 @@ export const DEFAULT_AGENTIC_CONFIG: AgenticConfig = {
|
||||
maxToolPreviewLines: 25
|
||||
} as const;
|
||||
|
||||
// Agentic tool call tag markers
|
||||
export const AGENTIC_TAGS = {
|
||||
/**
|
||||
* @deprecated Legacy marker tags - only used for migration of old stored messages.
|
||||
* New messages use structured fields (reasoningContent, toolCalls, toolCallId).
|
||||
*/
|
||||
export const LEGACY_AGENTIC_TAGS = {
|
||||
TOOL_CALL_START: '<<<AGENTIC_TOOL_CALL_START>>>',
|
||||
TOOL_CALL_END: '<<<AGENTIC_TOOL_CALL_END>>>',
|
||||
TOOL_NAME_PREFIX: '<<<TOOL_NAME:',
|
||||
@@ -25,39 +28,25 @@ export const AGENTIC_TAGS = {
|
||||
TAG_SUFFIX: '>>>'
|
||||
} as const;
|
||||
|
||||
export const REASONING_TAGS = {
|
||||
/**
|
||||
* @deprecated Legacy reasoning tags - only used for migration of old stored messages.
|
||||
* New messages use the dedicated reasoningContent field.
|
||||
*/
|
||||
export const LEGACY_REASONING_TAGS = {
|
||||
START: '<<<reasoning_content_start>>>',
|
||||
END: '<<<reasoning_content_end>>>'
|
||||
} as const;
|
||||
|
||||
// Regex for trimming leading/trailing newlines
|
||||
export const TRIM_NEWLINES_REGEX = /^\n+|\n+$/g;
|
||||
|
||||
// Regex patterns for parsing agentic content
|
||||
export const AGENTIC_REGEX = {
|
||||
// Matches completed tool calls (with END marker)
|
||||
/**
|
||||
* @deprecated Legacy regex patterns - only used for migration of old stored messages.
|
||||
*/
|
||||
export const LEGACY_AGENTIC_REGEX = {
|
||||
COMPLETED_TOOL_CALL:
|
||||
/<<<AGENTIC_TOOL_CALL_START>>>\n<<<TOOL_NAME:(.+?)>>>\n<<<TOOL_ARGS_START>>>([\s\S]*?)<<<TOOL_ARGS_END>>>([\s\S]*?)<<<AGENTIC_TOOL_CALL_END>>>/g,
|
||||
// Matches pending tool call (has NAME and ARGS but no END)
|
||||
PENDING_TOOL_CALL:
|
||||
/<<<AGENTIC_TOOL_CALL_START>>>\n<<<TOOL_NAME:(.+?)>>>\n<<<TOOL_ARGS_START>>>([\s\S]*?)<<<TOOL_ARGS_END>>>([\s\S]*)$/,
|
||||
// Matches partial tool call (has START and NAME, ARGS still streaming)
|
||||
PARTIAL_WITH_NAME:
|
||||
/<<<AGENTIC_TOOL_CALL_START>>>\n<<<TOOL_NAME:(.+?)>>>\n<<<TOOL_ARGS_START>>>([\s\S]*)$/,
|
||||
// Matches early tool call (just START marker)
|
||||
EARLY_MATCH: /<<<AGENTIC_TOOL_CALL_START>>>([\s\S]*)$/,
|
||||
// Matches partial marker at end of content
|
||||
PARTIAL_MARKER: /<<<[A-Za-z_]*$/,
|
||||
// Matches reasoning content blocks (including tags)
|
||||
REASONING_BLOCK: /<<<reasoning_content_start>>>[\s\S]*?<<<reasoning_content_end>>>/g,
|
||||
// Captures the reasoning text between start/end tags
|
||||
REASONING_EXTRACT: /<<<reasoning_content_start>>>([\s\S]*?)<<<reasoning_content_end>>>/,
|
||||
// Matches an opening reasoning tag and any remaining content (unterminated)
|
||||
REASONING_OPEN: /<<<reasoning_content_start>>>[\s\S]*$/,
|
||||
// Matches a complete agentic tool call display block (start to end marker)
|
||||
AGENTIC_TOOL_CALL_BLOCK: /\n*<<<AGENTIC_TOOL_CALL_START>>>[\s\S]*?<<<AGENTIC_TOOL_CALL_END>>>/g,
|
||||
// Matches a pending/partial agentic tool call (start marker with no matching end)
|
||||
AGENTIC_TOOL_CALL_OPEN: /\n*<<<AGENTIC_TOOL_CALL_START>>>[\s\S]*$/,
|
||||
// Matches tool name inside content
|
||||
TOOL_NAME_EXTRACT: /<<<TOOL_NAME:([^>]+)>>>/
|
||||
HAS_LEGACY_MARKERS: /<<<(?:AGENTIC_TOOL_CALL_START|reasoning_content_start)>>>/
|
||||
} as const;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { getJsonHeaders, formatAttachmentText, isAbortError } from '$lib/utils';
|
||||
import { getJsonHeaders } from '$lib/utils/api-headers';
|
||||
import { formatAttachmentText } from '$lib/utils/formatters';
|
||||
import { isAbortError } from '$lib/utils/abort';
|
||||
import {
|
||||
AGENTIC_REGEX,
|
||||
ATTACHMENT_LABEL_PDF_FILE,
|
||||
ATTACHMENT_LABEL_MCP_PROMPT,
|
||||
ATTACHMENT_LABEL_MCP_RESOURCE
|
||||
@@ -17,38 +18,6 @@ import type { DatabaseMessageExtraMcpPrompt, DatabaseMessageExtraMcpResource } f
|
||||
import { modelsStore } from '$lib/stores/models.svelte';
|
||||
|
||||
export class ChatService {
|
||||
private static stripReasoningContent(
|
||||
content: ApiChatMessageData['content'] | null | undefined
|
||||
): ApiChatMessageData['content'] | null | undefined {
|
||||
if (!content) {
|
||||
return content;
|
||||
}
|
||||
|
||||
if (typeof content === 'string') {
|
||||
return content
|
||||
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
|
||||
.replace(AGENTIC_REGEX.REASONING_OPEN, '')
|
||||
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK, '')
|
||||
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '');
|
||||
}
|
||||
|
||||
if (!Array.isArray(content)) {
|
||||
return content;
|
||||
}
|
||||
|
||||
return content.map((part: ApiChatMessageContentPart) => {
|
||||
if (part.type !== ContentPartType.TEXT || !part.text) return part;
|
||||
return {
|
||||
...part,
|
||||
text: part.text
|
||||
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
|
||||
.replace(AGENTIC_REGEX.REASONING_OPEN, '')
|
||||
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK, '')
|
||||
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '')
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
@@ -57,46 +26,6 @@ export class ChatService {
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Extracts reasoning text from content that contains internal reasoning tags.
|
||||
* Returns the concatenated reasoning content or undefined if none found.
|
||||
*/
|
||||
private static extractReasoningFromContent(
|
||||
content: ApiChatMessageData['content'] | null | undefined
|
||||
): string | undefined {
|
||||
if (!content) return undefined;
|
||||
|
||||
const extractFromString = (text: string): string => {
|
||||
const parts: string[] = [];
|
||||
// Use a fresh regex instance to avoid shared lastIndex state
|
||||
const re = new RegExp(AGENTIC_REGEX.REASONING_EXTRACT.source);
|
||||
let match = re.exec(text);
|
||||
while (match) {
|
||||
parts.push(match[1]);
|
||||
// advance past the matched portion and retry
|
||||
text = text.slice(match.index + match[0].length);
|
||||
match = re.exec(text);
|
||||
}
|
||||
return parts.join('');
|
||||
};
|
||||
|
||||
if (typeof content === 'string') {
|
||||
const result = extractFromString(content);
|
||||
return result || undefined;
|
||||
}
|
||||
|
||||
if (!Array.isArray(content)) return undefined;
|
||||
|
||||
const parts: string[] = [];
|
||||
for (const part of content) {
|
||||
if (part.type === ContentPartType.TEXT && part.text) {
|
||||
const result = extractFromString(part.text);
|
||||
if (result) parts.push(result);
|
||||
}
|
||||
}
|
||||
return parts.length > 0 ? parts.join('') : undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends a chat completion request to the llama.cpp server.
|
||||
* Supports both streaming and non-streaming responses with comprehensive parameter configuration.
|
||||
@@ -201,20 +130,15 @@ export class ChatService {
|
||||
|
||||
const requestBody: ApiChatCompletionRequest = {
|
||||
messages: normalizedMessages.map((msg: ApiChatMessageData) => {
|
||||
// Always strip internal reasoning/agentic tags from content
|
||||
const cleanedContent = ChatService.stripReasoningContent(msg.content);
|
||||
const mapped: ApiChatCompletionRequest['messages'][0] = {
|
||||
role: msg.role,
|
||||
content: cleanedContent,
|
||||
content: msg.content,
|
||||
tool_calls: msg.tool_calls,
|
||||
tool_call_id: msg.tool_call_id
|
||||
};
|
||||
// When preserving reasoning, extract it from raw content and send as separate field
|
||||
if (!excludeReasoningFromContext) {
|
||||
const reasoning = ChatService.extractReasoningFromContent(msg.content);
|
||||
if (reasoning) {
|
||||
mapped.reasoning_content = reasoning;
|
||||
}
|
||||
// Include reasoning_content from the dedicated field
|
||||
if (!excludeReasoningFromContext && msg.reasoning_content) {
|
||||
mapped.reasoning_content = msg.reasoning_content;
|
||||
}
|
||||
return mapped;
|
||||
}),
|
||||
@@ -726,6 +650,10 @@ export class ChatService {
|
||||
content: message.content
|
||||
};
|
||||
|
||||
if (message.reasoningContent) {
|
||||
result.reasoning_content = message.reasoningContent;
|
||||
}
|
||||
|
||||
if (toolCalls && toolCalls.length > 0) {
|
||||
result.tool_calls = toolCalls;
|
||||
}
|
||||
@@ -854,6 +782,9 @@ export class ChatService {
|
||||
role: message.role as MessageRole,
|
||||
content: contentParts
|
||||
};
|
||||
if (message.reasoningContent) {
|
||||
result.reasoning_content = message.reasoningContent;
|
||||
}
|
||||
if (toolCalls && toolCalls.length > 0) {
|
||||
result.tool_calls = toolCalls;
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ import type {
|
||||
import {
|
||||
buildProxiedUrl,
|
||||
buildProxiedHeaders,
|
||||
getAuthHeaders,
|
||||
throwIfAborted,
|
||||
isAbortError,
|
||||
createBase64DataUrl
|
||||
@@ -124,7 +125,14 @@ export class MCPService {
|
||||
const requestInit: RequestInit = {};
|
||||
|
||||
if (config.headers) {
|
||||
requestInit.headers = buildProxiedHeaders(config.headers);
|
||||
requestInit.headers = config.useProxy ? buildProxiedHeaders(config.headers) : config.headers;
|
||||
}
|
||||
|
||||
if (useProxy) {
|
||||
requestInit.headers = {
|
||||
...getAuthHeaders(),
|
||||
...(requestInit.headers as Record<string, string>)
|
||||
};
|
||||
}
|
||||
|
||||
if (config.credentials) {
|
||||
|
||||
@@ -7,6 +7,10 @@
|
||||
* - Session state management
|
||||
* - Turn limit enforcement
|
||||
*
|
||||
* Each agentic turn produces separate DB messages:
|
||||
* - One assistant message per LLM turn (with tool_calls if any)
|
||||
* - One tool result message per tool call execution
|
||||
*
|
||||
* **Architecture & Relationships:**
|
||||
* - **ChatService**: Stateless API layer (sendMessage, streaming)
|
||||
* - **mcpStore**: MCP connection management and tool execution
|
||||
@@ -16,7 +20,6 @@
|
||||
* @see mcpStore in stores/mcp.svelte.ts for MCP operations
|
||||
*/
|
||||
|
||||
import { SvelteMap } from 'svelte/reactivity';
|
||||
import { ChatService } from '$lib/services';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
@@ -24,7 +27,6 @@ import { modelsStore } from '$lib/stores/models.svelte';
|
||||
import { isAbortError } from '$lib/utils';
|
||||
import {
|
||||
DEFAULT_AGENTIC_CONFIG,
|
||||
AGENTIC_TAGS,
|
||||
NEWLINE_SEPARATOR,
|
||||
TURN_LIMIT_MESSAGE,
|
||||
LLM_ERROR_BLOCK_START,
|
||||
@@ -193,17 +195,6 @@ class AgenticStore {
|
||||
|
||||
async runAgenticFlow(params: AgenticFlowParams): Promise<AgenticFlowResult> {
|
||||
const { conversationId, messages, options = {}, callbacks, signal, perChatOverrides } = params;
|
||||
const {
|
||||
onChunk,
|
||||
onReasoningChunk,
|
||||
onToolCallChunk,
|
||||
onAttachments,
|
||||
onModel,
|
||||
onComplete,
|
||||
onError,
|
||||
onTimings,
|
||||
onTurnComplete
|
||||
} = callbacks;
|
||||
|
||||
const agenticConfig = this.getConfig(config(), perChatOverrides);
|
||||
if (!agenticConfig.enabled) return { handled: false };
|
||||
@@ -253,24 +244,14 @@ class AgenticStore {
|
||||
options,
|
||||
tools,
|
||||
agenticConfig,
|
||||
callbacks: {
|
||||
onChunk,
|
||||
onReasoningChunk,
|
||||
onToolCallChunk,
|
||||
onAttachments,
|
||||
onModel,
|
||||
onComplete,
|
||||
onError,
|
||||
onTimings,
|
||||
onTurnComplete
|
||||
},
|
||||
callbacks,
|
||||
signal
|
||||
});
|
||||
return { handled: true };
|
||||
} catch (error) {
|
||||
const normalizedError = error instanceof Error ? error : new Error(String(error));
|
||||
this.updateSession(conversationId, { lastError: normalizedError });
|
||||
onError?.(normalizedError);
|
||||
callbacks.onError?.(normalizedError);
|
||||
return { handled: true, error: normalizedError };
|
||||
} finally {
|
||||
this.updateSession(conversationId, { isRunning: false });
|
||||
@@ -295,17 +276,20 @@ class AgenticStore {
|
||||
const {
|
||||
onChunk,
|
||||
onReasoningChunk,
|
||||
onToolCallChunk,
|
||||
onToolCallsStreaming,
|
||||
onAttachments,
|
||||
onModel,
|
||||
onComplete,
|
||||
onAssistantTurnComplete,
|
||||
createToolResultMessage,
|
||||
createAssistantMessage,
|
||||
onFlowComplete,
|
||||
onTimings,
|
||||
onTurnComplete
|
||||
} = callbacks;
|
||||
|
||||
const sessionMessages: AgenticMessage[] = toAgenticMessages(messages);
|
||||
const allToolCalls: ApiChatCompletionToolCall[] = [];
|
||||
let capturedTimings: ChatMessageTimings | undefined;
|
||||
let totalToolCallCount = 0;
|
||||
|
||||
const agenticTimings: ChatMessageAgenticTimings = {
|
||||
turns: 0,
|
||||
@@ -316,12 +300,7 @@ class AgenticStore {
|
||||
llm: { predicted_n: 0, predicted_ms: 0, prompt_n: 0, prompt_ms: 0 }
|
||||
};
|
||||
const maxTurns = agenticConfig.maxTurns;
|
||||
const maxToolPreviewLines = agenticConfig.maxToolPreviewLines;
|
||||
|
||||
// Resolve effective model for vision capability checks.
|
||||
// In ROUTER mode, options.model is always set by the caller.
|
||||
// In MODEL mode, options.model is undefined; use the single loaded model
|
||||
// which carries modalities bridged from /props.
|
||||
const effectiveModel = options.model || modelsStore.models[0]?.model || '';
|
||||
|
||||
for (let turn = 0; turn < maxTurns; turn++) {
|
||||
@@ -329,23 +308,20 @@ class AgenticStore {
|
||||
agenticTimings.turns = turn + 1;
|
||||
|
||||
if (signal?.aborted) {
|
||||
onComplete?.(
|
||||
'',
|
||||
undefined,
|
||||
this.buildFinalTimings(capturedTimings, agenticTimings),
|
||||
undefined
|
||||
);
|
||||
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
|
||||
return;
|
||||
}
|
||||
|
||||
// For turns > 0, create a new assistant message via callback
|
||||
if (turn > 0 && createAssistantMessage) {
|
||||
await createAssistantMessage();
|
||||
}
|
||||
|
||||
let turnContent = '';
|
||||
let turnReasoningContent = '';
|
||||
let turnToolCalls: ApiChatCompletionToolCall[] = [];
|
||||
let lastStreamingToolCallName = '';
|
||||
let lastStreamingToolCallArgsLength = 0;
|
||||
const emittedToolCallStates = new SvelteMap<
|
||||
number,
|
||||
{ emittedOnce: boolean; lastArgs: string }
|
||||
>();
|
||||
let turnTimings: ChatMessageTimings | undefined;
|
||||
|
||||
const turnStats: ChatMessageAgenticTurnStats = {
|
||||
@@ -366,30 +342,15 @@ class AgenticStore {
|
||||
turnContent += chunk;
|
||||
onChunk?.(chunk);
|
||||
},
|
||||
onReasoningChunk,
|
||||
onReasoningChunk: (chunk: string) => {
|
||||
turnReasoningContent += chunk;
|
||||
onReasoningChunk?.(chunk);
|
||||
},
|
||||
onToolCallChunk: (serialized: string) => {
|
||||
try {
|
||||
turnToolCalls = JSON.parse(serialized) as ApiChatCompletionToolCall[];
|
||||
for (let i = 0; i < turnToolCalls.length; i++) {
|
||||
const toolCall = turnToolCalls[i];
|
||||
const toolName = toolCall.function?.name ?? '';
|
||||
const toolArgs = toolCall.function?.arguments ?? '';
|
||||
const state = emittedToolCallStates.get(i) || {
|
||||
emittedOnce: false,
|
||||
lastArgs: ''
|
||||
};
|
||||
if (!state.emittedOnce) {
|
||||
const output = `\n\n${AGENTIC_TAGS.TOOL_CALL_START}\n${AGENTIC_TAGS.TOOL_NAME_PREFIX}${toolName}${AGENTIC_TAGS.TAG_SUFFIX}\n${AGENTIC_TAGS.TOOL_ARGS_START}\n${toolArgs}`;
|
||||
onChunk?.(output);
|
||||
state.emittedOnce = true;
|
||||
state.lastArgs = toolArgs;
|
||||
emittedToolCallStates.set(i, state);
|
||||
} else if (toolArgs.length > state.lastArgs.length) {
|
||||
onChunk?.(toolArgs.slice(state.lastArgs.length));
|
||||
state.lastArgs = toolArgs;
|
||||
emittedToolCallStates.set(i, state);
|
||||
}
|
||||
}
|
||||
onToolCallsStreaming?.(turnToolCalls);
|
||||
|
||||
if (turnToolCalls.length > 0 && turnToolCalls[0]?.function) {
|
||||
const name = turnToolCalls[0].function.name || '';
|
||||
const args = turnToolCalls[0].function.arguments || '';
|
||||
@@ -442,77 +403,84 @@ class AgenticStore {
|
||||
}
|
||||
} catch (error) {
|
||||
if (signal?.aborted) {
|
||||
onComplete?.(
|
||||
'',
|
||||
undefined,
|
||||
// Save whatever we have for this turn before exiting
|
||||
await onAssistantTurnComplete?.(
|
||||
turnContent,
|
||||
turnReasoningContent || undefined,
|
||||
this.buildFinalTimings(capturedTimings, agenticTimings),
|
||||
undefined
|
||||
);
|
||||
|
||||
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
|
||||
return;
|
||||
}
|
||||
const normalizedError = error instanceof Error ? error : new Error('LLM stream error');
|
||||
// Save error as content in the current turn
|
||||
onChunk?.(`${LLM_ERROR_BLOCK_START}${normalizedError.message}${LLM_ERROR_BLOCK_END}`);
|
||||
onComplete?.(
|
||||
'',
|
||||
undefined,
|
||||
await onAssistantTurnComplete?.(
|
||||
turnContent + `${LLM_ERROR_BLOCK_START}${normalizedError.message}${LLM_ERROR_BLOCK_END}`,
|
||||
turnReasoningContent || undefined,
|
||||
this.buildFinalTimings(capturedTimings, agenticTimings),
|
||||
undefined
|
||||
);
|
||||
|
||||
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
|
||||
throw normalizedError;
|
||||
}
|
||||
|
||||
// No tool calls = final turn, save and complete
|
||||
if (turnToolCalls.length === 0) {
|
||||
agenticTimings.perTurn!.push(turnStats);
|
||||
|
||||
onComplete?.(
|
||||
'',
|
||||
undefined,
|
||||
this.buildFinalTimings(capturedTimings, agenticTimings),
|
||||
const finalTimings = this.buildFinalTimings(capturedTimings, agenticTimings);
|
||||
|
||||
await onAssistantTurnComplete?.(
|
||||
turnContent,
|
||||
turnReasoningContent || undefined,
|
||||
finalTimings,
|
||||
undefined
|
||||
);
|
||||
|
||||
if (finalTimings) onTurnComplete?.(finalTimings);
|
||||
|
||||
onFlowComplete?.(finalTimings);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Normalize and save assistant turn with tool calls
|
||||
const normalizedCalls = this.normalizeToolCalls(turnToolCalls);
|
||||
if (normalizedCalls.length === 0) {
|
||||
onComplete?.(
|
||||
'',
|
||||
undefined,
|
||||
await onAssistantTurnComplete?.(
|
||||
turnContent,
|
||||
turnReasoningContent || undefined,
|
||||
this.buildFinalTimings(capturedTimings, agenticTimings),
|
||||
undefined
|
||||
);
|
||||
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
|
||||
return;
|
||||
}
|
||||
|
||||
for (const call of normalizedCalls) {
|
||||
allToolCalls.push({
|
||||
id: call.id,
|
||||
type: call.type,
|
||||
function: call.function ? { ...call.function } : undefined
|
||||
});
|
||||
}
|
||||
totalToolCallCount += normalizedCalls.length;
|
||||
this.updateSession(conversationId, { totalToolCalls: totalToolCallCount });
|
||||
|
||||
this.updateSession(conversationId, { totalToolCalls: allToolCalls.length });
|
||||
onToolCallChunk?.(JSON.stringify(allToolCalls));
|
||||
// Save the assistant message with its tool calls
|
||||
await onAssistantTurnComplete?.(
|
||||
turnContent,
|
||||
turnReasoningContent || undefined,
|
||||
turnTimings,
|
||||
normalizedCalls
|
||||
);
|
||||
|
||||
// Add assistant message to session history
|
||||
sessionMessages.push({
|
||||
role: MessageRole.ASSISTANT,
|
||||
content: turnContent || undefined,
|
||||
tool_calls: normalizedCalls
|
||||
});
|
||||
|
||||
// Execute each tool call and create result messages
|
||||
for (const toolCall of normalizedCalls) {
|
||||
if (signal?.aborted) {
|
||||
onComplete?.(
|
||||
'',
|
||||
undefined,
|
||||
this.buildFinalTimings(capturedTimings, agenticTimings),
|
||||
undefined
|
||||
);
|
||||
|
||||
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -530,13 +498,7 @@ class AgenticStore {
|
||||
result = executionResult.content;
|
||||
} catch (error) {
|
||||
if (isAbortError(error)) {
|
||||
onComplete?.(
|
||||
'',
|
||||
undefined,
|
||||
this.buildFinalTimings(capturedTimings, agenticTimings),
|
||||
undefined
|
||||
);
|
||||
|
||||
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
|
||||
return;
|
||||
}
|
||||
result = `Error: ${error instanceof Error ? error.message : String(error)}`;
|
||||
@@ -557,21 +519,27 @@ class AgenticStore {
|
||||
turnStats.toolsMs += Math.round(toolDurationMs);
|
||||
|
||||
if (signal?.aborted) {
|
||||
onComplete?.(
|
||||
'',
|
||||
undefined,
|
||||
this.buildFinalTimings(capturedTimings, agenticTimings),
|
||||
undefined
|
||||
);
|
||||
|
||||
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
|
||||
return;
|
||||
}
|
||||
|
||||
const { cleanedResult, attachments } = this.extractBase64Attachments(result);
|
||||
if (attachments.length > 0) onAttachments?.(attachments);
|
||||
|
||||
this.emitToolCallResult(cleanedResult, maxToolPreviewLines, onChunk);
|
||||
// Create the tool result message in the DB
|
||||
let toolResultMessage: DatabaseMessage | undefined;
|
||||
if (createToolResultMessage) {
|
||||
toolResultMessage = await createToolResultMessage(
|
||||
toolCall.id,
|
||||
cleanedResult,
|
||||
attachments.length > 0 ? attachments : undefined
|
||||
);
|
||||
}
|
||||
|
||||
if (attachments.length > 0 && toolResultMessage) {
|
||||
onAttachments?.(toolResultMessage.id, attachments);
|
||||
}
|
||||
|
||||
// Build content parts for session history (including images for vision models)
|
||||
const contentParts: ApiChatMessageContentPart[] = [
|
||||
{ type: ContentPartType.TEXT, text: cleanedResult }
|
||||
];
|
||||
@@ -605,8 +573,15 @@ class AgenticStore {
|
||||
}
|
||||
}
|
||||
|
||||
// Turn limit reached
|
||||
onChunk?.(TURN_LIMIT_MESSAGE);
|
||||
onComplete?.('', undefined, this.buildFinalTimings(capturedTimings, agenticTimings), undefined);
|
||||
await onAssistantTurnComplete?.(
|
||||
TURN_LIMIT_MESSAGE,
|
||||
undefined,
|
||||
this.buildFinalTimings(capturedTimings, agenticTimings),
|
||||
undefined
|
||||
);
|
||||
onFlowComplete?.(this.buildFinalTimings(capturedTimings, agenticTimings));
|
||||
}
|
||||
|
||||
private buildFinalTimings(
|
||||
@@ -633,23 +608,6 @@ class AgenticStore {
|
||||
}));
|
||||
}
|
||||
|
||||
private emitToolCallResult(
|
||||
result: string,
|
||||
maxLines: number,
|
||||
emit?: (chunk: string) => void
|
||||
): void {
|
||||
if (!emit) {
|
||||
return;
|
||||
}
|
||||
|
||||
let output = `${NEWLINE_SEPARATOR}${AGENTIC_TAGS.TOOL_ARGS_END}`;
|
||||
const lines = result.split(NEWLINE_SEPARATOR);
|
||||
const trimmedLines = lines.length > maxLines ? lines.slice(-maxLines) : lines;
|
||||
|
||||
output += `${NEWLINE_SEPARATOR}${trimmedLines.join(NEWLINE_SEPARATOR)}${NEWLINE_SEPARATOR}${AGENTIC_TAGS.TOOL_CALL_END}${NEWLINE_SEPARATOR}`;
|
||||
emit(output);
|
||||
}
|
||||
|
||||
private extractBase64Attachments(result: string): {
|
||||
cleanedResult: string;
|
||||
attachments: DatabaseMessageExtra[];
|
||||
|
||||
@@ -12,7 +12,8 @@
|
||||
*/
|
||||
|
||||
import { SvelteMap } from 'svelte/reactivity';
|
||||
import { DatabaseService, ChatService } from '$lib/services';
|
||||
import { DatabaseService } from '$lib/services/database.service';
|
||||
import { ChatService } from '$lib/services/chat.service';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { agenticStore } from '$lib/stores/agentic.svelte';
|
||||
@@ -28,12 +29,12 @@ import {
|
||||
filterByLeafNodeId,
|
||||
findDescendantMessages,
|
||||
findLeafNode,
|
||||
findMessageById,
|
||||
isAbortError
|
||||
} from '$lib/utils';
|
||||
import {
|
||||
MAX_INACTIVE_CONVERSATION_STATES,
|
||||
INACTIVE_CONVERSATION_STATE_MAX_AGE_MS,
|
||||
REASONING_TAGS,
|
||||
SYSTEM_MESSAGE_PLACEHOLDER
|
||||
} from '$lib/constants';
|
||||
import type {
|
||||
@@ -49,15 +50,6 @@ interface ConversationStateEntry {
|
||||
lastAccessed: number;
|
||||
}
|
||||
|
||||
const countOccurrences = (source: string, token: string): number =>
|
||||
source ? source.split(token).length - 1 : 0;
|
||||
const hasUnclosedReasoningTag = (content: string): boolean =>
|
||||
countOccurrences(content, REASONING_TAGS.START) > countOccurrences(content, REASONING_TAGS.END);
|
||||
const wrapReasoningContent = (content: string, reasoningContent?: string): string => {
|
||||
if (!reasoningContent) return content;
|
||||
return `${REASONING_TAGS.START}${reasoningContent}${REASONING_TAGS.END}${content}`;
|
||||
};
|
||||
|
||||
class ChatStore {
|
||||
activeProcessingState = $state<ApiProcessingState | null>(null);
|
||||
currentResponse = $state('');
|
||||
@@ -416,7 +408,7 @@ class ChatStore {
|
||||
if (!activeConv) return false;
|
||||
try {
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const systemMessage = allMessages.find((m) => m.id === messageId);
|
||||
const systemMessage = findMessageById(allMessages, messageId);
|
||||
if (!systemMessage || systemMessage.role !== MessageRole.SYSTEM) return false;
|
||||
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
|
||||
if (!rootMessage) return false;
|
||||
@@ -556,83 +548,76 @@ class ChatStore {
|
||||
await modelsStore.fetchModelProps(effectiveModel);
|
||||
}
|
||||
|
||||
let streamedContent = '',
|
||||
streamedToolCallContent = '',
|
||||
isReasoningOpen = false,
|
||||
hasStreamedChunks = false,
|
||||
resolvedModel: string | null = null,
|
||||
modelPersisted = false;
|
||||
let streamedExtras: DatabaseMessageExtra[] = assistantMessage.extra
|
||||
? JSON.parse(JSON.stringify(assistantMessage.extra))
|
||||
: [];
|
||||
// Mutable state for the current message being streamed
|
||||
let currentMessageId = assistantMessage.id;
|
||||
let streamedContent = '';
|
||||
let streamedReasoningContent = '';
|
||||
let resolvedModel: string | null = null;
|
||||
let modelPersisted = false;
|
||||
const convId = assistantMessage.convId;
|
||||
|
||||
const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => {
|
||||
if (!modelName) return;
|
||||
const n = normalizeModelName(modelName);
|
||||
if (!n || n === resolvedModel) return;
|
||||
resolvedModel = n;
|
||||
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
|
||||
const idx = conversationsStore.findMessageIndex(currentMessageId);
|
||||
conversationsStore.updateMessageAtIndex(idx, { model: n });
|
||||
if (persistImmediately && !modelPersisted) {
|
||||
modelPersisted = true;
|
||||
DatabaseService.updateMessage(assistantMessage.id, { model: n }).catch(() => {
|
||||
DatabaseService.updateMessage(currentMessageId, { model: n }).catch(() => {
|
||||
modelPersisted = false;
|
||||
resolvedModel = null;
|
||||
});
|
||||
}
|
||||
};
|
||||
const updateStreamingContent = () => {
|
||||
this.setChatStreaming(assistantMessage.convId, streamedContent, assistantMessage.id);
|
||||
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
|
||||
|
||||
const updateStreamingUI = () => {
|
||||
this.setChatStreaming(convId, streamedContent, currentMessageId);
|
||||
const idx = conversationsStore.findMessageIndex(currentMessageId);
|
||||
conversationsStore.updateMessageAtIndex(idx, { content: streamedContent });
|
||||
};
|
||||
const appendContentChunk = (chunk: string) => {
|
||||
if (isReasoningOpen) {
|
||||
streamedContent += REASONING_TAGS.END;
|
||||
isReasoningOpen = false;
|
||||
}
|
||||
streamedContent += chunk;
|
||||
hasStreamedChunks = true;
|
||||
updateStreamingContent();
|
||||
};
|
||||
const appendReasoningChunk = (chunk: string) => {
|
||||
if (!isReasoningOpen) {
|
||||
streamedContent += REASONING_TAGS.START;
|
||||
isReasoningOpen = true;
|
||||
}
|
||||
streamedContent += chunk;
|
||||
hasStreamedChunks = true;
|
||||
updateStreamingContent();
|
||||
};
|
||||
const finalizeReasoning = () => {
|
||||
if (isReasoningOpen) {
|
||||
streamedContent += REASONING_TAGS.END;
|
||||
isReasoningOpen = false;
|
||||
}
|
||||
|
||||
const cleanupStreamingState = () => {
|
||||
this.setStreamingActive(false);
|
||||
this.setChatLoading(convId, false);
|
||||
this.clearChatStreaming(convId);
|
||||
this.setProcessingState(convId, null);
|
||||
};
|
||||
|
||||
this.setStreamingActive(true);
|
||||
this.setActiveProcessingConversation(assistantMessage.convId);
|
||||
const abortController = this.getOrCreateAbortController(assistantMessage.convId);
|
||||
this.setActiveProcessingConversation(convId);
|
||||
const abortController = this.getOrCreateAbortController(convId);
|
||||
|
||||
const streamCallbacks: ChatStreamCallbacks = {
|
||||
onChunk: (chunk: string) => appendContentChunk(chunk),
|
||||
onReasoningChunk: (chunk: string) => appendReasoningChunk(chunk),
|
||||
onToolCallChunk: (chunk: string) => {
|
||||
const c = chunk.trim();
|
||||
if (!c) return;
|
||||
streamedToolCallContent = c;
|
||||
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
|
||||
conversationsStore.updateMessageAtIndex(idx, { toolCalls: streamedToolCallContent });
|
||||
onChunk: (chunk: string) => {
|
||||
streamedContent += chunk;
|
||||
updateStreamingUI();
|
||||
},
|
||||
onAttachments: (extras: DatabaseMessageExtra[]) => {
|
||||
onReasoningChunk: (chunk: string) => {
|
||||
streamedReasoningContent += chunk;
|
||||
// Update UI to show reasoning is being received
|
||||
const idx = conversationsStore.findMessageIndex(currentMessageId);
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
reasoningContent: streamedReasoningContent
|
||||
});
|
||||
},
|
||||
onToolCallsStreaming: (toolCalls) => {
|
||||
const idx = conversationsStore.findMessageIndex(currentMessageId);
|
||||
conversationsStore.updateMessageAtIndex(idx, { toolCalls: JSON.stringify(toolCalls) });
|
||||
},
|
||||
onAttachments: (messageId: string, extras: DatabaseMessageExtra[]) => {
|
||||
if (!extras.length) return;
|
||||
streamedExtras = [...streamedExtras, ...extras];
|
||||
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
|
||||
conversationsStore.updateMessageAtIndex(idx, { extra: streamedExtras });
|
||||
DatabaseService.updateMessage(assistantMessage.id, { extra: streamedExtras }).catch(
|
||||
console.error
|
||||
);
|
||||
const idx = conversationsStore.findMessageIndex(messageId);
|
||||
if (idx === -1) return;
|
||||
const msg = conversationsStore.activeMessages[idx];
|
||||
const updatedExtras = [...(msg.extra || []), ...extras];
|
||||
conversationsStore.updateMessageAtIndex(idx, { extra: updatedExtras });
|
||||
DatabaseService.updateMessage(messageId, { extra: updatedExtras }).catch(console.error);
|
||||
},
|
||||
onModel: (modelName: string) => recordModel(modelName),
|
||||
onTurnComplete: (intermediateTimings: ChatMessageTimings) => {
|
||||
// Update the first assistant message with cumulative agentic timings
|
||||
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
|
||||
conversationsStore.updateMessageAtIndex(idx, { timings: intermediateTimings });
|
||||
},
|
||||
@@ -650,56 +635,104 @@ class ChatStore {
|
||||
cache_n: timings?.cache_n || 0,
|
||||
prompt_progress: promptProgress
|
||||
},
|
||||
assistantMessage.convId
|
||||
convId
|
||||
);
|
||||
},
|
||||
onComplete: async (
|
||||
finalContent?: string,
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings,
|
||||
toolCallContent?: string
|
||||
onAssistantTurnComplete: async (
|
||||
content: string,
|
||||
reasoningContent: string | undefined,
|
||||
timings: ChatMessageTimings | undefined,
|
||||
toolCalls: import('$lib/types/api').ApiChatCompletionToolCall[] | undefined
|
||||
) => {
|
||||
this.setStreamingActive(false);
|
||||
finalizeReasoning();
|
||||
const combinedContent = hasStreamedChunks
|
||||
? streamedContent
|
||||
: wrapReasoningContent(finalContent || '', reasoningContent);
|
||||
const updateData: Record<string, unknown> = {
|
||||
content: combinedContent,
|
||||
toolCalls: toolCallContent || streamedToolCallContent,
|
||||
content,
|
||||
reasoningContent: reasoningContent || undefined,
|
||||
toolCalls: toolCalls ? JSON.stringify(toolCalls) : '',
|
||||
timings
|
||||
};
|
||||
if (streamedExtras.length > 0) updateData.extra = streamedExtras;
|
||||
if (resolvedModel && !modelPersisted) updateData.model = resolvedModel;
|
||||
await DatabaseService.updateMessage(assistantMessage.id, updateData);
|
||||
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
|
||||
await DatabaseService.updateMessage(currentMessageId, updateData);
|
||||
const idx = conversationsStore.findMessageIndex(currentMessageId);
|
||||
const uiUpdate: Partial<DatabaseMessage> = {
|
||||
content: combinedContent,
|
||||
toolCalls: updateData.toolCalls as string
|
||||
content,
|
||||
reasoningContent: reasoningContent || undefined,
|
||||
toolCalls: toolCalls ? JSON.stringify(toolCalls) : ''
|
||||
};
|
||||
if (streamedExtras.length > 0) uiUpdate.extra = streamedExtras;
|
||||
if (timings) uiUpdate.timings = timings;
|
||||
if (resolvedModel) uiUpdate.model = resolvedModel;
|
||||
conversationsStore.updateMessageAtIndex(idx, uiUpdate);
|
||||
await conversationsStore.updateCurrentNode(assistantMessage.id);
|
||||
if (onComplete) await onComplete(combinedContent);
|
||||
this.setChatLoading(assistantMessage.convId, false);
|
||||
this.clearChatStreaming(assistantMessage.convId);
|
||||
this.setProcessingState(assistantMessage.convId, null);
|
||||
await conversationsStore.updateCurrentNode(currentMessageId);
|
||||
},
|
||||
createToolResultMessage: async (
|
||||
toolCallId: string,
|
||||
content: string,
|
||||
extras?: DatabaseMessageExtra[]
|
||||
) => {
|
||||
const msg = await DatabaseService.createMessageBranch(
|
||||
{
|
||||
convId,
|
||||
type: MessageType.TEXT,
|
||||
role: MessageRole.TOOL,
|
||||
content,
|
||||
toolCallId,
|
||||
timestamp: Date.now(),
|
||||
toolCalls: '',
|
||||
children: [],
|
||||
extra: extras
|
||||
},
|
||||
currentMessageId
|
||||
);
|
||||
conversationsStore.addMessageToActive(msg);
|
||||
await conversationsStore.updateCurrentNode(msg.id);
|
||||
return msg;
|
||||
},
|
||||
createAssistantMessage: async () => {
|
||||
// Reset streaming state for new message
|
||||
streamedContent = '';
|
||||
streamedReasoningContent = '';
|
||||
|
||||
const lastMsg =
|
||||
conversationsStore.activeMessages[conversationsStore.activeMessages.length - 1];
|
||||
const msg = await DatabaseService.createMessageBranch(
|
||||
{
|
||||
convId,
|
||||
type: MessageType.TEXT,
|
||||
role: MessageRole.ASSISTANT,
|
||||
content: '',
|
||||
timestamp: Date.now(),
|
||||
toolCalls: '',
|
||||
children: [],
|
||||
model: resolvedModel
|
||||
},
|
||||
lastMsg.id
|
||||
);
|
||||
conversationsStore.addMessageToActive(msg);
|
||||
currentMessageId = msg.id;
|
||||
return msg;
|
||||
},
|
||||
onFlowComplete: (finalTimings?: ChatMessageTimings) => {
|
||||
if (finalTimings) {
|
||||
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
|
||||
|
||||
conversationsStore.updateMessageAtIndex(idx, { timings: finalTimings });
|
||||
DatabaseService.updateMessage(assistantMessage.id, { timings: finalTimings }).catch(
|
||||
console.error
|
||||
);
|
||||
}
|
||||
|
||||
cleanupStreamingState();
|
||||
|
||||
if (onComplete) onComplete(streamedContent);
|
||||
if (isRouterMode()) modelsStore.fetchRouterModels().catch(console.error);
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
this.setStreamingActive(false);
|
||||
if (isAbortError(error)) {
|
||||
this.setChatLoading(assistantMessage.convId, false);
|
||||
this.clearChatStreaming(assistantMessage.convId);
|
||||
this.setProcessingState(assistantMessage.convId, null);
|
||||
cleanupStreamingState();
|
||||
return;
|
||||
}
|
||||
console.error('Streaming error:', error);
|
||||
this.setChatLoading(assistantMessage.convId, false);
|
||||
this.clearChatStreaming(assistantMessage.convId);
|
||||
this.setProcessingState(assistantMessage.convId, null);
|
||||
cleanupStreamingState();
|
||||
const idx = conversationsStore.findMessageIndex(assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
const failedMessage = conversationsStore.removeMessageAtIndex(idx);
|
||||
@@ -716,12 +749,13 @@ class ChatStore {
|
||||
if (onError) onError(error);
|
||||
}
|
||||
};
|
||||
|
||||
const perChatOverrides = conversationsStore.activeConversation?.mcpServerOverrides;
|
||||
|
||||
const agenticConfig = agenticStore.getConfig(config(), perChatOverrides);
|
||||
if (agenticConfig.enabled) {
|
||||
const agenticResult = await agenticStore.runAgenticFlow({
|
||||
conversationId: assistantMessage.convId,
|
||||
conversationId: convId,
|
||||
messages: allMessages,
|
||||
options: { ...this.getApiOptions(), ...(effectiveModel ? { model: effectiveModel } : {}) },
|
||||
callbacks: streamCallbacks,
|
||||
@@ -731,16 +765,50 @@ class ChatStore {
|
||||
if (agenticResult.handled) return;
|
||||
}
|
||||
|
||||
const completionOptions = {
|
||||
...this.getApiOptions(),
|
||||
...(effectiveModel ? { model: effectiveModel } : {}),
|
||||
...streamCallbacks
|
||||
};
|
||||
|
||||
// Non-agentic path: direct streaming into the single assistant message
|
||||
await ChatService.sendMessage(
|
||||
allMessages,
|
||||
completionOptions,
|
||||
assistantMessage.convId,
|
||||
{
|
||||
...this.getApiOptions(),
|
||||
...(effectiveModel ? { model: effectiveModel } : {}),
|
||||
stream: true,
|
||||
onChunk: streamCallbacks.onChunk,
|
||||
onReasoningChunk: streamCallbacks.onReasoningChunk,
|
||||
onModel: streamCallbacks.onModel,
|
||||
onTimings: streamCallbacks.onTimings,
|
||||
onComplete: async (
|
||||
finalContent?: string,
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings,
|
||||
toolCalls?: string
|
||||
) => {
|
||||
const content = streamedContent || finalContent || '';
|
||||
const reasoning = streamedReasoningContent || reasoningContent;
|
||||
const updateData: Record<string, unknown> = {
|
||||
content,
|
||||
reasoningContent: reasoning || undefined,
|
||||
toolCalls: toolCalls || '',
|
||||
timings
|
||||
};
|
||||
if (resolvedModel && !modelPersisted) updateData.model = resolvedModel;
|
||||
await DatabaseService.updateMessage(currentMessageId, updateData);
|
||||
const idx = conversationsStore.findMessageIndex(currentMessageId);
|
||||
const uiUpdate: Partial<DatabaseMessage> = {
|
||||
content,
|
||||
reasoningContent: reasoning || undefined,
|
||||
toolCalls: toolCalls || ''
|
||||
};
|
||||
if (timings) uiUpdate.timings = timings;
|
||||
if (resolvedModel) uiUpdate.model = resolvedModel;
|
||||
conversationsStore.updateMessageAtIndex(idx, uiUpdate);
|
||||
await conversationsStore.updateCurrentNode(currentMessageId);
|
||||
cleanupStreamingState();
|
||||
if (onComplete) await onComplete(content);
|
||||
if (isRouterMode()) modelsStore.fetchRouterModels().catch(console.error);
|
||||
},
|
||||
onError: streamCallbacks.onError
|
||||
},
|
||||
convId,
|
||||
abortController.signal
|
||||
);
|
||||
}
|
||||
@@ -878,7 +946,7 @@ class ChatStore {
|
||||
const msg = conversationsStore.activeMessages[idx];
|
||||
if (msg.role !== MessageRole.ASSISTANT) return;
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const parentMessage = allMessages.find((m) => m.id === msg.parent);
|
||||
const parentMessage = findMessageById(allMessages, msg.parent);
|
||||
if (!parentMessage) return;
|
||||
this.setChatLoading(activeConv.id, true);
|
||||
this.clearChatStreaming(activeConv.id);
|
||||
@@ -928,7 +996,7 @@ class ChatStore {
|
||||
if (!activeConv)
|
||||
return { totalCount: 0, userMessages: 0, assistantMessages: 0, messageTypes: [] };
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const messageToDelete = allMessages.find((m) => m.id === messageId);
|
||||
const messageToDelete = findMessageById(allMessages, messageId);
|
||||
|
||||
// For system messages, don't count descendants as they will be preserved (reparented to root)
|
||||
if (messageToDelete?.role === MessageRole.SYSTEM) {
|
||||
@@ -975,7 +1043,7 @@ class ChatStore {
|
||||
if (!activeConv) return;
|
||||
try {
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const messageToDelete = allMessages.find((m) => m.id === messageId);
|
||||
const messageToDelete = findMessageById(allMessages, messageId);
|
||||
|
||||
if (!messageToDelete) return;
|
||||
|
||||
@@ -1024,7 +1092,7 @@ class ChatStore {
|
||||
this.clearChatStreaming(activeConv.id);
|
||||
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const dbMessage = allMessages.find((m) => m.id === messageId);
|
||||
const dbMessage = findMessageById(allMessages, messageId);
|
||||
|
||||
if (!dbMessage) {
|
||||
this.setChatLoading(activeConv.id, false);
|
||||
@@ -1032,56 +1100,40 @@ class ChatStore {
|
||||
}
|
||||
|
||||
const originalContent = dbMessage.content;
|
||||
const originalReasoning = dbMessage.reasoningContent || '';
|
||||
const conversationContext = conversationsStore.activeMessages.slice(0, idx);
|
||||
const contextWithContinue = [
|
||||
...conversationContext,
|
||||
{ role: MessageRole.ASSISTANT as const, content: originalContent }
|
||||
];
|
||||
|
||||
let appendedContent = '',
|
||||
hasReceivedContent = false,
|
||||
isReasoningOpen = hasUnclosedReasoningTag(originalContent);
|
||||
let appendedContent = '';
|
||||
let appendedReasoning = '';
|
||||
let hasReceivedContent = false;
|
||||
|
||||
const updateStreamingContent = (fullContent: string) => {
|
||||
this.setChatStreaming(msg.convId, fullContent, msg.id);
|
||||
conversationsStore.updateMessageAtIndex(idx, { content: fullContent });
|
||||
};
|
||||
|
||||
const appendContentChunk = (chunk: string) => {
|
||||
if (isReasoningOpen) {
|
||||
appendedContent += REASONING_TAGS.END;
|
||||
isReasoningOpen = false;
|
||||
}
|
||||
appendedContent += chunk;
|
||||
hasReceivedContent = true;
|
||||
updateStreamingContent(originalContent + appendedContent);
|
||||
};
|
||||
|
||||
const appendReasoningChunk = (chunk: string) => {
|
||||
if (!isReasoningOpen) {
|
||||
appendedContent += REASONING_TAGS.START;
|
||||
isReasoningOpen = true;
|
||||
}
|
||||
appendedContent += chunk;
|
||||
hasReceivedContent = true;
|
||||
updateStreamingContent(originalContent + appendedContent);
|
||||
};
|
||||
|
||||
const finalizeReasoning = () => {
|
||||
if (isReasoningOpen) {
|
||||
appendedContent += REASONING_TAGS.END;
|
||||
isReasoningOpen = false;
|
||||
}
|
||||
};
|
||||
|
||||
const abortController = this.getOrCreateAbortController(msg.convId);
|
||||
|
||||
await ChatService.sendMessage(
|
||||
contextWithContinue,
|
||||
{
|
||||
...this.getApiOptions(),
|
||||
onChunk: (chunk: string) => appendContentChunk(chunk),
|
||||
onReasoningChunk: (chunk: string) => appendReasoningChunk(chunk),
|
||||
onChunk: (chunk: string) => {
|
||||
appendedContent += chunk;
|
||||
hasReceivedContent = true;
|
||||
updateStreamingContent(originalContent + appendedContent);
|
||||
},
|
||||
onReasoningChunk: (chunk: string) => {
|
||||
appendedReasoning += chunk;
|
||||
hasReceivedContent = true;
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
reasoningContent: originalReasoning + appendedReasoning
|
||||
});
|
||||
},
|
||||
onTimings: (timings?: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => {
|
||||
const tokensPerSecond =
|
||||
timings?.predicted_ms && timings?.predicted_n
|
||||
@@ -1104,21 +1156,23 @@ class ChatStore {
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings
|
||||
) => {
|
||||
finalizeReasoning();
|
||||
|
||||
const appendedFromCompletion = hasReceivedContent
|
||||
? appendedContent
|
||||
: wrapReasoningContent(finalContent || '', reasoningContent);
|
||||
const fullContent = originalContent + appendedFromCompletion;
|
||||
const finalAppendedContent = hasReceivedContent ? appendedContent : finalContent || '';
|
||||
const finalAppendedReasoning = hasReceivedContent
|
||||
? appendedReasoning
|
||||
: reasoningContent || '';
|
||||
const fullContent = originalContent + finalAppendedContent;
|
||||
const fullReasoning = originalReasoning + finalAppendedReasoning || undefined;
|
||||
|
||||
await DatabaseService.updateMessage(msg.id, {
|
||||
content: fullContent,
|
||||
reasoningContent: fullReasoning,
|
||||
timestamp: Date.now(),
|
||||
timings
|
||||
});
|
||||
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
content: fullContent,
|
||||
reasoningContent: fullReasoning,
|
||||
timestamp: Date.now(),
|
||||
timings
|
||||
});
|
||||
@@ -1134,11 +1188,13 @@ class ChatStore {
|
||||
if (hasReceivedContent && appendedContent) {
|
||||
await DatabaseService.updateMessage(msg.id, {
|
||||
content: originalContent + appendedContent,
|
||||
reasoningContent: originalReasoning + appendedReasoning || undefined,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
content: originalContent + appendedContent,
|
||||
reasoningContent: originalReasoning + appendedReasoning || undefined,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
}
|
||||
@@ -1280,7 +1336,10 @@ class ChatStore {
|
||||
|
||||
let messageIdForResponse: string;
|
||||
|
||||
if (msg.children.length === 0) {
|
||||
const dbMsg = findMessageById(allMessages, msg.id);
|
||||
const hasChildren = dbMsg ? dbMsg.children.length > 0 : msg.children.length > 0;
|
||||
|
||||
if (!hasChildren) {
|
||||
// No responses after this message — update in place instead of branching
|
||||
const updates: Partial<DatabaseMessage> = {
|
||||
content: newContent,
|
||||
|
||||
@@ -23,7 +23,7 @@ import { browser } from '$app/environment';
|
||||
import { toast } from 'svelte-sonner';
|
||||
import { DatabaseService } from '$lib/services/database.service';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { filterByLeafNodeId, findLeafNode } from '$lib/utils';
|
||||
import { filterByLeafNodeId, findLeafNode, runLegacyMigration } from '$lib/utils';
|
||||
import type { McpServerOverride } from '$lib/types/database';
|
||||
import { MessageRole } from '$lib/enums';
|
||||
import {
|
||||
@@ -128,6 +128,10 @@ class ConversationsStore {
|
||||
if (this.isInitialized) return;
|
||||
|
||||
try {
|
||||
// @deprecated Legacy migration for old marker-based messages.
|
||||
// Remove once all users have migrated to the structured format.
|
||||
await runLegacyMigration();
|
||||
|
||||
await this.loadConversations();
|
||||
this.isInitialized = true;
|
||||
} catch (error) {
|
||||
|
||||
43
tools/server/webui/src/lib/types/agentic.d.ts
vendored
43
tools/server/webui/src/lib/types/agentic.d.ts
vendored
@@ -2,6 +2,7 @@ import type { MessageRole } from '$lib/enums';
|
||||
import { ToolCallType } from '$lib/enums';
|
||||
import type {
|
||||
ApiChatCompletionRequest,
|
||||
ApiChatCompletionToolCall,
|
||||
ApiChatMessageContentPart,
|
||||
ApiChatMessageData
|
||||
} from './api';
|
||||
@@ -70,22 +71,48 @@ export interface AgenticSession {
|
||||
}
|
||||
|
||||
/**
|
||||
* Callbacks for agentic flow execution
|
||||
* Callbacks for agentic flow execution.
|
||||
*
|
||||
* The agentic loop creates separate DB messages for each turn:
|
||||
* - assistant messages (one per LLM turn, with tool_calls if any)
|
||||
* - tool result messages (one per tool call execution)
|
||||
*
|
||||
* The first assistant message is created by the caller before starting the flow.
|
||||
* Subsequent messages are created via createToolResultMessage / createAssistantMessage.
|
||||
*/
|
||||
export interface AgenticFlowCallbacks {
|
||||
/** Content chunk for the current assistant message */
|
||||
onChunk?: (chunk: string) => void;
|
||||
/** Reasoning content chunk for the current assistant message */
|
||||
onReasoningChunk?: (chunk: string) => void;
|
||||
onToolCallChunk?: (serializedToolCalls: string) => void;
|
||||
onAttachments?: (extras: DatabaseMessageExtra[]) => void;
|
||||
/** Tool calls being streamed (partial, accumulating) for the current turn */
|
||||
onToolCallsStreaming?: (toolCalls: ApiChatCompletionToolCall[]) => void;
|
||||
/** Attachments extracted from tool results */
|
||||
onAttachments?: (messageId: string, extras: DatabaseMessageExtra[]) => void;
|
||||
/** Model name detected from response */
|
||||
onModel?: (model: string) => void;
|
||||
onComplete?: (
|
||||
/** Current assistant turn's streaming is complete - save to DB */
|
||||
onAssistantTurnComplete?: (
|
||||
content: string,
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings,
|
||||
toolCalls?: string
|
||||
) => void;
|
||||
reasoningContent: string | undefined,
|
||||
timings: ChatMessageTimings | undefined,
|
||||
toolCalls: ApiChatCompletionToolCall[] | undefined
|
||||
) => Promise<void>;
|
||||
/** Create a tool result message in the DB tree */
|
||||
createToolResultMessage?: (
|
||||
toolCallId: string,
|
||||
content: string,
|
||||
extras?: DatabaseMessageExtra[]
|
||||
) => Promise<DatabaseMessage>;
|
||||
/** Create a new assistant message for the next agentic turn */
|
||||
createAssistantMessage?: () => Promise<DatabaseMessage>;
|
||||
/** Entire agentic flow is complete */
|
||||
onFlowComplete?: (timings?: ChatMessageTimings) => void;
|
||||
/** Error during flow */
|
||||
onError?: (error: Error) => void;
|
||||
/** Timing updates during streaming */
|
||||
onTimings?: (timings?: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void;
|
||||
/** An agentic turn (LLM + tool execution) completed - intermediate timing update */
|
||||
onTurnComplete?: (intermediateTimings: ChatMessageTimings) => void;
|
||||
}
|
||||
|
||||
|
||||
28
tools/server/webui/src/lib/types/chat.d.ts
vendored
28
tools/server/webui/src/lib/types/chat.d.ts
vendored
@@ -1,5 +1,6 @@
|
||||
import type { ErrorDialogType } from '$lib/enums';
|
||||
import type { DatabaseMessageExtra } from './database';
|
||||
import type { ApiChatCompletionToolCall } from './api';
|
||||
import type { DatabaseMessage, DatabaseMessageExtra } from './database';
|
||||
|
||||
export interface ChatUploadedFile {
|
||||
id: string;
|
||||
@@ -99,21 +100,28 @@ export interface ChatMessageToolCallTiming {
|
||||
}
|
||||
|
||||
/**
|
||||
* Callbacks for streaming chat responses
|
||||
* Callbacks for streaming chat responses (used by both agentic and non-agentic paths)
|
||||
*/
|
||||
export interface ChatStreamCallbacks {
|
||||
onChunk?: (chunk: string) => void;
|
||||
onReasoningChunk?: (chunk: string) => void;
|
||||
onToolCallChunk?: (chunk: string) => void;
|
||||
onAttachments?: (extras: DatabaseMessageExtra[]) => void;
|
||||
onToolCallsStreaming?: (toolCalls: ApiChatCompletionToolCall[]) => void;
|
||||
onAttachments?: (messageId: string, extras: DatabaseMessageExtra[]) => void;
|
||||
onModel?: (model: string) => void;
|
||||
onTimings?: (timings?: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void;
|
||||
onComplete?: (
|
||||
content?: string,
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings,
|
||||
toolCallContent?: string
|
||||
) => void;
|
||||
onAssistantTurnComplete?: (
|
||||
content: string,
|
||||
reasoningContent: string | undefined,
|
||||
timings: ChatMessageTimings | undefined,
|
||||
toolCalls: ApiChatCompletionToolCall[] | undefined
|
||||
) => Promise<void>;
|
||||
createToolResultMessage?: (
|
||||
toolCallId: string,
|
||||
content: string,
|
||||
extras?: DatabaseMessageExtra[]
|
||||
) => Promise<DatabaseMessage>;
|
||||
createAssistantMessage?: () => Promise<DatabaseMessage>;
|
||||
onFlowComplete?: (timings?: ChatMessageTimings) => void;
|
||||
onError?: (error: Error) => void;
|
||||
onTurnComplete?: (intermediateTimings: ChatMessageTimings) => void;
|
||||
}
|
||||
|
||||
@@ -92,6 +92,8 @@ export interface DatabaseMessage {
|
||||
* @deprecated - left for backward compatibility
|
||||
*/
|
||||
thinking?: string;
|
||||
/** Reasoning content produced by the model (separate from visible content) */
|
||||
reasoningContent?: string;
|
||||
/** Serialized JSON array of tool calls made by assistant messages */
|
||||
toolCalls?: string;
|
||||
/** Tool call ID for tool result messages (role: 'tool') */
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
import { AgenticSectionType } from '$lib/enums';
|
||||
import { AGENTIC_TAGS, AGENTIC_REGEX, REASONING_TAGS, TRIM_NEWLINES_REGEX } from '$lib/constants';
|
||||
import { AgenticSectionType, MessageRole } from '$lib/enums';
|
||||
import { ATTACHMENT_SAVED_REGEX, NEWLINE_SEPARATOR } from '$lib/constants';
|
||||
import type { ApiChatCompletionToolCall } from '$lib/types/api';
|
||||
import type {
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra,
|
||||
DatabaseMessageExtraImageFile
|
||||
} from '$lib/types/database';
|
||||
import { AttachmentType } from '$lib/enums';
|
||||
|
||||
/**
|
||||
* Represents a parsed section of agentic content
|
||||
* Represents a parsed section of agentic content for display
|
||||
*/
|
||||
export interface AgenticSection {
|
||||
type: AgenticSectionType;
|
||||
@@ -10,63 +17,70 @@ export interface AgenticSection {
|
||||
toolName?: string;
|
||||
toolArgs?: string;
|
||||
toolResult?: string;
|
||||
toolResultExtras?: DatabaseMessageExtra[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a segment of content that may contain reasoning blocks
|
||||
* Represents a tool result line that may reference an image attachment
|
||||
*/
|
||||
type ReasoningSegment = {
|
||||
type:
|
||||
| AgenticSectionType.TEXT
|
||||
| AgenticSectionType.REASONING
|
||||
| AgenticSectionType.REASONING_PENDING;
|
||||
content: string;
|
||||
export type ToolResultLine = {
|
||||
text: string;
|
||||
image?: DatabaseMessageExtraImageFile;
|
||||
};
|
||||
|
||||
/**
|
||||
* Parses agentic content into structured sections
|
||||
* Derives display sections from a single assistant message and its direct tool results.
|
||||
*
|
||||
* Main parsing function that processes content containing:
|
||||
* - Tool calls (completed, pending, or streaming)
|
||||
* - Reasoning blocks (completed or streaming)
|
||||
* - Regular text content
|
||||
*
|
||||
* The parser handles chronological display of agentic flow output, maintaining
|
||||
* the order of operations and properly identifying different states of tool calls
|
||||
* and reasoning blocks during streaming.
|
||||
*
|
||||
* @param rawContent - The raw content string to parse
|
||||
* @returns Array of structured agentic sections ready for display
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const content = "Some text <<<AGENTIC_TOOL_CALL>>>tool_name...";
|
||||
* const sections = parseAgenticContent(content);
|
||||
* // Returns: [{ type: 'text', content: 'Some text' }, { type: 'tool_call_streaming', ... }]
|
||||
* ```
|
||||
* @param message - The assistant message
|
||||
* @param toolMessages - Tool result messages for this assistant's tool_calls
|
||||
* @param streamingToolCalls - Partial tool calls during streaming (not yet persisted)
|
||||
*/
|
||||
export function parseAgenticContent(rawContent: string): AgenticSection[] {
|
||||
if (!rawContent) return [];
|
||||
|
||||
const segments = splitReasoningSegments(rawContent);
|
||||
function deriveSingleTurnSections(
|
||||
message: DatabaseMessage,
|
||||
toolMessages: DatabaseMessage[] = [],
|
||||
streamingToolCalls: ApiChatCompletionToolCall[] = []
|
||||
): AgenticSection[] {
|
||||
const sections: AgenticSection[] = [];
|
||||
|
||||
for (const segment of segments) {
|
||||
if (segment.type === AgenticSectionType.TEXT) {
|
||||
sections.push(...parseToolCallContent(segment.content));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (segment.type === AgenticSectionType.REASONING) {
|
||||
if (segment.content.trim()) {
|
||||
sections.push({ type: AgenticSectionType.REASONING, content: segment.content });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// 1. Reasoning content (from dedicated field)
|
||||
if (message.reasoningContent) {
|
||||
sections.push({
|
||||
type: AgenticSectionType.REASONING_PENDING,
|
||||
content: segment.content
|
||||
type: AgenticSectionType.REASONING,
|
||||
content: message.reasoningContent
|
||||
});
|
||||
}
|
||||
|
||||
// 2. Text content
|
||||
if (message.content?.trim()) {
|
||||
sections.push({
|
||||
type: AgenticSectionType.TEXT,
|
||||
content: message.content
|
||||
});
|
||||
}
|
||||
|
||||
// 3. Persisted tool calls (from message.toolCalls field)
|
||||
const toolCalls = parseToolCalls(message.toolCalls);
|
||||
for (const tc of toolCalls) {
|
||||
const resultMsg = toolMessages.find((m) => m.toolCallId === tc.id);
|
||||
sections.push({
|
||||
type: resultMsg ? AgenticSectionType.TOOL_CALL : AgenticSectionType.TOOL_CALL_PENDING,
|
||||
content: resultMsg?.content || '',
|
||||
toolName: tc.function?.name,
|
||||
toolArgs: tc.function?.arguments,
|
||||
toolResult: resultMsg?.content,
|
||||
toolResultExtras: resultMsg?.extra
|
||||
});
|
||||
}
|
||||
|
||||
// 4. Streaming tool calls (not yet persisted - currently being received)
|
||||
for (const tc of streamingToolCalls) {
|
||||
// Skip if already in persisted tool calls
|
||||
if (tc.id && toolCalls.find((t) => t.id === tc.id)) continue;
|
||||
sections.push({
|
||||
type: AgenticSectionType.TOOL_CALL_STREAMING,
|
||||
content: '',
|
||||
toolName: tc.function?.name,
|
||||
toolArgs: tc.function?.arguments
|
||||
});
|
||||
}
|
||||
|
||||
@@ -74,211 +88,123 @@ export function parseAgenticContent(rawContent: string): AgenticSection[] {
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses content containing tool call markers
|
||||
* Derives display sections from structured message data.
|
||||
*
|
||||
* Identifies and extracts tool calls from content, handling:
|
||||
* - Completed tool calls with name, arguments, and results
|
||||
* - Pending tool calls (execution in progress)
|
||||
* - Streaming tool calls (arguments being received)
|
||||
* - Early-stage tool calls (just started)
|
||||
* Handles both single-turn (one assistant + its tool results) and multi-turn
|
||||
* agentic sessions (multiple assistant + tool messages grouped together).
|
||||
*
|
||||
* @param rawContent - The raw content string to parse
|
||||
* @returns Array of agentic sections representing tool calls and text
|
||||
* When `toolMessages` contains continuation assistant messages (from multi-turn
|
||||
* agentic flows), they are processed in order to produce sections across all turns.
|
||||
*
|
||||
* @param message - The first/anchor assistant message
|
||||
* @param toolMessages - Tool result messages and continuation assistant messages
|
||||
* @param streamingToolCalls - Partial tool calls during streaming (not yet persisted)
|
||||
* @param isStreaming - Whether the message is currently being streamed
|
||||
*/
|
||||
function parseToolCallContent(rawContent: string): AgenticSection[] {
|
||||
if (!rawContent) return [];
|
||||
export function deriveAgenticSections(
|
||||
message: DatabaseMessage,
|
||||
toolMessages: DatabaseMessage[] = [],
|
||||
streamingToolCalls: ApiChatCompletionToolCall[] = []
|
||||
): AgenticSection[] {
|
||||
const hasAssistantContinuations = toolMessages.some((m) => m.role === MessageRole.ASSISTANT);
|
||||
|
||||
if (!hasAssistantContinuations) {
|
||||
return deriveSingleTurnSections(message, toolMessages, streamingToolCalls);
|
||||
}
|
||||
|
||||
const sections: AgenticSection[] = [];
|
||||
|
||||
const completedToolCallRegex = new RegExp(AGENTIC_REGEX.COMPLETED_TOOL_CALL.source, 'g');
|
||||
const firstTurnToolMsgs = collectToolMessages(toolMessages, 0);
|
||||
sections.push(...deriveSingleTurnSections(message, firstTurnToolMsgs));
|
||||
|
||||
let lastIndex = 0;
|
||||
let match;
|
||||
let i = firstTurnToolMsgs.length;
|
||||
|
||||
while ((match = completedToolCallRegex.exec(rawContent)) !== null) {
|
||||
if (match.index > lastIndex) {
|
||||
const textBefore = rawContent.slice(lastIndex, match.index).trim();
|
||||
if (textBefore) {
|
||||
sections.push({ type: AgenticSectionType.TEXT, content: textBefore });
|
||||
}
|
||||
while (i < toolMessages.length) {
|
||||
const msg = toolMessages[i];
|
||||
|
||||
if (msg.role === MessageRole.ASSISTANT) {
|
||||
const turnToolMsgs = collectToolMessages(toolMessages, i + 1);
|
||||
const isLastTurn = i + 1 + turnToolMsgs.length >= toolMessages.length;
|
||||
|
||||
sections.push(
|
||||
...deriveSingleTurnSections(msg, turnToolMsgs, isLastTurn ? streamingToolCalls : [])
|
||||
);
|
||||
|
||||
i += 1 + turnToolMsgs.length;
|
||||
} else {
|
||||
i++;
|
||||
}
|
||||
|
||||
const toolName = match[1];
|
||||
const toolArgs = match[2];
|
||||
const toolResult = match[3].replace(TRIM_NEWLINES_REGEX, '');
|
||||
|
||||
sections.push({
|
||||
type: AgenticSectionType.TOOL_CALL,
|
||||
content: toolResult,
|
||||
toolName,
|
||||
toolArgs,
|
||||
toolResult
|
||||
});
|
||||
|
||||
lastIndex = match.index + match[0].length;
|
||||
}
|
||||
|
||||
const remainingContent = rawContent.slice(lastIndex);
|
||||
|
||||
const pendingMatch = remainingContent.match(AGENTIC_REGEX.PENDING_TOOL_CALL);
|
||||
const partialWithNameMatch = remainingContent.match(AGENTIC_REGEX.PARTIAL_WITH_NAME);
|
||||
const earlyMatch = remainingContent.match(AGENTIC_REGEX.EARLY_MATCH);
|
||||
|
||||
if (pendingMatch) {
|
||||
const pendingIndex = remainingContent.indexOf(AGENTIC_TAGS.TOOL_CALL_START);
|
||||
|
||||
if (pendingIndex > 0) {
|
||||
const textBefore = remainingContent.slice(0, pendingIndex).trim();
|
||||
|
||||
if (textBefore) {
|
||||
sections.push({ type: AgenticSectionType.TEXT, content: textBefore });
|
||||
}
|
||||
}
|
||||
|
||||
const toolName = pendingMatch[1];
|
||||
const toolArgs = pendingMatch[2];
|
||||
const streamingResult = (pendingMatch[3] || '').replace(TRIM_NEWLINES_REGEX, '');
|
||||
|
||||
sections.push({
|
||||
type: AgenticSectionType.TOOL_CALL_PENDING,
|
||||
content: streamingResult,
|
||||
toolName,
|
||||
toolArgs,
|
||||
toolResult: streamingResult || undefined
|
||||
});
|
||||
} else if (partialWithNameMatch) {
|
||||
const pendingIndex = remainingContent.indexOf(AGENTIC_TAGS.TOOL_CALL_START);
|
||||
|
||||
if (pendingIndex > 0) {
|
||||
const textBefore = remainingContent.slice(0, pendingIndex).trim();
|
||||
if (textBefore) {
|
||||
sections.push({ type: AgenticSectionType.TEXT, content: textBefore });
|
||||
}
|
||||
}
|
||||
|
||||
const partialArgs = partialWithNameMatch[2] || '';
|
||||
|
||||
sections.push({
|
||||
type: AgenticSectionType.TOOL_CALL_STREAMING,
|
||||
content: '',
|
||||
toolName: partialWithNameMatch[1],
|
||||
toolArgs: partialArgs || undefined,
|
||||
toolResult: undefined
|
||||
});
|
||||
} else if (earlyMatch) {
|
||||
const pendingIndex = remainingContent.indexOf(AGENTIC_TAGS.TOOL_CALL_START);
|
||||
|
||||
if (pendingIndex > 0) {
|
||||
const textBefore = remainingContent.slice(0, pendingIndex).trim();
|
||||
if (textBefore) {
|
||||
sections.push({ type: AgenticSectionType.TEXT, content: textBefore });
|
||||
}
|
||||
}
|
||||
|
||||
const nameMatch = earlyMatch[1]?.match(AGENTIC_REGEX.TOOL_NAME_EXTRACT);
|
||||
|
||||
sections.push({
|
||||
type: AgenticSectionType.TOOL_CALL_STREAMING,
|
||||
content: '',
|
||||
toolName: nameMatch?.[1],
|
||||
toolArgs: undefined,
|
||||
toolResult: undefined
|
||||
});
|
||||
} else if (lastIndex < rawContent.length) {
|
||||
let remainingText = rawContent.slice(lastIndex).trim();
|
||||
|
||||
const partialMarkerMatch = remainingText.match(AGENTIC_REGEX.PARTIAL_MARKER);
|
||||
|
||||
if (partialMarkerMatch) {
|
||||
remainingText = remainingText.slice(0, partialMarkerMatch.index).trim();
|
||||
}
|
||||
|
||||
if (remainingText) {
|
||||
sections.push({ type: AgenticSectionType.TEXT, content: remainingText });
|
||||
}
|
||||
}
|
||||
|
||||
if (sections.length === 0 && rawContent.trim()) {
|
||||
sections.push({ type: AgenticSectionType.TEXT, content: rawContent });
|
||||
}
|
||||
|
||||
return sections;
|
||||
}
|
||||
|
||||
/**
|
||||
* Strips partial marker from text content
|
||||
*
|
||||
* Removes incomplete agentic markers (e.g., "<<<", "<<<AGENTIC") that may appear
|
||||
* at the end of streaming content.
|
||||
*
|
||||
* @param text - The text content to process
|
||||
* @returns Text with partial markers removed
|
||||
* Collect consecutive tool messages starting at `startIndex`.
|
||||
*/
|
||||
function stripPartialMarker(text: string): string {
|
||||
const partialMarkerMatch = text.match(AGENTIC_REGEX.PARTIAL_MARKER);
|
||||
function collectToolMessages(messages: DatabaseMessage[], startIndex: number): DatabaseMessage[] {
|
||||
const result: DatabaseMessage[] = [];
|
||||
|
||||
if (partialMarkerMatch) {
|
||||
return text.slice(0, partialMarkerMatch.index).trim();
|
||||
for (let i = startIndex; i < messages.length; i++) {
|
||||
if (messages[i].role === MessageRole.TOOL) {
|
||||
result.push(messages[i]);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return text;
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Splits raw content into segments based on reasoning blocks
|
||||
*
|
||||
* Identifies and extracts reasoning content wrapped in REASONING_TAGS.START/END markers,
|
||||
* separating it from regular text content. Handles both complete and incomplete
|
||||
* (streaming) reasoning blocks.
|
||||
*
|
||||
* @param rawContent - The raw content string to parse
|
||||
* @returns Array of reasoning segments with their types and content
|
||||
* Parse tool result text into lines, matching image attachments by name.
|
||||
*/
|
||||
function splitReasoningSegments(rawContent: string): ReasoningSegment[] {
|
||||
if (!rawContent) return [];
|
||||
export function parseToolResultWithImages(
|
||||
toolResult: string,
|
||||
extras?: DatabaseMessageExtra[]
|
||||
): ToolResultLine[] {
|
||||
const lines = toolResult.split(NEWLINE_SEPARATOR);
|
||||
return lines.map((line) => {
|
||||
const match = line.match(ATTACHMENT_SAVED_REGEX);
|
||||
if (!match || !extras) return { text: line };
|
||||
|
||||
const segments: ReasoningSegment[] = [];
|
||||
let cursor = 0;
|
||||
const attachmentName = match[1];
|
||||
const image = extras.find(
|
||||
(e): e is DatabaseMessageExtraImageFile =>
|
||||
e.type === AttachmentType.IMAGE && e.name === attachmentName
|
||||
);
|
||||
|
||||
while (cursor < rawContent.length) {
|
||||
const startIndex = rawContent.indexOf(REASONING_TAGS.START, cursor);
|
||||
return { text: line, image };
|
||||
});
|
||||
}
|
||||
|
||||
if (startIndex === -1) {
|
||||
const remainingText = rawContent.slice(cursor);
|
||||
/**
|
||||
* Safely parse the toolCalls JSON string from a DatabaseMessage.
|
||||
*/
|
||||
function parseToolCalls(toolCallsJson?: string): ApiChatCompletionToolCall[] {
|
||||
if (!toolCallsJson) return [];
|
||||
|
||||
if (remainingText) {
|
||||
segments.push({ type: AgenticSectionType.TEXT, content: remainingText });
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(toolCallsJson);
|
||||
|
||||
break;
|
||||
}
|
||||
return Array.isArray(parsed) ? parsed : [];
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
if (startIndex > cursor) {
|
||||
const textBefore = rawContent.slice(cursor, startIndex);
|
||||
/**
|
||||
* Check if a message has agentic content (tool calls or is part of an agentic flow).
|
||||
*/
|
||||
export function hasAgenticContent(
|
||||
message: DatabaseMessage,
|
||||
toolMessages: DatabaseMessage[] = []
|
||||
): boolean {
|
||||
if (message.toolCalls) {
|
||||
const tc = parseToolCalls(message.toolCalls);
|
||||
|
||||
if (textBefore) {
|
||||
segments.push({ type: AgenticSectionType.TEXT, content: textBefore });
|
||||
}
|
||||
}
|
||||
|
||||
const contentStart = startIndex + REASONING_TAGS.START.length;
|
||||
const endIndex = rawContent.indexOf(REASONING_TAGS.END, contentStart);
|
||||
|
||||
if (endIndex === -1) {
|
||||
const pendingContent = rawContent.slice(contentStart);
|
||||
|
||||
segments.push({
|
||||
type: AgenticSectionType.REASONING_PENDING,
|
||||
content: stripPartialMarker(pendingContent)
|
||||
});
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
const reasoningContent = rawContent.slice(contentStart, endIndex);
|
||||
segments.push({ type: AgenticSectionType.REASONING, content: reasoningContent });
|
||||
cursor = endIndex + REASONING_TAGS.END.length;
|
||||
if (tc.length > 0) return true;
|
||||
}
|
||||
|
||||
return segments;
|
||||
return toolMessages.length > 0;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,17 @@
|
||||
|
||||
import { MessageRole } from '$lib/enums';
|
||||
|
||||
/**
 * Looks up a single message by its ID.
 *
 * @param messages - Messages to search, scanned in array order
 * @param id - ID to look for; `null`, `undefined` and the empty string never match
 * @returns The first message whose `id` equals `id`, or `undefined` when there is none
 */
export function findMessageById(
	messages: readonly DatabaseMessage[],
	id: string | null | undefined
): DatabaseMessage | undefined {
	if (id === null || id === undefined || id === '') {
		return undefined;
	}

	for (const candidate of messages) {
		if (candidate.id === id) {
			return candidate;
		}
	}

	return undefined;
}
|
||||
|
||||
/**
|
||||
* Filters messages to get the conversation path from root to a specific leaf node.
|
||||
* If the leafNodeId doesn't exist, returns the path with the latest timestamp.
|
||||
|
||||
@@ -22,6 +22,7 @@ export { default as autoResizeTextarea } from './autoresize-textarea';
|
||||
// Branching utilities
|
||||
export {
|
||||
filterByLeafNodeId,
|
||||
findMessageById,
|
||||
findLeafNode,
|
||||
findDescendantMessages,
|
||||
getMessageSiblings,
|
||||
@@ -148,8 +149,17 @@ export { parseHeadersToArray, serializeHeaders } from './headers';
|
||||
// Favicon utilities
|
||||
export { getFaviconUrl } from './favicon';
|
||||
|
||||
// Agentic content parsing utilities
|
||||
export { parseAgenticContent, type AgenticSection } from './agentic';
|
||||
// Agentic content utilities (structured section derivation)
|
||||
export {
|
||||
deriveAgenticSections,
|
||||
parseToolResultWithImages,
|
||||
hasAgenticContent,
|
||||
type AgenticSection,
|
||||
type ToolResultLine
|
||||
} from './agentic';
|
||||
|
||||
// Legacy migration utilities
|
||||
export { runLegacyMigration, isMigrationNeeded } from './legacy-migration';
|
||||
|
||||
// Cache utilities
|
||||
export { TTLCache, ReactiveTTLMap, type TTLCacheOptions } from './cache-ttl';
|
||||
|
||||
345
tools/server/webui/src/lib/utils/legacy-migration.ts
Normal file
345
tools/server/webui/src/lib/utils/legacy-migration.ts
Normal file
@@ -0,0 +1,345 @@
|
||||
/**
|
||||
* @deprecated Legacy migration utility — remove at some point in the future once all users have migrated to the new structured agentic message format.
|
||||
*
|
||||
* Converts old marker-based agentic messages to the new structured format
|
||||
* with separate messages per turn.
|
||||
*
|
||||
* Old format: Single assistant message with markers in content:
|
||||
* <<<reasoning_content_start>>>...<<<reasoning_content_end>>>
|
||||
* <<<AGENTIC_TOOL_CALL_START>>>...<<<AGENTIC_TOOL_CALL_END>>>
|
||||
*
|
||||
* New format: Separate messages per turn:
|
||||
* - assistant (content + reasoningContent + toolCalls)
|
||||
* - tool (toolCallId + content)
|
||||
* - assistant (next turn)
|
||||
* - ...
|
||||
*/
|
||||
|
||||
import { LEGACY_AGENTIC_REGEX, LEGACY_REASONING_TAGS } from '$lib/constants';
|
||||
import { DatabaseService } from '$lib/services/database.service';
|
||||
import { MessageRole, MessageType } from '$lib/enums';
|
||||
import type { DatabaseMessage } from '$lib/types/database';
|
||||
|
||||
// localStorage key under which the completion marker of the migration is stored
const MIGRATION_DONE_KEY = 'llama-webui-migration-v2-done';

/**
 * @deprecated Part of legacy migration — remove with the migration module.
 * Reports whether the legacy migration still has to run.
 *
 * @returns `true` when nothing (or an empty string) is stored under
 * MIGRATION_DONE_KEY; `false` when a marker exists or when reading
 * localStorage throws.
 */
export function isMigrationNeeded(): boolean {
	let storedMarker: string | null;

	try {
		storedMarker = localStorage.getItem(MIGRATION_DONE_KEY);
	} catch {
		// Storage could not be read: report "not needed" so the caller skips the migration
		return false;
	}

	return !storedMarker;
}
|
||||
|
||||
/**
 * Records that the migration has run by storing the current `Date.now()`
 * value (as a string) under MIGRATION_DONE_KEY.
 *
 * Any error thrown by localStorage is swallowed.
 */
function markMigrationDone(): void {
	try {
		const completedAt = `${Date.now()}`;
		localStorage.setItem(MIGRATION_DONE_KEY, completedAt);
	} catch {
		// Ignore localStorage errors
	}
}
|
||||
|
||||
/**
 * Tells whether a message's content matches LEGACY_AGENTIC_REGEX.HAS_LEGACY_MARKERS.
 *
 * @param message - Message to inspect
 * @returns `false` when the content is empty or missing, otherwise the result of the regex test
 */
function hasLegacyMarkers(message: DatabaseMessage): boolean {
	const { content } = message;

	return content ? LEGACY_AGENTIC_REGEX.HAS_LEGACY_MARKERS.test(content) : false;
}
|
||||
|
||||
/**
 * Separates legacy inline reasoning from the rest of the content.
 *
 * @param content - Raw message content in the legacy marker format
 * @returns `reasoning`: capture group 1 of every REASONING_EXTRACT match,
 * concatenated in order of appearance; `cleanContent`: the content with all
 * REASONING_BLOCK matches removed, followed by one REASONING_OPEN replacement.
 */
function extractLegacyReasoning(content: string): { reasoning: string; cleanContent: string } {
	// Fresh global copies so every occurrence is visited without touching the shared constants
	const extractPattern = new RegExp(LEGACY_AGENTIC_REGEX.REASONING_EXTRACT.source, 'g');
	const blockPattern = new RegExp(LEGACY_AGENTIC_REGEX.REASONING_BLOCK.source, 'g');

	let reasoning = '';
	for (const found of content.matchAll(extractPattern)) {
		reasoning += found[1];
	}

	const withoutBlocks = content.replace(blockPattern, '');
	const cleanContent = withoutBlocks.replace(LEGACY_AGENTIC_REGEX.REASONING_OPEN, '');

	return { reasoning, cleanContent };
}
|
||||
|
||||
/**
 * One LLM turn recovered from legacy marker-based content.
 */
interface ParsedTurn {
	/** Trimmed assistant text that preceded this turn's tool calls */
	textBefore: string;
	/** Completed tool calls of this turn, in order of appearance */
	toolCalls: Array<{
		name: string;
		args: string;
		result: string;
	}>;
}

/**
 * Parse legacy content with tool call markers into structured turns.
 *
 * Every COMPLETED_TOOL_CALL match becomes a tool call. Consecutive tool calls
 * belong to the same turn; text found between two tool calls closes the current
 * turn and opens a new one. Text after the last completed tool call becomes a
 * final turn without tool calls, after an unterminated tool-call block
 * (AGENTIC_TOOL_CALL_OPEN) has been stripped from it.
 *
 * @param content - Message content, expected to have reasoning markers already removed
 * @returns At least one turn. When nothing usable remains (empty content, or
 * content consisting only of an unterminated tool-call block) the single turn
 * has an empty `textBefore` and no tool calls.
 */
function parseLegacyToolCalls(content: string): ParsedTurn[] {
	const turns: ParsedTurn[] = [];
	// Fresh global copy so exec() can iterate without mutating the shared constant
	const regex = new RegExp(LEGACY_AGENTIC_REGEX.COMPLETED_TOOL_CALL.source, 'g');

	let lastIndex = 0;
	let currentTurn: ParsedTurn = { textBefore: '', toolCalls: [] };
	let match;

	while ((match = regex.exec(content)) !== null) {
		const textBefore = content.slice(lastIndex, match.index).trim();

		// If there's text between tool calls and we already have tool calls,
		// that means a new turn started (text after tool results = new LLM turn)
		if (textBefore && currentTurn.toolCalls.length > 0) {
			turns.push(currentTurn);
			currentTurn = { textBefore, toolCalls: [] };
		} else if (textBefore) {
			currentTurn.textBefore = textBefore;
		}

		currentTurn.toolCalls.push({
			name: match[1],
			args: match[2],
			// Drop the newlines that padded the result inside the marker block
			result: match[3].replace(/^\n+|\n+$/g, '')
		});

		lastIndex = match.index + match[0].length;
	}

	if (currentTurn.toolCalls.length > 0) {
		turns.push(currentTurn);
	}

	// Text after the last completed tool call, minus any partial/open tool-call block
	const cleanRemaining = content
		.slice(lastIndex)
		.trim()
		.replace(LEGACY_AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '')
		.trim();

	// If there's text after all tool calls, it's the final assistant response
	if (cleanRemaining) {
		turns.push({ textBefore: cleanRemaining, toolCalls: [] });
	}

	// Nothing usable was found. Return an empty turn rather than the raw content:
	// at this point the content is either blank or consists solely of an open
	// tool-call block, and echoing it back would keep legacy markers in the
	// migrated message.
	if (turns.length === 0) {
		turns.push({ textBefore: '', toolCalls: [] });
	}

	return turns;
}
|
||||
|
||||
/** Entry shape expected inside a legacy message's `toolCalls` JSON. */
type LegacyStoredToolCall = { id: string; function?: { name: string; arguments: string } };

/** Structured tool call written to migrated assistant messages. */
type MigratedToolCall = {
	id: string;
	type: 'function';
	function: { name: string; arguments: string };
};

/**
 * Parses the `toolCalls` JSON of a legacy message.
 * Returns an empty list for missing, malformed or non-array JSON and drops non-object entries.
 */
function parseLegacyStoredToolCalls(toolCallsJson?: string): LegacyStoredToolCall[] {
	if (!toolCallsJson) return [];

	try {
		const parsed: unknown = JSON.parse(toolCallsJson);
		if (!Array.isArray(parsed)) return [];

		return parsed.filter(
			(entry): entry is LegacyStoredToolCall => typeof entry === 'object' && entry !== null
		);
	} catch {
		return [];
	}
}

/**
 * Creates an allocator that hands out one tool call ID per call, in call order across all turns.
 * A stored ID is preferred (same tool name first, then same position) and is never reused;
 * otherwise `legacy_tool_<position>` is generated, where position counts calls across turns.
 */
function createLegacyToolCallIdAllocator(
	existing: LegacyStoredToolCall[]
): (toolName: string) => string {
	const claimedIds = new Set<string>();
	let nextPosition = 0;

	const isAvailable = (entry: LegacyStoredToolCall | undefined): entry is LegacyStoredToolCall =>
		entry !== undefined && Boolean(entry.id) && !claimedIds.has(entry.id);

	return (toolName) => {
		const position = nextPosition++;
		const byName = existing.find(
			(entry) => isAvailable(entry) && entry.function?.name === toolName
		);
		const byPosition = existing[position];
		const matched = byName ?? (isAvailable(byPosition) ? byPosition : undefined);

		if (!matched) return `legacy_tool_${position}`;

		claimedIds.add(matched.id);
		return matched.id;
	};
}

/**
 * Converts the parsed tool calls of one turn into structured tool calls with allocated IDs.
 */
function buildMigratedToolCalls(
	turn: ParsedTurn,
	allocateId: (toolName: string) => string
): MigratedToolCall[] {
	return turn.toolCalls.map((tc) => ({
		id: allocateId(tc.name),
		type: 'function' as const,
		function: { name: tc.name, arguments: tc.args }
	}));
}

/**
 * Creates one tool-result message per tool call of a turn, chained one after another
 * below `parentId`. Returns the ID of the last message created (or `parentId` if none).
 */
async function appendLegacyToolResults(
	convId: string,
	turn: ParsedTurn,
	calls: MigratedToolCall[],
	baseTimestamp: number,
	parentId: string
): Promise<string> {
	let lastId = parentId;

	for (let i = 0; i < turn.toolCalls.length; i++) {
		const toolMsg = await DatabaseService.createMessageBranch(
			{
				convId,
				type: MessageType.TEXT,
				role: MessageRole.TOOL,
				content: turn.toolCalls[i].result,
				toolCallId: calls[i].id,
				// +1 so the result sorts after the assistant message it answers
				timestamp: baseTimestamp + i + 1,
				toolCalls: '',
				children: []
			},
			lastId
		);
		lastId = toolMsg.id;
	}

	return lastId;
}

/**
 * Migrate a single conversation's messages from legacy format to new format.
 *
 * Assistant messages that only carry reasoning markers get their reasoning moved
 * into `reasoningContent`. Assistant messages with agentic markers are split into
 * a chain: the original message keeps the first turn, and every tool result and
 * every later turn becomes a new message appended below it. Original non-tool
 * children of the message are then re-attached to the end of that chain.
 *
 * @param convId - ID of the conversation to migrate
 * @returns Number of original assistant messages that were rewritten
 */
async function migrateConversation(convId: string): Promise<number> {
	const allMessages = await DatabaseService.getConversationMessages(convId);
	let migratedCount = 0;

	for (const message of allMessages) {
		if (message.role !== MessageRole.ASSISTANT) continue;

		if (!hasLegacyMarkers(message)) {
			// Still check for reasoning-only markers (no tool calls)
			if (message.content?.includes(LEGACY_REASONING_TAGS.START)) {
				const { reasoning, cleanContent } = extractLegacyReasoning(message.content);
				await DatabaseService.updateMessage(message.id, {
					content: cleanContent.trim(),
					reasoningContent: reasoning || undefined
				});
				migratedCount++;
			}
			continue;
		}

		// Has agentic markers - full migration needed
		const { reasoning, cleanContent } = extractLegacyReasoning(message.content);
		const turns = parseLegacyToolCalls(cleanContent);

		// First turn uses the existing message
		const firstTurn = turns[0];
		if (!firstTurn) continue;

		// One allocator per message keeps IDs unique across all of its turns
		const allocateId = createLegacyToolCallIdAllocator(
			parseLegacyStoredToolCalls(message.toolCalls)
		);

		const firstTurnToolCalls = buildMigratedToolCalls(firstTurn, allocateId);

		// Update the existing message for the first turn
		await DatabaseService.updateMessage(message.id, {
			content: firstTurn.textBefore,
			reasoningContent: reasoning || undefined,
			toolCalls: firstTurnToolCalls.length > 0 ? JSON.stringify(firstTurnToolCalls) : ''
		});

		// Create tool result messages for the first turn
		let currentParentId = await appendLegacyToolResults(
			convId,
			firstTurn,
			firstTurnToolCalls,
			message.timestamp,
			message.id
		);

		// Create messages for subsequent turns
		for (let turnIdx = 1; turnIdx < turns.length; turnIdx++) {
			const turn = turns[turnIdx];
			const turnToolCalls = buildMigratedToolCalls(turn, allocateId);
			// Turns are spaced 100 apart so a turn's tool results (+1, +2, ...) stay ahead of the next turn
			const turnTimestamp = message.timestamp + turnIdx * 100;

			// Create assistant message for this turn
			const assistantMsg = await DatabaseService.createMessageBranch(
				{
					convId,
					type: MessageType.TEXT,
					role: MessageRole.ASSISTANT,
					content: turn.textBefore,
					timestamp: turnTimestamp,
					toolCalls: turnToolCalls.length > 0 ? JSON.stringify(turnToolCalls) : '',
					children: [],
					model: message.model
				},
				currentParentId
			);

			// Create tool result messages for this turn
			currentParentId = await appendLegacyToolResults(
				convId,
				turn,
				turnToolCalls,
				turnTimestamp,
				assistantMsg.id
			);
		}

		// Re-parent the original non-tool children (the next user message or similar)
		// to the last message of the newly created chain
		if (message.children.length > 0 && currentParentId !== message.id) {
			const movedChildIds = message.children.filter((childId) => {
				const child = allMessages.find((m) => m.id === childId);
				return child !== undefined && child.role !== MessageRole.TOOL;
			});

			if (movedChildIds.length > 0) {
				for (const childId of movedChildIds) {
					await DatabaseService.updateMessage(childId, { parent: currentParentId });
				}

				// Read the stored state once: `message` is the snapshot taken before any write
				const freshMessages = await DatabaseService.getConversationMessages(convId);
				const newParent = freshMessages.find((m) => m.id === currentParentId);
				const storedOriginal = freshMessages.find((m) => m.id === message.id);

				if (newParent) {
					const missing = movedChildIds.filter((id) => !newParent.children.includes(id));
					if (missing.length > 0) {
						await DatabaseService.updateMessage(currentParentId, {
							children: [...newParent.children, ...missing]
						});
					}
				}

				// Remove only the children that were moved.
				// NOTE(review): createMessageBranch presumably registers the new message in its
				// parent's `children`; filtering the stored list keeps such entries — confirm.
				const remainingChildren = (storedOriginal?.children ?? message.children).filter(
					(id) => !movedChildIds.includes(id)
				);
				await DatabaseService.updateMessage(message.id, { children: remainingChildren });
			}
		}

		migratedCount++;
	}

	return migratedCount;
}
|
||||
|
||||
/**
 * @deprecated Part of legacy migration — remove with the migration module.
 * Run the full migration across all conversations.
 * This should be called once at app startup if migration is needed.
 *
 * Does nothing when `isMigrationNeeded()` is false. Conversations are migrated
 * one after another; a conversation that fails is logged and skipped so the
 * remaining ones are still processed. The migration is marked as done in every
 * case, including failures, so it is never attempted again.
 */
export async function runLegacyMigration(): Promise<void> {
	if (!isMigrationNeeded()) return;

	console.log('[Migration] Starting legacy message format migration...');

	try {
		const conversations = await DatabaseService.getAllConversations();
		let totalMigrated = 0;
		let failedConversations = 0;

		for (const conv of conversations) {
			try {
				totalMigrated += await migrateConversation(conv.id);
			} catch (error) {
				// Isolate the failure: the done flag is set below, so conversations
				// skipped now would never get another chance
				failedConversations++;
				console.error(`[Migration] Failed to migrate conversation ${conv.id}:`, error);
			}
		}

		if (totalMigrated > 0) {
			console.log(
				`[Migration] Migrated ${totalMigrated} messages across ${conversations.length} conversations`
			);
		} else if (failedConversations === 0) {
			console.log('[Migration] No legacy messages found, marking as done');
		}

		if (failedConversations > 0) {
			console.warn(
				`[Migration] ${failedConversations} of ${conversations.length} conversations could not be migrated`
			);
		}
	} catch (error) {
		console.error('[Migration] Failed to migrate legacy messages:', error);
	} finally {
		// Mark as done even after a failure to avoid infinite retry loops
		markMigrationDone();
	}
}
|
||||
211
tools/server/webui/tests/unit/agentic-sections.test.ts
Normal file
211
tools/server/webui/tests/unit/agentic-sections.test.ts
Normal file
@@ -0,0 +1,211 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { deriveAgenticSections, hasAgenticContent } from '$lib/utils/agentic';
|
||||
import { AgenticSectionType, MessageRole } from '$lib/enums';
|
||||
import type { DatabaseMessage } from '$lib/types/database';
|
||||
import type { ApiChatCompletionToolCall } from '$lib/types/api';
|
||||
|
||||
/**
 * Builds an assistant message fixture.
 * Defaults: id 'ast-1', conversation 'conv-1', empty content, no parent and no children.
 * Every key present in `overrides` replaces the corresponding default.
 */
function makeAssistant(overrides: Partial<DatabaseMessage> = {}): DatabaseMessage {
	const { id = 'ast-1', content = '' } = overrides;

	return {
		id,
		convId: 'conv-1',
		type: 'text',
		timestamp: Date.now(),
		role: MessageRole.ASSISTANT,
		content,
		parent: null,
		children: [],
		...overrides
	} as DatabaseMessage;
}
|
||||
|
||||
/**
 * Builds a tool-result message fixture.
 * Defaults: id 'tool-1', conversation 'conv-1', content 'tool result', toolCallId 'call_1'.
 * Every key present in `overrides` replaces the corresponding default.
 */
function makeToolMsg(overrides: Partial<DatabaseMessage> = {}): DatabaseMessage {
	const { id = 'tool-1', content = 'tool result', toolCallId = 'call_1' } = overrides;

	return {
		id,
		convId: 'conv-1',
		type: 'text',
		timestamp: Date.now(),
		role: MessageRole.TOOL,
		content,
		parent: null,
		children: [],
		toolCallId,
		...overrides
	} as DatabaseMessage;
}
|
||||
|
||||
describe('deriveAgenticSections', () => {
	// Serializes a single function tool call the way it is stored on a message
	const singleCallJson = (id: string, name: string, args: string): string =>
		JSON.stringify([{ id, type: 'function', function: { name, arguments: args } }]);

	it('returns empty array for assistant with no content', () => {
		const derived = deriveAgenticSections(makeAssistant({ content: '' }));

		expect(derived).toEqual([]);
	});

	it('returns text section for simple assistant message', () => {
		const derived = deriveAgenticSections(makeAssistant({ content: 'Hello world' }));

		expect(derived).toHaveLength(1);
		const [only] = derived;
		expect(only.type).toBe(AgenticSectionType.TEXT);
		expect(only.content).toBe('Hello world');
	});

	it('returns reasoning + text for message with reasoning', () => {
		const derived = deriveAgenticSections(
			makeAssistant({
				content: 'Answer is 4.',
				reasoningContent: 'Let me think...'
			})
		);

		expect(derived).toHaveLength(2);
		const [reasoning, text] = derived;
		expect(reasoning.type).toBe(AgenticSectionType.REASONING);
		expect(reasoning.content).toBe('Let me think...');
		expect(text.type).toBe(AgenticSectionType.TEXT);
	});

	it('single turn: assistant with tool calls and results', () => {
		const assistant = makeAssistant({
			content: 'Let me check.',
			toolCalls: singleCallJson('call_1', 'search', '{"q":"test"}')
		});
		const result = makeToolMsg({
			toolCallId: 'call_1',
			content: 'Found 3 results'
		});

		const derived = deriveAgenticSections(assistant, [result]);

		expect(derived).toHaveLength(2);
		const [text, call] = derived;
		expect(text.type).toBe(AgenticSectionType.TEXT);
		expect(call.type).toBe(AgenticSectionType.TOOL_CALL);
		expect(call.toolName).toBe('search');
		expect(call.toolResult).toBe('Found 3 results');
	});

	it('single turn: pending tool call without result', () => {
		const assistant = makeAssistant({
			toolCalls: singleCallJson('call_1', 'bash', '{}')
		});

		const derived = deriveAgenticSections(assistant, []);

		expect(derived).toHaveLength(1);
		const [pending] = derived;
		expect(pending.type).toBe(AgenticSectionType.TOOL_CALL_PENDING);
		expect(pending.toolName).toBe('bash');
	});

	it('multi-turn: two assistant turns grouped as one session', () => {
		const firstTurn = makeAssistant({
			id: 'ast-1',
			content: 'Turn 1 text',
			toolCalls: singleCallJson('call_1', 'search', '{"q":"foo"}')
		});
		const firstResult = makeToolMsg({ id: 'tool-1', toolCallId: 'call_1', content: 'result 1' });
		const finalTurn = makeAssistant({
			id: 'ast-2',
			content: 'Final answer based on results.'
		});

		// The follow-up list carries both the tool result and the continuation assistant
		const derived = deriveAgenticSections(firstTurn, [firstResult, finalTurn]);

		expect(derived).toHaveLength(3);
		const [turnOneText, turnOneCall, finalText] = derived;

		// Turn 1
		expect(turnOneText.type).toBe(AgenticSectionType.TEXT);
		expect(turnOneText.content).toBe('Turn 1 text');
		expect(turnOneCall.type).toBe(AgenticSectionType.TOOL_CALL);
		expect(turnOneCall.toolName).toBe('search');
		expect(turnOneCall.toolResult).toBe('result 1');

		// Turn 2 (final)
		expect(finalText.type).toBe(AgenticSectionType.TEXT);
		expect(finalText.content).toBe('Final answer based on results.');
	});

	it('multi-turn: three turns with tool calls', () => {
		const listTurn = makeAssistant({
			id: 'ast-1',
			content: '',
			toolCalls: singleCallJson('call_1', 'list_files', '{}')
		});
		const listResult = makeToolMsg({ id: 'tool-1', toolCallId: 'call_1', content: 'file1 file2' });
		const readTurn = makeAssistant({
			id: 'ast-2',
			content: 'Reading file1...',
			toolCalls: singleCallJson('call_2', 'read_file', '{"path":"file1"}')
		});
		const readResult = makeToolMsg({
			id: 'tool-2',
			toolCallId: 'call_2',
			content: 'contents of file1'
		});
		const analysisTurn = makeAssistant({
			id: 'ast-3',
			content: 'Here is the analysis.',
			reasoningContent: 'The file contains...'
		});

		const derived = deriveAgenticSections(listTurn, [
			listResult,
			readTurn,
			readResult,
			analysisTurn
		]);

		// Expected layout:
		//   turn 1 -> tool_call (content is empty, so no text section)
		//   turn 2 -> text + tool_call
		//   turn 3 -> reasoning + text
		expect(derived).toHaveLength(5);
		const [listCall, readText, readCall, reasoning, analysisText] = derived;
		expect(listCall.type).toBe(AgenticSectionType.TOOL_CALL);
		expect(listCall.toolName).toBe('list_files');
		expect(readText.type).toBe(AgenticSectionType.TEXT);
		expect(readText.content).toBe('Reading file1...');
		expect(readCall.type).toBe(AgenticSectionType.TOOL_CALL);
		expect(readCall.toolName).toBe('read_file');
		expect(reasoning.type).toBe(AgenticSectionType.REASONING);
		expect(analysisText.type).toBe(AgenticSectionType.TEXT);
		expect(analysisText.content).toBe('Here is the analysis.');
	});

	it('multi-turn: streaming tool calls on last turn', () => {
		const firstTurn = makeAssistant({
			toolCalls: singleCallJson('call_1', 'search', '{}')
		});
		const firstResult = makeToolMsg({ toolCallId: 'call_1', content: 'result' });
		const streamingTurn = makeAssistant({ id: 'ast-2', content: '' });

		// Arguments are deliberately cut off mid-JSON, as they would be while streaming
		const inFlightCalls: ApiChatCompletionToolCall[] = [
			{ id: 'call_2', type: 'function', function: { name: 'write_file', arguments: '{"pa' } }
		];

		const derived = deriveAgenticSections(firstTurn, [firstResult, streamingTurn], inFlightCalls);
		const kinds = derived.map((section) => section.type);

		// Turn 1 yields a completed tool call, turn 2 a streaming one
		expect(kinds.includes(AgenticSectionType.TOOL_CALL)).toBe(true);
		expect(kinds.includes(AgenticSectionType.TOOL_CALL_STREAMING)).toBe(true);
	});
});
|
||||
|
||||
describe('hasAgenticContent', () => {
	it('returns false for plain assistant', () => {
		const plain = makeAssistant({ content: 'Just text' });

		expect(hasAgenticContent(plain)).toBe(false);
	});

	it('returns true when message has toolCalls', () => {
		const storedCalls = [
			{ id: 'call_1', type: 'function', function: { name: 'test', arguments: '{}' } }
		];
		const withCalls = makeAssistant({ toolCalls: JSON.stringify(storedCalls) });

		expect(hasAgenticContent(withCalls)).toBe(true);
	});

	it('returns true when toolMessages are provided', () => {
		// Neither fixture carries tool calls; only the follow-up list is non-empty
		const assistant = makeAssistant();
		const followUps = [makeToolMsg()];

		expect(hasAgenticContent(assistant, followUps)).toBe(true);
	});

	it('returns false for empty toolCalls JSON', () => {
		const emptyCalls = makeAssistant({ toolCalls: '[]' });

		expect(hasAgenticContent(emptyCalls)).toBe(false);
	});
});
|
||||
@@ -1,17 +1,22 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { AGENTIC_REGEX } from '$lib/constants/agentic';
|
||||
import { LEGACY_AGENTIC_REGEX } from '$lib/constants/agentic';
|
||||
|
||||
// Mirror the logic in ChatService.stripReasoningContent so we can test it in isolation.
|
||||
// The real function is private static, so we replicate the strip pipeline here.
|
||||
function stripContextMarkers(content: string): string {
|
||||
/**
|
||||
* Tests for legacy marker stripping (used in migration).
|
||||
* The new system does not embed markers in content - these tests verify
|
||||
* the legacy regex patterns still work for the migration code.
|
||||
*/
|
||||
|
||||
// Mirror the legacy stripping logic used during migration
|
||||
function stripLegacyContextMarkers(content: string): string {
|
||||
return content
|
||||
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
|
||||
.replace(AGENTIC_REGEX.REASONING_OPEN, '')
|
||||
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK, '')
|
||||
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '');
|
||||
.replace(new RegExp(LEGACY_AGENTIC_REGEX.REASONING_BLOCK.source, 'g'), '')
|
||||
.replace(LEGACY_AGENTIC_REGEX.REASONING_OPEN, '')
|
||||
.replace(new RegExp(LEGACY_AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK.source, 'g'), '')
|
||||
.replace(LEGACY_AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '');
|
||||
}
|
||||
|
||||
// A realistic complete tool call block as stored in message.content after a turn.
|
||||
// A realistic complete tool call block as stored in old message.content
|
||||
const COMPLETE_BLOCK =
|
||||
'\n\n<<<AGENTIC_TOOL_CALL_START>>>\n' +
|
||||
'<<<TOOL_NAME:bash_tool>>>\n' +
|
||||
@@ -30,11 +35,10 @@ const OPEN_BLOCK =
|
||||
'<<<TOOL_ARGS_END>>>\n' +
|
||||
'partial output...';
|
||||
|
||||
describe('agentic marker stripping for context', () => {
|
||||
describe('legacy agentic marker stripping (for migration)', () => {
|
||||
it('strips a complete tool call block, leaving surrounding text', () => {
|
||||
const input = 'Before.' + COMPLETE_BLOCK + 'After.';
|
||||
const result = stripContextMarkers(input);
|
||||
// markers gone; residual newlines between fragments are fine
|
||||
const result = stripLegacyContextMarkers(input);
|
||||
expect(result).not.toContain('<<<');
|
||||
expect(result).toContain('Before.');
|
||||
expect(result).toContain('After.');
|
||||
@@ -42,7 +46,7 @@ describe('agentic marker stripping for context', () => {
|
||||
|
||||
it('strips multiple complete tool call blocks', () => {
|
||||
const input = 'A' + COMPLETE_BLOCK + 'B' + COMPLETE_BLOCK + 'C';
|
||||
const result = stripContextMarkers(input);
|
||||
const result = stripLegacyContextMarkers(input);
|
||||
expect(result).not.toContain('<<<');
|
||||
expect(result).toContain('A');
|
||||
expect(result).toContain('B');
|
||||
@@ -51,19 +55,19 @@ describe('agentic marker stripping for context', () => {
|
||||
|
||||
it('strips an open/partial tool call block (no END marker)', () => {
|
||||
const input = 'Lead text.' + OPEN_BLOCK;
|
||||
const result = stripContextMarkers(input);
|
||||
const result = stripLegacyContextMarkers(input);
|
||||
expect(result).toBe('Lead text.');
|
||||
expect(result).not.toContain('<<<');
|
||||
});
|
||||
|
||||
it('does not alter content with no markers', () => {
|
||||
const input = 'Just a normal assistant response.';
|
||||
expect(stripContextMarkers(input)).toBe(input);
|
||||
expect(stripLegacyContextMarkers(input)).toBe(input);
|
||||
});
|
||||
|
||||
it('strips reasoning block independently', () => {
|
||||
const input = '<<<reasoning_content_start>>>think hard<<<reasoning_content_end>>>Answer.';
|
||||
expect(stripContextMarkers(input)).toBe('Answer.');
|
||||
expect(stripLegacyContextMarkers(input)).toBe('Answer.');
|
||||
});
|
||||
|
||||
it('strips both reasoning and agentic blocks together', () => {
|
||||
@@ -71,11 +75,21 @@ describe('agentic marker stripping for context', () => {
|
||||
'<<<reasoning_content_start>>>plan<<<reasoning_content_end>>>' +
|
||||
'Some text.' +
|
||||
COMPLETE_BLOCK;
|
||||
expect(stripContextMarkers(input)).not.toContain('<<<');
|
||||
expect(stripContextMarkers(input)).toContain('Some text.');
|
||||
expect(stripLegacyContextMarkers(input)).not.toContain('<<<');
|
||||
expect(stripLegacyContextMarkers(input)).toContain('Some text.');
|
||||
});
|
||||
|
||||
it('empty string survives', () => {
|
||||
expect(stripContextMarkers('')).toBe('');
|
||||
expect(stripLegacyContextMarkers('')).toBe('');
|
||||
});
|
||||
|
||||
it('detects legacy markers', () => {
|
||||
expect(LEGACY_AGENTIC_REGEX.HAS_LEGACY_MARKERS.test('normal text')).toBe(false);
|
||||
expect(
|
||||
LEGACY_AGENTIC_REGEX.HAS_LEGACY_MARKERS.test('text<<<AGENTIC_TOOL_CALL_START>>>more')
|
||||
).toBe(true);
|
||||
expect(LEGACY_AGENTIC_REGEX.HAS_LEGACY_MARKERS.test('<<<reasoning_content_start>>>think')).toBe(
|
||||
true
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,196 +1,89 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { AGENTIC_REGEX, REASONING_TAGS } from '$lib/constants/agentic';
|
||||
import { ContentPartType } from '$lib/enums';
|
||||
import { MessageRole } from '$lib/enums';
|
||||
|
||||
// Replicate ChatService.extractReasoningFromContent (private static)
|
||||
function extractReasoningFromContent(
|
||||
content: string | Array<{ type: string; text?: string }> | null | undefined
|
||||
): string | undefined {
|
||||
if (!content) return undefined;
|
||||
/**
|
||||
* Tests for the new reasoning content handling.
|
||||
* In the new architecture, reasoning content is stored in a dedicated
|
||||
* `reasoningContent` field on DatabaseMessage, not embedded in content with tags.
|
||||
* The API sends it as `reasoning_content` on ApiChatMessageData.
|
||||
*/
|
||||
|
||||
const extractFromString = (text: string): string => {
|
||||
const parts: string[] = [];
|
||||
const re = new RegExp(AGENTIC_REGEX.REASONING_EXTRACT.source);
|
||||
let match = re.exec(text);
|
||||
while (match) {
|
||||
parts.push(match[1]);
|
||||
text = text.slice(match.index + match[0].length);
|
||||
match = re.exec(text);
|
||||
}
|
||||
return parts.join('');
|
||||
};
|
||||
|
||||
if (typeof content === 'string') {
|
||||
const result = extractFromString(content);
|
||||
return result || undefined;
|
||||
}
|
||||
|
||||
if (!Array.isArray(content)) return undefined;
|
||||
|
||||
const parts: string[] = [];
|
||||
for (const part of content) {
|
||||
if (part.type === ContentPartType.TEXT && part.text) {
|
||||
const result = extractFromString(part.text);
|
||||
if (result) parts.push(result);
|
||||
}
|
||||
}
|
||||
return parts.length > 0 ? parts.join('') : undefined;
|
||||
}
|
||||
|
||||
// Replicate ChatService.stripReasoningContent (private static)
|
||||
function stripReasoningContent(
|
||||
content: string | Array<{ type: string; text?: string }> | null | undefined
|
||||
): typeof content {
|
||||
if (!content) return content;
|
||||
|
||||
if (typeof content === 'string') {
|
||||
return content
|
||||
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
|
||||
.replace(AGENTIC_REGEX.REASONING_OPEN, '')
|
||||
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK, '')
|
||||
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '');
|
||||
}
|
||||
|
||||
if (!Array.isArray(content)) return content;
|
||||
|
||||
return content.map((part) => {
|
||||
if (part.type !== ContentPartType.TEXT || !part.text) return part;
|
||||
return {
|
||||
...part,
|
||||
text: part.text
|
||||
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
|
||||
.replace(AGENTIC_REGEX.REASONING_OPEN, '')
|
||||
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK, '')
|
||||
.replace(AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN, '')
|
||||
describe('reasoning content in new structured format', () => {
|
||||
it('reasoning is stored as separate field, not in content', () => {
|
||||
// Simulate what the new chat store does
|
||||
const message = {
|
||||
content: 'The answer is 4.',
|
||||
reasoningContent: 'Let me think: 2+2=4, basic arithmetic.'
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
// Simulate the message mapping logic from ChatService.sendMessage
|
||||
function buildApiMessage(
|
||||
content: string,
|
||||
excludeReasoningFromContext: boolean
|
||||
): { role: string; content: string; reasoning_content?: string } {
|
||||
const cleaned = stripReasoningContent(content) as string;
|
||||
const mapped: { role: string; content: string; reasoning_content?: string } = {
|
||||
role: 'assistant',
|
||||
content: cleaned
|
||||
};
|
||||
if (!excludeReasoningFromContext) {
|
||||
const reasoning = extractReasoningFromContent(content);
|
||||
if (reasoning) {
|
||||
mapped.reasoning_content = reasoning;
|
||||
// Content should be clean
|
||||
expect(message.content).not.toContain('<<<');
|
||||
expect(message.content).toBe('The answer is 4.');
|
||||
|
||||
// Reasoning in dedicated field
|
||||
expect(message.reasoningContent).toBe('Let me think: 2+2=4, basic arithmetic.');
|
||||
});
|
||||
|
||||
it('convertDbMessageToApiChatMessageData includes reasoning_content', () => {
|
||||
// Simulate the conversion logic
|
||||
const dbMessage = {
|
||||
role: MessageRole.ASSISTANT,
|
||||
content: 'The answer is 4.',
|
||||
reasoningContent: 'Let me think: 2+2=4, basic arithmetic.'
|
||||
};
|
||||
|
||||
const apiMessage: Record<string, unknown> = {
|
||||
role: dbMessage.role,
|
||||
content: dbMessage.content
|
||||
};
|
||||
if (dbMessage.reasoningContent) {
|
||||
apiMessage.reasoning_content = dbMessage.reasoningContent;
|
||||
}
|
||||
}
|
||||
return mapped;
|
||||
}
|
||||
|
||||
// Helper: wrap reasoning the same way the chat store does during streaming
|
||||
function wrapReasoning(reasoning: string, content: string): string {
|
||||
return `${REASONING_TAGS.START}${reasoning}${REASONING_TAGS.END}${content}`;
|
||||
}
|
||||
|
||||
describe('reasoning content extraction', () => {
|
||||
it('extracts reasoning from tagged string content', () => {
|
||||
const input = wrapReasoning('step 1, step 2', 'The answer is 42.');
|
||||
const result = extractReasoningFromContent(input);
|
||||
expect(result).toBe('step 1, step 2');
|
||||
expect(apiMessage.content).toBe('The answer is 4.');
|
||||
expect(apiMessage.reasoning_content).toBe('Let me think: 2+2=4, basic arithmetic.');
|
||||
// No internal tags leak into either field
|
||||
expect(apiMessage.content).not.toContain('<<<');
|
||||
expect(apiMessage.reasoning_content).not.toContain('<<<');
|
||||
});
|
||||
|
||||
it('returns undefined when no reasoning tags present', () => {
|
||||
expect(extractReasoningFromContent('Just a normal response.')).toBeUndefined();
|
||||
it('API message excludes reasoning when excludeReasoningFromContext is true', () => {
|
||||
const dbMessage = {
|
||||
role: MessageRole.ASSISTANT,
|
||||
content: 'The answer is 4.',
|
||||
reasoningContent: 'internal thinking'
|
||||
};
|
||||
|
||||
const excludeReasoningFromContext = true;
|
||||
|
||||
const apiMessage: Record<string, unknown> = {
|
||||
role: dbMessage.role,
|
||||
content: dbMessage.content
|
||||
};
|
||||
if (!excludeReasoningFromContext && dbMessage.reasoningContent) {
|
||||
apiMessage.reasoning_content = dbMessage.reasoningContent;
|
||||
}
|
||||
|
||||
expect(apiMessage.content).toBe('The answer is 4.');
|
||||
expect(apiMessage.reasoning_content).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns undefined for null/empty input', () => {
|
||||
expect(extractReasoningFromContent(null)).toBeUndefined();
|
||||
expect(extractReasoningFromContent(undefined)).toBeUndefined();
|
||||
expect(extractReasoningFromContent('')).toBeUndefined();
|
||||
});
|
||||
it('handles messages with no reasoning', () => {
|
||||
const dbMessage = {
|
||||
role: MessageRole.ASSISTANT,
|
||||
content: 'No reasoning here.',
|
||||
reasoningContent: undefined
|
||||
};
|
||||
|
||||
it('extracts reasoning from content part arrays', () => {
|
||||
const input = [
|
||||
{
|
||||
type: ContentPartType.TEXT,
|
||||
text: wrapReasoning('thinking hard', 'result')
|
||||
}
|
||||
];
|
||||
expect(extractReasoningFromContent(input)).toBe('thinking hard');
|
||||
});
|
||||
const apiMessage: Record<string, unknown> = {
|
||||
role: dbMessage.role,
|
||||
content: dbMessage.content
|
||||
};
|
||||
if (dbMessage.reasoningContent) {
|
||||
apiMessage.reasoning_content = dbMessage.reasoningContent;
|
||||
}
|
||||
|
||||
it('handles multiple reasoning blocks', () => {
|
||||
const input =
|
||||
REASONING_TAGS.START +
|
||||
'block1' +
|
||||
REASONING_TAGS.END +
|
||||
'middle' +
|
||||
REASONING_TAGS.START +
|
||||
'block2' +
|
||||
REASONING_TAGS.END +
|
||||
'end';
|
||||
expect(extractReasoningFromContent(input)).toBe('block1block2');
|
||||
});
|
||||
|
||||
it('ignores non-text content parts', () => {
|
||||
const input = [{ type: 'image_url', text: wrapReasoning('hidden', 'img') }];
|
||||
expect(extractReasoningFromContent(input)).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('strip reasoning content', () => {
|
||||
it('removes reasoning tags from string content', () => {
|
||||
const input = wrapReasoning('internal thoughts', 'visible answer');
|
||||
expect(stripReasoningContent(input)).toBe('visible answer');
|
||||
});
|
||||
|
||||
it('removes reasoning from content part arrays', () => {
|
||||
const input = [
|
||||
{
|
||||
type: ContentPartType.TEXT,
|
||||
text: wrapReasoning('thoughts', 'answer')
|
||||
}
|
||||
];
|
||||
const result = stripReasoningContent(input) as Array<{ type: string; text?: string }>;
|
||||
expect(result[0].text).toBe('answer');
|
||||
});
|
||||
});
|
||||
|
||||
describe('API message building with reasoning preservation', () => {
|
||||
const storedContent = wrapReasoning('Let me think: 2+2=4, basic arithmetic.', 'The answer is 4.');
|
||||
|
||||
it('preserves reasoning_content when excludeReasoningFromContext is false', () => {
|
||||
const msg = buildApiMessage(storedContent, false);
|
||||
expect(msg.content).toBe('The answer is 4.');
|
||||
expect(msg.reasoning_content).toBe('Let me think: 2+2=4, basic arithmetic.');
|
||||
// no internal tags leak into either field
|
||||
expect(msg.content).not.toContain('<<<');
|
||||
expect(msg.reasoning_content).not.toContain('<<<');
|
||||
});
|
||||
|
||||
it('strips reasoning_content when excludeReasoningFromContext is true', () => {
|
||||
const msg = buildApiMessage(storedContent, true);
|
||||
expect(msg.content).toBe('The answer is 4.');
|
||||
expect(msg.reasoning_content).toBeUndefined();
|
||||
});
|
||||
|
||||
it('handles content with no reasoning in both modes', () => {
|
||||
const plain = 'No reasoning here.';
|
||||
const msgPreserve = buildApiMessage(plain, false);
|
||||
const msgExclude = buildApiMessage(plain, true);
|
||||
expect(msgPreserve.content).toBe(plain);
|
||||
expect(msgPreserve.reasoning_content).toBeUndefined();
|
||||
expect(msgExclude.content).toBe(plain);
|
||||
expect(msgExclude.reasoning_content).toBeUndefined();
|
||||
});
|
||||
|
||||
it('cleans agentic tool call blocks from content even when preserving reasoning', () => {
|
||||
const input =
|
||||
wrapReasoning('plan', 'text') +
|
||||
'\n\n<<<AGENTIC_TOOL_CALL_START>>>\n' +
|
||||
'<<<TOOL_NAME:bash>>>\n' +
|
||||
'<<<TOOL_ARGS_START>>>\n{}\n<<<TOOL_ARGS_END>>>\nout\n' +
|
||||
'<<<AGENTIC_TOOL_CALL_END>>>\n';
|
||||
const msg = buildApiMessage(input, false);
|
||||
expect(msg.content).not.toContain('<<<');
|
||||
expect(msg.reasoning_content).toBe('plan');
|
||||
expect(apiMessage.content).toBe('No reasoning here.');
|
||||
expect(apiMessage.reasoning_content).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -551,6 +551,8 @@ int main(int argc, char ** argv) {
|
||||
params.sampling.top_k = 4;
|
||||
params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, };
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -558,8 +560,6 @@ int main(int argc, char ** argv) {
|
||||
const int n_parallel = params.n_parallel;
|
||||
const int n_predict = params.n_predict;
|
||||
|
||||
common_init();
|
||||
|
||||
// init LLM
|
||||
|
||||
llama_backend_init();
|
||||
|
||||
2
vendor/cpp-httplib/CMakeLists.txt
vendored
2
vendor/cpp-httplib/CMakeLists.txt
vendored
@@ -39,7 +39,7 @@ if (LLAMA_BUILD_BORINGSSL)
|
||||
set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)")
|
||||
|
||||
set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository")
|
||||
set(BORINGSSL_VERSION "0.20260211.0" CACHE STRING "BoringSSL version")
|
||||
set(BORINGSSL_VERSION "0.20260327.0" CACHE STRING "BoringSSL version")
|
||||
|
||||
message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user