mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-05-21 17:17:24 +03:00
Compare commits
81 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40d5358d3c | ||
|
|
b65bb4baae | ||
|
|
a1a69f777a | ||
|
|
52fb93a2bd | ||
|
|
c9021714e8 | ||
|
|
1d7ab2b947 | ||
|
|
12e5d99078 | ||
|
|
7ea23ddf7b | ||
|
|
2fc8d1851e | ||
|
|
5e932a1c8d | ||
|
|
2754ce1b3e | ||
|
|
eeeaf6180b | ||
|
|
0be84685bd | ||
|
|
ce02093fdd | ||
|
|
6a257d4463 | ||
|
|
3a479c9132 | ||
|
|
ad27757261 | ||
|
|
3a6db741a8 | ||
|
|
510b5c2a35 | ||
|
|
a8681a0ed2 | ||
|
|
acd604fb27 | ||
|
|
6ce96713de | ||
|
|
c9872a2575 | ||
|
|
e947228222 | ||
|
|
29f1482221 | ||
|
|
e6b4acfe86 | ||
|
|
e2b129e1bf | ||
|
|
7e50ef7d79 | ||
|
|
5028447384 | ||
|
|
585080d310 | ||
|
|
57ebaf4edd | ||
|
|
871b0b70f8 | ||
|
|
b39a7bf1b0 | ||
|
|
b28a2f372a | ||
|
|
17d22a35b2 | ||
|
|
67ace021da | ||
|
|
a8078675a6 | ||
|
|
57cb35c886 | ||
|
|
7256fce047 | ||
|
|
b7393a4d19 | ||
|
|
ac76808e4d | ||
|
|
baf3cc6e1d | ||
|
|
d14ce3dab4 | ||
|
|
6db130445d | ||
|
|
4b262ab662 | ||
|
|
00c461ce1a | ||
|
|
ccee426426 | ||
|
|
3c81c8deea | ||
|
|
cd963fee6a | ||
|
|
d2e179a477 | ||
|
|
c85a242ed0 | ||
|
|
aabee047d8 | ||
|
|
f1c1c5c057 | ||
|
|
439f1b193d | ||
|
|
c3e9ade6dd | ||
|
|
9a532ae4ba | ||
|
|
b7340443d4 | ||
|
|
5cbaa5e69e | ||
|
|
45b455e66f | ||
|
|
3a9c1b854d | ||
|
|
b9a2170fce | ||
|
|
1ff0fc1384 | ||
|
|
a135ec0baa | ||
|
|
232f466583 | ||
|
|
49c21f97cd | ||
|
|
77e38d68f2 | ||
|
|
053e01dff6 | ||
|
|
c3f95c1f06 | ||
|
|
0caf2a1d48 | ||
|
|
5511965b19 | ||
|
|
e98bcfec28 | ||
|
|
1867a0c692 | ||
|
|
dd7cad7197 | ||
|
|
726704a160 | ||
|
|
87589042ca | ||
|
|
e0de4c2419 | ||
|
|
84c678242a | ||
|
|
3e12fbdea5 | ||
|
|
39cf5d6191 | ||
|
|
a6d6183dbc | ||
|
|
fcae601e44 |
@@ -5,6 +5,9 @@
|
||||
# Define the CANN base image for easier version updates later
|
||||
ARG CHIP_TYPE=910b
|
||||
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.5.0-${CHIP_TYPE}-openeuler24.03-py3.11
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
# ==============================================================================
|
||||
# BUILD STAGE
|
||||
@@ -55,6 +58,7 @@ RUN mkdir -p /app/lib && \
|
||||
RUN mkdir -p /app/full && \
|
||||
cp build/bin/* /app/full/ && \
|
||||
cp *.py /app/full/ && \
|
||||
cp -r conversion /app/full/ && \
|
||||
cp -r gguf-py /app/full/ && \
|
||||
cp -r requirements /app/full/ && \
|
||||
cp requirements.txt /app/full/
|
||||
@@ -67,6 +71,19 @@ RUN mkdir -p /app/full && \
|
||||
# ==============================================================================
|
||||
FROM ${CANN_BASE_IMAGE} AS base
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
ARG IMAGE_URL=https://github.com/ggml-org/llama.cpp
|
||||
ARG IMAGE_SOURCE=https://github.com/ggml-org/llama.cpp
|
||||
LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.version=$APP_VERSION \
|
||||
org.opencontainers.image.revision=$APP_REVISION \
|
||||
org.opencontainers.image.title="llama.cpp" \
|
||||
org.opencontainers.image.description="LLM inference in C/C++" \
|
||||
org.opencontainers.image.url=$IMAGE_URL \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
# -- Install runtime dependencies --
|
||||
RUN yum install -y libgomp curl && \
|
||||
yum clean all && \
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
@@ -27,6 +30,7 @@ RUN mkdir -p /app/lib && \
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
&& cp *.py /app/full \
|
||||
&& cp -r conversion /app/full \
|
||||
&& cp -r gguf-py /app/full \
|
||||
&& cp -r requirements /app/full \
|
||||
&& cp requirements.txt /app/full \
|
||||
@@ -35,6 +39,19 @@ RUN mkdir -p /app/full \
|
||||
## Base image
|
||||
FROM ubuntu:$UBUNTU_VERSION AS base
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
ARG IMAGE_URL=https://github.com/ggml-org/llama.cpp
|
||||
ARG IMAGE_SOURCE=https://github.com/ggml-org/llama.cpp
|
||||
LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.version=$APP_VERSION \
|
||||
org.opencontainers.image.revision=$APP_REVISION \
|
||||
org.opencontainers.image.title="llama.cpp" \
|
||||
org.opencontainers.image.description="LLM inference in C/C++" \
|
||||
org.opencontainers.image.url=$IMAGE_URL \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt autoremove -y \
|
||||
|
||||
@@ -6,6 +6,10 @@ ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VER
|
||||
|
||||
ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
FROM ${BASE_CUDA_DEV_CONTAINER} AS build
|
||||
|
||||
# CUDA architecture to build for (defaults to all supported archs)
|
||||
@@ -32,6 +36,7 @@ RUN mkdir -p /app/lib && \
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
&& cp *.py /app/full \
|
||||
&& cp -r conversion /app/full \
|
||||
&& cp -r gguf-py /app/full \
|
||||
&& cp -r requirements /app/full \
|
||||
&& cp requirements.txt /app/full \
|
||||
@@ -40,6 +45,19 @@ RUN mkdir -p /app/full \
|
||||
## Base image
|
||||
FROM ${BASE_CUDA_RUN_CONTAINER} AS base
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
ARG IMAGE_URL=https://github.com/ggml-org/llama.cpp
|
||||
ARG IMAGE_SOURCE=https://github.com/ggml-org/llama.cpp
|
||||
LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.version=$APP_VERSION \
|
||||
org.opencontainers.image.revision=$APP_REVISION \
|
||||
org.opencontainers.image.title="llama.cpp" \
|
||||
org.opencontainers.image.description="LLM inference in C/C++" \
|
||||
org.opencontainers.image.url=$IMAGE_URL \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt autoremove -y \
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
ARG ONEAPI_VERSION=2025.3.3-0-devel-ubuntu24.04
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
## Build Image
|
||||
|
||||
@@ -33,6 +36,7 @@ RUN mkdir -p /app/lib && \
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
&& cp *.py /app/full \
|
||||
&& cp -r conversion /app/full \
|
||||
&& cp -r gguf-py /app/full \
|
||||
&& cp -r requirements /app/full \
|
||||
&& cp requirements.txt /app/full \
|
||||
@@ -40,6 +44,19 @@ RUN mkdir -p /app/full \
|
||||
|
||||
FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS base
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
ARG IMAGE_URL=https://github.com/ggml-org/llama.cpp
|
||||
ARG IMAGE_SOURCE=https://github.com/ggml-org/llama.cpp
|
||||
LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.version=$APP_VERSION \
|
||||
org.opencontainers.image.revision=$APP_REVISION \
|
||||
org.opencontainers.image.title="llama.cpp" \
|
||||
org.opencontainers.image.description="LLM inference in C/C++" \
|
||||
org.opencontainers.image.url=$IMAGE_URL \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
ARG IGC_VERSION=v2.20.5
|
||||
ARG IGC_VERSION_FULL=2_2.20.5+19972
|
||||
ARG COMPUTE_RUNTIME_VERSION=25.40.35563.10
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
ARG ASCEND_VERSION=8.5.0-910b-openeuler22.03-py3.10
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
FROM ascendai/cann:$ASCEND_VERSION AS build
|
||||
|
||||
@@ -28,6 +31,20 @@ RUN echo "Building with static libs" && \
|
||||
|
||||
# TODO: use image with NNRT
|
||||
FROM ascendai/cann:$ASCEND_VERSION AS runtime
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
ARG IMAGE_URL=https://github.com/ggml-org/llama.cpp
|
||||
ARG IMAGE_SOURCE=https://github.com/ggml-org/llama.cpp
|
||||
LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.version=$APP_VERSION \
|
||||
org.opencontainers.image.revision=$APP_REVISION \
|
||||
org.opencontainers.image.title="llama.cpp" \
|
||||
org.opencontainers.image.description="LLM inference in C/C++" \
|
||||
org.opencontainers.image.url=$IMAGE_URL \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
COPY --from=build /app/build/bin/llama-cli /app/build/bin/llama-completion /
|
||||
|
||||
ENV LC_ALL=C.utf8
|
||||
|
||||
@@ -6,6 +6,10 @@ ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_V
|
||||
|
||||
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}-amd64
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
|
||||
|
||||
# MUSA architecture to build for (defaults to all supported archs)
|
||||
@@ -37,6 +41,7 @@ RUN mkdir -p /app/lib && \
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
&& cp *.py /app/full \
|
||||
&& cp -r conversion /app/full \
|
||||
&& cp -r gguf-py /app/full \
|
||||
&& cp -r requirements /app/full \
|
||||
&& cp requirements.txt /app/full \
|
||||
@@ -45,6 +50,19 @@ RUN mkdir -p /app/full \
|
||||
## Base image
|
||||
FROM ${BASE_MUSA_RUN_CONTAINER} AS base
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
ARG IMAGE_URL=https://github.com/ggml-org/llama.cpp
|
||||
ARG IMAGE_SOURCE=https://github.com/ggml-org/llama.cpp
|
||||
LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.version=$APP_VERSION \
|
||||
org.opencontainers.image.revision=$APP_REVISION \
|
||||
org.opencontainers.image.title="llama.cpp" \
|
||||
org.opencontainers.image.description="LLM inference in C/C++" \
|
||||
org.opencontainers.image.url=$IMAGE_URL \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt autoremove -y \
|
||||
|
||||
@@ -18,6 +18,10 @@ ARG LIBZE1_VERSION=1.27.0-1~24.04~ppa2
|
||||
ARG http_proxy=
|
||||
ARG https_proxy=
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
## Build Image
|
||||
FROM ubuntu:${UBUNTU_VERSION} AS build
|
||||
|
||||
@@ -77,6 +81,7 @@ RUN mkdir -p /app/lib && \
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/ReleaseOV/bin/* /app/full/ \
|
||||
&& cp *.py /app/full \
|
||||
&& cp -r conversion /app/full \
|
||||
&& cp -r gguf-py /app/full \
|
||||
&& cp -r requirements /app/full \
|
||||
&& cp requirements.txt /app/full \
|
||||
@@ -88,6 +93,18 @@ FROM ubuntu:${UBUNTU_VERSION} AS base
|
||||
# Pass proxy args to runtime stage
|
||||
ARG http_proxy
|
||||
ARG https_proxy
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
ARG IMAGE_URL=https://github.com/ggml-org/llama.cpp
|
||||
ARG IMAGE_SOURCE=https://github.com/ggml-org/llama.cpp
|
||||
LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.version=$APP_VERSION \
|
||||
org.opencontainers.image.revision=$APP_REVISION \
|
||||
org.opencontainers.image.title="llama.cpp" \
|
||||
org.opencontainers.image.description="LLM inference in C/C++" \
|
||||
org.opencontainers.image.url=$IMAGE_URL \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 libtbb12 curl wget ocl-icd-libopencl1 \
|
||||
|
||||
@@ -7,6 +7,10 @@ ARG AMDGPU_VERSION=7.2.1
|
||||
# Target the ROCm build image
|
||||
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
### Build image
|
||||
FROM ${BASE_ROCM_DEV_CONTAINER} AS build
|
||||
|
||||
@@ -49,6 +53,7 @@ RUN mkdir -p /app/lib \
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
&& cp *.py /app/full \
|
||||
&& cp -r conversion /app/full \
|
||||
&& cp -r gguf-py /app/full \
|
||||
&& cp -r requirements /app/full \
|
||||
&& cp requirements.txt /app/full \
|
||||
@@ -57,6 +62,19 @@ RUN mkdir -p /app/full \
|
||||
## Base image
|
||||
FROM ${BASE_ROCM_DEV_CONTAINER} AS base
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
ARG IMAGE_URL=https://github.com/ggml-org/llama.cpp
|
||||
ARG IMAGE_SOURCE=https://github.com/ggml-org/llama.cpp
|
||||
LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.version=$APP_VERSION \
|
||||
org.opencontainers.image.revision=$APP_REVISION \
|
||||
org.opencontainers.image.title="llama.cpp" \
|
||||
org.opencontainers.image.description="LLM inference in C/C++" \
|
||||
org.opencontainers.image.url=$IMAGE_URL \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt autoremove -y \
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
ARG GCC_VERSION=15.2.0
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
### Build Llama.cpp stage
|
||||
FROM gcc:${GCC_VERSION} AS build
|
||||
@@ -34,6 +37,7 @@ RUN --mount=type=cache,target=/root/.ccache \
|
||||
|
||||
COPY *.py /opt/llama.cpp/bin
|
||||
COPY .devops/tools.sh /opt/llama.cpp/bin
|
||||
COPY conversion /opt/llama.cpp/conversion
|
||||
|
||||
COPY gguf-py /opt/llama.cpp/gguf-py
|
||||
COPY requirements.txt /opt/llama.cpp/gguf-py
|
||||
@@ -44,14 +48,28 @@ COPY requirements /opt/llama.cpp/gguf-py/requirements
|
||||
FROM scratch AS collector
|
||||
|
||||
# Copy llama.cpp binaries and libraries
|
||||
COPY --from=build /opt/llama.cpp/bin /llama.cpp/bin
|
||||
COPY --from=build /opt/llama.cpp/lib /llama.cpp/lib
|
||||
COPY --from=build /opt/llama.cpp/gguf-py /llama.cpp/gguf-py
|
||||
COPY --from=build /opt/llama.cpp/bin /llama.cpp/bin
|
||||
COPY --from=build /opt/llama.cpp/lib /llama.cpp/lib
|
||||
COPY --from=build /opt/llama.cpp/gguf-py /llama.cpp/gguf-py
|
||||
COPY --from=build /opt/llama.cpp/conversion /llama.cpp/conversion
|
||||
|
||||
|
||||
### Base image
|
||||
FROM ubuntu:${UBUNTU_VERSION} AS base
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
ARG IMAGE_URL=https://github.com/ggml-org/llama.cpp
|
||||
ARG IMAGE_SOURCE=https://github.com/ggml-org/llama.cpp
|
||||
LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.version=$APP_VERSION \
|
||||
org.opencontainers.image.revision=$APP_REVISION \
|
||||
org.opencontainers.image.title="llama.cpp" \
|
||||
org.opencontainers.image.description="LLM inference in C/C++" \
|
||||
org.opencontainers.image.url=$IMAGE_URL \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt/lists,sharing=locked \
|
||||
apt update -y && \
|
||||
@@ -91,6 +109,7 @@ RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
|
||||
|
||||
COPY --from=collector /llama.cpp/bin /app
|
||||
COPY --from=collector /llama.cpp/gguf-py /app/gguf-py
|
||||
COPY --from=collector /llama.cpp/conversion /app/conversion
|
||||
|
||||
RUN pip install --no-cache-dir --break-system-packages \
|
||||
-r /app/gguf-py/requirements.txt
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
ARG UBUNTU_VERSION=26.04
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
@@ -23,6 +26,7 @@ RUN mkdir -p /app/lib && \
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
&& cp *.py /app/full \
|
||||
&& cp -r conversion /app/full \
|
||||
&& cp -r gguf-py /app/full \
|
||||
&& cp -r requirements /app/full \
|
||||
&& cp requirements.txt /app/full \
|
||||
@@ -31,6 +35,19 @@ RUN mkdir -p /app/full \
|
||||
## Base image
|
||||
FROM ubuntu:$UBUNTU_VERSION AS base
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
ARG IMAGE_URL=https://github.com/ggml-org/llama.cpp
|
||||
ARG IMAGE_SOURCE=https://github.com/ggml-org/llama.cpp
|
||||
LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.version=$APP_VERSION \
|
||||
org.opencontainers.image.revision=$APP_REVISION \
|
||||
org.opencontainers.image.title="llama.cpp" \
|
||||
org.opencontainers.image.description="LLM inference in C/C++" \
|
||||
org.opencontainers.image.url=$IMAGE_URL \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl libvulkan1 mesa-vulkan-drivers \
|
||||
libglvnd0 libgl1 libglx0 libegl1 libgles2 \
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/011-bug-results.yml
vendored
4
.github/ISSUE_TEMPLATE/011-bug-results.yml
vendored
@@ -100,8 +100,8 @@ body:
|
||||
label: Relevant log output
|
||||
description: >
|
||||
Please copy and paste any relevant log output, including the command that you entered and any generated text.
|
||||
For very long logs (thousands of lines), preferably upload them as files instead.
|
||||
On Linux you can redirect console output into a file by appending ` > llama.log 2>&1` to your command.
|
||||
For very long logs (thousands of lines), please upload them as files instead; the `--log-file` CLI argument can be used for this purpose.
|
||||
On Linux you can alternatively redirect the console output of any command into a file by appending ` > llama.log 2>&1` to your command.
|
||||
value: |
|
||||
<details>
|
||||
<summary>Logs</summary>
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/019-bug-misc.yml
vendored
4
.github/ISSUE_TEMPLATE/019-bug-misc.yml
vendored
@@ -88,8 +88,8 @@ body:
|
||||
description: >
|
||||
If applicable, please copy and paste any relevant log output, including any generated text.
|
||||
If you are encountering problems specifically with the `llama_params_fit` module, always upload `--verbose` logs as well.
|
||||
For very long logs (thousands of lines), please upload them as files instead.
|
||||
On Linux you can redirect console output into a file by appending ` > llama.log 2>&1` to your command.
|
||||
For very long logs (thousands of lines), please upload them as files instead; the `--log-file` CLI argument can be used for this purpose.
|
||||
On Linux you can alternatively redirect the console output of any command into a file by appending ` > llama.log 2>&1` to your command.
|
||||
value: |
|
||||
<details>
|
||||
<summary>Logs</summary>
|
||||
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
android-ndk-snapdragon:
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: 'ghcr.io/snapdragon-toolchain/arm64-android:v0.3'
|
||||
image: 'ghcr.io/snapdragon-toolchain/arm64-android:v0.6'
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -61,7 +61,7 @@ jobs:
|
||||
linux-iot-snapdragon:
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: 'ghcr.io/snapdragon-toolchain/arm64-linux:v0.1'
|
||||
image: 'ghcr.io/snapdragon-toolchain/arm64-linux:v0.6'
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
83
.github/workflows/docker.yml
vendored
83
.github/workflows/docker.yml
vendored
@@ -11,6 +11,11 @@ name: Publish Docker image
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
inputs:
|
||||
skip_s390x:
|
||||
description: "Skip the s390x build target (useful for fast test runs that do not need the IBM Z runner)"
|
||||
type: boolean
|
||||
default: false
|
||||
schedule:
|
||||
# Rebuild daily rather than on every push because it is expensive
|
||||
- cron: '12 4 * * *'
|
||||
@@ -64,6 +69,8 @@ jobs:
|
||||
- name: Generate build and merge matrices
|
||||
id: matrices
|
||||
shell: bash
|
||||
env:
|
||||
SKIP_S390X: ${{ inputs.skip_s390x || 'false' }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
@@ -86,6 +93,11 @@ jobs:
|
||||
]
|
||||
JSON
|
||||
|
||||
if [ "${SKIP_S390X}" = "true" ]; then
|
||||
jq 'map(select(.platforms != "linux/s390x"))' build-matrix.json > build-matrix.json.tmp
|
||||
mv build-matrix.json.tmp build-matrix.json
|
||||
fi
|
||||
|
||||
BUILD_MATRIX="$(jq -c . build-matrix.json)"
|
||||
MERGE_MATRIX="$(jq -c '
|
||||
reduce .[] as $entry ({}; .[$entry.tag] |= (
|
||||
@@ -132,6 +144,7 @@ jobs:
|
||||
config: ${{ fromJSON(needs.prepare_matrices.outputs.build_matrix) }}
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
@@ -187,6 +200,10 @@ jobs:
|
||||
env:
|
||||
GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}'
|
||||
|
||||
- name: Get build date
|
||||
id: build_date
|
||||
run: echo "date=$(date -u +"%Y-%m-%dT%H:%M:%SZ")" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
if: ${{ matrix.config.free_disk_space == true }}
|
||||
uses: ggml-org/free-disk-space@v1.3.1
|
||||
@@ -211,13 +228,26 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.config.platforms }}
|
||||
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true,oci-mediatypes=true
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
target: full
|
||||
provenance: false
|
||||
build-args: |
|
||||
BUILD_DATE=${{ steps.build_date.outputs.date }}
|
||||
APP_VERSION=${{ needs.create_tag.outputs.source_tag }}
|
||||
APP_REVISION=${{ steps.checkout.outputs.commit }}
|
||||
IMAGE_URL=${{ github.server_url }}/${{ github.repository }}
|
||||
IMAGE_SOURCE=${{ github.server_url }}/${{ github.repository }}
|
||||
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
|
||||
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
|
||||
annotations: |
|
||||
manifest:org.opencontainers.image.created=${{ steps.build_date.outputs.date }}
|
||||
manifest:org.opencontainers.image.version=${{ needs.create_tag.outputs.source_tag }}
|
||||
manifest:org.opencontainers.image.revision=${{ steps.checkout.outputs.commit }}
|
||||
manifest:org.opencontainers.image.title=llama.cpp
|
||||
manifest:org.opencontainers.image.description=LLM inference in C/C++
|
||||
manifest:org.opencontainers.image.url=${{ github.server_url }}/${{ github.repository }}
|
||||
manifest:org.opencontainers.image.source=${{ github.server_url }}/${{ github.repository }}
|
||||
# using github experimental cache
|
||||
#cache-from: type=gha
|
||||
#cache-to: type=gha,mode=max
|
||||
@@ -235,13 +265,26 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.config.platforms }}
|
||||
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true,oci-mediatypes=true
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
target: light
|
||||
provenance: false
|
||||
build-args: |
|
||||
BUILD_DATE=${{ steps.build_date.outputs.date }}
|
||||
APP_VERSION=${{ needs.create_tag.outputs.source_tag }}
|
||||
APP_REVISION=${{ steps.checkout.outputs.commit }}
|
||||
IMAGE_URL=${{ github.server_url }}/${{ github.repository }}
|
||||
IMAGE_SOURCE=${{ github.server_url }}/${{ github.repository }}
|
||||
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
|
||||
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
|
||||
annotations: |
|
||||
manifest:org.opencontainers.image.created=${{ steps.build_date.outputs.date }}
|
||||
manifest:org.opencontainers.image.version=${{ needs.create_tag.outputs.source_tag }}
|
||||
manifest:org.opencontainers.image.revision=${{ steps.checkout.outputs.commit }}
|
||||
manifest:org.opencontainers.image.title=llama.cpp
|
||||
manifest:org.opencontainers.image.description=LLM inference in C/C++
|
||||
manifest:org.opencontainers.image.url=${{ github.server_url }}/${{ github.repository }}
|
||||
manifest:org.opencontainers.image.source=${{ github.server_url }}/${{ github.repository }}
|
||||
# using github experimental cache
|
||||
#cache-from: type=gha
|
||||
#cache-to: type=gha,mode=max
|
||||
@@ -259,13 +302,26 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.config.platforms }}
|
||||
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ steps.meta.outputs.image_repo }},push-by-digest=true,name-canonical=true,push=true,oci-mediatypes=true
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
target: server
|
||||
provenance: false
|
||||
build-args: |
|
||||
BUILD_DATE=${{ steps.build_date.outputs.date }}
|
||||
APP_VERSION=${{ needs.create_tag.outputs.source_tag }}
|
||||
APP_REVISION=${{ steps.checkout.outputs.commit }}
|
||||
IMAGE_URL=${{ github.server_url }}/${{ github.repository }}
|
||||
IMAGE_SOURCE=${{ github.server_url }}/${{ github.repository }}
|
||||
${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
|
||||
${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
|
||||
annotations: |
|
||||
manifest:org.opencontainers.image.created=${{ steps.build_date.outputs.date }}
|
||||
manifest:org.opencontainers.image.version=${{ needs.create_tag.outputs.source_tag }}
|
||||
manifest:org.opencontainers.image.revision=${{ steps.checkout.outputs.commit }}
|
||||
manifest:org.opencontainers.image.title=llama.cpp
|
||||
manifest:org.opencontainers.image.description=LLM inference in C/C++
|
||||
manifest:org.opencontainers.image.url=${{ github.server_url }}/${{ github.repository }}
|
||||
manifest:org.opencontainers.image.source=${{ github.server_url }}/${{ github.repository }}
|
||||
# using github experimental cache
|
||||
#cache-from: type=gha
|
||||
#cache-to: type=gha,mode=max
|
||||
@@ -330,10 +386,15 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Get build date
|
||||
id: build_date
|
||||
run: echo "date=$(date -u +"%Y-%m-%dT%H:%M:%SZ")" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Download digest metadata
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
|
||||
with:
|
||||
@@ -361,6 +422,8 @@ jobs:
|
||||
IMAGE_REPO="ghcr.io/${REPO_OWNER}/${REPO_NAME}"
|
||||
PREFIX="${IMAGE_REPO}:"
|
||||
SRC_TAG="${{ needs.create_tag.outputs.source_tag }}"
|
||||
BUILD_DATE="${{ steps.build_date.outputs.date }}"
|
||||
COMMIT_SHA="${{ steps.checkout.outputs.commit }}"
|
||||
TAGS="${{ matrix.config.tag }}"
|
||||
ARCHES="${{ matrix.config.arches }}"
|
||||
DIGEST_GLOB="/tmp/digests/*.tsv"
|
||||
@@ -412,11 +475,21 @@ jobs:
|
||||
refs+=("${IMAGE_REPO}@${digest}")
|
||||
done
|
||||
|
||||
local annotations=(
|
||||
--annotation "index:org.opencontainers.image.created=${BUILD_DATE}"
|
||||
--annotation "index:org.opencontainers.image.version=${SRC_TAG}"
|
||||
--annotation "index:org.opencontainers.image.revision=${COMMIT_SHA}"
|
||||
--annotation "index:org.opencontainers.image.title=llama.cpp"
|
||||
--annotation "index:org.opencontainers.image.description=LLM inference in C/C++"
|
||||
--annotation "index:org.opencontainers.image.url=${{ github.server_url }}/${{ github.repository }}"
|
||||
--annotation "index:org.opencontainers.image.source=${{ github.server_url }}/${{ github.repository }}"
|
||||
)
|
||||
|
||||
echo "Creating ${merged_tag} from ${refs[*]}"
|
||||
docker buildx imagetools create --tag "${merged_tag}" "${refs[@]}"
|
||||
docker buildx imagetools create "${annotations[@]}" --tag "${merged_tag}" "${refs[@]}"
|
||||
|
||||
echo "Creating ${merged_versioned_tag} from ${refs[*]}"
|
||||
docker buildx imagetools create --tag "${merged_versioned_tag}" "${refs[@]}"
|
||||
docker buildx imagetools create "${annotations[@]}" --tag "${merged_versioned_tag}" "${refs[@]}"
|
||||
}
|
||||
|
||||
for tag in $TAGS; do
|
||||
|
||||
65
.github/workflows/server-self-hosted.yml
vendored
65
.github/workflows/server-self-hosted.yml
vendored
@@ -130,3 +130,68 @@ jobs:
|
||||
# pip install -r requirements.txt
|
||||
# export ${{ matrix.extra_args }}
|
||||
# pytest -v -x -m "not slow"
|
||||
|
||||
server-kleidiai:
|
||||
runs-on: ah-ubuntu_22_04-c8g_8x
|
||||
|
||||
name: server-kleidiai (${{ matrix.wf_name }})
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- build_type: Release
|
||||
extra_build_flags: "-DGGML_CPU_KLEIDIAI=ON"
|
||||
extra_args: ""
|
||||
wf_name: "CPUx1, kleidiai"
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
sudo apt-get update
|
||||
sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a \
|
||||
apt-get install -y \
|
||||
build-essential \
|
||||
libssl-dev \
|
||||
python3-venv \
|
||||
gpg \
|
||||
wget \
|
||||
time \
|
||||
git-lfs
|
||||
|
||||
git lfs install
|
||||
|
||||
# install the latest cmake
|
||||
sudo install -d /usr/share/keyrings
|
||||
wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc \
|
||||
| gpg --dearmor \
|
||||
| sudo tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null
|
||||
echo 'deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ jammy main' \
|
||||
| sudo tee /etc/apt/sources.list.d/kitware.list
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y cmake
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -DGGML_SCHED_NO_REALLOC=ON ${{ matrix.extra_build_flags }}
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
||||
|
||||
- name: Tests
|
||||
id: server_integration_tests
|
||||
if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }}
|
||||
run: |
|
||||
cd tools/server/tests
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
export ${{ matrix.extra_args }}
|
||||
pytest -v -x -m "not slow"
|
||||
|
||||
4
.github/workflows/ui-ci.yml
vendored
4
.github/workflows/ui-ci.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
ui-checks:
|
||||
name: UI Checks
|
||||
needs: ui-build
|
||||
runs-on: ubuntu-24.04-arm
|
||||
runs-on: ubuntu-latest
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@@ -93,7 +93,7 @@ jobs:
|
||||
e2e-tests:
|
||||
name: E2E Tests
|
||||
needs: ui-build
|
||||
runs-on: ubuntu-24.04-arm
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
|
||||
@@ -22,6 +22,8 @@ Pull requests (PRs):
|
||||
Commits:
|
||||
- On every commit that you make, include a "Assisted-by: llama.cpp:local pi" tag
|
||||
- Do not explicitly set the git author in commits - rely on the default git config
|
||||
- Always use `--no-gpg-sign` when committing
|
||||
- Never `git push` without explicit confirmation from the user
|
||||
|
||||
Resources (read on demand):
|
||||
- [CONTRIBUTING.md](CONTRIBUTING.md)
|
||||
|
||||
@@ -104,29 +104,26 @@ option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
|
||||
option(LLAMA_BUILD_COMMON "llama: build common utils library" ${LLAMA_STANDALONE})
|
||||
|
||||
# extra artifacts
|
||||
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
|
||||
# Deprecated: use LLAMA_BUILD_UI instead (kept for backward compat)
|
||||
option(LLAMA_BUILD_WEBUI "llama: build the embedded Web UI for server (deprecated: use LLAMA_BUILD_UI)" ON)
|
||||
option(LLAMA_USE_PREBUILT_WEBUI "llama: use prebuilt WebUI from HF Bucket when available (deprecated: use LLAMA_USE_PREBUILT_UI)" ON)
|
||||
|
||||
# New option names
|
||||
option(LLAMA_BUILD_UI "llama: build the embedded Web UI for server" ON)
|
||||
option(LLAMA_USE_PREBUILT_UI "llama: use prebuilt UI from HF Bucket when available (requires LLAMA_BUILD_UI=ON)" ON)
|
||||
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_APP "llama: build the unified binary" OFF)
|
||||
option(LLAMA_BUILD_UI "llama: build the embedded Web UI for server" ON)
|
||||
option(LLAMA_USE_PREBUILT_UI "llama: use prebuilt UI from HF Bucket when available (requires LLAMA_BUILD_UI=ON)" ON)
|
||||
|
||||
# Backward compat: when old var is set but new one isn't, forward the value
|
||||
if(DEFINED LLAMA_BUILD_WEBUI AND NOT DEFINED LLAMA_BUILD_UI)
|
||||
if(DEFINED LLAMA_BUILD_WEBUI)
|
||||
set(LLAMA_BUILD_UI ${LLAMA_BUILD_WEBUI})
|
||||
message(DEPRECATION "LLAMA_BUILD_WEBUI is deprecated, use LLAMA_BUILD_UI instead")
|
||||
endif()
|
||||
if(DEFINED LLAMA_USE_PREBUILT_WEBUI AND NOT DEFINED LLAMA_USE_PREBUILT_UI)
|
||||
if(DEFINED LLAMA_USE_PREBUILT_WEBUI)
|
||||
set(LLAMA_USE_PREBUILT_UI ${LLAMA_USE_PREBUILT_WEBUI})
|
||||
message(DEPRECATION "LLAMA_USE_PREBUILT_WEBUI is deprecated, use LLAMA_USE_PREBUILT_UI instead")
|
||||
endif()
|
||||
option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
|
||||
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
|
||||
|
||||
option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
|
||||
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
|
||||
|
||||
# 3rd party libs
|
||||
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON)
|
||||
@@ -231,6 +228,10 @@ if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TOOLS)
|
||||
add_subdirectory(tools)
|
||||
endif()
|
||||
|
||||
if (LLAMA_BUILD_APP)
|
||||
add_subdirectory(app)
|
||||
endif()
|
||||
|
||||
# Automatically add all files from the 'licenses' directory
|
||||
file(GLOB EXTRA_LICENSES "${CMAKE_SOURCE_DIR}/licenses/LICENSE-*")
|
||||
|
||||
@@ -286,18 +287,6 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/llama-config.cmake
|
||||
${CMAKE_CURRENT_BINARY_DIR}/llama-version.cmake
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/llama)
|
||||
|
||||
install(
|
||||
FILES convert_hf_to_gguf.py
|
||||
PERMISSIONS
|
||||
OWNER_READ
|
||||
OWNER_WRITE
|
||||
OWNER_EXECUTE
|
||||
GROUP_READ
|
||||
GROUP_EXECUTE
|
||||
WORLD_READ
|
||||
WORLD_EXECUTE
|
||||
DESTINATION ${CMAKE_INSTALL_BINDIR})
|
||||
|
||||
configure_file(cmake/llama.pc.in
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/llama.pc"
|
||||
@ONLY)
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
/common/fit.* @JohannesGaessler
|
||||
/common/jinja/ @CISC
|
||||
/common/ngram-map.* @srogmann
|
||||
/conversion/ @CISC
|
||||
/convert_*.py @CISC
|
||||
/docs/backend/snapdragon/ @ggml-org/ggml-hexagon
|
||||
/examples/batched.swift/ @ggerganov
|
||||
@@ -48,7 +49,6 @@
|
||||
/examples/parallel/ @ggerganov
|
||||
/examples/passkey/ @ggerganov
|
||||
/examples/retrieval/ @ggerganov
|
||||
/examples/save-load-state/ @ggerganov
|
||||
/examples/speculative-simple/ @ggerganov
|
||||
/examples/speculative/ @ggerganov
|
||||
/ggml/cmake/ @ggerganov
|
||||
|
||||
@@ -280,7 +280,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
| [Metal](docs/build.md#metal-build) | Apple Silicon |
|
||||
| [BLAS](docs/build.md#blas-build) | All |
|
||||
| [BLIS](docs/backend/BLIS.md) | All |
|
||||
| [SYCL](docs/backend/SYCL.md) | Intel and Nvidia GPU |
|
||||
| [SYCL](docs/backend/SYCL.md) | Intel GPU |
|
||||
| [OpenVINO [In Progress]](docs/backend/OPENVINO.md) | Intel CPUs, GPUs, and NPUs |
|
||||
| [MUSA](docs/build.md#musa) | Moore Threads GPU |
|
||||
| [CUDA](docs/build.md#cuda) | Nvidia GPU |
|
||||
|
||||
20
app/CMakeLists.txt
Normal file
20
app/CMakeLists.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
set(TARGET llama-app)
|
||||
|
||||
add_executable(${TARGET} llama.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
llama-server-impl
|
||||
llama-cli-impl
|
||||
llama-completion-impl
|
||||
llama-bench-impl
|
||||
llama-batched-bench-impl
|
||||
llama-fit-params-impl
|
||||
llama-quantize-impl
|
||||
llama-perplexity-impl
|
||||
)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
95
app/llama.cpp
Normal file
95
app/llama.cpp
Normal file
@@ -0,0 +1,95 @@
|
||||
#include "build-info.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// visible
|
||||
int llama_server(int argc, char ** argv);
|
||||
int llama_cli(int argc, char ** argv);
|
||||
|
||||
// hidden
|
||||
int llama_completion(int argc, char ** argv);
|
||||
int llama_bench(int argc, char ** argv);
|
||||
int llama_batched_bench(int argc, char ** argv);
|
||||
int llama_fit_params(int argc, char ** argv);
|
||||
int llama_quantize(int argc, char ** argv);
|
||||
int llama_perplexity(int argc, char ** argv);
|
||||
|
||||
static int help(int argc, char ** argv);
|
||||
static int version(int argc, char ** argv);
|
||||
|
||||
struct command {
|
||||
const char * name;
|
||||
const char * desc;
|
||||
std::vector<std::string> aliases;
|
||||
bool hidden;
|
||||
int (*func)(int, char **);
|
||||
};
|
||||
|
||||
static const command cmds[] = {
|
||||
{"serve", "HTTP API server", {"server"}, false, llama_server },
|
||||
{"cli", "Command-line interactive interface", {"client"}, false, llama_cli },
|
||||
{"completion", "Text completion", {"complete"}, true, llama_completion },
|
||||
{"bench", "Benchmark prompt processing and text generation", {}, true, llama_bench },
|
||||
{"batched-bench", "Benchmark batched decoding performance", {}, true, llama_batched_bench},
|
||||
{"fit-params", "Compute parameters to fit a model in device memory", {}, true, llama_fit_params },
|
||||
{"quantize", "Quantize a model", {}, true, llama_quantize },
|
||||
{"perplexity", "Compute model perplexity and KL divergence", {}, true, llama_perplexity },
|
||||
{"version", "Show version", {}, true, version },
|
||||
{"help", "Show available commands", {}, true, help },
|
||||
};
|
||||
|
||||
static int version(int argc, char ** argv) {
|
||||
printf("%s\n", llama_build_info());
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int help(int argc, char ** argv) {
|
||||
const bool show_all = argc >= 2 && std::string(argv[1]) == "all";
|
||||
|
||||
printf("Usage: llama <command> [options]\n\nAvailable commands:\n");
|
||||
|
||||
for (const auto & cmd : cmds) {
|
||||
if (show_all || !cmd.hidden) {
|
||||
printf(" %-15s %s\n", cmd.name, cmd.desc);
|
||||
}
|
||||
}
|
||||
printf("\nRun 'llama <command> --help' for command-specific usage.\n");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static bool matches(const std::string & arg, const command & cmd) {
|
||||
if (arg == cmd.name) {
|
||||
return true;
|
||||
}
|
||||
for (const auto & alias : cmd.aliases) {
|
||||
if (arg == alias) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
const std::string arg = argc >= 2 ? argv[1] : "help";
|
||||
|
||||
for (const auto & cmd : cmds) {
|
||||
if (matches(arg, cmd)) {
|
||||
|
||||
// router spawns children through this same binary, it needs the
|
||||
// subcommand to relaunch as 'llama serve' and not bare options
|
||||
#ifdef _WIN32
|
||||
_putenv_s("LLAMA_APP_CMD", cmd.name);
|
||||
#else
|
||||
setenv("LLAMA_APP_CMD", cmd.name, 1);
|
||||
#endif
|
||||
return cmd.func(argc - 1, argv + 1);
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "error: unknown command '%s'\n", arg.c_str());
|
||||
return 1;
|
||||
}
|
||||
14
ci/run.sh
14
ci/run.sh
@@ -117,6 +117,12 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then
|
||||
# if on Mac, disable METAL
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=OFF -DGGML_BLAS=OFF"
|
||||
|
||||
MACOS_RUNNER_CUSTOM_VULKAN_CMAKE_LOCATION="/usr/local/lib/cmake/vulkan"
|
||||
MACOS_RUNNER_CUSTOM_SPIRV_HEADERS_LOCATION="${MACOS_RUNNER_CUSTOM_VULKAN_CMAKE_LOCATION}/SPIRV-Headers/SPIRV-HeadersConfig.cmake"
|
||||
if [[ -f "${MACOS_RUNNER_CUSTOM_SPIRV_HEADERS_LOCATION}" || -h "${MACOS_RUNNER_CUSTOM_SPIRV_HEADERS_LOCATION}" ]]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DSPIRV-Headers_DIR=${MACOS_RUNNER_CUSTOM_VULKAN_CMAKE_LOCATION}/SPIRV-Headers"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Build shared libs on Windows
|
||||
@@ -455,10 +461,10 @@ function gg_run_qwen3_0_6b {
|
||||
|
||||
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
|
||||
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/test-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/test-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/test-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/test-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
|
||||
function check_ppl {
|
||||
qnt="$1"
|
||||
|
||||
@@ -7,7 +7,7 @@ set(LLAMA_SHARED_LIB @BUILD_SHARED_LIBS@)
|
||||
|
||||
set_and_check(LLAMA_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@")
|
||||
set_and_check(LLAMA_LIB_DIR "@PACKAGE_LLAMA_LIB_INSTALL_DIR@")
|
||||
set_and_check(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@")
|
||||
set(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@")
|
||||
|
||||
find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
#include "hf-cache.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
@@ -537,7 +536,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
|
||||
}
|
||||
if (!seen_args.insert(arg).second) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
const bool skip = (arg == "--spec-type");
|
||||
|
||||
if (!skip) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
}
|
||||
}
|
||||
auto & tmp = arg_to_options[arg];
|
||||
auto opt = *tmp.first;
|
||||
@@ -586,12 +589,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
// parse the first time to get -hf option (used for remote preset)
|
||||
parse_cli_args();
|
||||
|
||||
// TODO: Remove later
|
||||
try {
|
||||
hf_cache::migrate_old_cache_to_hf_cache(params.hf_token, params.offline);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_WRN("HF cache migration failed: %s\n", e.what());
|
||||
}
|
||||
// export_graph_ops loads only metadata
|
||||
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
@@ -900,7 +897,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
||||
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
|
||||
}
|
||||
if (!seen_args.insert(arg).second) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
const bool skip = (arg == "--spec-type");
|
||||
|
||||
if (!skip) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
}
|
||||
}
|
||||
auto opt = *arg_to_options[arg];
|
||||
std::string val;
|
||||
@@ -3363,7 +3364,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
" - 1: error\n"
|
||||
" - 2: warning\n"
|
||||
" - 3: info\n"
|
||||
" - 4: debug\n"
|
||||
" - 4: trace (more info)\n"
|
||||
" - 5: debug\n"
|
||||
"(default: %d)\n", params.verbosity),
|
||||
[](common_params & params, int value) {
|
||||
params.verbosity = value;
|
||||
@@ -3589,6 +3591,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.speculative.draft.p_min = std::stof(value);
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-draft-backend-sampling"},
|
||||
{"--no-spec-draft-backend-sampling"},
|
||||
string_format("offload draft sampling to the backend (default: %s)",
|
||||
params.speculative.draft.backend_sampling ? "enabled" : "disabled"),
|
||||
[](common_params & params, bool value) {
|
||||
params.speculative.draft.backend_sampling = value;
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_BACKEND_SAMPLING"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-draft-device", "-devd", "--device-draft"}, "<dev1,dev2,..>",
|
||||
"comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
|
||||
@@ -4124,6 +4135,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.speculative.ngram_mod.n_match = 24;
|
||||
params.speculative.ngram_mod.n_min = 48;
|
||||
params.speculative.ngram_mod.n_max = 64;
|
||||
|
||||
// TODO: not sure if this is a good config - explore more settings and potentially enable it
|
||||
//params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
|
||||
//params.speculative.ngram_map_k4v.size_n = 8;
|
||||
//params.speculative.ngram_map_k4v.size_m = 24;
|
||||
//params.speculative.ngram_map_k4v.min_hits = 2;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
|
||||
|
||||
@@ -43,11 +43,33 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
const autoparser & autoparser) {
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = autoparser.preserved_tokens;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = autoparser.preserved_tokens;
|
||||
|
||||
auto parser = autoparser.build_parser(inputs);
|
||||
std::string parser_generation_prompt = data.generation_prompt;
|
||||
|
||||
if (inputs.continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !inputs.continue_msg.empty()) {
|
||||
// Build up generation prompt manually
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
if (!autoparser.reasoning.start.empty()) {
|
||||
data.generation_prompt = data.generation_prompt.substr(0, data.generation_prompt.find(autoparser.reasoning.start));
|
||||
data.generation_prompt += autoparser.reasoning.start + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += autoparser.reasoning.end;
|
||||
}
|
||||
}
|
||||
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = autoparser.build_parser(inputs, parser_generation_prompt);
|
||||
data.parser = parser.save();
|
||||
|
||||
// Build grammar if tools are present
|
||||
@@ -87,7 +109,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
return data;
|
||||
}
|
||||
|
||||
common_peg_arena autoparser::build_parser(const generation_params & inputs) const {
|
||||
common_peg_arena autoparser::build_parser(const generation_params & inputs, const std::string & generation_prompt) const {
|
||||
if (!analysis_complete) {
|
||||
throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
|
||||
}
|
||||
@@ -121,7 +143,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
|
||||
} else {
|
||||
parser = content.build_parser(ctx);
|
||||
}
|
||||
return pure_content ? p.prefix(inputs.generation_prompt, reasoning.start) + parser : p.prefix(inputs.generation_prompt, reasoning.start) << parser;
|
||||
return pure_content ? p.prefix(generation_prompt, reasoning.start) + parser : p.prefix(generation_prompt, reasoning.start) << parser;
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -60,16 +60,21 @@ struct generation_params {
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
|
||||
bool stream = true;
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = false;
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
std::string generation_prompt;
|
||||
bool add_generation_prompt = false;
|
||||
common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE;
|
||||
common_chat_msg continue_msg;
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
json extra_context;
|
||||
bool add_bos = false;
|
||||
bool add_eos = false;
|
||||
bool is_inference = true;
|
||||
bool add_inference = false;
|
||||
bool mark_input = true; // whether to mark input strings in the jinja context
|
||||
|
||||
bool has_continuation() const {
|
||||
return continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !continue_msg.empty();
|
||||
}
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
@@ -386,7 +391,7 @@ struct autoparser {
|
||||
void analyze_template(const common_chat_template & tmpl);
|
||||
|
||||
// Build the PEG parser for this template
|
||||
common_peg_arena build_parser(const generation_params & inputs) const;
|
||||
common_peg_arena build_parser(const generation_params & inputs, const std::string & generation_prompt) const;
|
||||
|
||||
private:
|
||||
// Collect tokens from entire analysis to preserve
|
||||
|
||||
@@ -785,7 +785,7 @@ common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const s
|
||||
if (delimiter.empty()) {
|
||||
return literal(s);
|
||||
}
|
||||
return literal(s.substr(0, s.rfind(delimiter)));
|
||||
return literal(s.substr(0, s.find(delimiter)));
|
||||
}
|
||||
|
||||
common_peg_parser common_chat_peg_builder::optspace(const std::string & tag) {
|
||||
|
||||
255
common/chat.cpp
255
common/chat.cpp
@@ -70,6 +70,26 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) {
|
||||
return !msg.content.empty() || !msg.tool_calls.empty();
|
||||
}
|
||||
|
||||
std::string common_chat_msg::render_content(const std::string & delimiter) const {
|
||||
if (!content.empty() && !content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
}
|
||||
if (!content.empty()) {
|
||||
return content;
|
||||
}
|
||||
|
||||
std::string text;
|
||||
for (const auto & part : content_parts) {
|
||||
if (part.type == "text") {
|
||||
if (!text.empty()) {
|
||||
text += delimiter;
|
||||
}
|
||||
text += part.text;
|
||||
}
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
|
||||
if (!content.empty() && !content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
@@ -451,6 +471,22 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
|
||||
return result;
|
||||
}
|
||||
|
||||
common_chat_continuation common_chat_continuation_parse(const nlohmann::ordered_json & value) {
|
||||
if (value.is_boolean() && value.get<bool>()) {
|
||||
return COMMON_CHAT_CONTINUATION_AUTO;
|
||||
}
|
||||
if (value.is_string()) {
|
||||
auto value_str = value.get<std::string>();
|
||||
if (value_str == "reasoning_content") {
|
||||
return COMMON_CHAT_CONTINUATION_REASONING;
|
||||
}
|
||||
if (value_str == "content") {
|
||||
return COMMON_CHAT_CONTINUATION_CONTENT;
|
||||
}
|
||||
}
|
||||
return COMMON_CHAT_CONTINUATION_NONE;
|
||||
}
|
||||
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
@@ -811,6 +847,36 @@ std::string common_chat_template_direct_apply(
|
||||
return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt);
|
||||
}
|
||||
|
||||
static std::string common_chat_template_generation_prompt_impl(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt) {
|
||||
|
||||
auto adjusted_messages = messages_override ? *messages_override : inputs.messages;
|
||||
|
||||
autoparser::generation_params params = inputs;
|
||||
params.add_generation_prompt = false;
|
||||
params.continue_final_message = COMMON_CHAT_CONTINUATION_NONE;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params, adjusted_messages, tools_override, additional_context);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params, adjusted_messages, tools_override, additional_context);
|
||||
|
||||
size_t prefix_len = 0;
|
||||
size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size());
|
||||
while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) {
|
||||
prefix_len++;
|
||||
}
|
||||
return gen_prompt.substr(prefix_len);
|
||||
}
|
||||
|
||||
std::string common_chat_template_generation_prompt(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
return common_chat_template_generation_prompt_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
@@ -863,6 +929,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
data.thinking_start_tag = "[THINK]";
|
||||
data.thinking_end_tag = "[/THINK]";
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
"[THINK]",
|
||||
@@ -871,8 +938,19 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
"[ARGS]",
|
||||
};
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = "[THINK]" + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += "[/THINK]" + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, "[THINK]");
|
||||
auto generation_prompt = p.eps();
|
||||
auto reasoning =
|
||||
extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps();
|
||||
|
||||
@@ -963,6 +1041,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
}
|
||||
|
||||
data.prompt = prompt;
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
|
||||
@@ -972,6 +1051,18 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
"<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>",
|
||||
};
|
||||
|
||||
// Adjust prompt for continuation
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = "<|start|>assistant<|channel|>analysis<|message|>" + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += "<|end|><|start|>assistant<|channel|>final<|message|>" + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
@@ -1080,12 +1171,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
|
||||
if (inputs.add_generation_prompt && string_ends_with(data.prompt, "<turn|>\n")) {
|
||||
// This may happen if the model generates content + tool_call, the
|
||||
// template does not add the model's next turn and confuses the model
|
||||
// from emitting its proper reasoning token sequence.
|
||||
data.prompt += "<|turn>model\n";
|
||||
data.generation_prompt = "<|turn>model\n";
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
@@ -1101,13 +1194,25 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
"<|turn>",
|
||||
};
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = string_ends_with(data.prompt, "<turn|>\n") ? "<|turn>model\n" : "";
|
||||
data.generation_prompt += "<|channel>thought\n" + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += "<channel|>" + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto start = p.rule("start", p.prefix(inputs.generation_prompt, "<|channel>"));
|
||||
auto start = p.rule("start", p.optional(p.literal("<|turn>model\n")));
|
||||
|
||||
if (extract_reasoning) {
|
||||
p.rule("thought", p.literal("<|channel>thought") + p.space() + p.reasoning(p.until("<channel|>")) + p.literal("<channel|>"));
|
||||
@@ -1224,15 +1329,22 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
">>>all",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
data.generation_prompt = "<|start_header_id|>assistant<|end_header_id|>\n\n>>>all\n" + msg.render_content();
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// Functionary v3.2 format:
|
||||
// - Normal content: >>>all\n{content}
|
||||
@@ -1244,7 +1356,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
// When no tools, content goes until end
|
||||
auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>"));
|
||||
auto content_until_end = p.literal("all\n") + p.content(p.rest());
|
||||
auto generation_prompt = p.literal(inputs.generation_prompt);
|
||||
auto generation_prompt = p.literal("<|start_header_id|>assistant<|end_header_id|>\n\n>>>");
|
||||
|
||||
// If no tools or tool_choice is NONE, just parse content
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
@@ -1318,9 +1430,10 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_calls_section_end|>",
|
||||
@@ -1343,10 +1456,22 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
const std::string GEN_PROMPT = "<|im_assistant|>assistant<|im_middle|>";
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += THINK_END + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// Kimi K2 Thinking format:
|
||||
// - Reasoning: <think>{reasoning}</think>
|
||||
@@ -1366,7 +1491,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
auto reasoning = extract_reasoning ? p.optional(THINK_START + p.reasoning(
|
||||
p.until_one_of({ THINK_END, "<|tool_calls_section_begin|>", "<|tool_call_begin|>" })) +
|
||||
p.optional(p.literal(THINK_END))) : p.eps();
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto generation_prompt = p.literal(GEN_PROMPT);
|
||||
|
||||
|
||||
// Content only parser (no tools)
|
||||
@@ -1442,6 +1567,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
@@ -1461,12 +1587,24 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||
const std::string TOOL_CALL_END = "<|tool_call_end|>";
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
const std::string GEN_PROMPT = "<|im_start|>assistant\n";
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += THINK_END + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto generation_prompt = p.literal(GEN_PROMPT);
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
@@ -1521,6 +1659,7 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
@@ -1536,12 +1675,24 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ
|
||||
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
const std::string GEN_PROMPT = "<|im_start|>assistant\n";
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += THINK_END + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto generation_prompt = p.literal(GEN_PROMPT);
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
@@ -1592,6 +1743,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = false;
|
||||
data.preserved_tokens = {
|
||||
@@ -1599,6 +1751,12 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
"<|role_sep|>\n",
|
||||
};
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
data.generation_prompt = "assistant<|role_sep|>\n" + msg.render_content();
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
|
||||
@@ -1634,7 +1792,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
ret = p.content(p.rest());
|
||||
}
|
||||
|
||||
return p.literal(inputs.generation_prompt) + ret;
|
||||
return p.literal("assistant<|role_sep|>\n") + ret;
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
@@ -1662,12 +1820,13 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
data.preserved_tokens = {
|
||||
data.preserved_tokens = {
|
||||
"|DSML|",
|
||||
"<think>",
|
||||
"</think>",
|
||||
@@ -1687,9 +1846,21 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha
|
||||
const std::string INVOKE_END = "</" + DSML + "invoke>";
|
||||
const std::string PARAM_START = "<" + DSML + "parameter";
|
||||
const std::string PARAM_END = "</" + DSML + "parameter>";
|
||||
const std::string GEN_PROMPT = "<|Assistant|>";
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += THINK_END + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto generation_prompt = p.literal(GEN_PROMPT);
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
@@ -2116,21 +2287,6 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::string common_chat_templates_generation_prompt(const common_chat_template & tmpl, const autoparser::generation_params & inputs) {
|
||||
autoparser::generation_params params = inputs;
|
||||
params.add_generation_prompt = false;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
|
||||
size_t prefix_len = 0;
|
||||
size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size());
|
||||
while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) {
|
||||
prefix_len++;
|
||||
}
|
||||
return gen_prompt.substr(prefix_len);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs) {
|
||||
autoparser::generation_params params;
|
||||
@@ -2149,6 +2305,27 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
params.continue_final_message = inputs.continue_final_message;
|
||||
if (params.continue_final_message != COMMON_CHAT_CONTINUATION_NONE) {
|
||||
params.add_generation_prompt = false;
|
||||
|
||||
if (!inputs.messages.empty()) {
|
||||
// Render messages[:-1] and store continuation message separately
|
||||
params.continue_msg = inputs.messages.back();
|
||||
params.messages.erase(params.messages.size() - 1);
|
||||
}
|
||||
|
||||
if (params.continue_final_message == COMMON_CHAT_CONTINUATION_AUTO && !inputs.messages.empty()) {
|
||||
// Resolve based on message content
|
||||
params.continue_final_message = COMMON_CHAT_CONTINUATION_CONTENT;
|
||||
if (!params.continue_msg.reasoning_content.empty() &&
|
||||
params.continue_msg.content.empty() &&
|
||||
params.continue_msg.content_parts.empty()) {
|
||||
params.continue_final_message = COMMON_CHAT_CONTINUATION_REASONING;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
workaround::map_developer_role_to_system(params.messages);
|
||||
@@ -2169,8 +2346,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
workaround::func_args_not_string(params.messages);
|
||||
}
|
||||
|
||||
params.generation_prompt = common_chat_templates_generation_prompt(tmpl, params);
|
||||
|
||||
params.extra_context = common_chat_extra_context();
|
||||
for (auto el : inputs.chat_template_kwargs) {
|
||||
params.extra_context[el.first] = json::parse(el.second);
|
||||
@@ -2200,17 +2375,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
auto params_copy = params;
|
||||
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.generation_prompt = params.generation_prompt;
|
||||
auto parser = build_chat_peg_parser([¶ms](common_chat_peg_builder &p) {
|
||||
return p.prefix(params.generation_prompt) << p.content(p.rest());
|
||||
auto parser = build_chat_peg_parser([&data](common_chat_peg_builder &p) {
|
||||
return p.literal(data.generation_prompt) << p.content(p.rest());
|
||||
});
|
||||
data.parser = parser.save();
|
||||
return data;
|
||||
}
|
||||
|
||||
if (auto result = common_chat_try_specialized_template(tmpl, src, params)) {
|
||||
result->generation_prompt = params.generation_prompt;
|
||||
return *result;
|
||||
}
|
||||
|
||||
@@ -2224,7 +2398,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start);
|
||||
auto_params.thinking_end_tag = trim_whitespace(autoparser.reasoning.end);
|
||||
}
|
||||
auto_params.generation_prompt = params.generation_prompt;
|
||||
common_peg_arena arena;
|
||||
arena.load(auto_params.parser);
|
||||
LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str());
|
||||
|
||||
@@ -89,6 +89,8 @@ struct common_chat_msg {
|
||||
|
||||
nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const;
|
||||
|
||||
std::string render_content(const std::string & delimiter = "\n\n") const;
|
||||
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() &&
|
||||
tool_name.empty() && tool_call_id.empty();
|
||||
@@ -164,12 +166,22 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
|
||||
// Continuation method provided via `continue_final_message`
|
||||
enum common_chat_continuation {
|
||||
COMMON_CHAT_CONTINUATION_NONE,
|
||||
COMMON_CHAT_CONTINUATION_AUTO,
|
||||
COMMON_CHAT_CONTINUATION_REASONING,
|
||||
COMMON_CHAT_CONTINUATION_CONTENT,
|
||||
};
|
||||
|
||||
struct common_chat_templates_inputs {
|
||||
std::vector<common_chat_msg> messages;
|
||||
std::string grammar;
|
||||
std::string json_schema;
|
||||
bool add_generation_prompt = true;
|
||||
bool use_jinja = true;
|
||||
bool add_generation_prompt = true;
|
||||
common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE;
|
||||
bool use_jinja = true;
|
||||
// Parameters below only supported when use_jinja is true
|
||||
std::vector<common_chat_tool> tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
@@ -207,6 +219,7 @@ struct common_chat_parser_params {
|
||||
bool reasoning_in_content = false;
|
||||
std::string generation_prompt;
|
||||
bool parse_tool_calls = true;
|
||||
bool echo = false; // Include assistant prefilled msg in output
|
||||
bool debug = false; // Enable debug output for PEG parser
|
||||
common_peg_arena parser = {};
|
||||
common_chat_parser_params() = default;
|
||||
@@ -267,6 +280,8 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::or
|
||||
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
|
||||
|
||||
common_chat_continuation common_chat_continuation_parse(const nlohmann::ordered_json & value);
|
||||
|
||||
// DEPRECATED: only used in tests
|
||||
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
|
||||
@@ -279,6 +294,10 @@ std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs);
|
||||
|
||||
std::string common_chat_template_generation_prompt(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs);
|
||||
|
||||
std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::string & src,
|
||||
|
||||
@@ -1160,7 +1160,7 @@ struct common_init_result::impl {
|
||||
std::vector<llama_sampler_seq_config> samplers_seq_config;
|
||||
};
|
||||
|
||||
common_init_result::common_init_result(common_params & params) :
|
||||
common_init_result::common_init_result(common_params & params, bool model_only) :
|
||||
pimpl(new impl{}) {
|
||||
auto mparams = common_model_params_to_llama(params);
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
@@ -1173,7 +1173,7 @@ common_init_result::common_init_result(common_params & params) :
|
||||
params.tensor_buft_overrides.data(),
|
||||
params.fit_params_target.data(),
|
||||
params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
params.verbosity >= LOG_LEVEL_DEBUG ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
}
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
||||
@@ -1183,6 +1183,10 @@ common_init_result::common_init_result(common_params & params) :
|
||||
|
||||
pimpl->model.reset(model);
|
||||
|
||||
if (model_only) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// load and optionally apply lora adapters
|
||||
@@ -1252,29 +1256,6 @@ common_init_result::common_init_result(common_params & params) :
|
||||
cparams.n_samplers = pimpl->samplers_seq_config.size();
|
||||
}
|
||||
|
||||
// [TAG_RS_STATE_ROLLBACK_SUPPORT]
|
||||
// TODO: ngram speculative methods require checkpointing in addition to partial RS rollback
|
||||
// currently this is not supported. so we disable the partial rollback
|
||||
if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) {
|
||||
auto & types = params.speculative.types;
|
||||
|
||||
for (int i = 0; i < (int) types.size(); i++) {
|
||||
if (types[i] == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
continue;
|
||||
}
|
||||
if (types[i] == COMMON_SPECULATIVE_TYPE_DRAFT_MTP) {
|
||||
continue;
|
||||
}
|
||||
|
||||
cparams.n_rs_seq = 0;
|
||||
|
||||
LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__,
|
||||
common_speculative_type_to_str(types[i]).c_str());
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
@@ -1309,8 +1290,8 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
||||
return pimpl->lora;
|
||||
}
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
common_init_result_ptr res(new common_init_result(params));
|
||||
common_init_result_ptr common_init_from_params(common_params & params, bool model_only) {
|
||||
common_init_result_ptr res(new common_init_result(params, model_only));
|
||||
|
||||
llama_model * model = res->model();
|
||||
if (model == NULL) {
|
||||
@@ -1318,6 +1299,10 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
return res;
|
||||
}
|
||||
|
||||
if (model_only) {
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_context * lctx = res->context();
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
@@ -1381,7 +1366,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
}
|
||||
|
||||
if (params.warmup) {
|
||||
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||
LOG_INF("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||
|
||||
llama_set_warmup(lctx, true);
|
||||
|
||||
|
||||
@@ -299,11 +299,13 @@ struct common_params_model {
|
||||
|
||||
// draft-model-based speculative decoding parameters
|
||||
struct common_params_speculative_draft {
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
int32_t n_max = 3; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.0f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
bool backend_sampling = true; // offload draft sampling to the backend (default: on)
|
||||
|
||||
common_params_model mparams;
|
||||
|
||||
@@ -617,8 +619,6 @@ struct common_params {
|
||||
// UI configs
|
||||
#ifdef LLAMA_UI_DEFAULT_ENABLED
|
||||
bool ui = LLAMA_UI_DEFAULT_ENABLED != 0;
|
||||
#elif defined(LLAMA_WEBUI_DEFAULT_ENABLED)
|
||||
bool ui = LLAMA_WEBUI_DEFAULT_ENABLED != 0;
|
||||
#else
|
||||
bool ui = true; // default to enabled when not set
|
||||
#endif
|
||||
@@ -859,7 +859,7 @@ struct common_sampler;
|
||||
|
||||
// note: defines the model, context, samplers, ets. lifetimes
|
||||
struct common_init_result {
|
||||
common_init_result(common_params & params);
|
||||
common_init_result(common_params & params, bool model_only = false);
|
||||
~common_init_result();
|
||||
|
||||
llama_model * model();
|
||||
@@ -877,7 +877,7 @@ private:
|
||||
|
||||
using common_init_result_ptr = std::unique_ptr<common_init_result>;
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params);
|
||||
common_init_result_ptr common_init_from_params(common_params & params, bool model_only = false);
|
||||
|
||||
struct llama_model_params common_model_params_to_llama ( common_params & params);
|
||||
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <atomic>
|
||||
#include <regex> // migration only
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <stdexcept>
|
||||
@@ -336,15 +335,9 @@ hf_files get_repo_files(const std::string & repo_id,
|
||||
if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) {
|
||||
file.oid = item["lfs"]["oid"].get<std::string>();
|
||||
}
|
||||
if (item["lfs"].contains("size") && item["lfs"]["size"].is_number()) {
|
||||
file.size = item["lfs"]["size"].get<size_t>();
|
||||
}
|
||||
} else if (item.contains("oid") && item["oid"].is_string()) {
|
||||
file.oid = item["oid"].get<std::string>();
|
||||
}
|
||||
if (file.size == 0 && item.contains("size") && item["size"].is_number()) {
|
||||
file.size = item["size"].get<size_t>();
|
||||
}
|
||||
|
||||
if (!file.oid.empty() && !is_valid_oid(file.oid)) {
|
||||
LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str());
|
||||
@@ -502,271 +495,4 @@ std::string finalize_file(const hf_file & file) {
|
||||
return file.final_path;
|
||||
}
|
||||
|
||||
// delete everything after this line, one day
|
||||
|
||||
// copied from download.cpp without the tag part
|
||||
struct gguf_split_info {
|
||||
std::string prefix; // tag included
|
||||
int index;
|
||||
int count;
|
||||
};
|
||||
|
||||
static gguf_split_info get_gguf_split_info(const std::string & path) {
|
||||
static const std::regex re_split("^(.+)-([0-9]{5})-of-([0-9]{5})$", std::regex::icase);
|
||||
std::smatch m;
|
||||
|
||||
std::string prefix = path;
|
||||
if (!string_remove_suffix(prefix, ".gguf")) {
|
||||
return {};
|
||||
}
|
||||
|
||||
int index = 1;
|
||||
int count = 1;
|
||||
|
||||
if (std::regex_match(prefix, m, re_split)) {
|
||||
index = std::stoi(m[2].str());
|
||||
count = std::stoi(m[3].str());
|
||||
prefix = m[1].str();
|
||||
}
|
||||
|
||||
return {std::move(prefix), index, count};
|
||||
}
|
||||
|
||||
static std::pair<std::string, std::string> parse_manifest_name(std::string & filename) {
|
||||
static const std::regex re(R"(^manifest=([^=]+)=([^=]+)=.*\.json$)");
|
||||
std::smatch match;
|
||||
if (std::regex_match(filename, match, re)) {
|
||||
return {match[1].str(), match[2].str()};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
static std::string make_old_cache_filename(const std::string & owner,
|
||||
const std::string & repo,
|
||||
const std::string & filename) {
|
||||
auto result = owner + "_" + repo + "_" + filename;
|
||||
string_replace_all(result, "/", "_");
|
||||
return result;
|
||||
}
|
||||
|
||||
struct migrate_file {
|
||||
std::string path;
|
||||
std::string sha256;
|
||||
size_t size;
|
||||
fs::path old_path;
|
||||
fs::path etag_path;
|
||||
const hf_file * file;
|
||||
};
|
||||
|
||||
using migrate_files = std::vector<migrate_file>;
|
||||
|
||||
static bool collect_file(const fs::path & old_cache,
|
||||
const std::string & owner,
|
||||
const std::string & repo,
|
||||
const std::string & path,
|
||||
const std::string & sha256,
|
||||
const hf_files & files,
|
||||
migrate_files & to_migrate) {
|
||||
|
||||
const hf_file * file = nullptr;
|
||||
|
||||
for (const auto & f : files) {
|
||||
if (f.path == path) {
|
||||
file = &f;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::string old_filename = make_old_cache_filename(owner, repo, path);
|
||||
fs::path old_path = old_cache / old_filename;
|
||||
fs::path etag_path = old_path.string() + ".etag";
|
||||
|
||||
if (!fs::exists(old_path)) {
|
||||
if (file && fs::exists(file->final_path)) {
|
||||
return true;
|
||||
}
|
||||
LOG_WRN("%s: %s not found in old cache or HF cache\n", __func__, old_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!file) {
|
||||
LOG_WRN("%s: %s not found in current repo\n", __func__, old_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!sha256.empty() && !file->oid.empty() && sha256 != file->oid) {
|
||||
LOG_WRN("%s: %s is not up to date (sha256 mismatch)\n", __func__, old_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (file->size > 0) {
|
||||
size_t size = fs::file_size(old_path);
|
||||
if (size != file->size) {
|
||||
LOG_WRN("%s: %s has wrong size %zu (expected %zu)\n", __func__, old_filename.c_str(), size, file->size);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
to_migrate.push_back({path, sha256, file->size, old_path, etag_path, file});
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool collect_files(const fs::path & old_cache,
|
||||
const std::string & owner,
|
||||
const std::string & repo,
|
||||
const nl::json & node,
|
||||
const hf_files & files,
|
||||
migrate_files & to_migrate) {
|
||||
|
||||
if (!node.contains("rfilename") ||
|
||||
!node.contains("lfs") ||
|
||||
!node["lfs"].contains("sha256")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string path = node["rfilename"];
|
||||
std::string sha256 = node["lfs"]["sha256"];
|
||||
|
||||
auto split = get_gguf_split_info(path);
|
||||
|
||||
if (split.count <= 1) {
|
||||
return collect_file(old_cache, owner, repo, path, sha256, files, to_migrate);
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> splits;
|
||||
|
||||
for (const auto & f : files) {
|
||||
auto split_f = get_gguf_split_info(f.path);
|
||||
if (split_f.count == split.count && split_f.prefix == split.prefix) {
|
||||
// sadly the manifest only provides the sha256 of the first file (index == 1)
|
||||
// the rest will be verified using the size...
|
||||
std::string f_sha256 = (split_f.index == 1) ? sha256 : "";
|
||||
splits.emplace_back(f.path, f_sha256);
|
||||
}
|
||||
}
|
||||
|
||||
if ((int)splits.size() != split.count) {
|
||||
LOG_WRN("%s: expected %d split files but found %d in repo\n", __func__, split.count, (int)splits.size());
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto & [f_path, f_sha256] : splits) {
|
||||
if (!collect_file(old_cache, owner, repo, f_path, f_sha256, files, to_migrate)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool migrate_file(const migrate_file & file) {
|
||||
std::error_code ec;
|
||||
|
||||
fs::path new_path(file.file->local_path);
|
||||
fs::create_directories(new_path.parent_path(), ec);
|
||||
|
||||
if (!fs::exists(new_path, ec)) {
|
||||
fs::rename(file.old_path, new_path, ec);
|
||||
if (ec) {
|
||||
fs::copy_file(file.old_path, new_path, ec);
|
||||
if (ec) {
|
||||
LOG_ERR("%s: failed to move/copy %s: %s\n", __func__, file.old_path.string().c_str(), ec.message().c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
fs::remove(file.old_path, ec);
|
||||
}
|
||||
fs::remove(file.etag_path, ec);
|
||||
|
||||
std::string filename = finalize_file(*file.file);
|
||||
LOG_INF("%s: migrated %s -> %s\n", __func__, file.old_path.filename().string().c_str(), filename.c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) {
|
||||
fs::path old_cache = fs_get_cache_directory();
|
||||
if (!fs::exists(old_cache)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (offline) {
|
||||
LOG_WRN("%s: skipping migration in offline mode (will run when online)\n", __func__);
|
||||
return; // -hf is not going to work
|
||||
}
|
||||
|
||||
bool warned = false;
|
||||
|
||||
for (const auto & entry : fs::directory_iterator(old_cache)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
auto filename = entry.path().filename().string();
|
||||
auto [owner, repo] = parse_manifest_name(filename);
|
||||
|
||||
if (owner.empty() || repo.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!warned) {
|
||||
warned = true;
|
||||
LOG_WRN("================================================================================\n"
|
||||
"WARNING: Migrating cache to HuggingFace cache directory\n"
|
||||
" Old cache: %s\n"
|
||||
" New cache: %s\n"
|
||||
"This one-time migration moves models previously downloaded with -hf\n"
|
||||
"from the legacy llama.cpp cache to the standard HuggingFace cache.\n"
|
||||
"Models downloaded with --model-url are not affected.\n"
|
||||
"================================================================================\n",
|
||||
old_cache.string().c_str(), get_cache_directory().string().c_str());
|
||||
}
|
||||
|
||||
auto repo_id = owner + "/" + repo;
|
||||
auto files = get_repo_files(repo_id, token);
|
||||
|
||||
if (files.empty()) {
|
||||
LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
migrate_files to_migrate;
|
||||
bool ok = true;
|
||||
|
||||
try {
|
||||
std::ifstream manifest(entry.path());
|
||||
auto json = nl::json::parse(manifest);
|
||||
for (const char * key : {"ggufFile", "mmprojFile"}) {
|
||||
if (json.contains(key)) {
|
||||
if (!collect_files(old_cache, owner, repo, json[key], files, to_migrate)) {
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
LOG_WRN("%s: migration skipped: one or more files failed validation\n", __func__);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const auto & file : to_migrate) {
|
||||
if (!migrate_file(file)) {
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
LOG_WRN("%s: migration failed: could not migrate all files\n", __func__);
|
||||
continue;
|
||||
}
|
||||
|
||||
LOG_INF("%s: migration complete, deleting manifest: %s\n", __func__, entry.path().string().c_str());
|
||||
fs::remove(entry.path());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace hf_cache
|
||||
|
||||
@@ -14,7 +14,6 @@ struct hf_file {
|
||||
std::string final_path;
|
||||
std::string oid;
|
||||
std::string repo_id;
|
||||
size_t size = 0; // only for the migration
|
||||
};
|
||||
|
||||
using hf_files = std::vector<hf_file>;
|
||||
@@ -30,7 +29,4 @@ hf_files get_cached_files(const std::string & repo_id = {});
|
||||
// Create snapshot path (link or move/copy) and return it
|
||||
std::string finalize_file(const hf_file & file);
|
||||
|
||||
// TODO: Remove later
|
||||
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false);
|
||||
|
||||
} // namespace hf_cache
|
||||
|
||||
@@ -500,7 +500,7 @@ void common_ngram_map_draft(common_ngram_map & map,
|
||||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
|
||||
LOG_DBG("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, slot_max,
|
||||
curr_key.key_num, draft.size());
|
||||
|
||||
|
||||
@@ -32,6 +32,18 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
|
||||
{"ngram-cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
|
||||
};
|
||||
|
||||
static std::string common_speculative_get_devices_str(const std::vector<ggml_backend_dev_t> & devices) {
|
||||
std::string result;
|
||||
for (size_t i = 0; i < devices.size(); i++) {
|
||||
if (devices[i] == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (!result.empty()) result += ", ";
|
||||
result += ggml_backend_dev_name(devices[i]);
|
||||
}
|
||||
return result.empty() ? "default" : result;
|
||||
}
|
||||
|
||||
struct common_speculative_config {
|
||||
common_speculative_type type;
|
||||
common_params_speculative params;
|
||||
@@ -144,10 +156,13 @@ struct common_speculative_impl {
|
||||
|
||||
virtual void draft(common_speculative_draft_params_vec & dparams) = 0;
|
||||
|
||||
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
|
||||
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;
|
||||
|
||||
// true if this implementation requires the target context to extract embeddings
|
||||
// true if this implementation requires the target context to extract post-norm embeddings
|
||||
virtual bool need_embd() const = 0;
|
||||
|
||||
// true if this implementation requires the target context to extract pre-norm embeddings
|
||||
virtual bool need_embd_pre_norm() const { return false; }
|
||||
};
|
||||
|
||||
struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
@@ -164,6 +179,16 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
|
||||
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
|
||||
this->params.n_gpu_layers,
|
||||
ggml_type_name(this->params.cache_type_k),
|
||||
ggml_type_name(this->params.cache_type_v),
|
||||
ctx_tgt ? "yes" : "no",
|
||||
ctx_dft ? "yes" : "no",
|
||||
common_speculative_get_devices_str(this->params.devices).c_str());
|
||||
|
||||
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
||||
|
||||
// TODO: optimize or pass from outside?
|
||||
@@ -340,7 +365,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
@@ -352,8 +377,12 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
//common_params_speculative_eagle3 params;
|
||||
|
||||
common_speculative_impl_draft_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq) {}
|
||||
common_speculative_impl_draft_eagle3(const common_params_speculative & params, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min);
|
||||
}
|
||||
|
||||
void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
|
||||
// noop
|
||||
@@ -368,7 +397,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
// TODO: implement
|
||||
}
|
||||
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
@@ -377,13 +406,16 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
||||
struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
|
||||
|
||||
llama_batch batch;
|
||||
|
||||
std::vector<common_sampler_ptr> smpls;
|
||||
|
||||
// backend sampler chain per seq, attached to ctx_dft
|
||||
std::vector<llama_sampler *> backend_chains;
|
||||
|
||||
int32_t n_embd = 0;
|
||||
|
||||
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
|
||||
@@ -404,7 +436,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
||||
// pre-advancement before process() mirrored the verify batch.
|
||||
std::vector<uint16_t> last_n_drafted;
|
||||
|
||||
common_speculative_state_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
|
||||
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
|
||||
, params(params.draft)
|
||||
{
|
||||
@@ -414,6 +446,16 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
||||
|
||||
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
|
||||
this->params.n_gpu_layers,
|
||||
ggml_type_name(this->params.cache_type_k),
|
||||
ggml_type_name(this->params.cache_type_v),
|
||||
ctx_tgt ? "yes" : "no",
|
||||
ctx_dft ? "yes" : "no",
|
||||
common_speculative_get_devices_str(this->params.devices).c_str());
|
||||
|
||||
const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
|
||||
batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
|
||||
// llama_batch_init allocates only one of token/embd; MTP needs both.
|
||||
@@ -424,13 +466,29 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
||||
for (auto & s : smpls) {
|
||||
common_params_sampling sparams;
|
||||
sparams.no_perf = false;
|
||||
sparams.top_k = 1; // TODO: re-enable top_k == 10 and utilize `p_min` spec param
|
||||
sparams.top_k = 10;
|
||||
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
|
||||
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
|
||||
}
|
||||
|
||||
llama_set_embeddings_pre_norm(ctx_tgt, true);
|
||||
llama_set_embeddings_pre_norm(ctx_dft, true);
|
||||
// offload draft sampling to the backend
|
||||
backend_chains.assign(n_seq, nullptr);
|
||||
if (this->params.backend_sampling) {
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
|
||||
|
||||
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
|
||||
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
|
||||
llama_sampler_free(chain);
|
||||
chain = nullptr;
|
||||
}
|
||||
backend_chains[seq_id] = chain;
|
||||
}
|
||||
}
|
||||
|
||||
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
|
||||
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
@@ -443,7 +501,19 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
||||
last_n_drafted.assign(n_seq, 0);
|
||||
}
|
||||
|
||||
~common_speculative_state_draft_mtp() override {
|
||||
~common_speculative_impl_draft_mtp() override {
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) backend_chains.size(); ++seq_id) {
|
||||
if (backend_chains[seq_id] == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (ctx_dft) {
|
||||
llama_set_sampler(ctx_dft, seq_id, nullptr);
|
||||
}
|
||||
llama_sampler_free(backend_chains[seq_id]);
|
||||
}
|
||||
backend_chains.clear();
|
||||
|
||||
if (batch.token != nullptr) {
|
||||
free(batch.token);
|
||||
batch.token = nullptr;
|
||||
@@ -459,7 +529,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
if (pos_max < N - 1) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — "
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
|
||||
"process() hook may not have run on every prefill ubatch "
|
||||
"(need_embd / logits=1 on every prompt position?). "
|
||||
"Drafts may degrade.\n",
|
||||
@@ -630,6 +700,14 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
||||
// add drafted token for each sequence
|
||||
const llama_token id = cur_p->data[0].id;
|
||||
|
||||
// only collect very high-confidence draft tokens
|
||||
if (cur_p->data[0].p < params.p_min) {
|
||||
drafting[seq_id] = false;
|
||||
n_drafting--;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
common_sampler_accept(smpl, id, true);
|
||||
|
||||
auto & dp = dparams.at(seq_id);
|
||||
@@ -675,7 +753,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
||||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool /*is_other*/) override {
|
||||
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
|
||||
return;
|
||||
}
|
||||
@@ -691,6 +769,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
||||
}
|
||||
|
||||
bool need_embd() const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool need_embd_pre_norm() const override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
@@ -707,7 +789,12 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
|
||||
common_ngram_simple_config config)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq)
|
||||
, params(params.ngram_simple)
|
||||
, config(config) {}
|
||||
, config(config)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-simple'\n", __func__);
|
||||
LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__,
|
||||
this->params.size_n, this->params.size_m, this->params.min_hits);
|
||||
}
|
||||
|
||||
void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
|
||||
// noop
|
||||
@@ -731,7 +818,7 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
|
||||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
@@ -741,20 +828,21 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
|
||||
};
|
||||
|
||||
struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
|
||||
common_params_speculative_ngram_map params;
|
||||
|
||||
// n_seq configs
|
||||
std::vector<common_ngram_map> config;
|
||||
|
||||
common_speculative_impl_ngram_map_k(
|
||||
const common_params_speculative & params,
|
||||
const common_ngram_map & config,
|
||||
uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq)
|
||||
, params(params.ngram_map_k) {
|
||||
{
|
||||
for (uint32_t i = 0; i < n_seq; i++) {
|
||||
this->config.push_back(config);
|
||||
}
|
||||
|
||||
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str());
|
||||
LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__,
|
||||
config.size_key, config.size_value, config.key_only, config.min_hits);
|
||||
}
|
||||
|
||||
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
|
||||
@@ -781,9 +869,13 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
|
||||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override {
|
||||
GGML_ASSERT((seq_id < (llama_seq_id) config.size()));
|
||||
|
||||
if (is_other) {
|
||||
return;
|
||||
}
|
||||
|
||||
common_ngram_map_accept(config[seq_id], n_accepted);
|
||||
}
|
||||
|
||||
@@ -805,7 +897,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
// the last position in the prompt that was added to the ngram container
|
||||
size_t i_last = 0;
|
||||
|
||||
// length of the last drafted n‑gram (number of tokens returned by draft)
|
||||
// length of the last drafted n-gram (number of tokens returned by draft)
|
||||
size_t n_draft_last = 0;
|
||||
|
||||
// consecutive accept rounds with low acceptance fraction (< 0.5)
|
||||
@@ -823,8 +915,11 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
, verbose(std::getenv("LLAMA_TRACE") != nullptr) {
|
||||
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
|
||||
|
||||
LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__,
|
||||
this->params.n_match, mod.size(), (float)(mod.size_bytes())/1024/1024);
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__);
|
||||
LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__,
|
||||
this->params.n_match, this->params.n_max, this->params.n_min);
|
||||
LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__,
|
||||
mod.size(), (float)(mod.size_bytes())/1024/1024);
|
||||
|
||||
if (this->params.n_match < 16) {
|
||||
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
|
||||
@@ -914,7 +1009,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
}
|
||||
result.resize(result.size() - n);
|
||||
|
||||
// store length of drafted n‑gram for later acceptance analysis
|
||||
// store length of drafted n-gram for later acceptance analysis
|
||||
sinfo.n_draft_last = result.size();
|
||||
}
|
||||
|
||||
@@ -936,17 +1031,21 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override {
|
||||
if (is_other) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto & sinfo = sinfos[seq_id];
|
||||
|
||||
// compute acceptance fraction if we have a recorded draft length
|
||||
if (sinfo.n_draft_last > 0) {
|
||||
const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last;
|
||||
if (f_acc < 0.5) {
|
||||
if (f_acc < 0.25) {
|
||||
sinfo.n_low++;
|
||||
if (sinfo.n_low >= 3) {
|
||||
if (sinfo.n_low >= 5) {
|
||||
if (verbose) {
|
||||
LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, sinfo.n_low);
|
||||
LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low);
|
||||
}
|
||||
|
||||
mod.reset();
|
||||
@@ -996,6 +1095,12 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
, save_dynamic(save_dynamic)
|
||||
, save_static(save_static)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__);
|
||||
LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__,
|
||||
n_draft,
|
||||
path_static.empty() ? "none" : path_static.c_str(),
|
||||
path_dynamic.empty() ? "none" : path_dynamic.c_str());
|
||||
|
||||
sinfos.resize(n_seq);
|
||||
|
||||
if (!path_static.empty()) {
|
||||
@@ -1092,7 +1197,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
@@ -1278,7 +1383,6 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
|
||||
|
||||
for (const common_speculative_config & config : configs) {
|
||||
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(config.type).c_str());
|
||||
switch (config.type) {
|
||||
case COMMON_SPECULATIVE_TYPE_NONE:
|
||||
break;
|
||||
@@ -1291,7 +1395,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_draft_mtp>(config.params, n_seq));
|
||||
impls.push_back(std::make_unique<common_speculative_impl_draft_mtp>(config.params, n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
|
||||
@@ -1312,11 +1416,16 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
impls.push_back(std::move(state));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: {
|
||||
impls.push_back(
|
||||
std::make_unique<common_speculative_impl_ngram_map_k>(
|
||||
get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
|
||||
impls.push_back(
|
||||
std::make_unique<common_speculative_impl_ngram_map_k>(
|
||||
config.params, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq));
|
||||
get_common_ngram_map(config.type, config.params.ngram_map_k4v), n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
|
||||
@@ -1408,6 +1517,20 @@ bool common_speculative_need_embd(common_speculative * spec) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto & impl : spec->impls) {
|
||||
if (impl->need_embd_pre_norm()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void common_speculative_draft(common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
@@ -1494,11 +1617,6 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
|
||||
|
||||
GGML_ASSERT(impl);
|
||||
|
||||
// TODO: currently only the implementation that generated the draft is used to accept it
|
||||
// however, some implementations (such as MTP) need to also "see" the accepted tokens
|
||||
// extend `common_speculative_impl::accept()` with an extra argument `bool is_other` to
|
||||
// inform the implementation if the accepted tokens are from another implementation and
|
||||
// pass the accepted tokens to all remaining implementations using `is_other == true`
|
||||
{
|
||||
common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
|
||||
if (n_accepted > 0) {
|
||||
@@ -1506,9 +1624,16 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
|
||||
impl->n_acc_tokens += n_accepted;
|
||||
}
|
||||
|
||||
impl->accept(seq_id, n_accepted);
|
||||
impl->accept(seq_id, n_accepted, false);
|
||||
impl->n_call_accept++;
|
||||
}
|
||||
|
||||
// accept with the rest of the implementations, using is_other == true
|
||||
for (auto & impl_other : spec->impls) {
|
||||
if (impl_other.get() != impl) {
|
||||
impl_other->accept(seq_id, n_accepted, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void common_speculative_print_stats(const common_speculative * spec) {
|
||||
@@ -1528,7 +1653,7 @@ void common_speculative_print_stats(const common_speculative * spec) {
|
||||
str_perf = "";
|
||||
}
|
||||
|
||||
LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
|
||||
LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s\n",
|
||||
common_speculative_type_to_str(impl->type).c_str(),
|
||||
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
|
||||
impl->n_gen_drafts,
|
||||
|
||||
@@ -53,9 +53,12 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
|
||||
// process the batch and update the internal state of the speculative context
|
||||
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
|
||||
|
||||
// true if any implementation requires target embeddings to be extracted
|
||||
// true if any implementation requires target post-norm embeddings to be extracted
|
||||
bool common_speculative_need_embd(common_speculative * spec);
|
||||
|
||||
// true if any implementation requires target pre-norm embeddings to be extracted
|
||||
bool common_speculative_need_embd_pre_norm(common_speculative * spec);
|
||||
|
||||
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
|
||||
void common_speculative_draft(common_speculative * spec);
|
||||
|
||||
|
||||
@@ -1610,6 +1610,42 @@ class TextModel(ModelBase):
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_hybriddna(self):
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) # ty: ignore[unresolved-attribute]
|
||||
assert max(tokenizer.vocab.values()) < vocab_size # ty: ignore[unresolved-attribute]
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} # ty: ignore[unresolved-attribute]
|
||||
added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute]
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder # ty: ignore[unresolved-attribute]
|
||||
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
else:
|
||||
token: str = reverse_vocab[i]
|
||||
if token in added_vocab:
|
||||
if added_tokens_decoder[i].special or self.does_token_look_special(token):
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
tokens.append(token)
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
self.gguf_writer.add_tokenizer_model("hybriddna")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_qwen(self):
|
||||
from .qwen import QwenModel
|
||||
|
||||
|
||||
@@ -189,7 +189,8 @@ class HunYuanModel(TextModel):
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
# HunyuanOCR has pad_token_id=-1 in config.json; exclude pad from SpecialVocab
|
||||
# Some HunYuanVL variants (e.g. OCR-style configs) have pad_token_id=-1;
|
||||
# guard SpecialVocab so it doesn't try to emit an invalid pad id.
|
||||
token_types = None
|
||||
if (self.hparams.get("pad_token_id") or 0) < 0:
|
||||
token_types = ('bos', 'eos', 'unk', 'sep', 'cls', 'mask')
|
||||
@@ -250,7 +251,8 @@ class HunYuanModel(TextModel):
|
||||
self._fix_special_tokens()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
# HunyuanOCR has num_experts=1 which is not MoE, prevent parent from writing it
|
||||
# Some HunYuanVL variants set num_experts=1 (not real MoE);
|
||||
# prevent the parent class from emitting expert_count metadata in that case.
|
||||
saved_num_experts = self.hparams.pop("num_experts", None)
|
||||
super().set_gguf_parameters()
|
||||
if saved_num_experts is not None and saved_num_experts > 1:
|
||||
@@ -288,51 +290,21 @@ class HunYuanModel(TextModel):
|
||||
|
||||
@ModelBase.register("HunYuanVLForConditionalGeneration")
|
||||
class HunyuanVLVisionModel(MmprojModel):
|
||||
# Handles both HunyuanOCR and HunyuanVL, which share the HF architecture name
|
||||
# "HunYuanVLForConditionalGeneration" and the `vit.perceive.*` vision layout.
|
||||
# Each variant maps to a different projector type in clip.cpp so image
|
||||
# preprocessing follows the correct code path.
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
# HunyuanOCR / HunyuanVL uses max_image_size instead of image_size
|
||||
# HunyuanVL uses max_image_size instead of image_size
|
||||
if "image_size" not in self.hparams_vision:
|
||||
self.hparams_vision["image_size"] = self.hparams_vision.get("max_image_size", 2048)
|
||||
|
||||
@staticmethod
|
||||
def is_ocr_variant(hparams: dict) -> bool:
|
||||
"""Return True for HunyuanOCR, False for HunyuanVL.
|
||||
|
||||
The projector's output dim must equal the text model's hidden_size by
|
||||
construction (that's what "projector" means). HunyuanOCR pairs a 1B text
|
||||
backbone (hidden=1024); HunyuanVL pairs a 4B one (hidden=3072). So the
|
||||
ViT -> LLM projection dim is a hard architectural signature, not a
|
||||
magic number.
|
||||
"""
|
||||
vision_out = int((hparams.get("vision_config") or {}).get("out_hidden_size", 0))
|
||||
return vision_out == 1024
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
assert self.hparams_vision is not None
|
||||
vcfg = self.hparams_vision
|
||||
|
||||
if self.is_ocr_variant(self.global_config):
|
||||
# --- HunyuanOCR ---
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(vcfg.get("rms_norm_eps", 1e-5))
|
||||
self.gguf_writer.add_vision_spatial_merge_size(vcfg.get("spatial_merge_size", 2))
|
||||
self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
|
||||
self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])
|
||||
return
|
||||
|
||||
# --- HunyuanVL ---
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANVL)
|
||||
self.gguf_writer.add_vision_use_gelu(str(vcfg["hidden_act"]).lower() == "gelu")
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(float(vcfg["rms_norm_eps"]))
|
||||
self.gguf_writer.add_vision_spatial_merge_size(int(vcfg["spatial_merge_size"]))
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(vcfg.get("rms_norm_eps", 1e-5))
|
||||
self.gguf_writer.add_vision_spatial_merge_size(vcfg.get("spatial_merge_size", 2))
|
||||
self.gguf_writer.add_vision_min_pixels(int(self.preprocessor_config["min_pixels"]))
|
||||
self.gguf_writer.add_vision_max_pixels(int(self.preprocessor_config["max_pixels"]))
|
||||
|
||||
@@ -353,7 +325,7 @@ class HunyuanVLVisionModel(MmprojModel):
|
||||
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
# force conv weights to F32 or F16 to avoid BF16 IM2COL issues on Metal
|
||||
# Both HunyuanOCR and HunyuanVL emit the ViT -> LLM projection as mm.0/mm.2.
|
||||
# HunyuanVL emit the ViT -> LLM projection as mm.0/mm.2.
|
||||
if ("mm.0." in new_name or "mm.2." in new_name) and new_name.endswith(".weight"):
|
||||
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
@@ -361,40 +333,18 @@ class HunyuanVLVisionModel(MmprojModel):
|
||||
|
||||
@ModelBase.register("HunYuanVLForConditionalGeneration")
|
||||
class HunyuanVLTextModel(HunYuanModel):
|
||||
# The "HunYuanVLForConditionalGeneration" HF architecture covers both HunyuanOCR
|
||||
# and HunyuanVL. HunyuanOCR reuses the HunYuan-Dense text backbone (standard RoPE),
|
||||
# while HunyuanVL introduces a new LLM arch with XD-RoPE. Detect the variant from
|
||||
# the config and pick the matching GGUF architecture.
|
||||
model_arch = gguf.MODEL_ARCH.HUNYUAN_VL
|
||||
|
||||
@staticmethod
|
||||
def _is_ocr_config(hparams: dict) -> bool:
|
||||
# OCR pairs a 1B text backbone (hidden=1024) with a ViT projector that
|
||||
# outputs 1024-d; HunyuanVL uses 3072-d. Keep in sync with
|
||||
# HunyuanVLVisionModel.is_ocr_variant.
|
||||
return int((hparams.get("vision_config") or {}).get("out_hidden_size", 0)) == 1024
|
||||
|
||||
def __init__(self, dir_model: Path, *args, **kwargs):
|
||||
raw_hparams = kwargs.get("hparams") or ModelBase.load_hparams(dir_model, is_mistral_format=False)
|
||||
if self._is_ocr_config(raw_hparams):
|
||||
self.model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
|
||||
else:
|
||||
self.model_arch = gguf.MODEL_ARCH.HUNYUAN_VL
|
||||
super().__init__(dir_model, *args, **kwargs)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Only emit XD-RoPE metadata for the HunyuanVL backbone; HunyuanOCR uses
|
||||
# the HunYuan-Dense arch which already handles standard rope in super().
|
||||
if self.model_arch != gguf.MODEL_ARCH.HUNYUAN_VL:
|
||||
return
|
||||
|
||||
# XD-RoPE metadata for the HunyuanVL;
|
||||
if self.rope_parameters.get("rope_type") != "xdrope":
|
||||
return
|
||||
|
||||
# defaults for HunyuanVL. The C++ side later computes:
|
||||
# freq_base = rope_theta * alpha ** (head_dim / (head_dim - 2))
|
||||
self.gguf_writer.add_rope_freq_base(float(self.rope_parameters["rope_theta"]))
|
||||
self.gguf_writer.add_rope_scaling_alpha(float(self.rope_parameters["alpha"]))
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
|
||||
@@ -51,6 +51,15 @@ class LlamaModel(TextModel):
|
||||
if path_tekken_json.is_file() and not path_tokenizer_json.is_file():
|
||||
self._set_vocab_mistral()
|
||||
|
||||
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
|
||||
if tokenizer_config_file.is_file():
|
||||
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
|
||||
tokenizer_config_json = json.load(f)
|
||||
if (add_prefix_space := tokenizer_config_json.get("add_prefix_space")) is not None:
|
||||
self.gguf_writer.add_add_space_prefix(add_prefix_space)
|
||||
if tokenizer_config_json.get("tokenizer_class") == "HybridDNATokenizer":
|
||||
return self._set_vocab_hybriddna()
|
||||
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
@@ -72,13 +81,6 @@ class LlamaModel(TextModel):
|
||||
special_vocab._set_special_token("eot", 32010)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
|
||||
if tokenizer_config_file.is_file():
|
||||
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
|
||||
tokenizer_config_json = json.load(f)
|
||||
if "add_prefix_space" in tokenizer_config_json:
|
||||
self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
|
||||
|
||||
# Apply to granite small models only
|
||||
if self.hparams.get("vocab_size", 32000) == 49152:
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
|
||||
@@ -600,6 +600,7 @@ class _Qwen35MtpMixin:
|
||||
if name.find("layers.") != -1:
|
||||
assert bid is not None
|
||||
name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}")
|
||||
bid = bid + n_layer
|
||||
else:
|
||||
remapper = {
|
||||
"mtp.fc": "model.layers.{bid}.eh_proj",
|
||||
|
||||
@@ -115,15 +115,15 @@ def parse_args() -> argparse.Namespace:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mmproj", action="store_true",
|
||||
help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
|
||||
help="Export multimodal projector (mmproj) for vision models. This will only work on some vision models. An 'mmproj-' prefix will be added to the output file name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mtp", action="store_true",
|
||||
help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. Output file name will get a '-MTP' suffix.",
|
||||
help="Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. An 'mtp-' prefix will be added to the output file name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-mtp", action="store_true",
|
||||
help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.",
|
||||
help="Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, but even though the bundled default is more space-efficient overall, this allows differing quantization which may be more performant.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mistral-format", action="store_true",
|
||||
|
||||
@@ -445,6 +445,11 @@ if __name__ == '__main__':
|
||||
if self.lazy:
|
||||
tensor = LazyTorchTensor.from_eager(tensor)
|
||||
base_name = get_base_tensor_name(name)
|
||||
# filter base name, ignore tensor transformations for now
|
||||
data_gen = lambda g=tensor: g # noqa: E731
|
||||
if (titem := self.filter_tensors((base_name, data_gen))) is None:
|
||||
continue
|
||||
base_name, _ = titem
|
||||
# note: mergekit-extract-lora also adds token embeddings to the adapter
|
||||
is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
|
||||
is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
- [News](#news)
|
||||
- [OS](#os)
|
||||
- [Hardware](#hardware)
|
||||
- [Performance Reference](#performance-reference)
|
||||
- [Docker](#docker)
|
||||
- [Linux](#linux)
|
||||
- [Windows](#windows)
|
||||
@@ -51,9 +52,8 @@ The packages for FP32 and FP16 would have different accuracy and performance on
|
||||
|
||||
## News
|
||||
|
||||
- 2026.04
|
||||
|
||||
- Optimize mul_mat by reorder feature for data type: Q4_K, Q5_K, Q_K, Q8_0.
|
||||
- 2026.04-05
|
||||
- Optimize mul_mat by reorder feature for data type: Q4_K, Q5_K, Q6_K, Q8_0.
|
||||
- Fused MoE.
|
||||
- Upgrate CI and built package for oneAPI 2025.3.3, support Ubuntu 24.04 built package.
|
||||
|
||||
@@ -150,6 +150,13 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the
|
||||
|
||||
NA
|
||||
|
||||
## Performance Reference
|
||||
|
||||
|
||||
To get the supported LLMs, GPUs, and performance reference, please check [Performance of llama.cpp on Intel GPU with SYCL backend](https://github.com/ggml-org/llama.cpp/discussions/23313).
|
||||
|
||||
You could update your test result in it directly.
|
||||
|
||||
## Docker
|
||||
|
||||
The docker build option is currently limited to *Intel GPU* targets.
|
||||
|
||||
@@ -10,8 +10,8 @@
|
||||
"ANDROID_ABI": "arm64-v8a",
|
||||
"ANDROID_PLATFORM": "android-31",
|
||||
"CMAKE_TOOLCHAIN_FILE": "$env{ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake",
|
||||
"CMAKE_C_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_C_FLAGS": "-march=armv8.7a+fp16+dotprod+i8mm -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16+dotprod+i8mm -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
@@ -59,8 +59,8 @@
|
||||
"toolset": { "value": "host=x86_64", "strategy": "external" },
|
||||
"cacheVariables": {
|
||||
"CMAKE_TOOLCHAIN_FILE": "cmake/arm64-linux-clang.cmake",
|
||||
"CMAKE_C_FLAGS": "-march=armv8 -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_CXX_FLAGS": "-march=armv8 -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_C_FLAGS": "-march=armv8.2a+fp16+dotprod -fvectorize -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_CXX_FLAGS": "-march=armv8.2a+fp16+dotprod -fvectorize -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
|
||||
@@ -10,7 +10,7 @@ This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc.
|
||||
This method works on Linux, macOS, and Windows. macOS and Windows users should install Docker Desktop.
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ docker run -it -u $(id -u):$(id -g) --volume $(pwd):/workspace --platform linux/amd64 ghcr.io/snapdragon-toolchain/arm64-android:v0.3
|
||||
~/src/llama.cpp$ docker run -it -u $(id -u):$(id -g) --volume $(pwd):/workspace --platform linux/amd64 ghcr.io/snapdragon-toolchain/arm64-android:v0.6
|
||||
[d]/> cd /workspace
|
||||
```
|
||||
|
||||
|
||||
@@ -108,11 +108,12 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
|
||||
### General Speculative Parameters
|
||||
|
||||
```
|
||||
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
type of speculative decoding to use when no draft model is provided
|
||||
--spec-type [none|draft-simple|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
comma-separated list of types of speculative decoding to use
|
||||
(default: none)
|
||||
(env: LLAMA_ARG_SPEC_TYPE)
|
||||
--spec-default use default speculative decoding
|
||||
--spec-default use default speculative decoding config
|
||||
(enables ngram-mod)
|
||||
```
|
||||
|
||||
### Draft Model Parameters
|
||||
@@ -123,8 +124,9 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_MODEL)
|
||||
--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]
|
||||
HuggingFace repository for the draft model
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_HF_REPO)
|
||||
--spec-draft-n-max N
|
||||
number of tokens to draft for speculative decoding (default: 16)
|
||||
number of tokens to draft for speculative decoding (default: 3)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_N_MAX)
|
||||
--spec-draft-n-min N
|
||||
minimum number of draft tokens to use for speculative decoding (default: 0)
|
||||
@@ -133,18 +135,64 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
|
||||
speculative decoding split probability (default: 0.10)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT)
|
||||
--spec-draft-p-min, --draft-p-min P
|
||||
minimum speculative decoding probability (greedy) (default: 0.75)
|
||||
minimum speculative decoding probability (greedy) (default: 0.00)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_P_MIN)
|
||||
--spec-draft-ctx-size, -cd, --ctx-size-draft N
|
||||
size of the prompt context for the draft model (default: 0, 0 = loaded from model)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE)
|
||||
--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N
|
||||
max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
|
||||
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT)
|
||||
--spec-draft-device, -devd, --device-draft <dev1,dev2,..>
|
||||
comma-separated list of devices to use for offloading the draft model
|
||||
--spec-draft-replace, --spec-replace TARGET DRAFT
|
||||
translate the string in TARGET into DRAFT if the draft model and main model are not compatible
|
||||
(use --list-devices to see available devices)
|
||||
```
|
||||
|
||||
### Draft Model CPU Scheduling Parameters
|
||||
|
||||
```
|
||||
--spec-draft-threads, -td, --threads-draft N
|
||||
number of CPU threads to use during generation
|
||||
--spec-draft-threads-batch, -tbd, --threads-batch-draft N
|
||||
number of threads to use during batch and prompt processing (default: same as --threads-draft)
|
||||
--spec-draft-cpu-mask, -Cd, --cpu-mask-draft M
|
||||
Draft model CPU affinity mask. Complements cpu-range-draft
|
||||
--spec-draft-cpu-range, -Crd, --cpu-range-draft lo-hi
|
||||
Ranges of CPUs for affinity. Complements --cpu-mask-draft
|
||||
--spec-draft-cpu-strict, --cpu-strict-draft <0|1>
|
||||
Use strict CPU placement for draft model (default: same as --cpu-strict)
|
||||
--spec-draft-prio, --prio-draft N
|
||||
set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime
|
||||
--spec-draft-poll, --poll-draft <0|1>
|
||||
Use polling to wait for draft model work (default: same as --poll)
|
||||
--spec-draft-cpu-mask-batch, -Cbd, --cpu-mask-batch-draft M
|
||||
Draft model CPU affinity mask for batch. Complements cpu-range-batch-draft
|
||||
--spec-draft-cpu-range-batch, -Crbd, --cpu-range-batch-draft lo-hi
|
||||
Ranges of CPUs for affinity for batch. Complements --cpu-mask-batch-draft
|
||||
--spec-draft-cpu-strict-batch, --cpu-strict-batch-draft <0|1>
|
||||
Use strict CPU placement for draft model batch (default: --cpu-strict-draft)
|
||||
--spec-draft-prio-batch, --prio-batch-draft N
|
||||
set draft process/thread priority for batch : 0-normal, 1-medium, 2-high, 3-realtime
|
||||
--spec-draft-poll-batch, --poll-batch-draft <0|1>
|
||||
Use polling to wait for draft model work for batch (default: --poll-draft)
|
||||
```
|
||||
|
||||
### Draft Model KV Cache and Tensor Override Parameters
|
||||
|
||||
```
|
||||
--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE
|
||||
KV cache data type for K for the draft model
|
||||
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K)
|
||||
--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE
|
||||
KV cache data type for V for the draft model
|
||||
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V)
|
||||
--spec-draft-override-tensor, -otd, --override-tensor-draft <tensor name pattern>=<buffer type>,...
|
||||
override tensor buffer type for draft model
|
||||
--spec-draft-cpu-moe, -cmoed, --cpu-moe-draft
|
||||
keep all Mixture of Experts (MoE) weights in the CPU for the draft model
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_CPU_MOE)
|
||||
--spec-draft-n-cpu-moe, --spec-draft-ncmoe, -ncmoed, --n-cpu-moe-draft N
|
||||
keep the MoE weights of the first N layers in the CPU for the draft model
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_N_CPU_MOE)
|
||||
```
|
||||
|
||||
### n-gram Mod Parameters
|
||||
@@ -193,11 +241,13 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
|
||||
|
||||
### `--spec-type TYPE`
|
||||
|
||||
Specifies a type of speculative decoding without draft model.
|
||||
Specifies a comma-separated list of speculative decoding types to use.
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `none` | No speculative decoding (default) |
|
||||
| `draft-simple` | Use a simple draft model for speculation |
|
||||
| `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model |
|
||||
| `ngram-cache` | Use n-gram cache lookup |
|
||||
| `ngram-simple` | Use simple n-gram pattern matching |
|
||||
| `ngram-map-k` | Use n-gram pattern matching with n-gram-keys |
|
||||
@@ -209,6 +259,11 @@ Specifies a type of speculative decoding without draft model.
|
||||
./llama-server [...] --spec-type ngram-simple
|
||||
```
|
||||
|
||||
**Example:** Multiple speculative implementations.
|
||||
```bash
|
||||
./llama-server [...] --spec-type ngram-mod,ngram-map-k4v
|
||||
```
|
||||
|
||||
### `--spec-ngram-*-size-n N`
|
||||
|
||||
Sets the size N of the lookup n-gram for n-gram map based speculative decoding.
|
||||
|
||||
@@ -27,7 +27,6 @@ else()
|
||||
add_subdirectory(parallel)
|
||||
add_subdirectory(passkey)
|
||||
add_subdirectory(retrieval)
|
||||
add_subdirectory(save-load-state)
|
||||
add_subdirectory(simple)
|
||||
add_subdirectory(simple-chat)
|
||||
add_subdirectory(speculative)
|
||||
|
||||
@@ -149,6 +149,8 @@ class TaskState:
|
||||
t_gen_ms: Optional[float] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
server_name: Optional[str] = None
|
||||
chunk_idx: int = 0
|
||||
problem_idx: int = 0
|
||||
|
||||
|
||||
class EvalState:
|
||||
@@ -233,7 +235,9 @@ class EvalState:
|
||||
tps_gen: Optional[float] = None,
|
||||
t_gen_ms: Optional[float] = None,
|
||||
reasoning_content: Optional[str] = None,
|
||||
server_name: Optional[str] = None
|
||||
server_name: Optional[str] = None,
|
||||
chunk_idx: int = 0,
|
||||
problem_idx: int = 0,
|
||||
):
|
||||
with self._lock:
|
||||
if "cases" not in self.task_states:
|
||||
@@ -252,7 +256,9 @@ class EvalState:
|
||||
"tps_gen": tps_gen,
|
||||
"t_gen_ms": t_gen_ms,
|
||||
"reasoning_content": reasoning_content,
|
||||
"server_name": server_name
|
||||
"server_name": server_name,
|
||||
"chunk_idx": chunk_idx,
|
||||
"problem_idx": problem_idx,
|
||||
}
|
||||
|
||||
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
|
||||
@@ -289,6 +295,9 @@ class EvalState:
|
||||
all_cases = {}
|
||||
for i, task_id in tasks_to_save:
|
||||
question_text, prompt, expected = self.get_case(i)
|
||||
# Extract chunk_idx from task_id for pending cases
|
||||
_parts = task_id.rsplit("_", 2)
|
||||
_chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0
|
||||
if task_id in self.task_states.get("cases", {}):
|
||||
all_cases[task_id] = self.task_states["cases"][task_id]
|
||||
else:
|
||||
@@ -306,7 +315,9 @@ class EvalState:
|
||||
"tps_gen": None,
|
||||
"t_gen_ms": None,
|
||||
"reasoning_content": None,
|
||||
"server_name": None
|
||||
"server_name": None,
|
||||
"chunk_idx": _chunk_idx,
|
||||
"problem_idx": i,
|
||||
}
|
||||
|
||||
ci_lower, ci_upper = self.accuracy_ci()
|
||||
@@ -382,11 +393,12 @@ class EvalState:
|
||||
grader_log_str = self._escape_html(json.dumps(grader_log, indent=2))
|
||||
escaped_server = self._escape_html(server_name)
|
||||
|
||||
answer_class = status_class if status == "ok" else ""
|
||||
rows.append(f"""<tr class="task-row" onclick="toggleDetails('{task_id}')">
|
||||
<td>{task_id}</td>
|
||||
<td class="{status_class}">{status_text}</td>
|
||||
<td>{self._escape_html(expected)}</td>
|
||||
<td>{self._escape_html(answer)}</td>
|
||||
<td class="{answer_class}">{self._escape_html(answer)}</td>
|
||||
<td>{tokens_str}</td>
|
||||
<td>{tps_str}</td>
|
||||
<td>{t_gen_str}</td>
|
||||
@@ -405,6 +417,53 @@ class EvalState:
|
||||
|
||||
rows_html = "\n".join(rows)
|
||||
|
||||
# ---- per-problem summary table ----
|
||||
problem_groups: Dict[int, List[Dict[str, Any]]] = {}
|
||||
for _tid, _case in cases.items():
|
||||
if _case.get("status") != "ok":
|
||||
continue
|
||||
_pidx = _case.get("problem_idx")
|
||||
if _pidx is None:
|
||||
_p_parts = _tid.rsplit("_", 2)
|
||||
_pidx = int(_p_parts[-1]) if len(_p_parts) >= 3 else 0
|
||||
problem_groups.setdefault(_pidx, []).append(_case)
|
||||
|
||||
summary_rows_html = ""
|
||||
if problem_groups:
|
||||
def _stat(v, fmt=".1f", avg_fmt=None):
|
||||
if not v:
|
||||
return ("–", "–", "–")
|
||||
af = fmt if avg_fmt is None else avg_fmt
|
||||
return (f"{min(v):{fmt}}", f"{sum(v)/len(v):{af}}", f"{max(v):{fmt}}")
|
||||
|
||||
summary_data = []
|
||||
for pidx, g in problem_groups.items():
|
||||
runs = len(g)
|
||||
n_ok = sum(1 for c in g if c.get("correct", False))
|
||||
toks = [c["tokens"] for c in g if c.get("tokens") is not None]
|
||||
tps = [c["tps_gen"] for c in g if c.get("tps_gen") is not None]
|
||||
tg = [c["t_gen_ms"] / 1000 for c in g if c.get("t_gen_ms") is not None]
|
||||
summary_data.append((
|
||||
pidx, runs, n_ok,
|
||||
_stat(toks, "d", ".0f"),
|
||||
_stat(tps),
|
||||
_stat(tg),
|
||||
))
|
||||
|
||||
summary_data.sort(key=lambda r: r[0]) # sort by problem index ascending
|
||||
|
||||
summary_rows_html = "\n".join(
|
||||
f"""<tr class="summary-row">
|
||||
<td>{p:03d}</td>
|
||||
<td>{r}</td>
|
||||
<td>{n}/{r}</td>
|
||||
<td>{tk[0]}</td><td>{tk[1]}</td><td>{tk[2]}</td>
|
||||
<td>{tp[0]}</td><td>{tp[1]}</td><td>{tp[2]}</td>
|
||||
<td>{tg[0]}</td><td>{tg[1]}</td><td>{tg[2]}</td>
|
||||
</tr>"""
|
||||
for p, r, n, tk, tp, tg in summary_data
|
||||
)
|
||||
|
||||
html_content = f"""<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
@@ -412,10 +471,10 @@ class EvalState:
|
||||
<title>{self.dataset_type.upper()} Eval</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; margin: 0; padding: 16px; background: #fff; color: #222; }}
|
||||
.bar {{ padding: 8px 0; font-size: 14px; color: #555; }}
|
||||
.bar span {{ margin-right: 20px; }}
|
||||
.bar b {{ color: #222; }}
|
||||
table {{ width: 100%; border-collapse: collapse; font-size: 13px; }}
|
||||
.bar {{ padding: 8px 0; font-size: 13px; color: #555; font-family: 'SF Mono', 'Menlo', 'Consolas', monospace; display: grid; grid-template-columns: auto 1fr auto 1fr; gap: 2px 12px; align-items: baseline; }}
|
||||
.bar .label {{ color: #888; }}
|
||||
.bar .value {{ color: #222; }}
|
||||
table {{ width: 100%; border-collapse: collapse; font-size: 13px; font-family: 'SF Mono', 'Menlo', 'Consolas', monospace; }}
|
||||
th {{ text-align: left; padding: 6px 8px; border-bottom: 2px solid #ccc; font-weight: 600; }}
|
||||
td {{ padding: 4px 8px; border-bottom: 1px solid #eee; vertical-align: top; }}
|
||||
.task-row {{ cursor: pointer; }}
|
||||
@@ -429,37 +488,88 @@ class EvalState:
|
||||
.details-content {{ padding: 8px 16px; background: #f6f8fa; font-size: 12px; }}
|
||||
.details-content b {{ color: #555; }}
|
||||
.details-content pre {{ background: #fff; border: 1px solid #e1e4e8; padding: 8px; overflow-x: auto; white-space: pre-wrap; word-wrap: break-word; margin: 4px 0 8px; }}
|
||||
.summary-table {{ margin-bottom: 16px; font-size: 13px; width: 100%; }}
|
||||
.summary-row {{ background: #fafbfc; }}
|
||||
.summary-row:hover {{ background: #f5f5f5; }}
|
||||
.summary-table th {{ text-align: right; font-weight: 600; }}
|
||||
.summary-table th:first-child {{ text-align: left; }}
|
||||
.summary-table th[colspan] {{ text-align: center; }}
|
||||
.summary-table td {{ text-align: right; }}
|
||||
.summary-table td:first-child {{ text-align: left; }}
|
||||
.tabs {{ display: flex; border-bottom: 2px solid #ddd; margin: 12px 0 0; }}
|
||||
.tab-btn {{ padding: 6px 16px; border: none; background: none; font-size: 13px; cursor: pointer; color: #555; border-bottom: 2px solid transparent; margin-bottom: -2px; font-weight: 500; }}
|
||||
.tab-btn:hover {{ color: #222; }}
|
||||
.tab-btn.active {{ color: #222; border-bottom-color: #222; font-weight: 600; }}
|
||||
.tab-content {{ display: none; }}
|
||||
.tab-content.active {{ display: block; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="bar">
|
||||
<span><b>{self.dataset_type.upper()}</b></span>
|
||||
<span>Model: {self.model_name or 'N/A'}</span>
|
||||
<span>Accuracy: <b>{accuracy:.1f}%</b> [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]</span>
|
||||
<span>Correct: <span class="correct">{n_correct}</span> / {len(completed)}</span>
|
||||
<span>Pending: {n_pending}</span>
|
||||
<span>Time: {self.total_time:.1f}s</span>
|
||||
<span>Sampling: {sampling_str}</span>
|
||||
<div class="label">Dataset</div><div class="value"><b>{self.dataset_type.upper()}</b></div>
|
||||
<div class="label">Model</div><div class="value"><b>{self.model_name or 'N/A'}</b></div>
|
||||
<div class="label">Accuracy</div><div class="value"><b>{accuracy:.1f}%</b> [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]</div>
|
||||
<div class="label">Correct</div><div class="value"><span class="correct">{n_correct}</span> / {len(completed)}</div>
|
||||
<div class="label">Pending</div><div class="value">{n_pending}</div>
|
||||
<div class="label">Time</div><div class="value">{self.total_time:.1f}s</div>
|
||||
<div class="label">Sampling</div><div class="value">{sampling_str}</div>
|
||||
</div>
|
||||
<div class="tabs">
|
||||
<button class="tab-btn active" data-tab="detailed" onclick="switchTab(this)">Detailed</button>
|
||||
<button class="tab-btn" data-tab="summary" onclick="switchTab(this)">Summary</button>
|
||||
</div>
|
||||
<div id="tab-detailed" class="tab-content active">
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th></th>
|
||||
<th>Gold</th>
|
||||
<th>Answer</th>
|
||||
<th>Tokens</th>
|
||||
<th>T/s</th>
|
||||
<th>Gen s</th>
|
||||
<th>Server</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<div id="tab-summary" class="tab-content">
|
||||
<table class="summary-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Problem</th>
|
||||
<th>Runs</th>
|
||||
<th>Correct</th>
|
||||
<th colspan="3">Tokens</th>
|
||||
<th colspan="3">T/s</th>
|
||||
<th colspan="3">Gen s</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th></th>
|
||||
<th></th>
|
||||
<th></th>
|
||||
<th>min</th><th>avg</th><th>max</th>
|
||||
<th>min</th><th>avg</th><th>max</th>
|
||||
<th>min</th><th>avg</th><th>max</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{summary_rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th></th>
|
||||
<th>Gold</th>
|
||||
<th>Answer</th>
|
||||
<th>Tokens</th>
|
||||
<th>T/s</th>
|
||||
<th>Gen s</th>
|
||||
<th>Server</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
<script>
|
||||
function toggleDetails(id) {{ document.getElementById('details-'+id).classList.toggle('open'); }}
|
||||
function switchTab(btn) {{
|
||||
document.querySelectorAll('.tab-btn').forEach(b => b.classList.remove('active'));
|
||||
document.querySelectorAll('.tab-content').forEach(c => c.classList.remove('active'));
|
||||
btn.classList.add('active');
|
||||
document.getElementById('tab-'+btn.dataset.tab).classList.add('active');
|
||||
}}
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
@@ -1062,12 +1172,19 @@ class Processor:
|
||||
) -> TaskState:
|
||||
question_text, prompt, expected = eval_state.get_case(i)
|
||||
|
||||
# Extract chunk_idx from task_id: "{dataset_type}_{chunk_idx:03d}_{index:03d}"
|
||||
_parts = task_id.rsplit("_", 2)
|
||||
chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0
|
||||
problem_idx = i
|
||||
|
||||
task_state = TaskState(
|
||||
task_id=task_id,
|
||||
prompt=prompt,
|
||||
expected=expected,
|
||||
question_text=question_text,
|
||||
server_name=server_config.name
|
||||
server_name=server_config.name,
|
||||
chunk_idx=chunk_idx,
|
||||
problem_idx=problem_idx,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -1085,7 +1202,8 @@ class Processor:
|
||||
eval_state.add_result(
|
||||
task_id, prompt, expected, result, None,
|
||||
{"finish_reason": finish_reason}, False, task_state.status,
|
||||
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
|
||||
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name,
|
||||
chunk_idx, problem_idx,
|
||||
)
|
||||
eval_state.dump()
|
||||
return task_state
|
||||
@@ -1108,7 +1226,8 @@ class Processor:
|
||||
eval_state.add_result(
|
||||
task_id, prompt, expected, result, answer,
|
||||
grader_log, is_correct, "ok",
|
||||
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
|
||||
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name,
|
||||
chunk_idx, problem_idx,
|
||||
)
|
||||
|
||||
eval_state.dump()
|
||||
|
||||
@@ -65,34 +65,70 @@ def normalize_number(s: str) -> Optional[int]:
|
||||
return int(match.group(0))
|
||||
|
||||
class AimeDataset:
|
||||
def __init__(self, split: str = "train"):
|
||||
def __init__(self, split: str = "train", dataset_type: str = "aime"):
|
||||
self.split = split
|
||||
self.dataset_type = dataset_type
|
||||
self.questions: List[Dict] = []
|
||||
self._load_dataset()
|
||||
|
||||
def _load_dataset(self):
|
||||
print(f"Loading AIME dataset (split: {self.split})...")
|
||||
def _get_question_text(self, question: Dict) -> str:
|
||||
"""Get question text, handling different dataset field names."""
|
||||
return question.get("problem", question.get("question", ""))
|
||||
|
||||
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
|
||||
if cache_path.exists():
|
||||
print(f"Using cached dataset from {cache_path}")
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
|
||||
def _load_dataset(self):
|
||||
if self.dataset_type == "aime":
|
||||
print(f"Loading AIME dataset (split: {self.split})...")
|
||||
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
|
||||
if cache_path.exists():
|
||||
print(f"Using cached dataset from {cache_path}")
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
|
||||
else:
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
|
||||
elif self.dataset_type == "aime2025":
|
||||
print(f"Loading AIME2025 dataset...")
|
||||
ds_list = []
|
||||
for config_name in ["AIME2025-I", "AIME2025-II"]:
|
||||
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "opencompass___AIME2025" / "default" / "0.0.0"
|
||||
if cache_path.exists():
|
||||
print(f"Using cached dataset from {cache_path}")
|
||||
ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test", cache_dir=str(cache_path))
|
||||
else:
|
||||
ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test")
|
||||
ds_list.extend(ds)
|
||||
ds = ds_list
|
||||
else:
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
|
||||
raise ValueError(f"Unknown dataset type: {self.dataset_type}")
|
||||
|
||||
self.questions = list(ds)
|
||||
print(f"AIME dataset loaded: {len(self.questions)} questions")
|
||||
print(f"{self.dataset_type} dataset loaded: {len(self.questions)} questions")
|
||||
|
||||
def find_question(self, request_text: str) -> Optional[Dict]:
|
||||
# Strip common template prefixes to get the actual question text
|
||||
# Templates include things like "Solve the following math problem step by step..."
|
||||
# The actual question usually follows a blank line or after the template instruction
|
||||
cleaned = request_text
|
||||
# Split on double newline and take the part that looks like the problem
|
||||
parts = cleaned.split('\n\n')
|
||||
if len(parts) > 1:
|
||||
# Find the part that's longest (likely the actual problem text)
|
||||
problem_parts = [p for p in parts if len(p.strip()) > 100]
|
||||
if problem_parts:
|
||||
cleaned = max(problem_parts, key=lambda x: len(x))
|
||||
|
||||
best_match = None
|
||||
best_distance = -1
|
||||
best_index = -1
|
||||
|
||||
for i, question in enumerate(self.questions):
|
||||
question_text = question["problem"]
|
||||
request_lower = request_text.lower()
|
||||
question_text = self._get_question_text(question)
|
||||
request_lower = cleaned.lower()
|
||||
question_lower = question_text.lower()
|
||||
|
||||
# Check if question text is contained in the cleaned request
|
||||
if question_lower in request_lower or request_lower in question_lower:
|
||||
debug_log(f"DEBUG: Found substring match at index {i}")
|
||||
return question
|
||||
|
||||
# Exact match
|
||||
if question_lower == request_lower:
|
||||
debug_log(f"DEBUG: Found exact match at index {i}")
|
||||
@@ -118,7 +154,7 @@ class AimeDataset:
|
||||
debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}")
|
||||
return best_match
|
||||
|
||||
debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...")
|
||||
debug_log(f"DEBUG: No matching question found for cleaned: {cleaned[:100]}...")
|
||||
return None
|
||||
|
||||
def get_answer(self, question: Dict) -> str:
|
||||
@@ -134,15 +170,16 @@ class Simulator:
|
||||
port: int = 8033,
|
||||
host: str = "localhost",
|
||||
success_rate: float = 0.8,
|
||||
dataset_split: str = "train"
|
||||
dataset_split: str = "train",
|
||||
dataset_type: str = "aime"
|
||||
):
|
||||
self.port = port
|
||||
self.host = host
|
||||
self.success_rate = success_rate
|
||||
self.dataset = AimeDataset(dataset_split)
|
||||
self.dataset = AimeDataset(dataset_split, dataset_type)
|
||||
self.eval_state = EvalState(
|
||||
id="aime-2025",
|
||||
tasks=["aime"],
|
||||
id=dataset_type,
|
||||
tasks=[dataset_type],
|
||||
task_states={},
|
||||
sampling_config={"temperature": 0, "max_tokens": 2048}
|
||||
)
|
||||
@@ -159,6 +196,10 @@ class Simulator:
|
||||
else:
|
||||
response_text = self._generate_wrong_answer(question)
|
||||
|
||||
comp_tokens = random.randint(10000, 60000)
|
||||
tps_gen = random.uniform(90.0, 110.0)
|
||||
t_gen_ms = comp_tokens / tps_gen * 1000
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{int(time.time())}",
|
||||
"object": "chat.completion",
|
||||
@@ -176,8 +217,12 @@ class Simulator:
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150
|
||||
"completion_tokens": comp_tokens,
|
||||
"total_tokens": 100 + comp_tokens
|
||||
},
|
||||
"timings": {
|
||||
"predicted_ms": t_gen_ms,
|
||||
"predicted_per_second": tps_gen
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,6 +263,12 @@ class Simulator:
|
||||
return response
|
||||
|
||||
class RequestHandler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == "/v1/models":
|
||||
self._send_json({"data": [{"id": "llama", "object": "model"}]}, 200)
|
||||
return
|
||||
self._send_json({"error": "Not found"}, 404)
|
||||
|
||||
def do_POST(self):
|
||||
if self.path != "/v1/chat/completions":
|
||||
self._send_json({"error": "Not found"}, 404)
|
||||
@@ -280,6 +331,13 @@ def main():
|
||||
default=0.8,
|
||||
help="Success rate 0-1 (default: 0.8)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="aime",
|
||||
choices=["aime", "aime2025"],
|
||||
help="Dataset type (default: aime)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-split",
|
||||
type=str,
|
||||
@@ -294,7 +352,8 @@ def main():
|
||||
port=args.port,
|
||||
host=args.host,
|
||||
success_rate=args.success_rate,
|
||||
dataset_split=args.dataset_split
|
||||
dataset_split=args.dataset_split,
|
||||
dataset_type=args.dataset
|
||||
)
|
||||
|
||||
server = HTTPServer((args.host, args.port), RequestHandler)
|
||||
@@ -304,7 +363,7 @@ def main():
|
||||
print("\n=== llama-server-simulator ===")
|
||||
print(f"Server running on http://{args.host}:{args.port}")
|
||||
print(f"Success rate: {args.success_rate}")
|
||||
print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions")
|
||||
print(f"{args.dataset} dataset loaded: {len(simulator.dataset.questions)} questions")
|
||||
print("\nPress Ctrl+C to stop\n")
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
set(TARGET llama-save-load-state)
|
||||
add_executable(${TARGET} save-load-state.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
@@ -1,320 +0,0 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
params.prompt = "The quick brown fox";
|
||||
params.sampling.seed = 1234;
|
||||
|
||||
const std::string_view state_file = "dump_state.bin";
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.n_parallel == 1) {
|
||||
// the example uses 2 sequences, so when n_parallel == 1, we need to enable unified kv cache
|
||||
printf("%s: n_parallel == 1, enabling unified kv cache\n", __func__);
|
||||
params.kv_unified = true;
|
||||
}
|
||||
|
||||
if (params.n_predict < 0) {
|
||||
params.n_predict = 16;
|
||||
}
|
||||
|
||||
auto n_past = 0;
|
||||
|
||||
std::string result0;
|
||||
std::string result1;
|
||||
std::string result2;
|
||||
std::string result3;
|
||||
|
||||
// init
|
||||
|
||||
ggml_backend_load_all();
|
||||
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
if (model == nullptr || ctx == nullptr) {
|
||||
fprintf(stderr, "%s : failed to init\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));
|
||||
|
||||
// tokenize prompt
|
||||
auto tokens = common_tokenize(ctx, params.prompt, true);
|
||||
|
||||
const bool save_state = true;
|
||||
if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// first run
|
||||
printf("\nfirst run: %s", params.prompt.c_str());
|
||||
|
||||
llama_batch batch = llama_batch_init(1, 0, 1);
|
||||
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto next_token = llama_sampler_sample(smpl, ctx, -1);
|
||||
auto next_token_str = common_token_to_piece(ctx, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result0 += next_token_str;
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, next_token, n_past, {0}, true);
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
return 1;
|
||||
}
|
||||
n_past += 1;
|
||||
}
|
||||
|
||||
printf("\n\n");
|
||||
|
||||
// make new context
|
||||
llama_context * ctx2 = llama_init_from_model(model, common_context_params_to_llama(params));
|
||||
|
||||
llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sampling.seed));
|
||||
|
||||
printf("\nsecond run: %s", params.prompt.c_str());
|
||||
|
||||
// load state from file
|
||||
std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
|
||||
size_t n_token_count_out = 0;
|
||||
|
||||
if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
||||
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
||||
|
||||
// restore state (last tokens)
|
||||
n_past = n_token_count_out;
|
||||
if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// second run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto next_token = llama_sampler_sample(smpl2, ctx2, -1);
|
||||
auto next_token_str = common_token_to_piece(ctx2, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result1 += next_token_str;
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, next_token, n_past, {0}, true);
|
||||
|
||||
if (llama_decode(ctx2, batch)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
return 1;
|
||||
}
|
||||
n_past += 1;
|
||||
}
|
||||
|
||||
printf("\n\n");
|
||||
|
||||
if (result0 != result1) {
|
||||
fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// make new context
|
||||
auto params_ctx3 = common_context_params_to_llama(params);
|
||||
params_ctx3.n_seq_max = 2;
|
||||
llama_context * ctx3 = llama_init_from_model(model, params_ctx3);
|
||||
|
||||
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed));
|
||||
|
||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
n_token_count_out = 0;
|
||||
|
||||
if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
||||
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
||||
|
||||
// restore state (last tokens)
|
||||
n_past = n_token_count_out;
|
||||
if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// save seq 0 and load into seq 1
|
||||
{
|
||||
// save kv of seq 0
|
||||
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
|
||||
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
|
||||
if (ncopy != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
||||
return 1;
|
||||
}
|
||||
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
|
||||
|
||||
// erase whole kv
|
||||
llama_memory_clear(llama_get_memory(ctx3), true);
|
||||
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||
|
||||
// restore kv into seq 1
|
||||
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
|
||||
if (nset != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
||||
return 1;
|
||||
}
|
||||
fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
|
||||
}
|
||||
|
||||
// third run with seq 1 instead of 0
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto next_token = llama_sampler_sample(smpl3, ctx3, -1);
|
||||
auto next_token_str = common_token_to_piece(ctx3, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result2 += next_token_str;
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, next_token, n_past, {1}, true);
|
||||
|
||||
if (llama_decode(ctx3, batch)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
return 1;
|
||||
}
|
||||
n_past += 1;
|
||||
}
|
||||
|
||||
// test on-device state save/load
|
||||
auto params_ctx4 = common_context_params_to_llama(params);
|
||||
params_ctx4.n_seq_max = 2;
|
||||
llama_context * ctx4 = llama_init_from_model(model, params_ctx4);
|
||||
|
||||
llama_sampler * smpl4 = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl4, llama_sampler_init_dist(params.sampling.seed));
|
||||
|
||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
n_token_count_out = 0;
|
||||
|
||||
if (!llama_state_load_file(ctx4, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
||||
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
||||
|
||||
// restore state (last tokens)
|
||||
n_past = n_token_count_out;
|
||||
if (!common_replay_last_token(ctx4, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// save seq 0 and load into seq 1
|
||||
{
|
||||
// save kv of seq 0
|
||||
std::vector<uint8_t> seq_store(llama_state_seq_get_size_ext(ctx4, 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE));
|
||||
const size_t ncopy = llama_state_seq_get_data_ext(ctx4, seq_store.data(), seq_store.size(), 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (ncopy != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
||||
return 1;
|
||||
}
|
||||
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
|
||||
|
||||
// erase whole kv
|
||||
llama_memory_clear(llama_get_memory(ctx4), true);
|
||||
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||
|
||||
// restore kv into seq 0
|
||||
const size_t nset = llama_state_seq_set_data_ext(ctx4, seq_store.data(), seq_store.size(), 1, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (nset != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
||||
return 1;
|
||||
}
|
||||
fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
|
||||
}
|
||||
|
||||
// forth run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto next_token = llama_sampler_sample(smpl4, ctx4, -1);
|
||||
auto next_token_str = common_token_to_piece(ctx4, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result3 += next_token_str;
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, next_token, n_past, {1}, true);
|
||||
|
||||
if (llama_decode(ctx4, batch)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
return 1;
|
||||
}
|
||||
n_past += 1;
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
llama_sampler_free(smpl2);
|
||||
llama_sampler_free(smpl3);
|
||||
llama_sampler_free(smpl4);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
// this one is managed by common_init_result
|
||||
//llama_free(ctx);
|
||||
|
||||
llama_free(ctx2);
|
||||
llama_free(ctx3);
|
||||
llama_free(ctx4);
|
||||
|
||||
if (result0 != result2) {
|
||||
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (result0 != result3) {
|
||||
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n%s : success\n", __func__);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -111,7 +111,6 @@ if [ $GGML_SYCL_DEVICE -ne -1 ]; then
|
||||
echo "Use $GGML_SYCL_DEVICE as main GPU"
|
||||
#use signle GPU only
|
||||
GPUS_SETTING="-mg $GGML_SYCL_DEVICE -sm ${SPLIT_MODE}"
|
||||
export ONEAPI_DEVICE_SELECTOR="level_zero:${GGML_SYCL_DEVICE}"
|
||||
echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}"
|
||||
else
|
||||
echo "Use all Intel GPUs, including iGPU & dGPU"
|
||||
|
||||
@@ -119,7 +119,6 @@ if [ $GGML_SYCL_DEVICE -ne -1 ]; then
|
||||
echo "Use $GGML_SYCL_DEVICE as main GPU"
|
||||
#use signle GPU only
|
||||
GPUS_SETTING="-mg $GGML_SYCL_DEVICE -sm ${SPLIT_MODE}"
|
||||
export ONEAPI_DEVICE_SELECTOR="level_zero:${GGML_SYCL_DEVICE}"
|
||||
echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}"
|
||||
else
|
||||
echo "Use all Intel GPUs, including iGPU & dGPU"
|
||||
|
||||
@@ -164,7 +164,6 @@ if not "%GGML_SYCL_DEVICE%"=="-1" (
|
||||
echo Use %GGML_SYCL_DEVICE% as main GPU
|
||||
REM Use single GPU only.
|
||||
set "GPUS_SETTING=-mg %GGML_SYCL_DEVICE% -sm %SPLIT_MODE%"
|
||||
set "ONEAPI_DEVICE_SELECTOR=level_zero:%GGML_SYCL_DEVICE%"
|
||||
echo ONEAPI_DEVICE_SELECTOR=%ONEAPI_DEVICE_SELECTOR%
|
||||
) else (
|
||||
echo Use all Intel GPUs, including iGPU ^& dGPU
|
||||
|
||||
@@ -186,7 +186,6 @@ if not "%GGML_SYCL_DEVICE%"=="-1" (
|
||||
echo Use %GGML_SYCL_DEVICE% as main GPU
|
||||
REM Use single GPU only.
|
||||
set "GPUS_SETTING=-mg %GGML_SYCL_DEVICE% -sm %SPLIT_MODE%"
|
||||
set "ONEAPI_DEVICE_SELECTOR=level_zero:%GGML_SYCL_DEVICE%"
|
||||
echo ONEAPI_DEVICE_SELECTOR=%ONEAPI_DEVICE_SELECTOR%
|
||||
) else (
|
||||
echo Use all Intel GPUs, including iGPU ^& dGPU
|
||||
|
||||
@@ -379,7 +379,7 @@ void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data,
|
||||
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
|
||||
GGML_ASSERT(buf != NULL && "tensor buffer not set");
|
||||
|
||||
if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) {
|
||||
if (n_copies <= 1 || buf->iface.get_tensor_2d == NULL) {
|
||||
for (size_t i = 0; i < n_copies; i++) {
|
||||
ggml_backend_tensor_get(tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size);
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
|
||||
# 80 == Ampere, asynchronous data loading, faster tensor core instructions
|
||||
# 86 == RTX 3000, needs CUDA v11.1
|
||||
# 89 == RTX 4000, needs CUDA v11.8
|
||||
# 90 == Hopper H100/200, needs CUDA v11.8
|
||||
# 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores
|
||||
#
|
||||
# XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
|
||||
@@ -33,7 +34,7 @@ if (CUDAToolkit_FOUND)
|
||||
list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real)
|
||||
|
||||
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
|
||||
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
|
||||
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-virtual)
|
||||
endif()
|
||||
|
||||
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
|
||||
template<typename T, size_t>
|
||||
using type_for_index = T;
|
||||
|
||||
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
|
||||
return b;
|
||||
GGML_UNUSED(a);
|
||||
@@ -52,6 +55,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
const int s12,
|
||||
const int s13,
|
||||
src1_ptrs... src1s) {
|
||||
ggml_cuda_pdl_lc();
|
||||
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
|
||||
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
|
||||
@@ -72,6 +76,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
|
||||
const uint32_t i10 = fastmodulo(i0, ne10);
|
||||
|
||||
@@ -141,6 +146,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
||||
|
||||
const int i10 = fastmodulo(i0, ne10);
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
@@ -282,35 +288,24 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
|
||||
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
|
||||
|
||||
if constexpr (sizeof...(I) > 0) {
|
||||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
|
||||
{
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)block_num, block_size, 0, stream);
|
||||
ggml_cuda_kernel_launch(k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params,
|
||||
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
|
||||
ne12, ne13,
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
|
||||
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
|
||||
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13);
|
||||
}
|
||||
} else {
|
||||
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
|
||||
if constexpr (sizeof...(I) > 0) {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
|
||||
{
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
||||
ggml_cuda_kernel_launch(k_bin_bcast<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params,
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00 ,s01, s02, s03,
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13);
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -333,6 +328,7 @@ static __global__ void k_repeat_back(
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
|
||||
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
|
||||
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "ggml-cuda.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
@@ -27,6 +28,7 @@
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
@@ -50,6 +52,7 @@
|
||||
#define GGML_CUDA_CC_TURING 750
|
||||
#define GGML_CUDA_CC_AMPERE 800
|
||||
#define GGML_CUDA_CC_ADA_LOVELACE 890
|
||||
#define GGML_CUDA_CC_HOPPER 900
|
||||
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
|
||||
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
|
||||
#define GGML_CUDA_CC_BLACKWELL 1200
|
||||
@@ -107,6 +110,24 @@
|
||||
# define GGML_CUDA_USE_CUB
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
|
||||
|
||||
// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8 and excludes HIP/MUSA.
|
||||
// __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code.
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080
|
||||
# define GGML_CUDA_USE_PDL
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080
|
||||
|
||||
static __device__ __forceinline__ void ggml_cuda_pdl_sync() {
|
||||
#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
|
||||
cudaGridDependencySynchronize();
|
||||
#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void ggml_cuda_pdl_lc() {
|
||||
#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
|
||||
}
|
||||
|
||||
#ifdef __CUDA_ARCH_LIST__
|
||||
constexpr bool ggml_cuda_has_arch_impl(int) {
|
||||
return false;
|
||||
@@ -165,6 +186,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
|
||||
|
||||
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
|
||||
|
||||
|
||||
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
|
||||
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
||||
return cublasGetStatusString(err);
|
||||
@@ -1487,3 +1509,67 @@ struct ggml_cuda_mm_fusion_args_device {
|
||||
const void * gate_bias = nullptr;
|
||||
ggml_glu_op glu_op;
|
||||
};
|
||||
|
||||
struct ggml_cuda_kernel_launch_params {
|
||||
dim3 block_nums;
|
||||
dim3 block_dims;
|
||||
size_t shmem;
|
||||
cudaStream_t stream;
|
||||
|
||||
// size_t shmem
|
||||
ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const size_t shmem_, const cudaStream_t stream_)
|
||||
: block_nums(block_nums_), block_dims(block_dims_), shmem(shmem_), stream(stream_) {}
|
||||
|
||||
// Some call sites pass ints instead of the required size_t. This 2nd constructor casts int->size_t to avoid these -Wnarrowing warnings.
|
||||
ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const int shmem_, const cudaStream_t stream_)
|
||||
: block_nums(block_nums_), block_dims(block_dims_), shmem((size_t)shmem_), stream(stream_) {}
|
||||
};
|
||||
|
||||
#if defined(GGML_CUDA_USE_PDL)
|
||||
struct ggml_cuda_pdl_config {
|
||||
cudaLaunchAttribute attr;
|
||||
cudaLaunchConfig_t cfg;
|
||||
|
||||
ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) {
|
||||
attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attr.val.programmaticStreamSerializationAllowed = 1;
|
||||
|
||||
cfg = {};
|
||||
cfg.gridDim = params.block_nums;
|
||||
cfg.blockDim = params.block_dims;
|
||||
cfg.dynamicSmemBytes = params.shmem;
|
||||
cfg.stream = params.stream;
|
||||
cfg.attrs = &attr;
|
||||
cfg.numAttrs = 1;
|
||||
}
|
||||
|
||||
// Delete due to &attr
|
||||
ggml_cuda_pdl_config(const ggml_cuda_pdl_config&) = delete;
|
||||
ggml_cuda_pdl_config& operator=(const ggml_cuda_pdl_config&) = delete;
|
||||
ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete;
|
||||
|
||||
};
|
||||
#endif //defined(GGML_CUDA_USE_PDL)
|
||||
|
||||
|
||||
template<typename Kernel, typename... Args>
|
||||
static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) {
|
||||
#if defined(GGML_CUDA_USE_PDL)
|
||||
|
||||
static const bool env_pdl_enabled = []() {
|
||||
const char * env = getenv("GGML_CUDA_PDL");
|
||||
return env == nullptr || std::atoi(env) != 0;
|
||||
}();
|
||||
|
||||
if (env_pdl_enabled && ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_HOPPER) {
|
||||
auto pdl_cfg = ggml_cuda_pdl_config(launch_params);
|
||||
|
||||
CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)... ));
|
||||
return;
|
||||
}
|
||||
#endif //defined(GGML_CUDA_USE_PDL)
|
||||
|
||||
kernel<<<launch_params.block_nums, launch_params.block_dims, launch_params.shmem, launch_params.stream>>>(std::forward<Args>(args)... );
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont
|
||||
|
||||
const int64_t n = ne0 * ne1 * ne2;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) {
|
||||
if constexpr (dim == 0) {
|
||||
const int64_t row = i / ne0;
|
||||
@@ -64,8 +65,8 @@ static void concat_f32_cuda(const float * x,
|
||||
const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
|
||||
|
||||
if (dim == 0) {
|
||||
concat_f32_cont<0>
|
||||
<<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
|
||||
return;
|
||||
}
|
||||
if (dim == 1) {
|
||||
|
||||
@@ -16,6 +16,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
||||
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
|
||||
const int64_t nb12, const int64_t nb13) {
|
||||
ggml_cuda_pdl_lc();
|
||||
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
@@ -36,6 +37,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne
|
||||
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
||||
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
cpy_1(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
@@ -59,6 +61,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
|
||||
__shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
|
||||
int cur_tile_buf = 0;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
|
||||
|
||||
@@ -142,6 +145,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
|
||||
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
||||
const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
@@ -168,6 +172,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
|
||||
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
||||
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
@@ -182,6 +187,7 @@ static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const
|
||||
const src_t * x = (const src_t *) cx;
|
||||
dst_t * dst = (dst_t *) cdst;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
dst[i] = ggml_cuda_cast<dst_t>(x[i]);
|
||||
}
|
||||
|
||||
@@ -192,8 +198,8 @@ cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar_contiguous<src_t, dst_t>, launch_params, cx, cdst, ne);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t, bool transposed = false>
|
||||
@@ -223,13 +229,15 @@ static void ggml_cpy_scalar_cuda(
|
||||
GGML_ASSERT(grid_z < USHRT_MAX);
|
||||
dim3 dimGrid(grid_x, grid_y, grid_z);
|
||||
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
|
||||
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
|
||||
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar_transpose<dst_t>, launch_params,
|
||||
cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
} else {
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar<cpy_1_scalar<src_t, dst_t>>, launch_params,
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -636,6 +636,7 @@ static __global__ void flash_attn_mask_to_KV_max(
|
||||
if (tid < WARP_SIZE) {
|
||||
buf_iw[tid] = 1;
|
||||
}
|
||||
ggml_cuda_pdl_sync();
|
||||
__syncthreads();
|
||||
|
||||
int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
|
||||
@@ -687,6 +688,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform(
|
||||
const uint3 fd_iter_j_z,
|
||||
const uint3 fd_iter_j) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
ggml_cuda_pdl_lc();
|
||||
|
||||
const int tile_idx = blockIdx.x; // One block per output tile.
|
||||
const int j = blockIdx.y;
|
||||
@@ -718,6 +720,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform(
|
||||
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
// Load the partial result that needs a fixup
|
||||
float dst_val = *dst;
|
||||
float max_val;
|
||||
@@ -809,6 +812,7 @@ static __global__ void flash_attn_stream_k_fixup_general(
|
||||
float dst_val = 0.0f;
|
||||
float max_val = 0.0f;
|
||||
float rowsum = 0.0f;
|
||||
ggml_cuda_pdl_sync();
|
||||
{
|
||||
dst_val = *dst;
|
||||
|
||||
@@ -867,6 +871,7 @@ static __global__ void flash_attn_combine_results(
|
||||
const float2 * __restrict__ VKQ_meta,
|
||||
float * __restrict__ dst,
|
||||
const int parallel_blocks) {
|
||||
ggml_cuda_pdl_lc();
|
||||
// Dimension 0: threadIdx.x
|
||||
// Dimension 1: blockIdx.x
|
||||
// Dimension 2: blockIdx.y
|
||||
@@ -890,6 +895,7 @@ static __global__ void flash_attn_combine_results(
|
||||
__builtin_assume(tid < D);
|
||||
|
||||
extern __shared__ float2 meta[];
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int i = tid; i < 2*parallel_blocks; i += D) {
|
||||
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
|
||||
}
|
||||
@@ -1146,7 +1152,9 @@ void launch_fattn(
|
||||
const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
|
||||
|
||||
GGML_ASSERT(block_dim.x % warp_size == 0);
|
||||
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
||||
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream);
|
||||
ggml_cuda_kernel_launch(fattn_kernel, launch_params,
|
||||
(const char *) Q->data,
|
||||
K_data,
|
||||
V_data,
|
||||
@@ -1176,9 +1184,9 @@ void launch_fattn(
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2};
|
||||
|
||||
flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr,
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream);
|
||||
ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>, launch_params,
|
||||
(float *) KQV->data, dst_tmp_meta.ptr,
|
||||
Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk,
|
||||
gqa_ratio, bpt, fd0, fd1, fd2);
|
||||
} else if (ntiles_dst % blocks_num.x != 0) {
|
||||
@@ -1193,9 +1201,9 @@ void launch_fattn(
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
||||
|
||||
flash_attn_stream_k_fixup_general<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr,
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream);
|
||||
ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_general<DV, ncols1, ncols2>, launch_params,
|
||||
(float *) KQV->data, dst_tmp_meta.ptr,
|
||||
Q->ne[1], Q->ne[2], gqa_ratio, total_work,
|
||||
fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k);
|
||||
}
|
||||
@@ -1204,9 +1212,9 @@ void launch_fattn(
|
||||
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
|
||||
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
||||
|
||||
flash_attn_combine_results<DV>
|
||||
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream);
|
||||
ggml_cuda_kernel_launch(flash_attn_combine_results<DV>, launch_params,
|
||||
dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
@@ -1724,6 +1724,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
ggml_cuda_pdl_sync(); // TODO optimize placement
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE))
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
|
||||
@@ -894,6 +894,8 @@ static __global__ void flash_attn_tile(
|
||||
}
|
||||
float KQ_sum[cpw] = {0.0f};
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
|
||||
// Load Q data, convert to FP16 if fast:
|
||||
#pragma unroll
|
||||
for (int jc0 = 0; jc0 < cpw; ++jc0) {
|
||||
|
||||
@@ -40,6 +40,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
ggml_cuda_pdl_lc();
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
@@ -136,6 +137,8 @@ static __global__ void flash_attn_ext_vec(
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
||||
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
if constexpr (Q_q8_1) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
|
||||
@@ -86,6 +86,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
||||
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "gated_delta_net.cuh"
|
||||
#include "ggml-cuda/common.cuh"
|
||||
|
||||
template <int S_v, bool KDA, bool keep_rs_t>
|
||||
__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
|
||||
@@ -53,6 +54,7 @@ gated_delta_net_cuda(const float * q,
|
||||
float s_shard[rows_per_lane];
|
||||
// state is stored transposed: M[col][i] = S[i][col], row col is contiguous
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
@@ -189,28 +191,29 @@ static void launch_gated_delta_net(
|
||||
|
||||
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream);
|
||||
switch (S_v) {
|
||||
case 16:
|
||||
gated_delta_net_cuda<16, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
break;
|
||||
case 32:
|
||||
gated_delta_net_cuda<32, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
break;
|
||||
case 64: {
|
||||
gated_delta_net_cuda<64, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
break;
|
||||
}
|
||||
case 128: {
|
||||
gated_delta_net_cuda<128, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
|
||||
@@ -11,6 +11,7 @@ static __global__ void k_get_rows(
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
|
||||
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
@@ -48,6 +49,8 @@ static __global__ void k_get_rows_float(
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
ggml_cuda_pdl_lc();
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
|
||||
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
@@ -83,6 +86,7 @@ static __global__ void k_get_rows_back_float(
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int64_t i = 0; i < nrows_grad; ++i) {
|
||||
if (rows[i] != dst_row) {
|
||||
continue;
|
||||
@@ -156,7 +160,8 @@ static void get_rows_cuda_float(
|
||||
GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12);
|
||||
const uint3 ne12_fdv = init_fastdiv_values(ne12);
|
||||
|
||||
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{block_nums, block_dims, 0, stream};
|
||||
ggml_cuda_kernel_launch(k_get_rows_float<src0_t, dst_t>, launch_params,
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10,*/ ne11, ne12_fdv, /*ne13,*/
|
||||
|
||||
@@ -67,9 +67,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
// See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132
|
||||
if ((nrows / nsm) < 2) {
|
||||
const dim3 block_dims(512, 1, 1);
|
||||
reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
||||
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/true>, launch_params, src0_d, dst_d, ncols);
|
||||
} else {
|
||||
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
|
||||
reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
||||
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/true>, launch_params, src0_d, dst_d, ncols);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ static __global__ void mul_mat_vec_f(
|
||||
int channel_y;
|
||||
int sample_dst;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
if constexpr (is_multi_token_id) {
|
||||
// Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
|
||||
token_idx = blockIdx.z;
|
||||
@@ -298,6 +299,7 @@ static __global__ void mul_mat_vec_f(
|
||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||
}
|
||||
|
||||
ggml_cuda_pdl_lc();
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
||||
@@ -382,11 +384,13 @@ static void mul_mat_vec_f_switch_fusion(
|
||||
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {
|
||||
|
||||
const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, nbytes_shared, stream};
|
||||
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
if constexpr (ncols_dst == 1) {
|
||||
if (has_fusion) {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
ggml_cuda_kernel_launch(mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id>, launch_params,
|
||||
x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
|
||||
return;
|
||||
@@ -395,8 +399,8 @@ static void mul_mat_vec_f_switch_fusion(
|
||||
|
||||
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
||||
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
ggml_cuda_kernel_launch(mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id>, launch_params,
|
||||
x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
|
||||
|
||||
|
||||
@@ -359,7 +359,9 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_K:
|
||||
return 8;
|
||||
case GGML_TYPE_Q6_K:
|
||||
return 2;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return 8;
|
||||
default:
|
||||
@@ -422,6 +424,7 @@ static __global__ void mul_mat_vec_q(
|
||||
uint32_t channel_y;
|
||||
uint32_t sample_dst;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
|
||||
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
|
||||
sample_dst = blockIdx.z;
|
||||
@@ -681,8 +684,9 @@ static void mul_mat_vec_q_switch_fusion(
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
if constexpr (c_ncols_dst == 1) {
|
||||
if (has_fusion) {
|
||||
mul_mat_vec_q<type, c_ncols_dst, true, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream);
|
||||
ggml_cuda_kernel_launch(mul_mat_vec_q<type, c_ncols_dst, true, small_k>, launch_params,
|
||||
vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
|
||||
return;
|
||||
@@ -691,8 +695,9 @@ static void mul_mat_vec_q_switch_fusion(
|
||||
|
||||
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
||||
|
||||
mul_mat_vec_q<type, c_ncols_dst, false, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream);
|
||||
ggml_cuda_kernel_launch(mul_mat_vec_q<type, c_ncols_dst, false, small_k>, launch_params,
|
||||
vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ static __global__ void norm_f32(
|
||||
|
||||
float2 mean_var = make_float2(0.0f, 0.0f);
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
mean_var.x += xi;
|
||||
@@ -46,6 +47,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int j = start; j < end; j += block_size) {
|
||||
tmp += x[j];
|
||||
}
|
||||
@@ -95,6 +97,7 @@ static __global__ void rms_norm_f32(const float * x,
|
||||
const uint3 add_nrows_packed = make_uint3(0, 0, 0),
|
||||
const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
|
||||
const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
|
||||
ggml_cuda_pdl_lc();
|
||||
const int nrows = gridDim.x;
|
||||
const int nchannels = gridDim.y;
|
||||
|
||||
@@ -124,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x,
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
tmp += xi * xi;
|
||||
@@ -163,6 +167,7 @@ static __global__ void rms_norm_back_f32(
|
||||
float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
|
||||
float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xfi = xf[col];
|
||||
sum_xx += xfi * xfi;
|
||||
@@ -253,6 +258,7 @@ static __global__ void l2_norm_f32(
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
tmp += xi * xi;
|
||||
@@ -261,6 +267,7 @@ static __global__ void l2_norm_f32(
|
||||
// sum up partial sums
|
||||
extern __shared__ float s_sum[];
|
||||
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
|
||||
ggml_cuda_pdl_lc();
|
||||
|
||||
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
||||
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
||||
@@ -300,10 +307,19 @@ static void rms_norm_f32_cuda(
|
||||
const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(256, 1, 1);
|
||||
rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
const ggml_cuda_kernel_launch_params launch_params = {blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
||||
ggml_cuda_kernel_launch(rms_norm_f32<256, false>, launch_params,
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps,
|
||||
// underlying cudaLaunchKernelEx does not support default params
|
||||
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0),
|
||||
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
||||
ggml_cuda_kernel_launch(rms_norm_f32<1024, false>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps,
|
||||
// underlying cudaLaunchKernelEx does not support default params
|
||||
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0),
|
||||
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,14 +362,20 @@ static void rms_norm_mul_f32_cuda(const float * x,
|
||||
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(256, 1, 1);
|
||||
rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
||||
ggml_cuda_kernel_launch(rms_norm_f32<256, true>, launch_params,
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed,
|
||||
// underlying cudaLaunchKernelEx does not support default params
|
||||
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
||||
ggml_cuda_kernel_launch(rms_norm_f32<1024, true>, launch_params,
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed,
|
||||
// underlying cudaLaunchKernelEx does not support default params
|
||||
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
|
||||
}
|
||||
} else {
|
||||
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
|
||||
@@ -367,14 +389,16 @@ static void rms_norm_mul_f32_cuda(const float * x,
|
||||
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(256, 1, 1);
|
||||
rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims,block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
||||
ggml_cuda_kernel_launch(rms_norm_f32<256, true, true>, launch_params,
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
||||
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
||||
add_nchannels_packed, add_nsamples_packed);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
||||
ggml_cuda_kernel_launch(rms_norm_f32<1024, true, true>, launch_params,
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
||||
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
||||
@@ -399,10 +423,12 @@ static void l2_norm_f32_cuda(
|
||||
const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, 0, stream};
|
||||
ggml_cuda_kernel_launch(l2_norm_f32<WARP_SIZE>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
||||
ggml_cuda_kernel_launch(l2_norm_f32<1024>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ static __global__ void quantize_q8_1(
|
||||
const float * __restrict__ x, void * __restrict__ vy,
|
||||
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
|
||||
ggml_cuda_pdl_lc();
|
||||
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
@@ -28,6 +29,7 @@ static __global__ void quantize_q8_1(
|
||||
const int64_t ib = i_cont / QK8_1; // block index
|
||||
const int64_t iqs = i_cont % QK8_1; // quant index
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
|
||||
float amax = fabsf(xi);
|
||||
float sum = xi;
|
||||
@@ -196,6 +198,7 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
|
||||
const int64_t i2 = blockIdx.z % ne2;
|
||||
const int64_t i3 = blockIdx.z / ne2;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
const int64_t i01 = ids ? ids[i1] : i1;
|
||||
const int64_t i02 = i2;
|
||||
const int64_t i03 = i3;
|
||||
@@ -288,6 +291,7 @@ static __global__ void quantize_mmq_q8_1(
|
||||
const int64_t i3 = blockIdx.z / ne2;
|
||||
|
||||
const int64_t i00 = i0;
|
||||
ggml_cuda_pdl_sync();
|
||||
const int64_t i01 = ids ? ids[i1] : i1;
|
||||
const int64_t i02 = i2;
|
||||
const int64_t i03 = i3;
|
||||
@@ -378,7 +382,8 @@ void quantize_row_q8_1_cuda(
|
||||
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
||||
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
|
||||
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
|
||||
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, block_size, 0, stream);
|
||||
ggml_cuda_kernel_launch(quantize_q8_1, launch_params, x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
|
||||
GGML_UNUSED(type_src0);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
|
||||
const int num_unroll = 8;
|
||||
float temp[num_unroll];
|
||||
float sum_temp[num_unroll] = { 0.0f };
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int i = col; i < ncols;) {
|
||||
for (int j = 0; j < num_unroll; ++j) {
|
||||
if (i < ncols) {
|
||||
|
||||
@@ -134,6 +134,7 @@ static __global__ void rope_neox(const T * x,
|
||||
const float * freq_factors,
|
||||
const int64_t * row_indices,
|
||||
const int set_rows_stride) {
|
||||
ggml_cuda_pdl_lc();
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (i0 >= ne00) {
|
||||
@@ -148,6 +149,7 @@ static __global__ void rope_neox(const T * x,
|
||||
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
ggml_cuda_pdl_sync();
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
||||
@@ -216,6 +218,7 @@ static __global__ void rope_multi(const T * x,
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
||||
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
||||
@@ -300,6 +303,7 @@ static __global__ void rope_vision(const T * x,
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
const int sect_dims = sections.v[0] + sections.v[1];
|
||||
const int sec_w = sections.v[1] + sections.v[0];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
@@ -399,13 +403,14 @@ static void rope_neox_cuda(const T * x,
|
||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, 0, stream};
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
ggml_cuda_kernel_launch(rope_neox<forward, false, T, D>, launch_params,
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
} else {
|
||||
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
ggml_cuda_kernel_launch(rope_neox<forward, true, T, D>, launch_params,
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
}
|
||||
@@ -443,11 +448,13 @@ static void rope_multi_cuda(const T * x,
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
||||
ggml_cuda_kernel_launch(rope_multi<forward, false, T>, launch_params,
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
||||
} else {
|
||||
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
||||
ggml_cuda_kernel_launch(rope_multi<forward, true, T>, launch_params,
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
||||
}
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
#define MAX_GRIDDIM_X 0x7FFFFFFF
|
||||
|
||||
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
|
||||
ggml_cuda_pdl_lc();
|
||||
int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
|
||||
int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int64_t i = tid; i < nelements; i += stride) {
|
||||
dst[i] = scale * x[i] + bias;
|
||||
}
|
||||
@@ -13,7 +15,8 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
|
||||
|
||||
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
|
||||
const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
|
||||
scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(scale_f32, launch_params, x, dst, scale, bias, nelements);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -53,6 +53,7 @@ static __global__ void k_set_rows_quant(const float * __restrict__ src0,
|
||||
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
||||
const int64_t i10 = i01;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
||||
|
||||
const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
|
||||
@@ -157,7 +158,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0,
|
||||
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
||||
const int64_t i10 = i01;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
||||
ggml_cuda_pdl_lc();
|
||||
|
||||
const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
|
||||
dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
|
||||
@@ -203,9 +206,11 @@ static void set_rows_cuda(
|
||||
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
|
||||
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
|
||||
|
||||
k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
|
||||
s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
|
||||
ne11_fd, ne12_fd);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_size, block_size, 0, stream);
|
||||
ggml_cuda_kernel_launch(k_set_rows<src_t, idx_t, dst_t>, launch_params,
|
||||
src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
|
||||
s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
|
||||
ne11_fd, ne12_fd);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
#include "softcap.cuh"
|
||||
|
||||
static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {
|
||||
ggml_cuda_pdl_lc();
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
dst[i] = tanhf(scale * x[i]) * softcap;
|
||||
}
|
||||
|
||||
static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
|
||||
softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(softcap_f32, launch_params, x, dst, scale, softcap, k);
|
||||
}
|
||||
|
||||
// fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "common.cuh"
|
||||
#include "ssm-conv.cuh"
|
||||
#include "unary.cuh"
|
||||
|
||||
@@ -7,6 +8,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
|
||||
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
|
||||
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
|
||||
const int64_t n_t) {
|
||||
ggml_cuda_pdl_lc();
|
||||
GGML_UNUSED(src0_nb0);
|
||||
const int tid = threadIdx.x;
|
||||
const int bidx = blockIdx.x;
|
||||
@@ -23,6 +25,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
|
||||
float x[d_conv] = { 0.0f };
|
||||
float w[d_conv] = { 0.0f };
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
#pragma unroll
|
||||
for (size_t j = 0; j < d_conv; j++) {
|
||||
w[j] = w_block[tid * stride_w + j];
|
||||
@@ -128,8 +131,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa
|
||||
constexpr int kNC = decltype(NC)::value;
|
||||
if (n_t <= 32) {
|
||||
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
|
||||
ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
||||
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream);
|
||||
ggml_cuda_kernel_launch(ssm_conv_f32<apply_silu, threads, kNC>, launch_params, src0, src1, bias, src0_nb0, src0_nb1,
|
||||
src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
} else {
|
||||
const int64_t split_n_t = 32;
|
||||
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
|
||||
@@ -140,11 +144,12 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa
|
||||
};
|
||||
|
||||
switch (nc) {
|
||||
case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
|
||||
case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
|
||||
case 5: launch_kernel(std::integral_constant<int, 5>{}); break;
|
||||
case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
|
||||
default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now.");
|
||||
case 3: launch_kernel(std::integral_constant<int, 3 >{}); break;
|
||||
case 4: launch_kernel(std::integral_constant<int, 4 >{}); break;
|
||||
case 5: launch_kernel(std::integral_constant<int, 5 >{}); break;
|
||||
case 9: launch_kernel(std::integral_constant<int, 9 >{}); break;
|
||||
case 15: launch_kernel(std::integral_constant<int, 15>{}); break;
|
||||
default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9, 15 right now.");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ __global__ void __launch_bounds__(splitD, 1)
|
||||
const int64_t s_off, const int64_t d_inner, const int64_t L_param)
|
||||
{
|
||||
const size_t L = L_template == 0 ? L_param : L_template;
|
||||
ggml_cuda_pdl_sync();
|
||||
const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
|
||||
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
|
||||
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
|
||||
@@ -135,6 +136,7 @@ __global__ void __launch_bounds__(d_state, 1)
|
||||
|
||||
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
// TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
|
||||
const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
|
||||
const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
|
||||
@@ -206,7 +208,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
||||
constexpr int num_warps = threads/WARP_SIZE;
|
||||
|
||||
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
|
||||
ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>(
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream);
|
||||
ggml_cuda_kernel_launch(ssm_scan_f32_group<128/WARP_SIZE, 128>, launch_params,
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
|
||||
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
|
||||
@@ -215,7 +218,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
||||
constexpr int num_warps = threads/WARP_SIZE;
|
||||
|
||||
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
|
||||
ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>(
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream);
|
||||
ggml_cuda_kernel_launch(ssm_scan_f32_group<256/WARP_SIZE, 256>, launch_params,
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
|
||||
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
|
||||
@@ -231,58 +235,59 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
||||
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
|
||||
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
|
||||
if (d_state == 16) {
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream);
|
||||
switch (n_tok)
|
||||
{
|
||||
case 1:
|
||||
ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
|
||||
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 1>, launch_params,
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
|
||||
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
|
||||
break;
|
||||
case 2:
|
||||
ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
|
||||
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 2>, launch_params,
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
|
||||
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
|
||||
break;
|
||||
case 3:
|
||||
ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
|
||||
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 3>, launch_params,
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
|
||||
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
|
||||
break;
|
||||
case 4:
|
||||
ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
|
||||
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 4>, launch_params,
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
|
||||
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
|
||||
break;
|
||||
case 5:
|
||||
ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
|
||||
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 5>, launch_params,
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
|
||||
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
|
||||
break;
|
||||
case 6:
|
||||
ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
|
||||
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 6>, launch_params,
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
|
||||
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
|
||||
break;
|
||||
case 7:
|
||||
ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
|
||||
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 7>, launch_params,
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
|
||||
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
|
||||
break;
|
||||
case 8:
|
||||
ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
|
||||
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 8>, launch_params,
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
|
||||
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
|
||||
break;
|
||||
default:
|
||||
ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
|
||||
ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 0>, launch_params,
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
|
||||
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
|
||||
|
||||
@@ -7,10 +7,12 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
if ((nrows / nsm) < 2) {
|
||||
const dim3 block_dims(512, 1, 1);
|
||||
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
||||
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, x, dst, ncols);
|
||||
} else {
|
||||
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
|
||||
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
||||
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, x, dst, ncols);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,10 +36,12 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
if ((nrows / nsm) < 2) {
|
||||
// Increase num threads to 512 for small nrows to better hide the latency
|
||||
const dim3 block_dims(512, 1, 1);
|
||||
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
||||
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, src0_d, dst_d, ncols);
|
||||
} else {
|
||||
// Enough active SMs to hide latency, use smaller blocks to allow better scheduling
|
||||
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
|
||||
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
||||
ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, src0_d, dst_d, ncols);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
# include <cub/cub.cuh>
|
||||
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
|
||||
# define CUB_TOP_K_AVAILABLE
|
||||
# include <cuda/iterator>
|
||||
using namespace cub;
|
||||
# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
||||
@@ -105,6 +105,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
wt[i] = -INFINITY;
|
||||
}
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const int expert = i + threadIdx.x;
|
||||
@@ -161,6 +162,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
output_weights[i] = 0.f;
|
||||
}
|
||||
|
||||
ggml_cuda_pdl_lc();
|
||||
for (int k = 0; k < n_expert_used; k++) {
|
||||
float max_val = wt[0];
|
||||
int max_expert = threadIdx.x;
|
||||
@@ -271,51 +273,52 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
|
||||
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
|
||||
cudaStream_t stream = ctx.stream();
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream);
|
||||
|
||||
switch (n_expert) {
|
||||
case 1:
|
||||
topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<1, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 2:
|
||||
topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<2, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 4:
|
||||
topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<4, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 8:
|
||||
topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<8, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 16:
|
||||
topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<16, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 32:
|
||||
topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<32, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 64:
|
||||
topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<64, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 128:
|
||||
topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<128, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 256:
|
||||
topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<256, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 512:
|
||||
topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<512, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 576:
|
||||
topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<576, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "fatal error");
|
||||
|
||||
@@ -116,19 +116,22 @@ static __device__ __forceinline__ float op_trunc(float x) {
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
|
||||
ggml_cuda_pdl_lc();
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
dst[i] = (T)op((float)x[i]);
|
||||
}
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
|
||||
unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(unary_op_kernel<op, T>, launch_params, x, dst, k);
|
||||
}
|
||||
|
||||
template <float (*op)(float)>
|
||||
@@ -258,6 +261,7 @@ void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {
|
||||
ggml_cuda_pdl_lc();
|
||||
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
@@ -268,13 +272,15 @@ static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst,
|
||||
const int64_t j0 = (i / n) * o0 + (i % n);
|
||||
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
dst[i] = (T)(op((float)x[j0]) * (float)g[j1]);
|
||||
}
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {
|
||||
const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
|
||||
unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(unary_gated_op_kernel<op, T>, launch_params, x, g, dst, k, n, o0, o1);
|
||||
}
|
||||
|
||||
template <float (*op)(float)>
|
||||
|
||||
@@ -2661,7 +2661,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
|
||||
|
||||
int mode = op_params[2];
|
||||
|
||||
if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
|
||||
if (mode == GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
if (mode & 1) {
|
||||
@@ -2735,15 +2735,28 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session *
|
||||
if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contiguous tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
|
||||
if (src0->nb[0] != sizeof(float) || src1->nb[0] != sizeof(float) || dst->nb[0] != sizeof(float)) {
|
||||
return false;
|
||||
}
|
||||
if (src0->nb[1] != src0->ne[0] * sizeof(float) || src1->nb[1] != src1->ne[0] * sizeof(float)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_pad(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_UNUSED(sess);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
@@ -2816,6 +2829,21 @@ static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session *
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) { return false; }
|
||||
if (dst->type != GGML_TYPE_F32) { return false; }
|
||||
if (!ggml_are_same_shape(src0, dst)) { return false; }
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { return false; }
|
||||
|
||||
return true;
|
||||
|
||||
GGML_UNUSED(sess);
|
||||
}
|
||||
|
||||
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
|
||||
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
|
||||
return sess->c_name();
|
||||
@@ -2843,6 +2871,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
|
||||
case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS;
|
||||
case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS;
|
||||
case GGML_OP_ARGSORT: return HTP_OP_ARGSORT;
|
||||
case GGML_OP_NORM: return HTP_OP_NORM;
|
||||
case GGML_OP_L2_NORM: return HTP_OP_L2_NORM;
|
||||
case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM;
|
||||
case GGML_OP_SCALE: return HTP_OP_SCALE;
|
||||
@@ -2857,6 +2886,9 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
|
||||
case GGML_OP_FILL: return HTP_OP_FILL;
|
||||
case GGML_OP_DIAG: return HTP_OP_DIAG;
|
||||
case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI;
|
||||
case GGML_OP_TRI: return HTP_OP_TRI;
|
||||
case GGML_OP_PAD: return HTP_OP_PAD;
|
||||
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(t)) {
|
||||
case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU;
|
||||
@@ -3308,10 +3340,8 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_add_id(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SCALE:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
@@ -3416,6 +3446,14 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_solve_tri(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_TRI:
|
||||
supp = ggml_hexagon_supported_tri(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_PAD:
|
||||
supp = ggml_hexagon_supported_pad(sess, op);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ add_library(${HTP_LIB} SHARED
|
||||
diag-ops.c
|
||||
solve-tri-ops.c
|
||||
gated-delta-net-ops.c
|
||||
pad-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
|
||||
@@ -201,11 +201,10 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32
|
||||
|
||||
// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using
|
||||
// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls.
|
||||
// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes.
|
||||
static inline void dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes.
|
||||
static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
const uint8_t *packed_128, bool upper_nibbles,
|
||||
const __fp16 *scales_4, const HVX_Vector vlut_cvt,
|
||||
HVX_Vector out[4]) {
|
||||
const __fp16 *scales_4, const HVX_Vector vlut_cvt) {
|
||||
// Load all 128 packed bytes (4 contiguous 32-byte groups)
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
@@ -221,8 +220,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16]
|
||||
|
||||
// Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b
|
||||
volatile HVX_Vector vscale = hvx_vmemu(scales_4);
|
||||
|
||||
HVX_Vector vscale = hvx_vmemu(scales_4);
|
||||
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
|
||||
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));
|
||||
|
||||
@@ -230,8 +228,9 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
// Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter
|
||||
out[0] = v_lo; // group0 already in [0:63]
|
||||
out[1] = v_hi; // group2 already in [0:63]
|
||||
HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */
|
||||
v_hi /* group2 already in [0:63] */ };
|
||||
return r;
|
||||
}
|
||||
|
||||
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
|
||||
@@ -292,12 +291,11 @@ static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed
|
||||
}
|
||||
|
||||
// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes).
|
||||
static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128,
|
||||
static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128,
|
||||
bool upper_nibbles,
|
||||
int sub_blk_base,
|
||||
const HVX_Vector vlut_cvt,
|
||||
mxfp4_scales_t scales,
|
||||
HVX_Vector out[4]) {
|
||||
mxfp4_scales_t scales) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
|
||||
@@ -318,10 +316,8 @@ static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_12
|
||||
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
out[0] = v_lo;
|
||||
out[1] = Q6_V_vror_VR(v_lo, 64);
|
||||
out[2] = v_hi;
|
||||
out[3] = Q6_V_vror_VR(v_hi, 64);
|
||||
HVX_Vector_x4 r = { v_lo, Q6_V_vror_VR(v_lo, 64), v_hi, Q6_V_vror_VR(v_hi, 64) };
|
||||
return r;
|
||||
}
|
||||
|
||||
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
|
||||
@@ -372,18 +368,18 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
|
||||
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
HVX_Vector v0[2];
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
|
||||
|
||||
r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
|
||||
@@ -415,21 +411,21 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector v0[4], v1[4];
|
||||
dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0);
|
||||
HVX_Vector_x4 dv0, dv1;
|
||||
dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8);
|
||||
if (row1 < n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1);
|
||||
dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8);
|
||||
} else {
|
||||
v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
|
||||
dv1.v[0] = dv1.v[1] = dv1.v[2] = dv1.v[3] = Q6_V_vzero();
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) {
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[g]);
|
||||
}
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
for (int g = 0; g < 4; g++) {
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[g]);
|
||||
}
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
@@ -612,11 +608,13 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict
|
||||
const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
for (int k = 0; k < n_dot_tiles; ++k) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
|
||||
row_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
for (int k = 0, k_block; k < n_dot_tiles; k += k_block) {
|
||||
k_block = hex_smin(n_dot_tiles - k, 32);
|
||||
const uint32_t range = 2048u * (uint32_t)k_block - 1;
|
||||
Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range);
|
||||
row_tiles += k_block * HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += k_block * HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
__fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS;
|
||||
@@ -832,10 +830,6 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *
|
||||
worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#define FALLBACK_TO_STANDARD 1
|
||||
|
||||
// C += AB
|
||||
static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b,
|
||||
const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile,
|
||||
@@ -861,314 +855,80 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047);
|
||||
}
|
||||
|
||||
for (int k = 0; k < n_dot_tiles; ++k) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
|
||||
row_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
for (int k = 0, k_block; k < n_dot_tiles; k += k_block) {
|
||||
k_block = hex_smin(n_dot_tiles - k, 32);
|
||||
const uint32_t range = 2048u * (uint32_t)k_block - 1;
|
||||
Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range);
|
||||
row_tiles += k_block * HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += k_block * HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
Q6_mxmem_AR_after_hf(accum_tile, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx,
|
||||
float *restrict out, const float *restrict x, const uint8_t *restrict w,
|
||||
int m, int k, int n, int weight_type) {
|
||||
// assume k % 32 == 0 && n % 32 == 0
|
||||
const size_t row_stride = get_x4x2_row_stride(weight_type, k);
|
||||
if (row_stride == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
|
||||
const size_t K_BLOCK_SIZE = 1024;
|
||||
|
||||
// Fallback: if k doesn't need K-blocking, out-stationary has no advantage
|
||||
const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE;
|
||||
if (k_iters_check <= 1) {
|
||||
FARF(HIGH, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k);
|
||||
return FALLBACK_TO_STANDARD;
|
||||
}
|
||||
|
||||
// Dynamic M,N search via hmx_compute_chunks
|
||||
const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE);
|
||||
const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32)
|
||||
+ K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles)
|
||||
const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant)
|
||||
+ K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles)
|
||||
const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary)
|
||||
|
||||
// Alignment margin: hex_align_up can add up to 2047 bytes per buffer;
|
||||
// scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin
|
||||
const size_t align_margin = 4 * HMX_FP16_TILE_SIZE;
|
||||
const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment
|
||||
|
||||
size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used;
|
||||
// Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost.
|
||||
// From profiling: wt_dequant per element ≈ 1.5× activation load per element.
|
||||
// m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive).
|
||||
// n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper).
|
||||
const size_t m_block_cost = (size_t) n * 3;
|
||||
const size_t n_block_cost = (size_t) m * 2;
|
||||
if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
m_block_cost, n_block_cost, &M_BLOCK_SIZE,
|
||||
&N_BLOCK_SIZE, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Compute precise buffer sizes from searched M,N and fixed K
|
||||
const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE);
|
||||
const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE);
|
||||
|
||||
const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256;
|
||||
if (total_vtcm > vtcm_budget) {
|
||||
FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm,
|
||||
vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE);
|
||||
return -1;
|
||||
}
|
||||
|
||||
uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
|
||||
__fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size);
|
||||
__fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size);
|
||||
__fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size);
|
||||
uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz);
|
||||
uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz);
|
||||
__fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE);
|
||||
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
|
||||
assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget);
|
||||
|
||||
FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type,
|
||||
M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
|
||||
|
||||
// initialize eye tile (32x32 identity matrix)
|
||||
{
|
||||
HVX_Vector v;
|
||||
v = Q6_V_vzero();
|
||||
v = Q6_Vw_vinsert_VwR(v, 0x3c000000);
|
||||
v = Q6_V_vror_VR(v, VLEN - 4);
|
||||
v = Q6_Vw_vinsert_VwR(v, 0x00003c00);
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
((HVX_Vector *) vtcm_eye_tile)[i] = v;
|
||||
v = Q6_V_vror_VR(v, VLEN - 8);
|
||||
}
|
||||
}
|
||||
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
|
||||
|
||||
TIMER_DEFINE(fetch);
|
||||
TIMER_DEFINE(act_load);
|
||||
TIMER_DEFINE(wt_dequant);
|
||||
TIMER_DEFINE(core);
|
||||
|
||||
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
|
||||
|
||||
for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) {
|
||||
size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE);
|
||||
for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) {
|
||||
size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE);
|
||||
|
||||
const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS);
|
||||
const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS);
|
||||
|
||||
for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) {
|
||||
const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE);
|
||||
|
||||
TIMER_START(fetch);
|
||||
// fetch activation block into VTCM
|
||||
{
|
||||
const float *activation_block = x + mr * k + kk;
|
||||
|
||||
dma_queue_push(ctx->dma[0],
|
||||
dma_make_ptr(vtcm_scratch1, activation_block),
|
||||
k_blk_sz * sizeof(float),
|
||||
k * sizeof(float),
|
||||
k_blk_sz * sizeof(float),
|
||||
m_blk_sz);
|
||||
}
|
||||
|
||||
// fetch weight block into VTCM (x4x2 sub-block: quants + scales)
|
||||
const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
|
||||
{
|
||||
const int blk_start = kk / QK_Q4_0x4x2;
|
||||
const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;
|
||||
const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2);
|
||||
const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE;
|
||||
uint8_t *dst = vtcm_scratch0;
|
||||
const uint8_t *src = w + nc * row_stride;
|
||||
const size_t n_rows = n_blk_sz;
|
||||
const size_t src_stride = row_stride;
|
||||
const size_t dst_stride = sub_row_stride;
|
||||
const size_t quant_off = (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2));
|
||||
const size_t quant_width = (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2));
|
||||
const size_t scale_off = full_qrow + blk_start * scale_blk_size;
|
||||
const size_t scale_width = nb_sub * scale_blk_size;
|
||||
|
||||
// 2D DMA: quants sub-range
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows);
|
||||
// 2D DMA: scales sub-range
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows);
|
||||
}
|
||||
TIMER_STOP(fetch);
|
||||
|
||||
TIMER_START(act_load);
|
||||
// load activation block
|
||||
{
|
||||
dma_queue_pop(ctx->dma[0]); // wait for act DNA
|
||||
transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz);
|
||||
}
|
||||
TIMER_STOP(act_load);
|
||||
|
||||
TIMER_START(wt_dequant);
|
||||
// dequantize weight block
|
||||
{
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
// vtcm_scratch0 is used to store the qweight chunk
|
||||
// worker_pool_run_func already returned, so fetch is done
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0,
|
||||
n_blk_sz, k_blk_sz, sub_row_stride, weight_type);
|
||||
}
|
||||
TIMER_STOP(wt_dequant);
|
||||
|
||||
// core mma
|
||||
TIMER_START(core);
|
||||
{
|
||||
core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles,
|
||||
n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0);
|
||||
}
|
||||
TIMER_STOP(core);
|
||||
}
|
||||
|
||||
// store output block
|
||||
{
|
||||
float *output_block = out + (mr * n + nc);
|
||||
transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us",
|
||||
TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core));
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
|
||||
int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
|
||||
const uint8_t *restrict permuted_weight, int m, int k, int n,
|
||||
int weight_type) {
|
||||
if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; }
|
||||
if (k % 32 != 0 || n % 32 != 0) { return -1; }
|
||||
|
||||
if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// for large m, k (e.g. prefill FFN Down), use out-stationary version
|
||||
if (m >= 128 && k > n && n > 1024) {
|
||||
int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type);
|
||||
if (rc != FALLBACK_TO_STANDARD) {
|
||||
return rc; // 0 success, -1 error
|
||||
}
|
||||
FARF(HIGH, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n);
|
||||
// fall through to standard path
|
||||
}
|
||||
|
||||
size_t row_stride = get_x4x2_row_stride(weight_type, k);
|
||||
if (row_stride == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
FARF(HIGH, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type);
|
||||
|
||||
// --- Dynamic VTCM layout ---
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
const size_t vec_dot_size = k * sizeof(__fp16);
|
||||
const size_t vec_dot_size = k * sizeof(__fp16);
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
size_t vtcm_used = 0;
|
||||
|
||||
// Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap.
|
||||
// Only pays off when the chunker yields >=2 n-chunks, so the main loop can
|
||||
// overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for
|
||||
// double-buffered output and the worker-dispatch overhead are pure loss.
|
||||
// Try pipeline costs first; fall back to sequential if the layout collapses
|
||||
// to one n-chunk. m >= 128 floor keeps HMX utilization reasonable.
|
||||
const size_t pipe_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs)
|
||||
const size_t pipe_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer)
|
||||
const size_t seq_per_n = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs)
|
||||
const size_t seq_per_mn = sizeof(__fp16); // O x 1
|
||||
const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs)
|
||||
const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer)
|
||||
|
||||
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
|
||||
bool use_pipeline = false;
|
||||
|
||||
if (m >= 128) {
|
||||
size_t mc = 0, nc = 0, used = 0;
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
/*m_block_cost=*/(size_t) n * 3,
|
||||
/*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 &&
|
||||
hmx_ceil_div((size_t) n, nc) >= 2) {
|
||||
m_chunk_n_rows = mc;
|
||||
n_chunk_n_cols = nc;
|
||||
vtcm_used = used;
|
||||
use_pipeline = true;
|
||||
}
|
||||
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0;
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
/*m_block_cost=*/(size_t) n * 3,
|
||||
/*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) {
|
||||
FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!use_pipeline) {
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
/*m_block_cost=*/(size_t) n * 3,
|
||||
/*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute precise buffer sizes per execution path
|
||||
const size_t weight_area_size = hex_align_up(
|
||||
n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE);
|
||||
const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE);
|
||||
const size_t output_area_size = hex_align_up(
|
||||
m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE);
|
||||
const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE);
|
||||
const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
|
||||
size_t scratch0_size, scratch1_size, scratch2_size;
|
||||
if (use_pipeline) {
|
||||
scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0
|
||||
scratch1_size = scratch0_size; // dequant buf 1
|
||||
scratch2_size = output_area_size; // output buf 1
|
||||
} else {
|
||||
scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0
|
||||
scratch1_size = scratch0_size; // x4x2 DMA buf 1
|
||||
scratch2_size = 0; // unused
|
||||
}
|
||||
scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0
|
||||
scratch1_size = scratch0_size; // dequant buf 1
|
||||
scratch2_size = output_area_size; // output buf 1
|
||||
|
||||
uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
|
||||
__fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size);
|
||||
__fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size);
|
||||
__fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size);
|
||||
__fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size);
|
||||
void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size);
|
||||
void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size);
|
||||
void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL;
|
||||
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
|
||||
if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) {
|
||||
FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__,
|
||||
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
|
||||
|
||||
vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base;
|
||||
if (vtcm_used > vtcm_budget) {
|
||||
FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget);
|
||||
return -1;
|
||||
}
|
||||
|
||||
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
|
||||
|
||||
FARF(HIGH, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu",
|
||||
__func__, m, k, n, weight_type, use_pipeline,
|
||||
m_chunk_n_rows, n_chunk_n_cols,
|
||||
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
|
||||
FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu",
|
||||
m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget);
|
||||
|
||||
TIMER_DEFINE(activation_load);
|
||||
TIMER_DEFINE(weight_load);
|
||||
@@ -1178,184 +938,115 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
TIMER_DEFINE(total);
|
||||
TIMER_START(total);
|
||||
|
||||
FARF(HIGH, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu",
|
||||
use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols,
|
||||
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
|
||||
// 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D)
|
||||
// HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D).
|
||||
|
||||
if (!use_pipeline) {
|
||||
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
|
||||
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
|
||||
// transfer activation matrix chunk into VTCM
|
||||
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
|
||||
const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
|
||||
// A --> B: vtcm_qweight, 1 buffer
|
||||
// B --> C: vtcm_weight0/vtcm_weight1, 2 buffers
|
||||
// C --> D: vtcm_output0/vtcm_output1, 2 buffers
|
||||
|
||||
TIMER_START(activation_load);
|
||||
{
|
||||
const float *activation_chunk = activation + mr * k;
|
||||
transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
|
||||
}
|
||||
TIMER_STOP(activation_load);
|
||||
// Async timeline (C overlaps B+D):
|
||||
// main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2]
|
||||
// HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████]
|
||||
|
||||
void *buf_curr = vtcm_scratch0;
|
||||
void *buf_next = vtcm_scratch1;
|
||||
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
|
||||
hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors
|
||||
|
||||
{
|
||||
const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first);
|
||||
}
|
||||
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
|
||||
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
|
||||
|
||||
for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
|
||||
const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
|
||||
const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
|
||||
void *vtcm_qweight = vtcm_weight;
|
||||
void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 };
|
||||
void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 };
|
||||
|
||||
TIMER_START(weight_load);
|
||||
{
|
||||
dma_queue_pop(ctx->dma[0]); // wait until current weight chunk become ready
|
||||
|
||||
const size_t nc_next = nc + n_chunk_n_cols;
|
||||
if (nc_next < n) {
|
||||
const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols);
|
||||
|
||||
const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride;
|
||||
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next);
|
||||
}
|
||||
|
||||
// Dequant + vscatter writes directly to [K, N] transposed tiles.
|
||||
// HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight.
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type);
|
||||
|
||||
hex_swap_ptr(&buf_curr, &buf_next);
|
||||
}
|
||||
TIMER_STOP(weight_load);
|
||||
|
||||
TIMER_START(hmx_core);
|
||||
{
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
|
||||
}
|
||||
TIMER_STOP(hmx_core);
|
||||
|
||||
TIMER_START(output_store);
|
||||
{
|
||||
float *output = dst + (mr * n + nc);
|
||||
transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n);
|
||||
}
|
||||
TIMER_STOP(output_store);
|
||||
}
|
||||
// prologue: A0
|
||||
const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols);
|
||||
{
|
||||
const uint8_t *qweight_chunk_A0 = permuted_weight;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0);
|
||||
}
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
} else {
|
||||
// 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D)
|
||||
// HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D).
|
||||
|
||||
// A --> B: vtcm_qweight, 1 buffer
|
||||
// B --> C: vtcm_weight0/vtcm_weight1, 2 buffers
|
||||
// C --> D: vtcm_output0/vtcm_output1, 2 buffers
|
||||
{
|
||||
const float *activation_chunk = activation + mr * k;
|
||||
transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
|
||||
}
|
||||
|
||||
// Async timeline (C overlaps B+D):
|
||||
// main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2]
|
||||
// HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████]
|
||||
// prologue: B0, A1, submit C0 (async), B1 (overlaps C0)
|
||||
{
|
||||
// B0: wait for DMA, dequant weight chunk 0
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type);
|
||||
|
||||
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
|
||||
hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors
|
||||
|
||||
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
|
||||
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
|
||||
|
||||
void *vtcm_qweight = vtcm_weight;
|
||||
void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 };
|
||||
void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 };
|
||||
|
||||
// prologue: A0
|
||||
const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols);
|
||||
{
|
||||
// Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow.
|
||||
const uint8_t *qweight_chunk_A0 = permuted_weight;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0);
|
||||
// A1: issue DMA for weight chunk 1
|
||||
const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
|
||||
if (1 < n_chunk_cnt) {
|
||||
const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
|
||||
}
|
||||
|
||||
{
|
||||
const float *activation_chunk = activation + mr * k;
|
||||
transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
|
||||
}
|
||||
// submit C0 (non-blocking — HMX worker executes in parallel)
|
||||
hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation,
|
||||
(__fp16 *) vtcm_weight_bufs[0], vtcm_scales,
|
||||
hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
|
||||
hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0]));
|
||||
|
||||
// prologue: B0, A1, submit C0 (async), B1 (overlaps C0)
|
||||
{
|
||||
// B0: wait for DMA, dequant weight chunk 0
|
||||
// B1: DMA pop + dequant (runs in parallel with C0 on HMX worker)
|
||||
if (1 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type);
|
||||
|
||||
// A1: issue DMA for weight chunk 1
|
||||
const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
|
||||
if (1 < n_chunk_cnt) {
|
||||
const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
|
||||
}
|
||||
|
||||
// submit C0 (non-blocking — HMX worker executes in parallel)
|
||||
hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation,
|
||||
(__fp16 *) vtcm_weight_bufs[0], vtcm_scales,
|
||||
hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
|
||||
hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0]));
|
||||
|
||||
// B1: DMA pop + dequant (runs in parallel with C0 on HMX worker)
|
||||
if (1 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type);
|
||||
}
|
||||
}
|
||||
|
||||
// main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1})
|
||||
for (int i = 0; i < n_chunk_cnt; ++i) {
|
||||
const size_t nc = i * n_chunk_n_cols;
|
||||
const size_t nc_p1 = nc + 1 * n_chunk_n_cols;
|
||||
const size_t nc_p2 = nc + 2 * n_chunk_n_cols;
|
||||
|
||||
const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
|
||||
const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols);
|
||||
const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols);
|
||||
|
||||
// issue A_{i+2}: DMA push (non-blocking)
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
|
||||
}
|
||||
|
||||
// wait C_i: block until prologue/previous C completes
|
||||
hmx_queue_pop(ctx->hmx_queue);
|
||||
|
||||
// submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below)
|
||||
// job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's
|
||||
// counterpart — and (i+1)%2 was last used by C_{i-1} which completed
|
||||
// before C_i was submitted.
|
||||
if (i + 1 < n_chunk_cnt) {
|
||||
hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2],
|
||||
(__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2],
|
||||
vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
|
||||
hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2]));
|
||||
}
|
||||
|
||||
// D_i: store output (multi-thread HVX, parallel with C_{i+1})
|
||||
float *output_chunk = dst + (mr * n + nc);
|
||||
transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n);
|
||||
|
||||
// B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1})
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type);
|
||||
}
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type);
|
||||
}
|
||||
}
|
||||
|
||||
hmx_queue_suspend(ctx->hmx_queue);
|
||||
// main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1})
|
||||
for (int i = 0; i < n_chunk_cnt; ++i) {
|
||||
const size_t nc = i * n_chunk_n_cols;
|
||||
const size_t nc_p1 = nc + 1 * n_chunk_n_cols;
|
||||
const size_t nc_p2 = nc + 2 * n_chunk_n_cols;
|
||||
|
||||
const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
|
||||
const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols);
|
||||
const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols);
|
||||
|
||||
// issue A_{i+2}: DMA push (non-blocking)
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
|
||||
}
|
||||
|
||||
// wait C_i: block until prologue/previous C completes
|
||||
hmx_queue_pop(ctx->hmx_queue);
|
||||
|
||||
// submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below)
|
||||
// job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's
|
||||
// counterpart — and (i+1)%2 was last used by C_{i-1} which completed
|
||||
// before C_i was submitted.
|
||||
if (i + 1 < n_chunk_cnt) {
|
||||
hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2],
|
||||
(__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2],
|
||||
vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
|
||||
hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2]));
|
||||
}
|
||||
|
||||
// D_i: store output (multi-thread HVX, parallel with C_{i+1})
|
||||
float *output_chunk = dst + (mr * n + nc);
|
||||
transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n);
|
||||
|
||||
// B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1})
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hmx_queue_suspend(ctx->hmx_queue);
|
||||
|
||||
TIMER_STOP(total);
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline);
|
||||
FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n);
|
||||
if (!use_pipeline) {
|
||||
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
|
||||
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
|
||||
@@ -1370,15 +1061,15 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
|
||||
//
|
||||
|
||||
static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) {
|
||||
static inline int hmx_matmul_batch_r2(const hmx_matmul_f16_f32_batched_params_t *params) {
|
||||
return params->ne02 > 0 ? params->ne12 / params->ne02 : 1;
|
||||
}
|
||||
|
||||
static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) {
|
||||
static inline int hmx_matmul_batch_r3(const hmx_matmul_f16_f32_batched_params_t *params) {
|
||||
return params->ne03 > 0 ? params->ne13 / params->ne03 : 1;
|
||||
}
|
||||
|
||||
static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
|
||||
static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params,
|
||||
int dst_b2, int dst_b3) {
|
||||
const int r2 = hmx_matmul_batch_r2(params);
|
||||
const int r3 = hmx_matmul_batch_r3(params);
|
||||
@@ -1387,37 +1078,36 @@ static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_
|
||||
(size_t) (dst_b3 / r3) * params->src0_nb3);
|
||||
}
|
||||
|
||||
static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
|
||||
static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params,
|
||||
int dst_b2, int dst_b3) {
|
||||
return (const float *) ((const uint8_t *) params->activation +
|
||||
(size_t) dst_b2 * params->src1_nb2 +
|
||||
(size_t) dst_b3 * params->src1_nb3);
|
||||
}
|
||||
|
||||
static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
|
||||
static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params,
|
||||
int dst_b2, int dst_b3) {
|
||||
return (float *) ((uint8_t *) params->dst +
|
||||
(size_t) dst_b2 * params->dst_nb2 +
|
||||
(size_t) dst_b3 * params->dst_nb3);
|
||||
}
|
||||
|
||||
static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx,
|
||||
const hmx_matmul_w16a32_batched_params_t *params) {
|
||||
static int hmx_matmul_f16_f32_batched_legacy(struct htp_context *ctx,
|
||||
const hmx_matmul_f16_f32_batched_params_t *params) {
|
||||
int ret = 0;
|
||||
for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) {
|
||||
for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) {
|
||||
ret = hmx_mat_mul_permuted_w16a32(ctx,
|
||||
hmx_matmul_dst_batch_ptr(params, b2, b3),
|
||||
hmx_matmul_activation_batch_ptr(params, b2, b3),
|
||||
hmx_matmul_weight_batch_ptr(params, b2, b3),
|
||||
params->m, params->k, params->n,
|
||||
params->act_stride, params->weight_stride);
|
||||
ret = hmx_matmul_f16_f32(ctx, hmx_matmul_dst_batch_ptr(params, b2, b3),
|
||||
hmx_matmul_activation_batch_ptr(params, b2, b3),
|
||||
hmx_matmul_weight_batch_ptr(params, b2, b3),
|
||||
params->m, params->k, params->n,
|
||||
params->act_stride, params->weight_stride);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) {
|
||||
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params) {
|
||||
if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; }
|
||||
if (!params->m || !params->k || !params->n) { return -1; }
|
||||
if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; }
|
||||
@@ -1435,7 +1125,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
||||
|
||||
if (group_size <= 1) {
|
||||
FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size);
|
||||
return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
|
||||
return hmx_matmul_f16_f32_batched_legacy(ctx, params);
|
||||
}
|
||||
|
||||
// Grouped path: reuse interleaved weight across all q_heads sharing a
|
||||
@@ -1464,7 +1154,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
||||
/*m_block_cost=*/(size_t) params->n,
|
||||
/*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__);
|
||||
return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
|
||||
return hmx_matmul_f16_f32_batched_legacy(ctx, params);
|
||||
}
|
||||
|
||||
const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads
|
||||
@@ -1486,7 +1176,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
||||
|
||||
if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) {
|
||||
FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__);
|
||||
return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
|
||||
return hmx_matmul_f16_f32_batched_legacy(ctx, params);
|
||||
}
|
||||
|
||||
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
|
||||
@@ -1614,7 +1304,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
||||
|
||||
//
|
||||
|
||||
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
|
||||
int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
|
||||
const __fp16 *restrict permuted_weight, int m, int k, int n,
|
||||
int act_stride, int weight_stride) {
|
||||
if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; }
|
||||
|
||||
@@ -33,14 +33,14 @@ typedef struct {
|
||||
size_t src1_nb3;
|
||||
size_t dst_nb2;
|
||||
size_t dst_nb3;
|
||||
} hmx_matmul_w16a32_batched_params_t;
|
||||
} hmx_matmul_f16_f32_batched_params_t;
|
||||
|
||||
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
|
||||
// act_stride: activation row stride in elements (= k for contiguous, or
|
||||
// nb[1]/sizeof(float) for permuted tensors like attention Q).
|
||||
// weight_stride: weight row stride in elements (= k for compact weights, or
|
||||
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
|
||||
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx,
|
||||
int hmx_matmul_f16_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const __fp16 *permuted_weight,
|
||||
@@ -48,13 +48,12 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx,
|
||||
int act_stride,
|
||||
int weight_stride);
|
||||
|
||||
// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32.
|
||||
// Batched F16 wrapper over hmx_mat_mul_f16_f32.
|
||||
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
|
||||
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx,
|
||||
const hmx_matmul_w16a32_batched_params_t *params);
|
||||
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params);
|
||||
|
||||
// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL)
|
||||
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
|
||||
// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4)
|
||||
int hmx_matmul_q_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const uint8_t *permuted_weight,
|
||||
|
||||
@@ -107,5 +107,7 @@ int op_fill(struct htp_ops_context * octx);
|
||||
int op_diag(struct htp_ops_context * octx);
|
||||
int op_solve_tri(struct htp_ops_context * octx);
|
||||
int op_gated_delta_net(struct htp_ops_context * octx);
|
||||
int op_tri(struct htp_ops_context * octx);
|
||||
int op_pad(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_CTX_H */
|
||||
|
||||
@@ -86,6 +86,9 @@ enum htp_op_code {
|
||||
HTP_OP_SOLVE_TRI,
|
||||
HTP_OP_L2_NORM,
|
||||
HTP_OP_GATED_DELTA_NET,
|
||||
HTP_OP_TRI,
|
||||
HTP_OP_PAD,
|
||||
HTP_OP_NORM,
|
||||
|
||||
HTP_OP_INVALID
|
||||
};
|
||||
|
||||
@@ -87,6 +87,27 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
|
||||
}
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ >= 75
|
||||
{
|
||||
// Power on HMX and set HMX clock
|
||||
HAP_power_request_t request;
|
||||
memset(&request, 0, sizeof(HAP_power_request_t));
|
||||
request.type = HAP_power_set_HMX_v2;
|
||||
request.hmx_v2.set_power = TRUE;
|
||||
request.hmx_v2.power_up = TRUE;
|
||||
request.hmx_v2.set_clock = TRUE;
|
||||
request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH;
|
||||
FARF(ALWAYS, "Setting HMX clock\n");
|
||||
err = HAP_power_set((void *) ctx, &request);
|
||||
if (err != AEE_SUCCESS) {
|
||||
FARF(ERROR, "Error setting HMX clock.");
|
||||
return err;
|
||||
}
|
||||
}
|
||||
#else
|
||||
{
|
||||
// Power on HMX
|
||||
HAP_power_request_t request;
|
||||
@@ -94,31 +115,12 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
|
||||
request.type = HAP_power_set_HMX;
|
||||
request.hmx.power_up = TRUE;
|
||||
FARF(ALWAYS, "Powering HMX on\n");
|
||||
err = HAP_power_set((void *) &ctx, &request);
|
||||
err = HAP_power_set((void *) ctx, &request);
|
||||
if (err != AEE_SUCCESS) {
|
||||
FARF(ERROR, "Error powering on HMX.");
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ >= 75
|
||||
{
|
||||
// Set HMX clock
|
||||
HAP_power_request_t request;
|
||||
memset(&request, 0, sizeof(HAP_power_request_t));
|
||||
request.type = HAP_power_set_HMX_v2;
|
||||
request.hmx_v2.set_clock = TRUE;
|
||||
request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH;
|
||||
FARF(ALWAYS, "Setting HMX clock\n");
|
||||
err = HAP_power_set((void *) &ctx, &request);
|
||||
if (err != AEE_SUCCESS) {
|
||||
FARF(ERROR, "Error setting HMX clock.");
|
||||
return err;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return AEE_SUCCESS;
|
||||
@@ -534,6 +536,7 @@ static int execute_op(struct htp_ops_context * octx) {
|
||||
case HTP_OP_ADD_ID:
|
||||
return op_binary(octx);
|
||||
|
||||
case HTP_OP_NORM:
|
||||
case HTP_OP_RMS_NORM:
|
||||
case HTP_OP_SCALE:
|
||||
case HTP_OP_SQR:
|
||||
@@ -595,9 +598,15 @@ static int execute_op(struct htp_ops_context * octx) {
|
||||
case HTP_OP_SOLVE_TRI:
|
||||
return op_solve_tri(octx);
|
||||
|
||||
case HTP_OP_PAD:
|
||||
return op_pad(octx);
|
||||
|
||||
case HTP_OP_GATED_DELTA_NET:
|
||||
return op_gated_delta_net(octx);
|
||||
|
||||
case HTP_OP_TRI:
|
||||
return op_tri(octx);
|
||||
|
||||
case HTP_OP_INVALID:
|
||||
break;
|
||||
|
||||
|
||||
@@ -2995,7 +2995,6 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
// is handled by HMX itself; when M < 32 fall back to HVX.
|
||||
const int m_total = (int) src1->ne[1];
|
||||
const int m_hmx = m_total & ~31; // 0 when M < 32
|
||||
|
||||
if (m_hmx == 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
@@ -3020,7 +3019,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
|
||||
if (src0->type == HTP_TYPE_F16) {
|
||||
if (is_batched) {
|
||||
hmx_matmul_w16a32_batched_params_t batch_params = {
|
||||
hmx_matmul_f16_f32_batched_params_t batch_params = {
|
||||
.dst = (float *) dst->data,
|
||||
.activation = (float *) src1->data,
|
||||
.permuted_weight = (const __fp16 *) src0->data,
|
||||
@@ -3041,15 +3040,14 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
.dst_nb2 = dst->nb[2],
|
||||
.dst_nb3 = dst->nb[3],
|
||||
};
|
||||
ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params);
|
||||
ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params);
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_w16a32(octx->ctx,
|
||||
ret = hmx_matmul_f16_f32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
|
||||
m_total, k, n, act_stride, wgt_stride);
|
||||
}
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
|
||||
ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
|
||||
m_total, k, n, (int) src0->type);
|
||||
}
|
||||
|
||||
|
||||
545
ggml/src/ggml-hexagon/htp/pad-ops.c
Normal file
545
ggml/src/ggml-hexagon/htp/pad-ops.c
Normal file
@@ -0,0 +1,545 @@
|
||||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
/* Circular wrap: maps any integer x into [0, n) */
|
||||
static inline uint32_t wrap_around(int32_t x, uint32_t n) {
|
||||
return (uint32_t)(((x % (int32_t)n) + (int32_t)n) % (int32_t)n);
|
||||
}
|
||||
|
||||
/* Decompose a flat dst row index into (i1, i2, i3) */
|
||||
static inline void pad_decompose_row(uint32_t ir, uint32_t ne1, uint32_t ne2,
|
||||
uint32_t *i1, uint32_t *i2, uint32_t *i3) {
|
||||
*i1 = ir % ne1;
|
||||
*i2 = (ir / ne1) % ne2;
|
||||
*i3 = ir / (ne1 * ne2);
|
||||
}
|
||||
|
||||
/* Return non-zero if row (i1,i2,i3) falls in the non-padded interior */
|
||||
static inline int pad_is_interior(uint32_t i1, uint32_t i2, uint32_t i3,
|
||||
int32_t lp1, int32_t rp1, uint32_t ne1,
|
||||
int32_t lp2, int32_t rp2, uint32_t ne2,
|
||||
int32_t lp3, int32_t rp3, uint32_t ne3) {
|
||||
return ((int32_t)i1 >= lp1 && (int32_t)i1 < (int32_t)ne1 - rp1) &&
|
||||
((int32_t)i2 >= lp2 && (int32_t)i2 < (int32_t)ne2 - rp2) &&
|
||||
((int32_t)i3 >= lp3 && (int32_t)i3 < (int32_t)ne3 - rp3);
|
||||
}
|
||||
|
||||
/* Compute the DDR src row pointer for a zero-pad interior row */
|
||||
static inline const uint8_t * pad_src_row_ptr(const struct htp_tensor * src,
|
||||
uint32_t i1, uint32_t i2, uint32_t i3,
|
||||
int32_t lp1, int32_t lp2, int32_t lp3) {
|
||||
return (const uint8_t *) src->data
|
||||
+ (i1 - (uint32_t)lp1) * src->nb[1]
|
||||
+ (i2 - (uint32_t)lp2) * src->nb[2]
|
||||
+ (i3 - (uint32_t)lp3) * src->nb[3];
|
||||
}
|
||||
|
||||
/* Compute the DDR src row pointer for a circular row (wrap-around indexing) */
|
||||
static inline const uint8_t * pad_circ_src_row_ptr(const struct htp_tensor * src,
|
||||
uint32_t i1, uint32_t i2, uint32_t i3,
|
||||
int32_t lp1, int32_t lp2, int32_t lp3) {
|
||||
return (const uint8_t *) src->data
|
||||
+ wrap_around((int32_t)i1 - lp1, src->ne[1]) * src->nb[1]
|
||||
+ wrap_around((int32_t)i2 - lp2, src->ne[2]) * src->nb[2]
|
||||
+ wrap_around((int32_t)i3 - lp3, src->ne[3]) * src->nb[3];
|
||||
}
|
||||
|
||||
struct htp_pad_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
int32_t lp0, rp0;
|
||||
int32_t lp1, rp1;
|
||||
int32_t lp2, rp2;
|
||||
int32_t lp3, rp3;
|
||||
|
||||
uint32_t nrows_per_thread;
|
||||
uint32_t total_dst_rows;
|
||||
|
||||
size_t type_size;
|
||||
|
||||
// Row sizes for DMA kernel (populated when VTCM is available)
|
||||
size_t src_row_size;
|
||||
size_t src_row_size_aligned;
|
||||
size_t dst_row_size;
|
||||
size_t dst_row_size_aligned;
|
||||
};
|
||||
|
||||
#define htp_pad_preamble \
|
||||
const struct htp_tensor * src = octx->src[0]; \
|
||||
const struct htp_tensor * dst = octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src->ne[0]; \
|
||||
const uint32_t nb00 = src->nb[0]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3]; \
|
||||
\
|
||||
const int32_t lp0 = pctx->lp0, rp0 = pctx->rp0; \
|
||||
const int32_t lp1 = pctx->lp1, rp1 = pctx->rp1; \
|
||||
const int32_t lp2 = pctx->lp2, rp2 = pctx->rp2; \
|
||||
const int32_t lp3 = pctx->lp3, rp3 = pctx->rp3; \
|
||||
\
|
||||
const size_t type_size = pctx->type_size; \
|
||||
\
|
||||
const uint32_t row_start = pctx->nrows_per_thread * ith; \
|
||||
const uint32_t row_end = MIN(row_start + pctx->nrows_per_thread, pctx->total_dst_rows);
|
||||
|
||||
|
||||
#define htp_pad_dma_preamble \
|
||||
const size_t src_row_size = pctx->src_row_size; \
|
||||
const size_t src_row_size_aligned = pctx->src_row_size_aligned; \
|
||||
const size_t dst_row_size = pctx->dst_row_size; \
|
||||
const size_t dst_row_size_aligned = pctx->dst_row_size_aligned; \
|
||||
\
|
||||
uint8_t * src_spad_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; \
|
||||
uint8_t * dst_spad_base = octx->dst_spad.data + ith * octx->dst_spad.size_per_thread; \
|
||||
\
|
||||
dma_queue * dma = octx->ctx->dma[ith];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HVX vectorized PAD kernel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static void pad_job_per_thread_hvx(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_pad_context * pctx = (const struct htp_pad_context *) data;
|
||||
struct htp_ops_context * octx = pctx->octx;
|
||||
htp_pad_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) {
|
||||
uint32_t i1, i2, i3;
|
||||
pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3);
|
||||
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3;
|
||||
|
||||
const int interior = pad_is_interior(i1, i2, i3,
|
||||
lp1, rp1, ne1,
|
||||
lp2, rp2, ne2,
|
||||
lp3, rp3, ne3);
|
||||
|
||||
if (!interior) {
|
||||
hvx_splat_f32_u(dst_ptr, 0.0f, ne0);
|
||||
} else {
|
||||
const uint8_t * src_ptr = pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3);
|
||||
|
||||
if (lp0 > 0) {
|
||||
hvx_splat_f32_u(dst_ptr, 0.0f, (uint32_t)lp0);
|
||||
}
|
||||
|
||||
uint8_t * dst_row_start = dst_ptr + (size_t)lp0 * type_size;
|
||||
if (nb00 == type_size) {
|
||||
hvx_copy_f32_uu(dst_row_start, src_ptr, ne00);
|
||||
} else {
|
||||
for (uint32_t i = 0; i < ne00; i++) {
|
||||
memcpy(dst_row_start + i * type_size,
|
||||
src_ptr + (size_t)i * nb00,
|
||||
type_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (rp0 > 0) {
|
||||
hvx_splat_f32_u(dst_ptr + ((size_t)lp0 + ne00) * type_size, 0.0f, (uint32_t)rp0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "pad-hvx %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n",
|
||||
ith, nth,
|
||||
src->ne[0], src->ne[1], src->ne[2], src->ne[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
row_start, row_end,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HVX + DMA PAD kernel — aligned, double-buffered
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static void pad_job_per_thread_hvx_dma(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_pad_context * pctx = (const struct htp_pad_context *) data;
|
||||
struct htp_ops_context * octx = pctx->octx;
|
||||
htp_pad_preamble;
|
||||
htp_pad_dma_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the
|
||||
// double-buffer pipeline before the main loop begins.
|
||||
// -----------------------------------------------------------------------
|
||||
for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) {
|
||||
uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned;
|
||||
uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned;
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma,
|
||||
dma_make_ptr((uint8_t *)dst->data, dst_spad_cur),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
|
||||
uint32_t i1, i2, i3;
|
||||
pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3);
|
||||
const int interior = pad_is_interior(i1, i2, i3,
|
||||
lp1, rp1, ne1,
|
||||
lp2, rp2, ne2,
|
||||
lp3, rp3, ne3);
|
||||
|
||||
const uint8_t * src_ptr = interior
|
||||
? pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3) : NULL;
|
||||
|
||||
// Interior row: real DMA (1 row) from DDR to VTCM.
|
||||
// Border row: null DMA (nrows=0)
|
||||
dma_queue_push_ddr_to_vtcm(dma,
|
||||
dma_make_ptr(src_spad_cur,
|
||||
src_ptr ? src_ptr : (const uint8_t *)src_spad_cur),
|
||||
src_row_size_aligned, src_row_size, src_ptr ? 1 : 0);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Main loop: pop completed DMAs, compute in VTCM with aligned HVX ops,
|
||||
// push dst DMA and prefetch src for the next+1 row.
|
||||
// -----------------------------------------------------------------------
|
||||
for (uint32_t ir = row_start; ir < row_end; ir++) {
|
||||
uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src;
|
||||
uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst;
|
||||
|
||||
uint32_t i1, i2, i3;
|
||||
pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3);
|
||||
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3;
|
||||
|
||||
const int interior = pad_is_interior(i1, i2, i3,
|
||||
lp1, rp1, ne1,
|
||||
lp2, rp2, ne2,
|
||||
lp3, rp3, ne3);
|
||||
|
||||
if (!interior) {
|
||||
hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0);
|
||||
} else {
|
||||
hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0);
|
||||
|
||||
uint8_t * dst_interior = dst_spad_cur + (size_t)lp0 * type_size;
|
||||
|
||||
if ((uintptr_t)dst_interior % VLEN == 0) {
|
||||
hvx_copy_f32_aa(dst_interior, src_spad_cur, ne00);
|
||||
} else {
|
||||
hvx_copy_f32_ua(dst_interior, src_spad_cur, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma,
|
||||
dma_make_ptr(dst_ptr, dst_spad_cur),
|
||||
dst_row_size, dst_row_size_aligned, 1);
|
||||
|
||||
const uint32_t next_row = ir + 2;
|
||||
if (next_row < row_end) {
|
||||
uint32_t ni1, ni2, ni3;
|
||||
pad_decompose_row(next_row, ne1, ne2, &ni1, &ni2, &ni3);
|
||||
const int next_interior = pad_is_interior(ni1, ni2, ni3,
|
||||
lp1, rp1, ne1,
|
||||
lp2, rp2, ne2,
|
||||
lp3, rp3, ne3);
|
||||
const uint8_t * next_src_ptr = next_interior
|
||||
? pad_src_row_ptr(src, ni1, ni2, ni3, lp1, lp2, lp3) : NULL;
|
||||
|
||||
dma_queue_push_ddr_to_vtcm(dma,
|
||||
dma_make_ptr(src_spad_cur,
|
||||
next_src_ptr ? next_src_ptr : (const uint8_t *)src_spad_cur),
|
||||
src_row_size_aligned, src_row_size, next_src_ptr ? 1 : 0);
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_flush(dma);
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "pad-hvx-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n",
|
||||
ith, nth,
|
||||
src->ne[0], src->ne[1], src->ne[2], src->ne[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
row_start, row_end,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HVX circular PAD kernel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static void pad_job_per_thread_hvx_circular(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_pad_context * pctx = (const struct htp_pad_context *) data;
|
||||
struct htp_ops_context * octx = pctx->octx;
|
||||
htp_pad_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) {
|
||||
uint32_t i1, i2, i3;
|
||||
pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3);
|
||||
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3;
|
||||
const uint8_t * src_row = pad_circ_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3);
|
||||
|
||||
if (nb00 == type_size) {
|
||||
|
||||
if (lp0 > 0) {
|
||||
if ((uint32_t)lp0 < 32) {
|
||||
memcpy(dst_ptr,
|
||||
src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size,
|
||||
(size_t)lp0 * type_size);
|
||||
} else {
|
||||
hvx_copy_f32_uu(dst_ptr,
|
||||
src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size,
|
||||
(uint32_t)lp0);
|
||||
}
|
||||
}
|
||||
hvx_copy_f32_uu(dst_ptr + (size_t)lp0 * type_size, src_row, ne00);
|
||||
if (rp0 > 0) {
|
||||
if ((uint32_t)rp0 < 32) {
|
||||
memcpy(dst_ptr + ((size_t)lp0 + ne00) * type_size,
|
||||
src_row,
|
||||
(size_t)rp0 * type_size);
|
||||
} else {
|
||||
hvx_copy_f32_uu(dst_ptr + ((size_t)lp0 + ne00) * type_size,
|
||||
src_row,
|
||||
(uint32_t)rp0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint32_t i = 0; i < (uint32_t)lp0; i++) {
|
||||
*(float *)(dst_ptr + i * type_size) =
|
||||
*(const float *)(src_row + (size_t)(ne00 - (uint32_t)lp0 + i) * nb00);
|
||||
}
|
||||
for (uint32_t i = 0; i < ne00; i++) {
|
||||
*(float *)(dst_ptr + ((size_t)lp0 + i) * type_size) =
|
||||
*(const float *)(src_row + (size_t)i * nb00);
|
||||
}
|
||||
for (uint32_t i = 0; i < (uint32_t)rp0; i++) {
|
||||
*(float *)(dst_ptr + ((size_t)lp0 + ne00 + i) * type_size) =
|
||||
*(const float *)(src_row + (size_t)i * nb00);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "pad-hvx-circ %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n",
|
||||
ith, nth,
|
||||
src->ne[0], src->ne[1], src->ne[2], src->ne[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
row_start, row_end,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HVX + DMA circular PAD kernel — aligned, double-buffered
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static void pad_job_per_thread_hvx_circular_dma(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_pad_context * pctx = (const struct htp_pad_context *) data;
|
||||
struct htp_ops_context * octx = pctx->octx;
|
||||
htp_pad_preamble;
|
||||
htp_pad_dma_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the
|
||||
// double-buffer pipeline. Every row is a real src DMA (no null DMAs).
|
||||
// -----------------------------------------------------------------------
|
||||
for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) {
|
||||
uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned;
|
||||
uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned;
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma,
|
||||
dma_make_ptr((uint8_t *)dst->data, dst_spad_cur),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
|
||||
uint32_t pi1, pi2, pi3;
|
||||
pad_decompose_row(ir, ne1, ne2, &pi1, &pi2, &pi3);
|
||||
dma_queue_push_ddr_to_vtcm(dma,
|
||||
dma_make_ptr(src_spad_cur, pad_circ_src_row_ptr(src, pi1, pi2, pi3, lp1, lp2, lp3)),
|
||||
src_row_size_aligned, src_row_size, 1);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Main loop: pop completed DMAs, assemble circular row in VTCM with
|
||||
// aligned HVX ops, push dst DMA and prefetch src for the next+1 row.
|
||||
// -----------------------------------------------------------------------
|
||||
for (uint32_t ir = row_start; ir < row_end; ir++) {
|
||||
uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src;
|
||||
uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst;
|
||||
|
||||
uint32_t i1, i2, i3;
|
||||
pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3);
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3;
|
||||
|
||||
|
||||
if (lp0 > 0) {
|
||||
uint8_t * dst_left = dst_spad_cur;
|
||||
const uint8_t * src_left = src_spad_cur + (size_t)(ne00 - (uint32_t)lp0) * type_size;
|
||||
if ((uint32_t)lp0 < 32) {
|
||||
memcpy(dst_left, src_left, (size_t)lp0 * type_size);
|
||||
} else {
|
||||
hvx_copy_f32_uu(dst_left, src_left, (uint32_t)lp0);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
uint8_t * dst_mid = dst_spad_cur + (size_t)lp0 * type_size;
|
||||
if ((uintptr_t)dst_mid % VLEN == 0) {
|
||||
hvx_copy_f32_aa(dst_mid, src_spad_cur, ne00);
|
||||
} else {
|
||||
hvx_copy_f32_ua(dst_mid, src_spad_cur, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
if (rp0 > 0) {
|
||||
uint8_t * dst_right = dst_spad_cur + ((size_t)lp0 + ne00) * type_size;
|
||||
if ((uint32_t)rp0 < 32) {
|
||||
memcpy(dst_right, src_spad_cur, (size_t)rp0 * type_size);
|
||||
} else {
|
||||
if ((uintptr_t)dst_right % VLEN == 0) {
|
||||
hvx_copy_f32_aa(dst_right, src_spad_cur, (uint32_t)rp0);
|
||||
} else {
|
||||
hvx_copy_f32_ua(dst_right, src_spad_cur, (uint32_t)rp0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma,
|
||||
dma_make_ptr(dst_ptr, dst_spad_cur),
|
||||
dst_row_size, dst_row_size_aligned, 1);
|
||||
|
||||
const uint32_t next_row = ir + 2;
|
||||
if (next_row < row_end) {
|
||||
uint32_t nri1, nri2, nri3;
|
||||
pad_decompose_row(next_row, ne1, ne2, &nri1, &nri2, &nri3);
|
||||
dma_queue_push_ddr_to_vtcm(dma,
|
||||
dma_make_ptr(src_spad_cur,
|
||||
pad_circ_src_row_ptr(src, nri1, nri2, nri3, lp1, lp2, lp3)),
|
||||
src_row_size_aligned, src_row_size, 1);
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_flush(dma);
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "pad-hvx-circ-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n",
|
||||
ith, nth,
|
||||
src->ne[0], src->ne[1], src->ne[2], src->ne[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
row_start, row_end,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
int op_pad(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
// Only F32 supported
|
||||
size_t type_size;
|
||||
switch (src0->type) {
|
||||
case HTP_TYPE_F32: type_size = 4; break;
|
||||
default:
|
||||
FARF(ERROR, "pad-hvx: unsupported type %u\n", src0->type);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
const int32_t lp0 = octx->op_params[0];
|
||||
const int32_t rp0 = octx->op_params[1];
|
||||
const int32_t lp1 = octx->op_params[2];
|
||||
const int32_t rp1 = octx->op_params[3];
|
||||
const int32_t lp2 = octx->op_params[4];
|
||||
const int32_t rp2 = octx->op_params[5];
|
||||
const int32_t lp3 = octx->op_params[6];
|
||||
const int32_t rp3 = octx->op_params[7];
|
||||
const int32_t circular = octx->op_params[8];
|
||||
|
||||
const uint32_t ne0 = dst->ne[0];
|
||||
const uint32_t ne00 = src0->ne[0];
|
||||
|
||||
const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows > 0 ? total_dst_rows : 1);
|
||||
|
||||
const size_t src_row_size = (size_t)ne00 * type_size;
|
||||
const size_t dst_row_size = (size_t)ne0 * type_size;
|
||||
const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
// Total VTCM needed: 2 buffers (ping+pong) for src and dst, per thread
|
||||
const size_t vtcm_needed = (size_t)n_threads * 2 * (src_row_size_aligned + dst_row_size_aligned);
|
||||
|
||||
const int use_dma = (src0->nb[0] == (uint32_t)type_size) &&
|
||||
(ne00 >= 512) &&
|
||||
(octx->ctx->vtcm_base != NULL) &&
|
||||
(octx->ctx->vtcm_size >= vtcm_needed);
|
||||
|
||||
if (use_dma) {
|
||||
octx->src0_spad.size_per_thread = 2 * src_row_size_aligned;
|
||||
octx->dst_spad.size_per_thread = 2 * dst_row_size_aligned;
|
||||
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
||||
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
}
|
||||
|
||||
struct htp_pad_context pctx = {
|
||||
.octx = octx,
|
||||
.lp0 = lp0, .rp0 = rp0,
|
||||
.lp1 = lp1, .rp1 = rp1,
|
||||
.lp2 = lp2, .rp2 = rp2,
|
||||
.lp3 = lp3, .rp3 = rp3,
|
||||
.nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads,
|
||||
.total_dst_rows = total_dst_rows,
|
||||
.type_size = type_size,
|
||||
.src_row_size = src_row_size,
|
||||
.src_row_size_aligned = src_row_size_aligned,
|
||||
.dst_row_size = dst_row_size,
|
||||
.dst_row_size_aligned = dst_row_size_aligned,
|
||||
};
|
||||
|
||||
FARF(HIGH, "pad-hvx%s%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) pads=(%d,%d,%d,%d,%d,%d,%d,%d)\n",
|
||||
circular ? "-circ" : "",
|
||||
use_dma ? "-dma" : "",
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
|
||||
|
||||
if (circular && use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular_dma, &pctx, n_threads); }
|
||||
else if (circular) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular, &pctx, n_threads); }
|
||||
else if (use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_dma, &pctx, n_threads); }
|
||||
else { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx, &pctx, n_threads); }
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
@@ -18,9 +18,11 @@
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h
|
||||
// Redefined the rope type constants as we can't include ggml.h
|
||||
#define HTP_ROPE_TYPE_NORMAL 0
|
||||
#define HTP_ROPE_TYPE_NEOX 2
|
||||
#define HTP_ROPE_TYPE_MROPE 8
|
||||
#define HTP_ROPE_TYPE_IMROPE 40
|
||||
|
||||
#define HTP_ROPE_SPAD_NROWS 16
|
||||
#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2)
|
||||
@@ -82,7 +84,30 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
return (1 - MIN(1, MAX(0, y)));
|
||||
}
|
||||
|
||||
static void rope_cache_init(const float theta_base,
|
||||
// Compute one (cos, sin) pair into cache[i0], cache[i0+1] applying YaRN scaling.
|
||||
static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dims,
|
||||
uint32_t i0, float ext_factor, float mscale,
|
||||
float * cache) {
|
||||
float theta_extrap = theta;
|
||||
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta_final = theta_interp;
|
||||
float mscale_final = mscale;
|
||||
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||
}
|
||||
|
||||
cache[i0 + 0] = cosf(theta_final) * mscale_final;
|
||||
cache[i0 + 1] = sinf(theta_final) * mscale_final;
|
||||
}
|
||||
|
||||
static __attribute__((noinline)) void rope_cache_init(const float theta_base,
|
||||
const float freq_scale,
|
||||
const float * freq_factors,
|
||||
float * corr_dims,
|
||||
@@ -96,29 +121,65 @@ static void rope_cache_init(const float theta_base,
|
||||
|
||||
for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
|
||||
|
||||
float theta_extrap = theta / ff;
|
||||
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta_final = theta_interp;
|
||||
float mscale_final = mscale;
|
||||
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||
}
|
||||
|
||||
cache[i0 + 0] = cosf(theta_final) * mscale_final;
|
||||
cache[i0 + 1] = sinf(theta_final) * mscale_final;
|
||||
rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache);
|
||||
|
||||
theta *= theta_scale;
|
||||
}
|
||||
}
|
||||
|
||||
// pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra).
|
||||
// sections[4]: number of head dims assigned to each position component.
|
||||
static __attribute__((noinline)) void mrope_cache_init(const float pos_t,
|
||||
const float pos_h,
|
||||
const float pos_w,
|
||||
const float pos_e,
|
||||
const int32_t sections[4],
|
||||
const bool is_imrope,
|
||||
const float freq_scale,
|
||||
const float * freq_factors,
|
||||
float * corr_dims,
|
||||
const uint32_t ne0,
|
||||
const float ext_factor,
|
||||
const float mscale,
|
||||
float * cache,
|
||||
const float theta_scale) {
|
||||
const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
|
||||
const int sec_w = sections[0] + sections[1];
|
||||
const int sec_e = sec_w + sections[2];
|
||||
|
||||
float theta_t = pos_t;
|
||||
float theta_h = pos_h;
|
||||
float theta_w = pos_w;
|
||||
float theta_e = pos_e;
|
||||
|
||||
for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta;
|
||||
if (is_imrope) {
|
||||
// Interleaved: sector mod 3 selects component
|
||||
if (sector % 3 == 0 && sector < 3 * sections[0]) { theta = theta_t; }
|
||||
else if (sector % 3 == 1 && sector < 3 * sections[1]) { theta = theta_h; }
|
||||
else if (sector % 3 == 2 && sector < 3 * sections[2]) { theta = theta_w; }
|
||||
else { theta = theta_e; }
|
||||
} else {
|
||||
// Contiguous sections
|
||||
if (sector < sections[0]) { theta = theta_t; }
|
||||
else if (sector < sec_w) { theta = theta_h; }
|
||||
else if (sector < sec_e) { theta = theta_w; }
|
||||
else { theta = theta_e; }
|
||||
}
|
||||
|
||||
rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache);
|
||||
|
||||
theta_t *= theta_scale;
|
||||
theta_h *= theta_scale;
|
||||
theta_w *= theta_scale;
|
||||
theta_e *= theta_scale;
|
||||
}
|
||||
}
|
||||
|
||||
#define M_PI 3.1415926535897932384626433
|
||||
|
||||
static void rope_corr_dims(int n_dims,
|
||||
@@ -274,7 +335,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
uint64_t tt = HAP_perf_get_qtimer_count();
|
||||
|
||||
const int32_t mode = rctx->mode;
|
||||
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
|
||||
// MROPE and IMROPE use NEOX-style pairing for the rotation
|
||||
const bool is_neox = (mode & HTP_ROPE_TYPE_NEOX) || (mode & HTP_ROPE_TYPE_MROPE);
|
||||
|
||||
// VTCM setup
|
||||
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
|
||||
@@ -326,8 +388,25 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
if (i2 != prev_i2) {
|
||||
prev_i2 = i2;
|
||||
|
||||
const int32_t p = pos[i2];
|
||||
rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);
|
||||
const bool is_mrope = (rctx->mode & HTP_ROPE_TYPE_MROPE) != 0;
|
||||
if (is_mrope) {
|
||||
// src1 holds four position arrays stacked along ne0:
|
||||
// pos[i2], pos[i2+ne2], pos[i2+ne2*2], pos[i2+ne2*3]
|
||||
const bool is_imrope = (rctx->mode == HTP_ROPE_TYPE_IMROPE);
|
||||
mrope_cache_init(
|
||||
(float) pos[i2],
|
||||
(float) pos[i2 + ne2],
|
||||
(float) pos[i2 + ne2 * 2],
|
||||
(float) pos[i2 + ne2 * 3],
|
||||
rctx->sections, is_imrope,
|
||||
rctx->freq_scale, freq_factors, rctx->corr_dims,
|
||||
ne0, rctx->ext_factor, rctx->attn_factor,
|
||||
theta_cache, rctx->theta_scale);
|
||||
} else {
|
||||
rope_cache_init(pos[i2], rctx->freq_scale, freq_factors, rctx->corr_dims,
|
||||
ne0, rctx->ext_factor, rctx->attn_factor,
|
||||
theta_cache, rctx->theta_scale);
|
||||
}
|
||||
|
||||
// FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
|
||||
@@ -20,55 +20,56 @@
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define htp_ssm_conv_tensors_preamble \
|
||||
const struct htp_tensor * restrict src0 = octx->src[0]; \
|
||||
const struct htp_tensor * restrict src1 = octx->src[1]; \
|
||||
const struct htp_tensor * restrict dst = octx->dst; \
|
||||
struct htp_spad * restrict src0_spad = &octx->src0_spad; \
|
||||
struct htp_spad * restrict src1_spad = &octx->src1_spad; \
|
||||
struct htp_spad * restrict dst_spad = &octx->dst_spad; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1->ne[0]; \
|
||||
const uint32_t ne11 = src1->ne[1]; \
|
||||
const uint32_t ne12 = src1->ne[2]; \
|
||||
const uint32_t ne13 = src1->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = src1->nb[0]; \
|
||||
const uint32_t nb11 = src1->nb[1]; \
|
||||
const uint32_t nb12 = src1->nb[2]; \
|
||||
const uint32_t nb13 = src1->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
#define htp_ssm_conv_tensors_preamble \
|
||||
const struct htp_tensor * restrict src0 = octx->src[0]; \
|
||||
const struct htp_tensor * restrict src1 = octx->src[1]; \
|
||||
const struct htp_tensor * restrict dst = octx->dst; \
|
||||
struct htp_spad * restrict src0_spad = &octx->src0_spad; \
|
||||
struct htp_spad * restrict src1_spad = &octx->src1_spad; \
|
||||
struct htp_spad * restrict dst_spad = &octx->dst_spad; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1->ne[0]; \
|
||||
const uint32_t ne11 = src1->ne[1]; \
|
||||
const uint32_t ne12 = src1->ne[2]; \
|
||||
const uint32_t ne13 = src1->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = src1->nb[0]; \
|
||||
const uint32_t nb11 = src1->nb[1]; \
|
||||
const uint32_t nb12 = src1->nb[2]; \
|
||||
const uint32_t nb13 = src1->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct htp_ssm_conv_context {
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t nrows_per_thread;
|
||||
uint32_t d_inner_tile;
|
||||
uint64_t t_start;
|
||||
};
|
||||
|
||||
#define htp_ssm_conv_preamble \
|
||||
#define htp_ssm_conv_preamble \
|
||||
struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \
|
||||
struct htp_ops_context * octx = scctx->octx; \
|
||||
htp_ssm_conv_tensors_preamble; \
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
struct htp_ops_context * octx = scctx->octx; \
|
||||
htp_ssm_conv_tensors_preamble; \
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
|
||||
// Scalar FP32 SSM_CONV implementation
|
||||
static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
@@ -128,118 +129,211 @@ static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension
|
||||
|
||||
// In-register 32x32 fp32 transpose using std 5-stage HVX vshuff butterfly.
|
||||
static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) {
|
||||
HVX_Vector tmp[32];
|
||||
|
||||
// Stage 0 (R = -4): pair (2i, 2i+1) for i = 0..15. m -> tmp.
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
HVX_VectorPair p = Q6_W_vshuff_VVR(m[2*i + 1], m[2*i], -4);
|
||||
tmp[2*i + 0] = Q6_V_lo_W(p);
|
||||
tmp[2*i + 1] = Q6_V_hi_W(p);
|
||||
}
|
||||
|
||||
// Stage 1 (R = -8): per block of 4, pair (b+0, b+2) and (b+1, b+3). tmp -> m.
|
||||
for (int b = 0; b < 32; b += 4) {
|
||||
HVX_VectorPair p0 = Q6_W_vshuff_VVR(tmp[b + 2], tmp[b + 0], -8);
|
||||
HVX_VectorPair p1 = Q6_W_vshuff_VVR(tmp[b + 3], tmp[b + 1], -8);
|
||||
m[b + 0] = Q6_V_lo_W(p0); m[b + 1] = Q6_V_hi_W(p0);
|
||||
m[b + 2] = Q6_V_lo_W(p1); m[b + 3] = Q6_V_hi_W(p1);
|
||||
}
|
||||
|
||||
// Stage 2 (R = -16): per block of 8, pair (b+i, b+i+4) for i = 0..3. m -> tmp.
|
||||
for (int b = 0; b < 32; b += 8) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
HVX_VectorPair p = Q6_W_vshuff_VVR(m[b + i + 4], m[b + i], -16);
|
||||
tmp[b + 2*i + 0] = Q6_V_lo_W(p);
|
||||
tmp[b + 2*i + 1] = Q6_V_hi_W(p);
|
||||
}
|
||||
}
|
||||
|
||||
// Stage 3 (R = -32): per block of 16, pair (b+i, b+i+8) for i = 0..7. tmp -> m.
|
||||
for (int b = 0; b < 32; b += 16) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
HVX_VectorPair p = Q6_W_vshuff_VVR(tmp[b + i + 8], tmp[b + i], -32);
|
||||
m[b + 2*i + 0] = Q6_V_lo_W(p);
|
||||
m[b + 2*i + 1] = Q6_V_hi_W(p);
|
||||
}
|
||||
}
|
||||
|
||||
// Stage 4 (R = -64): pair (i, i+16) for i = 0..15. m -> tmp -> m.
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
HVX_VectorPair p = Q6_W_vshuff_VVR(m[i + 16], m[i], -64);
|
||||
tmp[2 * i + 0] = Q6_V_lo_W(p);
|
||||
tmp[2 * i + 1] = Q6_V_hi_W(p);
|
||||
}
|
||||
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
m[i] = tmp[i];
|
||||
}
|
||||
}
|
||||
|
||||
// HVX FP32 SSM_CONV implementation - channel-vectorized HVX kernel with src0/src1
|
||||
// transposed into VTCM.
|
||||
//
|
||||
// VTCM layouts (per thread):
|
||||
// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small).
|
||||
// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile.
|
||||
//
|
||||
// d_inner_tile is chosen so that per-thread VTCM stays under the budget.
|
||||
// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially.
|
||||
#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread
|
||||
|
||||
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM)
|
||||
static inline void transpose_src1(const float * src1_data,
|
||||
uint32_t src1_stride_inner,
|
||||
uint32_t i1_off,
|
||||
uint32_t d_inner_per_thread,
|
||||
uint32_t d_conv,
|
||||
float * src1_T) {
|
||||
for (uint32_t i = 0; i < d_inner_per_thread; ++i) {
|
||||
const float * src_row = src1_data + (i1_off + i) * src1_stride_inner;
|
||||
for (uint32_t j = 0; j < d_conv; ++j) {
|
||||
src1_T[j * d_inner_per_thread + i] = src_row[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HVX 32x32 src0 transpose: src0 {ncs, d_inner} (DDR) -> src0_T {d_inner_tile, ncs} (VTCM)
|
||||
static inline void transpose_src0_block(const float * src0_block,
|
||||
uint32_t ncs,
|
||||
uint32_t cb_n,
|
||||
uint32_t d_inner_tile,
|
||||
float * src0_T_block_dst,
|
||||
uint32_t cb /* dst column offset */) {
|
||||
const uint32_t T_TILE = VLEN_FP32;
|
||||
|
||||
HVX_Vector __attribute__((aligned(VLEN))) sub[32];
|
||||
|
||||
for (uint32_t t0 = 0; t0 < ncs; t0 += T_TILE) {
|
||||
const uint32_t t_n = MIN(T_TILE, ncs - t0);
|
||||
|
||||
// Load 32 rows (channels) of T_TILE samples; pad missing channels with zeros.
|
||||
for (uint32_t r = 0; r < cb_n; ++r) {
|
||||
const float * src_row = src0_block + r * ncs + t0;
|
||||
if (t_n == T_TILE) {
|
||||
sub[r] = *(const HVX_UVector *) src_row;
|
||||
} else {
|
||||
HVX_Vector v = hvx_vec_splat_f32(0.0f);
|
||||
hvx_vec_store_u(&v, t_n * sizeof(float), hvx_vec_splat_f32(0.0f));
|
||||
|
||||
float __attribute__((aligned(VLEN))) tmp[VLEN_FP32] = { 0 };
|
||||
for (uint32_t k = 0; k < t_n; ++k) tmp[k] = src_row[k];
|
||||
v = *(const HVX_Vector *) tmp;
|
||||
sub[r] = v;
|
||||
}
|
||||
}
|
||||
for (uint32_t r = cb_n; r < T_TILE; ++r) {
|
||||
sub[r] = hvx_vec_splat_f32(0.0f);
|
||||
}
|
||||
|
||||
hvx_transpose_32x32_f32(sub);
|
||||
|
||||
// Store transposed sub-tile to src0_T at offsets (t0 + j) * d_inner_tile + cb.
|
||||
// Only write the valid t_n rows of the transposed result.
|
||||
for (uint32_t r = 0; r < t_n; ++r) {
|
||||
float * dst = src0_T_block_dst + (t0 + r) * d_inner_tile + cb;
|
||||
if (cb_n == T_TILE) {
|
||||
*(HVX_UVector *) dst = sub[r];
|
||||
} else {
|
||||
hvx_vec_store_u(dst, cb_n * sizeof(float), sub[r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) {
|
||||
htp_ssm_conv_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const int nc = src1->ne[0]; // d_conv
|
||||
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
|
||||
|
||||
const uint32_t d_conv = src1->ne[0];
|
||||
const uint32_t d_inner = src0->ne[1];
|
||||
const uint32_t n_t = dst->ne[1];
|
||||
const uint32_t n_s = dst->ne[2];
|
||||
const uint32_t ncs = src0->ne[0];
|
||||
|
||||
const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float);
|
||||
const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float);
|
||||
const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float);
|
||||
const uint32_t dst_stride_token = dst->nb[1] / sizeof(float);
|
||||
const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float);
|
||||
|
||||
const uint32_t dr = scctx->nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = MIN(ir0 + dr, d_inner);
|
||||
|
||||
if (ir0 >= ir1) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t d_inner_per_thread = ir1 - ir0;
|
||||
const uint32_t d_inner_tile = scctx->d_inner_tile;
|
||||
|
||||
const float * src0_data = (const float *) src0->data;
|
||||
const float * src1_data = (const float *) src1->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
// Calculate row range for this thread
|
||||
const int dr = scctx->nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = MIN(ir0 + dr, d_inner);
|
||||
const uint32_t ir = ir1 - ir0;
|
||||
// Per-thread VTCM regions.
|
||||
float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread);
|
||||
float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread);
|
||||
|
||||
if (ir0 >= ir1) {
|
||||
return; // No work for this thread
|
||||
}
|
||||
// Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout.
|
||||
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T);
|
||||
|
||||
// src0 and src1 gather offsets
|
||||
uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 };
|
||||
uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 };
|
||||
|
||||
for (uint32_t i = 0; i < VLEN_FP32; ++i) {
|
||||
src0_offsets[i] = i * (ncs) * sizeof(float);
|
||||
src1_offsets[i] = i * (d_conv) * sizeof(float);
|
||||
}
|
||||
|
||||
const uint32_t src0_gather_len = VLEN * ncs;
|
||||
const uint32_t src1_gather_len = VLEN * d_conv;
|
||||
|
||||
// gather scratchpads
|
||||
HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0);
|
||||
HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN);
|
||||
|
||||
float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]);
|
||||
float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]);
|
||||
|
||||
uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;
|
||||
uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;
|
||||
|
||||
// copy src1 workload to VTCM
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir);
|
||||
|
||||
// FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir);
|
||||
const uint32_t C_TILE = VLEN_FP32;
|
||||
|
||||
for (uint32_t i3 = 0; i3 < n_s; ++i3) {
|
||||
float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2]));
|
||||
for (uint32_t tile_off = 0; tile_off < d_inner_per_thread; tile_off += d_inner_tile) {
|
||||
const uint32_t tile_n = MIN(d_inner_tile, d_inner_per_thread - tile_off);
|
||||
|
||||
// copy src0 workload to VTCM
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir);
|
||||
// Place src0 chunk into VTCM in {d_inner_tile, ncs} layout.
|
||||
const float * src0_block = src0_data + i3 * src0_stride_seq + (ir0 + tile_off) * src0_stride_inner;
|
||||
|
||||
// FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir);
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
|
||||
for (uint32_t i2 = 0; i2 < n_t; ++i2) {
|
||||
float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2]));
|
||||
|
||||
const uint32_t nvec = ir / VLEN_FP32;
|
||||
const uint32_t nloe = ir % VLEN_FP32;
|
||||
uint32_t i1 = 0;
|
||||
|
||||
for (uint32_t vi1 = 0; vi1 < nvec; vi1++) {
|
||||
HVX_Vector acc_vec = Q6_V_vsplat_R(0);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
|
||||
uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]);
|
||||
uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float);
|
||||
Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets));
|
||||
Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets));
|
||||
|
||||
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
|
||||
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
|
||||
}
|
||||
|
||||
*(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec);
|
||||
i1 += VLEN_FP32;
|
||||
for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) {
|
||||
const uint32_t cb_n = MIN(C_TILE, tile_n - cb);
|
||||
transpose_src0_block(src0_block + cb * src0_stride_inner, ncs, cb_n, d_inner_tile, src0_T, cb);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector acc_vec = Q6_V_vsplat_R(0);
|
||||
for (uint32_t t = 0; t < n_t; ++t) {
|
||||
for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) {
|
||||
const uint32_t cb_n = MIN(C_TILE, tile_n - cb);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
|
||||
uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]);
|
||||
uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float);
|
||||
Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets));
|
||||
Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets));
|
||||
HVX_Vector acc = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t j = 0; j < d_conv; ++j) {
|
||||
HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb);
|
||||
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb);
|
||||
acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w));
|
||||
}
|
||||
HVX_Vector res = Q6_Vsf_equals_Vqf32(acc);
|
||||
|
||||
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
|
||||
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
|
||||
float * dst_ptr = dst_data + i3 * dst_stride_seq + t * dst_stride_token + (ir0 + tile_off + cb);
|
||||
if (cb_n == C_TILE) {
|
||||
*(HVX_UVector *) dst_ptr = res;
|
||||
} else {
|
||||
hvx_vec_store_u(dst_ptr, cb_n * sizeof(float), res);
|
||||
}
|
||||
}
|
||||
|
||||
hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
|
||||
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
|
||||
FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) tile=%u * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
|
||||
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, d_inner_tile,
|
||||
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
|
||||
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
@@ -264,46 +358,44 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t use_hvx = 0;
|
||||
if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) {
|
||||
int is_aligned = hex_is_aligned((void *) src0->data, VLEN) &&
|
||||
hex_is_aligned((void *) src1->data, VLEN) &&
|
||||
hex_is_aligned((void *) dst->data, VLEN);
|
||||
|
||||
if (is_aligned) {
|
||||
use_hvx = 1;
|
||||
}
|
||||
if (d_inner >= VLEN_FP32 && n_t >= VLEN_FP32) {
|
||||
use_hvx = 1;
|
||||
}
|
||||
|
||||
if (use_hvx) {
|
||||
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread
|
||||
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even
|
||||
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads;
|
||||
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1);
|
||||
|
||||
octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256);
|
||||
const uint32_t d_inner_per_thread = scctx.nrows_per_thread;
|
||||
const uint32_t ncs = src0->ne[0];
|
||||
|
||||
const uint32_t src1_T_size = hex_round_up(d_conv * d_inner_per_thread * sizeof(float), 256);
|
||||
const uint32_t src0_T_max = HTP_SSM_CONV_VTCM_BUDGET > src1_T_size ? HTP_SSM_CONV_VTCM_BUDGET - src1_T_size : 0;
|
||||
|
||||
uint32_t d_inner_tile = (src0_T_max / sizeof(float)) / ncs;
|
||||
d_inner_tile -= (d_inner_tile % VLEN_FP32);
|
||||
if (d_inner_tile == 0) {
|
||||
FARF(HIGH, "ssm_conv-f32: inner tile rounds to 0 (ncs=%u), falling back to scalar\n", ncs);
|
||||
use_hvx = 0;
|
||||
} else {
|
||||
scctx.d_inner_tile = d_inner_tile;
|
||||
|
||||
octx->src0_spad.size_per_thread = hex_round_up(d_inner_tile * ncs * sizeof(float), 256);
|
||||
octx->src1_spad.size_per_thread = src1_T_size;
|
||||
octx->dst_spad.size_per_thread = 0;
|
||||
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads;
|
||||
octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads;
|
||||
octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads;
|
||||
octx->dst_spad.size = 0;
|
||||
|
||||
// Compute gather scratchpad size for src0 and src1
|
||||
const size_t gather_spad_size = n_threads * VLEN * 2;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.src = NULL;
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
|
||||
gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread,
|
||||
octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size,
|
||||
octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data);
|
||||
|
||||
const size_t total_spad_size =
|
||||
gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (total_spad_size > octx->ctx->vtcm_size) {
|
||||
FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size,
|
||||
octx->ctx->vtcm_size);
|
||||
const size_t total_spad = octx->src0_spad.size + octx->src1_spad.size;
|
||||
if (total_spad > octx->ctx->vtcm_size) {
|
||||
FARF(HIGH, "ssm_conv-f32: scratchpad %zu exceeds VTCM %zu, falling back to scalar\n",
|
||||
total_spad, octx->ctx->vtcm_size);
|
||||
use_hvx = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
struct htp_unary_context {
|
||||
struct htp_ops_context * octx;
|
||||
@@ -159,6 +158,79 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_fast_norm_f32(const uint8_t * restrict src,
|
||||
uint8_t * restrict dst,
|
||||
uint8_t * restrict pad,
|
||||
const int num_elems,
|
||||
float epsilon) {
|
||||
(void)pad;
|
||||
|
||||
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
|
||||
const int nvec = num_elems / VLEN_FP32; // number of full vectors
|
||||
const int nloe = num_elems % VLEN_FP32; // leftover elements
|
||||
|
||||
// Compute sum of squares and sum of values for full vectors
|
||||
HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000);
|
||||
HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000);
|
||||
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < nvec; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
|
||||
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
|
||||
}
|
||||
|
||||
// Handle tail elements using vectorized ops with masking
|
||||
if (nloe > 0) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
|
||||
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
|
||||
}
|
||||
|
||||
// Reduce HVX sums
|
||||
sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v));
|
||||
sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v));
|
||||
|
||||
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
||||
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
||||
HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v);
|
||||
HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v);
|
||||
HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v));
|
||||
HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v);
|
||||
HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v);
|
||||
|
||||
// scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction
|
||||
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v));
|
||||
HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v)));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < nvec; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
|
||||
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
|
||||
v_dst[i] = Q6_Vsf_equals_Vqf32(v3);
|
||||
}
|
||||
|
||||
// Handle tail elements using vectorized ops with masking
|
||||
if (nloe > 0) {
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
||||
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
|
||||
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
|
||||
HVX_Vector result = Q6_Vsf_equals_Vqf32(v3);
|
||||
|
||||
// Store with masking to avoid overwriting memory beyond the tensor
|
||||
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
|
||||
}
|
||||
}
|
||||
|
||||
static void scale_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
@@ -197,6 +269,24 @@ static void rms_norm_f32(const float * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void norm_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
float epsilon = 0.f;
|
||||
memcpy(&epsilon, op_params, sizeof(float));
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
static void sqr_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
@@ -277,6 +367,95 @@ static void sigmoid_f32(const float * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void tri_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
const uint32_t ir,
|
||||
const struct htp_unary_context * uctx) {
|
||||
|
||||
const int32_t ttype = op_params[0];
|
||||
const HVX_Vector zero = hvx_vec_splat_f32(0.0f);
|
||||
const uint32_t nvec = row_elems / VLEN_FP32;
|
||||
const uint32_t nloe = row_elems % VLEN_FP32;
|
||||
|
||||
const uint32_t ne01 = uctx->octx->src[0]->ne[1];
|
||||
|
||||
for (uint32_t b = 0; b < num_rows; b++) {
|
||||
const uint32_t abs_row = ir + b;
|
||||
const uint32_t i01 = abs_row % ne01;
|
||||
|
||||
const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size);
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size);
|
||||
|
||||
uint32_t boundary;
|
||||
int keep_left;
|
||||
switch (ttype) {
|
||||
case 0: boundary = i01; keep_left = 0; break; // keep col >= row
|
||||
case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row
|
||||
case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row
|
||||
case 3: boundary = i01; keep_left = 1; break; // keep col < row
|
||||
default: boundary = 0; keep_left = 0; break;
|
||||
}
|
||||
if (boundary > row_elems) boundary = row_elems;
|
||||
|
||||
// Full HVX vectors — each starts at a 128-byte aligned offset
|
||||
for (uint32_t i = 0; i < nvec; i++) {
|
||||
const uint32_t vec_start = i * VLEN_FP32;
|
||||
const uint32_t vec_end = vec_start + VLEN_FP32;
|
||||
if (keep_left) {
|
||||
if (vec_end <= boundary) {
|
||||
v_dst[i] = v_src[i];
|
||||
} else if (vec_start >= boundary) {
|
||||
v_dst[i] = zero;
|
||||
} else {
|
||||
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
||||
v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero);
|
||||
}
|
||||
} else {
|
||||
if (vec_end <= boundary) {
|
||||
v_dst[i] = zero;
|
||||
} else if (vec_start >= boundary) {
|
||||
v_dst[i] = v_src[i];
|
||||
} else {
|
||||
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
||||
v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tail elements (row_elems not a multiple of VLEN_FP32)
|
||||
if (nloe > 0) {
|
||||
const uint32_t vec_start = nvec * VLEN_FP32;
|
||||
const uint32_t vec_end = vec_start + nloe;
|
||||
HVX_Vector tail_val;
|
||||
if (keep_left) {
|
||||
if (vec_end <= boundary) {
|
||||
tail_val = v_src[nvec];
|
||||
} else if (vec_start >= boundary) {
|
||||
tail_val = zero;
|
||||
} else {
|
||||
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
||||
tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero);
|
||||
}
|
||||
} else {
|
||||
if (vec_end <= boundary) {
|
||||
tail_val = zero;
|
||||
} else if (vec_start >= boundary) {
|
||||
tail_val = v_src[nvec];
|
||||
} else {
|
||||
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
||||
tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]);
|
||||
}
|
||||
}
|
||||
hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void softplus_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
@@ -468,6 +647,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
|
||||
// Process block in VTCM
|
||||
switch (htp_op) {
|
||||
case HTP_OP_NORM:
|
||||
norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_RMS_NORM:
|
||||
rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
@@ -498,6 +680,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
case HTP_OP_L2_NORM:
|
||||
l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_TRI:
|
||||
tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -541,6 +726,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
const char * op_type = NULL;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_NORM:
|
||||
op_type = "norm-f32";
|
||||
break;
|
||||
case HTP_OP_RMS_NORM:
|
||||
op_type = "rmsnorm-f32";
|
||||
break;
|
||||
@@ -571,6 +759,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
case HTP_OP_L2_NORM:
|
||||
op_type = "l2norm-f32";
|
||||
break;
|
||||
case HTP_OP_TRI:
|
||||
op_type = "tri-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
@@ -640,6 +832,22 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
return err;
|
||||
}
|
||||
|
||||
int op_tri(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src[0]->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_unary_f32(octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
err = HTP_STATUS_NO_SUPPORT;
|
||||
break;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
int op_unary(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
|
||||
@@ -1897,7 +1897,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
|
||||
// note: this is slower
|
||||
//const bool is_c4 = op->src[0]->ne[0] % 4 == 0 && op->ne[0] % 4 == 0;
|
||||
const bool is_c4 = false;
|
||||
|
||||
snprintf(base, 256, "kernel_pad_%s%s", ggml_type_name(op->src[0]->type), is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
@@ -1907,6 +1911,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
res.c4 = is_c4;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -564,9 +564,20 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
const int nth = std::min(1024, ne0);
|
||||
int nth = std::min(256, ne0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
if (nth < 256) {
|
||||
nrptg = std::min((256 + nth - 1) / nth, ne1);
|
||||
if (nrptg * nth > 256) {
|
||||
nrptg = 256 / nth;
|
||||
}
|
||||
}
|
||||
|
||||
const int nw0 = (ne1 + nrptg - 1) / nrptg;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nw0, ne2, ne3, nth, nrptg, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -816,9 +827,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
} else {
|
||||
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const int nth = MIN(args.ne00, nth_max);
|
||||
|
||||
const int nk0 = (args.ne00 + nth - 1)/nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
|
||||
@@ -1788,7 +1797,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
|
||||
nk0 = ne10/ggml_blck_size(op->type);
|
||||
}
|
||||
|
||||
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
int nth = std::min<int>(nk0*ne11, 256);
|
||||
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
@@ -1799,7 +1808,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
|
||||
nrptg = (nth + nk0 - 1)/nk0;
|
||||
nth = nk0;
|
||||
|
||||
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
if (nrptg*nth > 256) {
|
||||
nrptg--;
|
||||
}
|
||||
}
|
||||
@@ -1863,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
nk0 = ne00/ggml_blck_size(op->type);
|
||||
}
|
||||
|
||||
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
int nth = std::min<int>(nk0*ne01, 256);
|
||||
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
@@ -1874,7 +1883,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
nrptg = (nth + nk0 - 1)/nk0;
|
||||
nth = nk0;
|
||||
|
||||
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
if (nrptg*nth > 256) {
|
||||
nrptg--;
|
||||
}
|
||||
}
|
||||
@@ -4039,14 +4048,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
|
||||
|
||||
const int nth = std::min(1024, ne0);
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
const int nth = MIN(args.ne0, nth_max);
|
||||
const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel!
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -2643,7 +2643,7 @@ kernel void kernel_gated_delta_net_impl(
|
||||
b_ptr += args.ne21;
|
||||
g_ptr += args.ne21*G;
|
||||
|
||||
if (K > 1u) {
|
||||
if (K > 1) {
|
||||
const int target_slot = (int)t - shift;
|
||||
if (target_slot >= 0 && target_slot < (int)K) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
|
||||
@@ -2655,7 +2655,7 @@ kernel void kernel_gated_delta_net_impl(
|
||||
}
|
||||
}
|
||||
|
||||
if (K == 1u) {
|
||||
if (K == 1) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
@@ -5104,7 +5104,7 @@ kernel void kernel_upscale_bilinear_f32(
|
||||
for (int64_t sx = x_min; sx < x_max; ++sx) {
|
||||
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
|
||||
const float w = wx * wy;
|
||||
const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
|
||||
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
|
||||
sum += (*src_ptr) * w;
|
||||
wsum += w;
|
||||
}
|
||||
@@ -5286,7 +5286,7 @@ kernel void kernel_upscale_bicubic_f32(
|
||||
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
|
||||
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
|
||||
|
||||
const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
|
||||
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
|
||||
sum += (*src_ptr) * wx * wy;
|
||||
}
|
||||
}
|
||||
@@ -5329,42 +5329,46 @@ kernel void kernel_roll_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_pad_f32(
|
||||
template <typename T>
|
||||
kernel void kernel_pad_impl(
|
||||
constant ggml_metal_kargs_pad & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t i3 = tgpig.z;
|
||||
const int32_t i2 = tgpig.y;
|
||||
const int32_t k0 = tgpig.x/args.ne1;
|
||||
const int32_t i1 = tgpig.x - k0*args.ne1;
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
const int32_t i03 = i3;
|
||||
const int32_t i02 = i2;
|
||||
const int32_t i01 = i1;
|
||||
|
||||
const int64_t i03 = i3;
|
||||
const int64_t i02 = i2;
|
||||
const int64_t i01 = i1;
|
||||
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
if (i0 < args.ne00) {
|
||||
dst_ptr[i0] = src0_ptr[i0];
|
||||
} else {
|
||||
dst_ptr[i0] = 0.0f;
|
||||
}
|
||||
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
|
||||
const int32_t i0 = k0*1024 + tpitg.x + l0;
|
||||
if (i0 >= args.ne0) {
|
||||
break;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
dst_ptr[i0] = 0.0f;
|
||||
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
||||
dst_ptr[i0] = src0_ptr[i0];
|
||||
} else {
|
||||
dst_ptr[i0] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
|
||||
|
||||
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
|
||||
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
|
||||
|
||||
// TODO: this is slow - optimize
|
||||
kernel void kernel_pad_reflect_1d_f32(
|
||||
constant ggml_metal_kargs_pad_reflect_1d & args,
|
||||
device const char * src0,
|
||||
@@ -7328,23 +7332,27 @@ kernel void kernel_cpy_t_t(
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
|
||||
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
|
||||
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
|
||||
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
dst_data[i00] = (T1) src[0];
|
||||
break;
|
||||
@@ -7376,23 +7384,27 @@ kernel void kernel_cpy_f32_q(
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
|
||||
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
|
||||
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
|
||||
|
||||
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
||||
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
|
||||
|
||||
quantize_func(src, dst_data[i00]);
|
||||
@@ -7417,24 +7429,28 @@ kernel void kernel_cpy_q_f32(
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
|
||||
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
|
||||
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
||||
T4x4 temp;
|
||||
dequantize_func(src_data + i00/nl, i00%nl, temp);
|
||||
dst_data[i00] = temp;
|
||||
@@ -7470,7 +7486,11 @@ kernel void kernel_concat(
|
||||
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
|
||||
|
||||
if (i1 >= args.ne1) {
|
||||
return;
|
||||
}
|
||||
|
||||
int o[4] = {0, 0, 0, 0};
|
||||
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user