mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-16 16:27:32 +03:00
Compare commits
30 Commits
b8777
...
gg/libcomm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aefc0d1653 | ||
|
|
64c8c88ac9 | ||
|
|
35f69bd4b2 | ||
|
|
a9e852d21a | ||
|
|
48cb5bcb0e | ||
|
|
639b199eb2 | ||
|
|
742a584ccb | ||
|
|
8dc530b86d | ||
|
|
e1a9a6dcbe | ||
|
|
e39eba26f3 | ||
|
|
5d14e5d19b | ||
|
|
fae3a28070 | ||
|
|
c0de6eda72 | ||
|
|
707c0b7a6e | ||
|
|
1f30ac0cea | ||
|
|
f4b5bf2f32 | ||
|
|
aa0f1897b7 | ||
|
|
be76dd0bb2 | ||
|
|
2e05f06ffb | ||
|
|
acc37a42ea | ||
|
|
5a23695d5a | ||
|
|
56666fa607 | ||
|
|
6a6780a232 | ||
|
|
e489a5ca0e | ||
|
|
e21cdc11a0 | ||
|
|
e974923698 | ||
|
|
1c0d9081fd | ||
|
|
a8bad3842e | ||
|
|
75f3bc94e6 | ||
|
|
aa00911d12 |
@@ -7,7 +7,7 @@ RUN apt update && apt install -y git build-essential cmake wget xz-utils
|
||||
|
||||
# Install SSL and Vulkan SDK dependencies
|
||||
RUN apt install -y libssl-dev curl \
|
||||
libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libvulkan-dev glslc
|
||||
libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libvulkan-dev glslc spirv-headers
|
||||
|
||||
# Build it
|
||||
WORKDIR /app
|
||||
|
||||
108
.github/workflows/build-self-hosted.yml
vendored
108
.github/workflows/build-self-hosted.yml
vendored
@@ -141,61 +141,59 @@ jobs:
|
||||
# amd-smi static
|
||||
# GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
# TODO: sandbox Mac runners
|
||||
# ggml-ci-mac-metal:
|
||||
# runs-on: [self-hosted, macOS, ARM64]
|
||||
#
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v6
|
||||
#
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
# run: |
|
||||
# GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
|
||||
#
|
||||
# ggml-ci-mac-webgpu:
|
||||
# runs-on: [self-hosted, macOS, ARM64]
|
||||
#
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v6
|
||||
#
|
||||
# - name: Dawn Dependency
|
||||
# id: dawn-depends
|
||||
# run: |
|
||||
# DAWN_VERSION="v2.0.0"
|
||||
# DAWN_OWNER="reeselevine"
|
||||
# DAWN_REPO="dawn"
|
||||
# DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release"
|
||||
# echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
# curl -L -o artifact.zip \
|
||||
# "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
# mkdir dawn
|
||||
# unzip artifact.zip
|
||||
# tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
|
||||
#
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
# run: |
|
||||
# GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \
|
||||
# bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
|
||||
#
|
||||
# ggml-ci-mac-vulkan:
|
||||
# runs-on: [self-hosted, macOS, ARM64]
|
||||
#
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v6
|
||||
#
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
# run: |
|
||||
# vulkaninfo --summary
|
||||
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
|
||||
ggml-ci-mac-metal:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
|
||||
|
||||
ggml-ci-mac-webgpu:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dawn Dependency
|
||||
id: dawn-depends
|
||||
run: |
|
||||
DAWN_VERSION="v20260317.182325"
|
||||
DAWN_OWNER="google"
|
||||
DAWN_REPO="dawn"
|
||||
DAWN_ASSET_NAME="Dawn-18eb229ef5f707c1464cc581252e7603c73a3ef0-macos-latest-Release"
|
||||
echo "Fetching release asset from https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
|
||||
curl -L -o artifact.tar.gz \
|
||||
"https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
|
||||
mkdir dawn
|
||||
tar -xvf artifact.tar.gz -C dawn --strip-components=1
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \
|
||||
bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
|
||||
|
||||
ggml-ci-mac-vulkan:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
|
||||
|
||||
ggml-ci-linux-intel-vulkan:
|
||||
runs-on: [self-hosted, Linux, Intel]
|
||||
|
||||
3
.github/workflows/build-vulkan.yml
vendored
3
.github/workflows/build-vulkan.yml
vendored
@@ -93,4 +93,5 @@ jobs:
|
||||
export GGML_VK_DISABLE_F16=1
|
||||
export GGML_VK_DISABLE_COOPMAT=1
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 4800
|
||||
# test-backend-ops is too slow on llvmpipe, skip it
|
||||
ctest -L main -E test-backend-ops --verbose --timeout 900
|
||||
|
||||
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -318,7 +318,7 @@ jobs:
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y gcc-14 g++-14 build-essential glslc libvulkan-dev libssl-dev ninja-build
|
||||
sudo apt-get install -y gcc-14 g++-14 build-essential glslc libvulkan-dev spirv-headers libssl-dev ninja-build
|
||||
echo "CC=gcc-14" >> "$GITHUB_ENV"
|
||||
echo "CXX=g++-14" >> "$GITHUB_ENV"
|
||||
|
||||
|
||||
2
.github/workflows/close-issue.yml
vendored
2
.github/workflows/close-issue.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
|
||||
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap,security"
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -202,7 +202,7 @@ jobs:
|
||||
sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libssl-dev
|
||||
else
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y gcc-14 g++-14 build-essential glslc libvulkan-dev libssl-dev ninja-build
|
||||
sudo apt-get install -y gcc-14 g++-14 build-essential glslc libvulkan-dev spirv-headers libssl-dev ninja-build
|
||||
echo "CC=gcc-14" >> "$GITHUB_ENV"
|
||||
echo "CXX=g++-14" >> "$GITHUB_ENV"
|
||||
fi
|
||||
|
||||
77
.github/workflows/server-self-hosted.yml
vendored
77
.github/workflows/server-self-hosted.yml
vendored
@@ -84,41 +84,42 @@ jobs:
|
||||
export ${{ matrix.extra_args }}
|
||||
pytest -v -x -m "not slow"
|
||||
|
||||
server-cuda:
|
||||
runs-on: [self-hosted, llama-server, Linux, NVIDIA]
|
||||
|
||||
name: server-cuda (${{ matrix.wf_name }})
|
||||
strategy:
|
||||
matrix:
|
||||
build_type: [Release]
|
||||
wf_name: ["GPUx1"]
|
||||
include:
|
||||
- build_type: Release
|
||||
extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
|
||||
wf_name: "GPUx1, backend-sampling"
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -DGGML_SCHED_NO_REALLOC=ON
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(sysctl -n hw.logicalcpu) --target llama-server
|
||||
|
||||
- name: Tests
|
||||
id: server_integration_tests
|
||||
if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }}
|
||||
run: |
|
||||
cd tools/server/tests
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
export ${{ matrix.extra_args }}
|
||||
pytest -v -x -m "not slow"
|
||||
# TODO: provision CUDA runner
|
||||
# server-cuda:
|
||||
# runs-on: [self-hosted, llama-server, Linux, NVIDIA]
|
||||
#
|
||||
# name: server-cuda (${{ matrix.wf_name }})
|
||||
# strategy:
|
||||
# matrix:
|
||||
# build_type: [Release]
|
||||
# wf_name: ["GPUx1"]
|
||||
# include:
|
||||
# - build_type: Release
|
||||
# extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
|
||||
# wf_name: "GPUx1, backend-sampling"
|
||||
# fail-fast: false
|
||||
#
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v6
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
# ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
#
|
||||
# - name: Build
|
||||
# id: cmake_build
|
||||
# run: |
|
||||
# cmake -B build -DGGML_SCHED_NO_REALLOC=ON
|
||||
# cmake --build build --config ${{ matrix.build_type }} -j $(sysctl -n hw.logicalcpu) --target llama-server
|
||||
#
|
||||
# - name: Tests
|
||||
# id: server_integration_tests
|
||||
# if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }}
|
||||
# run: |
|
||||
# cd tools/server/tests
|
||||
# python3 -m venv venv
|
||||
# source venv/bin/activate
|
||||
# pip install -r requirements.txt
|
||||
# export ${{ matrix.extra_args }}
|
||||
# pytest -v -x -m "not slow"
|
||||
|
||||
@@ -225,7 +225,7 @@ foreach(FILE_PATH ${EXTRA_LICENSES})
|
||||
endforeach()
|
||||
|
||||
if (LLAMA_BUILD_COMMON)
|
||||
license_generate(common)
|
||||
license_generate(llama-common)
|
||||
endif()
|
||||
|
||||
#
|
||||
@@ -249,6 +249,10 @@ set_target_properties(llama
|
||||
|
||||
install(TARGETS llama LIBRARY PUBLIC_HEADER)
|
||||
|
||||
if (LLAMA_BUILD_COMMON)
|
||||
install(TARGETS llama-common LIBRARY)
|
||||
endif()
|
||||
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/llama-config.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/llama-config.cmake
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# common
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
llama_add_compile_flags()
|
||||
|
||||
#
|
||||
# llama-common-base
|
||||
#
|
||||
|
||||
# Build info header
|
||||
|
||||
if(EXISTS "${PROJECT_SOURCE_DIR}/.git")
|
||||
@@ -33,17 +35,25 @@ endif()
|
||||
|
||||
set(TEMPLATE_FILE "${CMAKE_CURRENT_SOURCE_DIR}/build-info.cpp.in")
|
||||
set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/build-info.cpp")
|
||||
|
||||
configure_file(${TEMPLATE_FILE} ${OUTPUT_FILE})
|
||||
|
||||
set(TARGET build_info)
|
||||
add_library(${TARGET} OBJECT ${OUTPUT_FILE})
|
||||
set(TARGET llama-common-base)
|
||||
add_library(${TARGET} STATIC ${OUTPUT_FILE})
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC .)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
set(TARGET common)
|
||||
#
|
||||
# llama-common
|
||||
#
|
||||
|
||||
add_library(${TARGET} STATIC
|
||||
set(TARGET llama-common)
|
||||
|
||||
add_library(${TARGET}
|
||||
arg.cpp
|
||||
arg.h
|
||||
base64.hpp
|
||||
@@ -106,17 +116,24 @@ add_library(${TARGET} STATIC
|
||||
jinja/caps.h
|
||||
)
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES
|
||||
VERSION ${LLAMA_INSTALL_VERSION}
|
||||
SOVERSION 0
|
||||
MACHO_CURRENT_VERSION 0 # keep macOS linker from seeing oversized version number
|
||||
)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC . ../vendor)
|
||||
target_compile_features (${TARGET} PUBLIC cxx_std_17)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
# TODO: make fine-grained exports in the future
|
||||
set_target_properties(${TARGET} PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
|
||||
endif()
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
build_info
|
||||
cpp-httplib
|
||||
)
|
||||
target_link_libraries(${TARGET} PUBLIC llama-common-base)
|
||||
target_link_libraries(${TARGET} PRIVATE cpp-httplib)
|
||||
|
||||
if (LLAMA_LLGUIDANCE)
|
||||
include(ExternalProject)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "arg.h"
|
||||
|
||||
#include "build-info.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
@@ -1044,8 +1045,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--version"},
|
||||
"show version and build info",
|
||||
[](common_params &) {
|
||||
fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
|
||||
fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
|
||||
fprintf(stderr, "version: %d (%s)\n", llama_build_number(), llama_commit());
|
||||
fprintf(stderr, "built with %s for %s\n", llama_compiler(), llama_build_target());
|
||||
exit(0);
|
||||
}
|
||||
));
|
||||
|
||||
@@ -1,4 +1,35 @@
|
||||
#include "build-info.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
||||
int LLAMA_BUILD_NUMBER = @LLAMA_BUILD_NUMBER@;
|
||||
char const *LLAMA_COMMIT = "@LLAMA_BUILD_COMMIT@";
|
||||
char const *LLAMA_COMPILER = "@BUILD_COMPILER@";
|
||||
char const *LLAMA_BUILD_TARGET = "@BUILD_TARGET@";
|
||||
char const * LLAMA_COMMIT = "@LLAMA_BUILD_COMMIT@";
|
||||
char const * LLAMA_COMPILER = "@BUILD_COMPILER@";
|
||||
char const * LLAMA_BUILD_TARGET = "@BUILD_TARGET@";
|
||||
|
||||
int llama_build_number(void) {
|
||||
return LLAMA_BUILD_NUMBER;
|
||||
}
|
||||
|
||||
const char * llama_commit(void) {
|
||||
return LLAMA_COMMIT;
|
||||
}
|
||||
|
||||
const char * llama_compiler(void) {
|
||||
return LLAMA_COMPILER;
|
||||
}
|
||||
|
||||
const char * llama_build_target(void) {
|
||||
return LLAMA_BUILD_TARGET;
|
||||
}
|
||||
|
||||
const char * llama_build_info(void) {
|
||||
static std::string s = "b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT;
|
||||
return s.c_str();
|
||||
}
|
||||
|
||||
void llama_print_build_info(void) {
|
||||
fprintf(stderr, "%s: build = %d (%s)\n", __func__, llama_build_number(), llama_commit());
|
||||
fprintf(stderr, "%s: built with %s for %s\n", __func__, llama_compiler(), llama_build_target());
|
||||
}
|
||||
|
||||
11
common/build-info.h
Normal file
11
common/build-info.h
Normal file
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
int llama_build_number(void);
|
||||
|
||||
const char * llama_commit(void);
|
||||
const char * llama_compiler(void);
|
||||
|
||||
const char * llama_build_target(void);
|
||||
const char * llama_build_info(void);
|
||||
|
||||
void llama_print_build_info(void);
|
||||
@@ -198,10 +198,19 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont
|
||||
args_field = format.function_field + "." + args_field;
|
||||
}
|
||||
|
||||
auto tools_parser = p.standard_json_tools(
|
||||
format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls,
|
||||
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped,
|
||||
format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order);
|
||||
auto tools_parser = p.eps();
|
||||
if (format.section_start.empty() && !format.per_call_start.empty()) {
|
||||
auto single_tool_parser = p.standard_json_tools(
|
||||
format.per_call_start, format.per_call_end, inputs.tools, inputs.parallel_tool_calls,
|
||||
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped,
|
||||
format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order);
|
||||
tools_parser = p.trigger_rule("tool-calls", p.one_or_more(single_tool_parser + p.space()));
|
||||
} else {
|
||||
tools_parser = p.standard_json_tools(
|
||||
format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls,
|
||||
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped,
|
||||
format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order);
|
||||
}
|
||||
|
||||
// Handle content wrappers if present
|
||||
if (ctx.content && ctx.content->is_always_wrapped()) {
|
||||
|
||||
@@ -308,19 +308,23 @@ struct analyze_tools : analyze_base {
|
||||
|
||||
private:
|
||||
// Extract tool calling 'haystack' for further analysis and delegate further analysis based on format
|
||||
void analyze_tool_calls(const analyze_reasoning & reasoning);
|
||||
void analyze_tool_calls(const analyze_reasoning & reasoning, bool supports_parallel_tool_calls);
|
||||
|
||||
// Analyze format based on position of function and argument name in needle
|
||||
void analyze_tool_call_format(const std::string & haystack,
|
||||
const std::string & fun_name_needle,
|
||||
const std::string & arg_name_needle,
|
||||
const analyze_reasoning & reasoning);
|
||||
const analyze_reasoning & reasoning,
|
||||
bool supports_parallel_tool_calls);
|
||||
|
||||
// Analyze specifics of JSON native format (entire tool call is a JSON object)
|
||||
void analyze_tool_call_format_json_native(const std::string & clean_haystack,
|
||||
const std::string & fun_name_needle,
|
||||
const std::string & arg_name_needle);
|
||||
|
||||
// Check if parallel calls in JSON native format array wrapped or tag wrapped
|
||||
void analyze_json_native_parallel_calls();
|
||||
|
||||
// Analyze specifics of non-JSON native format (tags for function name or for function name and arguments)
|
||||
void analyze_tool_call_format_non_json(const std::string & clean_haystack,
|
||||
const std::string & fun_name_needle);
|
||||
|
||||
@@ -558,7 +558,7 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl,
|
||||
: analyze_base(tmpl) {
|
||||
LOG_DBG(ANSI_ORANGE "Phase 3: Tool call analysis\n" ANSI_RESET);
|
||||
|
||||
analyze_tool_calls(reasoning);
|
||||
analyze_tool_calls(reasoning, caps.supports_parallel_tool_calls);
|
||||
|
||||
if (format.mode != tool_format::NONE && format.mode != tool_format::JSON_NATIVE) {
|
||||
if (caps.supports_parallel_tool_calls) {
|
||||
@@ -577,7 +577,7 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl,
|
||||
}
|
||||
}
|
||||
|
||||
void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning) {
|
||||
void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning, bool supports_parallel_tool_calls) {
|
||||
json assistant_no_tools = json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", ASSISTANT_MSG }
|
||||
@@ -611,13 +611,14 @@ void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning) {
|
||||
return;
|
||||
}
|
||||
|
||||
analyze_tool_call_format(tool_section, FUN_FIRST, ARG_FIRST, reasoning);
|
||||
analyze_tool_call_format(tool_section, FUN_FIRST, ARG_FIRST, reasoning, supports_parallel_tool_calls);
|
||||
}
|
||||
|
||||
void analyze_tools::analyze_tool_call_format(const std::string & haystack,
|
||||
const std::string & fun_name_needle,
|
||||
const std::string & arg_name_needle,
|
||||
const analyze_reasoning & reasoning) {
|
||||
const analyze_reasoning & reasoning,
|
||||
bool supports_parallel_tool_calls) {
|
||||
if (fun_name_needle.empty() || arg_name_needle.empty() || haystack.empty()) {
|
||||
return;
|
||||
}
|
||||
@@ -660,6 +661,9 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
|
||||
|
||||
if (format.mode == tool_format::JSON_NATIVE) {
|
||||
analyze_tool_call_format_json_native(clean_haystack, fun_name_needle, arg_name_needle);
|
||||
if (supports_parallel_tool_calls) {
|
||||
analyze_json_native_parallel_calls();
|
||||
}
|
||||
} else {
|
||||
analyze_tool_call_format_non_json(clean_haystack, fun_name_needle);
|
||||
}
|
||||
@@ -668,6 +672,42 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
|
||||
format.per_call_end = trim_whitespace(format.per_call_end);
|
||||
}
|
||||
|
||||
void analyze_tools::analyze_json_native_parallel_calls() {
|
||||
json assistant_one_tool = json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", "" },
|
||||
{ "tool_calls", json::array({ first_tool_call }) }
|
||||
};
|
||||
|
||||
json assistant_two_tools = json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", "" },
|
||||
{ "tool_calls", json::array({ first_tool_call, second_tool_call }) }
|
||||
};
|
||||
|
||||
template_params params;
|
||||
params.messages = json::array({ user_msg, assistant_one_tool });
|
||||
params.tools = tools;
|
||||
params.add_generation_prompt = false;
|
||||
params.enable_thinking = true;
|
||||
|
||||
auto comparison = compare_variants(
|
||||
*tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_two_tools }); });
|
||||
|
||||
if (!comparison) {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
std::string & second_call = comparison->diff.right;
|
||||
if (!format.section_start.empty() && second_call.find(format.section_start) != std::string::npos) {
|
||||
format.per_call_start = format.section_start;
|
||||
format.per_call_end = format.section_end;
|
||||
format.section_start.clear();
|
||||
format.section_end.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void analyze_tools::analyze_tool_call_format_json_native(const std::string & clean_haystack,
|
||||
const std::string & fun_name_needle,
|
||||
const std::string & arg_name_needle) {
|
||||
|
||||
@@ -676,7 +676,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
|
||||
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
|
||||
|
||||
auto nested_name = literal("\"" + nested_name_field + "\"") + space() + literal(":") + space() +
|
||||
literal("\"") + tool_name(literal(name)) + literal("\"");
|
||||
atomic(literal("\"") + tool_name(literal(name)) + literal("\""));
|
||||
auto nested_args = literal("\"" + nested_args_field + "\"") + space() + literal(":") + space() +
|
||||
tool_args(schema(json(), "tool-" + name + "-schema", params));
|
||||
|
||||
@@ -744,7 +744,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
|
||||
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
|
||||
|
||||
auto tool_name_ = name_key_parser + space() + literal(":") + space() +
|
||||
literal("\"") + tool_name(literal(name)) + literal("\"");
|
||||
atomic(literal("\"") + tool_name(literal(name)) + literal("\""));
|
||||
auto tool_args_ = args_key_parser + space() + literal(":") + space() +
|
||||
tool_args(schema(json(), "tool-" + name + "-schema", params));
|
||||
|
||||
|
||||
197
common/chat.cpp
197
common/chat.cpp
@@ -1091,6 +1091,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
|
||||
if (inputs.add_generation_prompt && string_ends_with(data.prompt, "<turn|>\n")) {
|
||||
// This may happen if the model generates content + tool_call, the
|
||||
// template does not add the model's next turn and confuses the model
|
||||
// from emitting its proper reasoning token sequence.
|
||||
data.prompt += "<|turn>model\n";
|
||||
}
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<|channel>thought";
|
||||
@@ -1118,7 +1126,8 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
p.rule("thought", p.content(p.literal("<|channel>thought") + p.space() + p.until("<channel|>") + p.literal("<channel|>")));
|
||||
}
|
||||
|
||||
auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>"));
|
||||
auto consume_empty_channels = p.gbnf(p.zero_or_more(p.literal("<|channel>") + p.negate(p.literal("thought"))), "");
|
||||
auto thought = (p.peek(p.literal("<|channel>")) + consume_empty_channels + p.ref("thought")) | p.negate(p.literal("<|channel>"));
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.literal("```json") <<
|
||||
@@ -1182,12 +1191,16 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
/* max = */ inputs.parallel_tool_calls ? -1 : 1
|
||||
));
|
||||
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<|tool_call>"})));
|
||||
auto scan_to_toolcall = p.rule("scan-to-toolcall", p.until("<|tool_call>"));
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>", "<|tool_call>"})));
|
||||
auto message = p.rule("message", thought + content);
|
||||
return start + p.zero_or_more(message) + tool_call;
|
||||
return start + p.zero_or_more(message) + scan_to_toolcall + tool_call;
|
||||
}
|
||||
|
||||
auto content = p.rule("content", p.content(p.until("<|channel>")));
|
||||
// Gemma 4 may emit an extra <|channel>thought\n<channel|> at the end of the content. It may
|
||||
// also emit a single trailing <channel|> token. Consume all complete reasoning blocks and
|
||||
// then stop at the first unmatched <channel|> token.
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>"})));
|
||||
auto message = p.rule("message", thought + content);
|
||||
return start + p.one_or_more(message);
|
||||
});
|
||||
@@ -1656,6 +1669,173 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_v3_2(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
data.preserved_tokens = {
|
||||
"|DSML|",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
|
||||
const std::string DSML = "|DSML|";
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
const std::string FC_START = "<" + DSML + "function_calls>";
|
||||
const std::string FC_END = "</" + DSML + "function_calls>";
|
||||
const std::string INVOKE_START = "<" + DSML + "invoke";
|
||||
const std::string INVOKE_END = "</" + DSML + "invoke>";
|
||||
const std::string PARAM_START = "<" + DSML + "parameter";
|
||||
const std::string PARAM_END = "</" + DSML + "parameter>";
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
if (extract_reasoning && inputs.enable_thinking) {
|
||||
reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
|
||||
} else if (extract_reasoning) {
|
||||
// Thinking disabled but reasoning extraction requested: the generation prompt
|
||||
// contains an empty <think></think> pair that must still be consumed.
|
||||
reasoning = p.optional(p.literal(THINK_START) + p.until(THINK_END) + p.literal(THINK_END));
|
||||
}
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.rule("response-format",
|
||||
p.literal("```json") + p.space() +
|
||||
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)) +
|
||||
p.space() + p.literal("```"));
|
||||
return generation_prompt + reasoning + response_format + end;
|
||||
}
|
||||
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return generation_prompt + reasoning + p.content(p.rest()) + end;
|
||||
}
|
||||
|
||||
auto tool_choice = p.choice();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto params = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
const auto & props = params.contains("properties") ? params.at("properties") : json::object();
|
||||
|
||||
std::set<std::string> required;
|
||||
if (params.contains("required")) {
|
||||
params.at("required").get_to(required);
|
||||
}
|
||||
|
||||
auto schema_info = common_schema_info();
|
||||
schema_info.resolve_refs(params);
|
||||
|
||||
std::vector<common_peg_parser> required_parsers;
|
||||
std::vector<common_peg_parser> optional_parsers;
|
||||
for (const auto & [param_name, param_schema] : props.items()) {
|
||||
bool is_required = required.find(param_name) != required.end();
|
||||
bool is_string = schema_info.resolves_to_string(param_schema);
|
||||
|
||||
auto arg = p.tool_arg(
|
||||
p.tool_arg_open(
|
||||
p.literal(PARAM_START + " name=\"") +
|
||||
p.tool_arg_name(p.literal(param_name)) +
|
||||
p.literal("\" string=\"" + std::string(is_string ? "true" : "false") + "\">")) +
|
||||
(is_string
|
||||
? p.tool_arg_string_value(p.until(PARAM_END))
|
||||
: p.tool_arg_json_value(p.schema(p.json(),
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, false))) +
|
||||
p.tool_arg_close(p.literal(PARAM_END)));
|
||||
|
||||
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
|
||||
if (is_required) {
|
||||
required_parsers.push_back(named_arg);
|
||||
} else {
|
||||
optional_parsers.push_back(named_arg);
|
||||
}
|
||||
}
|
||||
|
||||
common_peg_parser args_seq = p.eps();
|
||||
for (size_t i = 0; i < required_parsers.size(); i++) {
|
||||
if (i > 0) {
|
||||
args_seq = args_seq + p.space();
|
||||
}
|
||||
args_seq = args_seq + required_parsers[i];
|
||||
}
|
||||
|
||||
if (!optional_parsers.empty()) {
|
||||
common_peg_parser any_opt = p.choice();
|
||||
for (const auto & opt : optional_parsers) {
|
||||
any_opt |= opt;
|
||||
}
|
||||
args_seq = args_seq + p.repeat(p.space() + any_opt, 0, -1);
|
||||
}
|
||||
|
||||
common_peg_parser invoke_body = args_seq;
|
||||
auto func_parser = p.tool(
|
||||
p.tool_open(p.literal(INVOKE_START + " name=\"") +
|
||||
p.tool_name(p.literal(name)) + p.literal("\">\n")) +
|
||||
invoke_body + p.space() +
|
||||
p.tool_close(p.literal(INVOKE_END)));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
auto require_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_parser tool_calls = p.eps();
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_calls = p.trigger_rule("tool-call",
|
||||
p.literal(FC_START) + p.space() + tool_choice +
|
||||
p.zero_or_more(p.space() + tool_choice) + p.space() + p.literal(FC_END));
|
||||
} else {
|
||||
tool_calls = p.trigger_rule("tool-call",
|
||||
p.literal(FC_START) + p.space() + tool_choice + p.space() + p.literal(FC_END));
|
||||
}
|
||||
|
||||
if (!require_tools) {
|
||||
tool_calls = p.optional(tool_calls);
|
||||
}
|
||||
|
||||
auto content_before_tools = p.content(p.until(FC_START));
|
||||
return generation_prompt + reasoning + content_before_tools + tool_calls + end;
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, FC_START },
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
namespace workaround {
|
||||
|
||||
static void map_developer_role_to_system(json & messages) {
|
||||
@@ -1927,6 +2107,15 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
return common_chat_params_init_gigachat_v3(tmpl, params);
|
||||
}
|
||||
|
||||
// DeepSeek V3.2 format detection: template defines dsml_token and uses it for tool calls.
|
||||
// The template source contains the token as a variable assignment, not as a literal in markup.
|
||||
if (src.find("dsml_token") != std::string::npos &&
|
||||
src.find("function_calls") != std::string::npos &&
|
||||
src.find("DSML") != std::string::npos) {
|
||||
LOG_DBG("Using specialized template: DeepSeek V3.2\n");
|
||||
return common_chat_params_init_deepseek_v3_2(tmpl, params);
|
||||
}
|
||||
|
||||
// Gemma4 format detection
|
||||
if (src.find("'<|tool_call>call:'") != std::string::npos) {
|
||||
if (src.find("{#- OpenAI Chat Completions:") == std::string::npos) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include "build-info.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
@@ -372,7 +373,7 @@ void common_init() {
|
||||
const char * build_type = " (debug)";
|
||||
#endif
|
||||
|
||||
LOG_DBG("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
|
||||
LOG_DBG("build: %d (%s) with %s for %s%s\n", llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
|
||||
}
|
||||
|
||||
std::string common_params_get_system_info(const common_params & params) {
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include "ggml-opt.h"
|
||||
#include "ggml.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
@@ -27,11 +28,6 @@
|
||||
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
|
||||
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
|
||||
|
||||
#define print_build_info() do { \
|
||||
fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \
|
||||
fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
|
||||
} while(0)
|
||||
|
||||
struct common_time_meas {
|
||||
common_time_meas(int64_t & t_acc, bool disable = false);
|
||||
~common_time_meas();
|
||||
@@ -53,14 +49,6 @@ struct common_adapter_lora_info {
|
||||
|
||||
using llama_tokens = std::vector<llama_token>;
|
||||
|
||||
// build info
|
||||
extern int LLAMA_BUILD_NUMBER;
|
||||
extern const char * LLAMA_COMMIT;
|
||||
extern const char * LLAMA_COMPILER;
|
||||
extern const char * LLAMA_BUILD_TARGET;
|
||||
|
||||
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
|
||||
|
||||
struct common_control_vector_load_info;
|
||||
|
||||
//
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "arg.h"
|
||||
|
||||
#include "build-info.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "download.h"
|
||||
@@ -258,6 +259,9 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
|
||||
if (callback) {
|
||||
callback->on_update(p);
|
||||
if (callback->is_cancelled()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
progress_step = 0;
|
||||
}
|
||||
@@ -300,7 +304,7 @@ static int common_download_file_single_online(const std::string & url,
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
headers.emplace("User-Agent", "llama-cpp/" + std::string(llama_build_info()));
|
||||
}
|
||||
if (!opts.bearer_token.empty()) {
|
||||
headers.emplace("Authorization", "Bearer " + opts.bearer_token);
|
||||
@@ -373,6 +377,9 @@ static int common_download_file_single_online(const std::string & url,
|
||||
}
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
if (opts.callback && opts.callback->is_cancelled()) {
|
||||
break;
|
||||
}
|
||||
if (i) {
|
||||
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(delay));
|
||||
@@ -412,6 +419,12 @@ static int common_download_file_single_online(const std::string & url,
|
||||
if (opts.callback) {
|
||||
opts.callback->on_done(p, success);
|
||||
}
|
||||
if (opts.callback && opts.callback->is_cancelled() &&
|
||||
std::filesystem::exists(path_temporary)) {
|
||||
if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, path_temporary.c_str());
|
||||
}
|
||||
}
|
||||
if (!success) {
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
return -1; // max attempts reached
|
||||
@@ -429,7 +442,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
headers.emplace("User-Agent", "llama-cpp/" + std::string(llama_build_info()));
|
||||
}
|
||||
|
||||
if (params.timeout > 0) {
|
||||
|
||||
@@ -21,6 +21,7 @@ public:
|
||||
virtual void on_start(const common_download_progress & p) = 0;
|
||||
virtual void on_update(const common_download_progress & p) = 0;
|
||||
virtual void on_done(const common_download_progress & p, bool ok) = 0;
|
||||
virtual bool is_cancelled() const { return false; }
|
||||
};
|
||||
|
||||
struct common_remote_params {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "hf-cache.h"
|
||||
|
||||
#include "build-info.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "http.h"
|
||||
@@ -200,7 +201,7 @@ static nl::json api_get(const std::string & url,
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers headers = {
|
||||
{"User-Agent", "llama-cpp/" + build_info},
|
||||
{"User-Agent", "llama-cpp/" + std::string(llama_build_info())},
|
||||
{"Accept", "application/json"}
|
||||
};
|
||||
|
||||
|
||||
@@ -23,6 +23,10 @@
|
||||
|
||||
int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;
|
||||
|
||||
int common_log_get_verbosity_thold(void) {
|
||||
return common_log_verbosity_thold;
|
||||
}
|
||||
|
||||
void common_log_set_verbosity_thold(int verbosity) {
|
||||
common_log_verbosity_thold = verbosity;
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ enum log_colors {
|
||||
|
||||
// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower
|
||||
// set via common_log_set_verbosity()
|
||||
extern int common_log_verbosity_thold;
|
||||
int common_log_get_verbosity_thold(void);
|
||||
|
||||
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
||||
|
||||
@@ -98,7 +98,7 @@ void common_log_flush (struct common_log * log); // f
|
||||
|
||||
#define LOG_TMPL(level, verbosity, ...) \
|
||||
do { \
|
||||
if ((verbosity) <= common_log_verbosity_thold) { \
|
||||
if ((verbosity) <= common_log_get_verbosity_thold()) { \
|
||||
common_log_add(common_log_main(), (level), __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
@@ -890,6 +890,10 @@ struct parser_executor {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
};
|
||||
|
||||
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
|
||||
@@ -957,7 +961,8 @@ void common_peg_arena::resolve_refs() {
|
||||
std::is_same_v<T, common_peg_and_parser> ||
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
@@ -1036,6 +1041,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
|
||||
return "Not(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
return "Atomic(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
|
||||
return "Any";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
|
||||
@@ -1565,6 +1572,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_schema_parser>) {
|
||||
visit(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
@@ -1651,10 +1659,13 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
|
||||
std::string s;
|
||||
for (const auto & child : p.children) {
|
||||
auto child_gbnf = to_gbnf(child);
|
||||
if (child_gbnf.empty()) {
|
||||
continue;
|
||||
}
|
||||
if (!s.empty()) {
|
||||
s += " ";
|
||||
}
|
||||
auto child_gbnf = to_gbnf(child);
|
||||
const auto & child_parser = effective_parser(child);
|
||||
if (std::holds_alternative<common_peg_choice_parser>(child_parser) ||
|
||||
std::holds_alternative<common_peg_sequence_parser>(child_parser)) {
|
||||
@@ -1754,6 +1765,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return p.grammar;
|
||||
} else {
|
||||
static_assert(is_always_false_v<T>);
|
||||
}
|
||||
@@ -1888,6 +1901,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
||||
{"child", p.child},
|
||||
{"tag", p.tag}
|
||||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
|
||||
}
|
||||
}, variant);
|
||||
}
|
||||
@@ -2050,6 +2065,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
||||
};
|
||||
}
|
||||
|
||||
if (type == "gbnf") {
|
||||
if (!j.contains("child") || !j.contains("grammar")) {
|
||||
throw std::runtime_error("gbnf parser missing required fields");
|
||||
}
|
||||
return common_peg_gbnf_parser{
|
||||
j["child"].get<common_peg_parser_id>(),
|
||||
j["grammar"].get<std::string>(),
|
||||
};
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unknown parser type: " + type);
|
||||
}
|
||||
|
||||
|
||||
@@ -270,6 +270,11 @@ struct common_peg_tag_parser {
|
||||
std::string tag;
|
||||
};
|
||||
|
||||
struct common_peg_gbnf_parser {
|
||||
common_peg_parser_id child;
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
@@ -290,7 +295,8 @@ using common_peg_parser_variant = std::variant<
|
||||
common_peg_rule_parser,
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser
|
||||
common_peg_tag_parser,
|
||||
common_peg_gbnf_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
@@ -504,6 +510,10 @@ class common_peg_parser_builder {
|
||||
// Unlike rules, you can tag multiple nodes with the same tag.
|
||||
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
|
||||
|
||||
// Wraps a child parser but emits a custom GBNF grammar string instead of
|
||||
// the child's grammar. Parsing delegates entirely to the child.
|
||||
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
|
||||
@@ -287,8 +287,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler
|
||||
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) {
|
||||
// reasoning budget sampler (skip when budget is unlimited unless a lazy grammar is active, which needs rbudget for thinking-block suppression)
|
||||
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty() && (params.grammar_lazy || params.reasoning_budget_tokens >= 0)) {
|
||||
rbudget = common_reasoning_budget_init(
|
||||
vocab,
|
||||
params.reasoning_budget_start,
|
||||
|
||||
@@ -456,7 +456,8 @@ pacman -S git \
|
||||
mingw-w64-ucrt-x86_64-gcc \
|
||||
mingw-w64-ucrt-x86_64-cmake \
|
||||
mingw-w64-ucrt-x86_64-vulkan-devel \
|
||||
mingw-w64-ucrt-x86_64-shaderc
|
||||
mingw-w64-ucrt-x86_64-shaderc \
|
||||
mingw-w64-ucrt-x86_64-spirv-headers
|
||||
```
|
||||
|
||||
Switch into the `llama.cpp` directory and build using CMake.
|
||||
@@ -490,9 +491,11 @@ First, follow the official LunarG instructions for the installation and setup of
|
||||
|
||||
On Debian / Ubuntu, you can install the required dependencies using:
|
||||
```sh
|
||||
sudo apt-get install libvulkan-dev glslc
|
||||
sudo apt-get install libvulkan-dev glslc spirv-headers
|
||||
```
|
||||
|
||||
SPIRV-Headers (`spirv/unified1/spirv.hpp`) are required for the Vulkan backend and are **not** always pulled in by the Vulkan loader dev package alone. Other distros use names such as `spirv-headers` (Ubuntu / Debian / Arch), or `spirv-headers-devel` (Fedora / openSUSE). On Windows, the LunarG Vulkan SDK’s `Include` directory already contains these headers.
|
||||
|
||||
#### Common steps
|
||||
|
||||
Second, after verifying that you have followed all of the SDK installation/setup steps, use this command to make sure before proceeding:
|
||||
|
||||
@@ -114,6 +114,10 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
|
||||
# Mistral's Voxtral
|
||||
(tool_name) -hf ggml-org/Voxtral-Mini-3B-2507-GGUF
|
||||
|
||||
# Qwen3-ASR
|
||||
(tool_name) -hf ggml-org/Qwen3-ASR-0.6B-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen3-ASR-1.7B-GGUF
|
||||
```
|
||||
|
||||
**Mixed modalities**:
|
||||
@@ -124,6 +128,11 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF
|
||||
|
||||
# Qwen3 Omni
|
||||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/Qwen3-Omni-30B-A3B-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen3-Omni-30B-A3B-Thinking-GGUF
|
||||
|
||||
# Gemma 4
|
||||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/gemma-4-E2B-it-GGUF
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-batched)
|
||||
add_executable(${TARGET} batched.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-convert-llama2c-to-ggml)
|
||||
add_executable(${TARGET} convert-llama2c-to-ggml.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-debug)
|
||||
add_executable(${TARGET} debug.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-diffusion-cli)
|
||||
add_executable(${TARGET} diffusion-cli.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -602,8 +602,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
int n_input = input_tokens.size();
|
||||
|
||||
if (n_input >= params.n_ctx) {
|
||||
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
|
||||
if (static_cast<uint32_t>(n_input) >= llama_n_ctx(ctx)) {
|
||||
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, llama_n_ctx(ctx));
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
return 1;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-embedding)
|
||||
add_executable(${TARGET} embedding.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
set(TARGET llama-eval-callback)
|
||||
add_executable(${TARGET} eval-callback.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_BUILD_TESTS)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-gen-docs)
|
||||
add_executable(${TARGET} gen-docs.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-idle)
|
||||
add_executable(${TARGET} idle.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-lookahead)
|
||||
add_executable(${TARGET} lookahead.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
set(TARGET llama-lookup)
|
||||
add_executable(${TARGET} lookup.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
set(TARGET llama-lookup-create)
|
||||
add_executable(${TARGET} lookup-create.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
set(TARGET llama-lookup-merge)
|
||||
add_executable(${TARGET} lookup-merge.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
set(TARGET llama-lookup-stats)
|
||||
add_executable(${TARGET} lookup-stats.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-parallel)
|
||||
add_executable(${TARGET} parallel.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-passkey)
|
||||
add_executable(${TARGET} passkey.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-retrieval)
|
||||
add_executable(${TARGET} retrieval.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-save-load-state)
|
||||
add_executable(${TARGET} save-load-state.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-speculative-simple)
|
||||
add_executable(${TARGET} speculative-simple.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-speculative)
|
||||
add_executable(${TARGET} speculative.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -5,5 +5,5 @@
|
||||
set(TARGET llama-ls-sycl-device)
|
||||
add_executable(${TARGET} ls-sycl-device.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
set(TARGET llama-finetune)
|
||||
add_executable(${TARGET} finetune.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories.
|
||||
|
||||
# ref: https://cmake.org/cmake/help/latest/policy/CMP0194.html
|
||||
# MSVC is not a valid assembler for the ASM language.
|
||||
# Set to NEW to avoid a warning on CMake 4.1+ with MSVC.
|
||||
if (POLICY CMP0194)
|
||||
cmake_policy(SET CMP0194 NEW)
|
||||
endif()
|
||||
project("ggml" C CXX ASM)
|
||||
|
||||
### GGML Version
|
||||
|
||||
@@ -348,6 +348,53 @@ extern "C" {
|
||||
// Set a callback to be called for each resulting node during graph compute
|
||||
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
|
||||
|
||||
//
|
||||
// Meta backend
|
||||
//
|
||||
|
||||
#define GGML_BACKEND_META_MAX_DEVICES 16
|
||||
|
||||
enum ggml_backend_meta_split_axis {
|
||||
// tensor split by tensor dimensions:
|
||||
GGML_BACKEND_SPLIT_AXIS_0 = 0,
|
||||
GGML_BACKEND_SPLIT_AXIS_1 = 1,
|
||||
GGML_BACKEND_SPLIT_AXIS_2 = 2,
|
||||
GGML_BACKEND_SPLIT_AXIS_3 = 3,
|
||||
|
||||
GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends
|
||||
GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum
|
||||
|
||||
// for internal bookkeeping only:
|
||||
GGML_BACKEND_SPLIT_AXIS_NONE = 98,
|
||||
GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99,
|
||||
};
|
||||
GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis);
|
||||
|
||||
struct ggml_backend_meta_split_state {
|
||||
enum ggml_backend_meta_split_axis axis;
|
||||
|
||||
// for tensors with axis >= 0 && axis < GGML_MAX_DIMS:
|
||||
// - each device has a slice of the tensor along the split axis
|
||||
// - most tensors have n_segments == 1 and a contiguous slice of the tensor data
|
||||
// - some tensors have an inhomogenenous data layout along the split axis,
|
||||
// those tensors are divided into segments which are each individually split across devices
|
||||
// - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis,
|
||||
// the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1],
|
||||
// - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments
|
||||
// that each need to be split individually across devices so that each device gets a slice of Q, K, and V
|
||||
int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES];
|
||||
uint32_t n_segments;
|
||||
};
|
||||
|
||||
// function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible:
|
||||
typedef struct ggml_backend_meta_split_state(*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
// create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this:
|
||||
// TODO: this looks a bit strange - a backend API creates a device. I think we should try
|
||||
// express this as a backend registry functionality instead
|
||||
GGML_API ggml_backend_dev_t ggml_backend_meta_device(
|
||||
ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud);
|
||||
|
||||
//
|
||||
// Utils
|
||||
//
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <limits.h>
|
||||
#include <stdarg.h>
|
||||
|
||||
@@ -5,9 +5,6 @@
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
// TODO: tmp
|
||||
#include "ggml-ext.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
@@ -783,6 +783,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b));
|
||||
const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4));
|
||||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs);
|
||||
const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16);
|
||||
const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b));
|
||||
@@ -794,15 +795,40 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b));
|
||||
|
||||
const int32x4_t p0 = vaddq_s32(
|
||||
ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0),
|
||||
ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0));
|
||||
vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0),
|
||||
vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0));
|
||||
const int32x4_t p1 = vaddq_s32(
|
||||
ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1),
|
||||
ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1));
|
||||
vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1),
|
||||
vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1));
|
||||
|
||||
const int32x4_t sums = vpaddq_s32(p0, p1);
|
||||
const int32x4_t sumi = vpaddq_s32(p0, p1);
|
||||
#else
|
||||
const int8x8_t q4_0_lo = vget_low_s8(q4_lo_0);
|
||||
const int8x8_t q4_0_hi = vget_low_s8(q4_hi_0);
|
||||
const int8x8_t q4_1_lo = vget_high_s8(q4_lo_0);
|
||||
const int8x8_t q4_1_hi = vget_high_s8(q4_hi_0);
|
||||
const int8x8_t q4_2_lo = vget_low_s8(q4_lo_1);
|
||||
const int8x8_t q4_2_hi = vget_low_s8(q4_hi_1);
|
||||
const int8x8_t q4_3_lo = vget_high_s8(q4_lo_1);
|
||||
const int8x8_t q4_3_hi = vget_high_s8(q4_hi_1);
|
||||
|
||||
const int8x8_t q8_0_lo = vld1_s8(y[2*ib].qs);
|
||||
const int8x8_t q8_0_hi = vld1_s8(y[2*ib].qs + 8);
|
||||
const int8x8_t q8_1_lo = vld1_s8(y[2*ib].qs + 16);
|
||||
const int8x8_t q8_1_hi = vld1_s8(y[2*ib].qs + 24);
|
||||
const int8x8_t q8_2_lo = vld1_s8(y[2*ib+1].qs);
|
||||
const int8x8_t q8_2_hi = vld1_s8(y[2*ib+1].qs + 8);
|
||||
const int8x8_t q8_3_lo = vld1_s8(y[2*ib+1].qs + 16);
|
||||
const int8x8_t q8_3_hi = vld1_s8(y[2*ib+1].qs + 24);
|
||||
|
||||
const int32x4_t sumi = (int32x4_t){
|
||||
vaddvq_s32(ggml_nvfp4_dot8(q4_0_lo, q8_0_lo, q4_0_hi, q8_0_hi)),
|
||||
vaddvq_s32(ggml_nvfp4_dot8(q4_1_lo, q8_1_lo, q4_1_hi, q8_1_hi)),
|
||||
vaddvq_s32(ggml_nvfp4_dot8(q4_2_lo, q8_2_lo, q4_2_hi, q8_2_hi)),
|
||||
vaddvq_s32(ggml_nvfp4_dot8(q4_3_lo, q8_3_lo, q4_3_hi, q8_3_hi)),
|
||||
};
|
||||
#endif
|
||||
|
||||
// Decode 4 UE4M3 scales to f32 and multiply with q8 scales
|
||||
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
|
||||
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
|
||||
const float32x4_t nvsc = {
|
||||
@@ -813,7 +839,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
};
|
||||
const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1});
|
||||
|
||||
acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales);
|
||||
acc = vfmaq_f32(acc, vcvtq_f32_s32(sumi), scales);
|
||||
}
|
||||
sumf = vaddvq_f32(acc);
|
||||
#else
|
||||
|
||||
@@ -306,6 +306,7 @@ inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
|
||||
|
||||
#if !defined(__ARM_FEATURE_DOTPROD)
|
||||
|
||||
// NOTE: this fallback produces the same total sum as native vdotq_s32 but with different per-lane grouping — do not use when individual lane values matter.
|
||||
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
|
||||
const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
|
||||
const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
|
||||
@@ -319,6 +320,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
|
||||
|
||||
#endif // !defined(__ARM_FEATURE_DOTPROD)
|
||||
|
||||
static inline int32x4_t ggml_nvfp4_dot8(const int8x8_t q4_lo, const int8x8_t q8_lo,
|
||||
const int8x8_t q4_hi, const int8x8_t q8_hi) {
|
||||
const int16x8_t p_lo = vmull_s8(q4_lo, q8_lo);
|
||||
const int16x8_t p_hi = vmull_s8(q4_hi, q8_hi);
|
||||
const int32x4_t sum_lo = vpaddlq_s16(p_lo);
|
||||
const int32x4_t sum_hi = vpaddlq_s16(p_hi);
|
||||
return vaddq_s32(sum_lo, sum_hi);
|
||||
}
|
||||
|
||||
#endif // defined(__ARM_NEON)
|
||||
|
||||
#ifdef __wasm_simd128__
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
// This is a "staging" header for new ggml API
|
||||
// It is not publicly available and it should not be used by 3rd party projects
|
||||
//
|
||||
// When the API matures enough, it will be moved to the official public API
|
||||
|
||||
//
|
||||
// Meta backend
|
||||
//
|
||||
|
||||
#define GGML_BACKEND_META_MAX_DEVICES 16
|
||||
|
||||
enum ggml_backend_meta_split_axis {
|
||||
// tensor split by tensor dimensions:
|
||||
GGML_BACKEND_SPLIT_AXIS_0 = 0,
|
||||
GGML_BACKEND_SPLIT_AXIS_1 = 1,
|
||||
GGML_BACKEND_SPLIT_AXIS_2 = 2,
|
||||
GGML_BACKEND_SPLIT_AXIS_3 = 3,
|
||||
|
||||
GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends
|
||||
GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum
|
||||
|
||||
// for internal bookkeeping only:
|
||||
GGML_BACKEND_SPLIT_AXIS_NONE = 98,
|
||||
GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99,
|
||||
};
|
||||
GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis);
|
||||
|
||||
struct ggml_backend_meta_split_state {
|
||||
enum ggml_backend_meta_split_axis axis;
|
||||
|
||||
// for tensors with axis >= 0 && axis < GGML_MAX_DIMS:
|
||||
// - each device has a slice of the tensor along the split axis
|
||||
// - most tensors have n_segments == 1 and a contiguous slice of the tensor data
|
||||
// - some tensors have an inhomogenenous data layout along the split axis,
|
||||
// those tensors are divided into segments which are each individually split across devices
|
||||
// - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis,
|
||||
// the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1],
|
||||
// - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments
|
||||
// that each need to be split individually across devices so that each device gets a slice of Q, K, and V
|
||||
int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES];
|
||||
uint32_t n_segments;
|
||||
};
|
||||
|
||||
// function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible:
|
||||
typedef struct ggml_backend_meta_split_state(*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
// create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this:
|
||||
// TODO: this looks a bit strange - a backend API creates a device. I think we should try
|
||||
// express this as a backend registry functionality instead
|
||||
GGML_API ggml_backend_dev_t ggml_backend_meta_device(
|
||||
ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud);
|
||||
@@ -47,6 +47,7 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
|
||||
if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-queue.c
|
||||
hmx-matmul-ops.c
|
||||
)
|
||||
|
||||
|
||||
@@ -31,6 +31,14 @@ static inline uint64_t hex_get_pktcnt() {
|
||||
return pktcnt;
|
||||
}
|
||||
|
||||
static inline uint32_t hex_ceil_pow2(uint32_t x) {
|
||||
if (x <= 1) { return 1; }
|
||||
int p = 2;
|
||||
x--;
|
||||
while (x >>= 1) { p <<= 1; }
|
||||
return p;
|
||||
}
|
||||
|
||||
static inline size_t hmx_ceil_div(size_t num, size_t den) {
|
||||
return (num + den - 1) / den;
|
||||
}
|
||||
@@ -73,8 +81,7 @@ static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride,
|
||||
#define HEX_L2_LINE_SIZE 64
|
||||
#define HEX_L2_FLUSH_SIZE (128 * 1024)
|
||||
|
||||
static inline void hex_l2flush(void * addr, size_t size)
|
||||
{
|
||||
static inline void hex_l2flush(void * addr, size_t size) {
|
||||
if (size > HEX_L2_FLUSH_SIZE) {
|
||||
qurt_mem_cache_clean((qurt_addr_t) 0, 0, QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE);
|
||||
} else {
|
||||
@@ -89,4 +96,8 @@ static inline void hex_l2flush(void * addr, size_t size)
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hex_pause() {
|
||||
asm volatile(" pause(#255)\n");
|
||||
}
|
||||
|
||||
#endif /* HEX_UTILS_H */
|
||||
|
||||
@@ -16,14 +16,16 @@
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#include "hvx-utils.h"
|
||||
#include "hvx-dump.h"
|
||||
#include "worker-pool.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#include "hmx-utils.h"
|
||||
#include "hmx-ops.h"
|
||||
#include "hmx-utils.h"
|
||||
#include "hmx-queue.h"
|
||||
#include "hmx-profile.h"
|
||||
|
||||
static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
@@ -47,7 +49,8 @@ static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = {
|
||||
0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128,
|
||||
8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
16*128, 17*128, 18*128, 19*128, 20*128, 21*128, 22*128, 23*128,
|
||||
24*128, 25*128, 26*128, 27*128, 28*128, 29*128, 30*128, 31*128
|
||||
};
|
||||
|
||||
// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes
|
||||
@@ -109,36 +112,45 @@ static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Search for optimal (mc, nc) chunk sizes that maximize mc * nc within VTCM budget.
|
||||
// Search for optimal (mc, nc) chunk sizes within VTCM budget.
|
||||
//
|
||||
// Cost model: total = nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead
|
||||
// per_n_cost: bytes per nc column (weight + scratch buffers)
|
||||
// per_m_cost: bytes per mc row (activation)
|
||||
// per_mn_cost: bytes per mc*nc element (output)
|
||||
// overhead: fixed bytes (scales 256B, eye_tile 2048B, etc.)
|
||||
// VTCM model: nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead
|
||||
//
|
||||
// Minimize ceil(m/mc) * m_block_cost + ceil(n/nc) * n_block_cost.
|
||||
// All matmul paths repeat weight processing per M-block and activation loading
|
||||
// per N-block, so discrete block counts drive total overhead.
|
||||
// Tie-break: when cost is equal, prefer larger mc * nc.
|
||||
//
|
||||
// Caller-provided coefficients:
|
||||
// m_block_cost: penalty per extra M-block (weight redundancy, scales with n).
|
||||
// n_block_cost: penalty per extra N-block (activation redundancy, scales with m).
|
||||
//
|
||||
// Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max.
|
||||
// Returns 0 on success, -1 if VTCM is insufficient.
|
||||
static int hmx_compute_chunks(
|
||||
size_t vtcm_total, size_t overhead,
|
||||
size_t per_n_cost, size_t per_m_cost, size_t per_mn_cost,
|
||||
int m, int n,
|
||||
size_t *m_chunk_out, size_t *n_chunk_out,
|
||||
size_t *total_out)
|
||||
{
|
||||
static int hmx_compute_chunks(size_t vtcm_total,
|
||||
size_t overhead,
|
||||
size_t per_n_cost,
|
||||
size_t per_m_cost,
|
||||
size_t per_mn_cost,
|
||||
int m,
|
||||
int n,
|
||||
size_t m_block_cost,
|
||||
size_t n_block_cost,
|
||||
size_t * m_chunk_out,
|
||||
size_t * n_chunk_out,
|
||||
size_t * total_out) {
|
||||
if (m <= 0 || n <= 0) return -1;
|
||||
if (vtcm_total <= overhead) return -1;
|
||||
if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1;
|
||||
|
||||
const size_t usable = vtcm_total - overhead;
|
||||
size_t best_mn = 0, best_m = 0, best_n = 0;
|
||||
|
||||
size_t best_cost = SIZE_MAX;
|
||||
size_t best_mn = 0;
|
||||
size_t best_m = 0, best_n = 0;
|
||||
|
||||
const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS);
|
||||
for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) {
|
||||
// Early exit: if nc * m_max cannot beat best, smaller nc won't either
|
||||
if (nc * hex_align_down((size_t)m, HMX_FP16_TILE_N_ROWS) <= best_mn)
|
||||
break;
|
||||
|
||||
size_t n_fixed = 0, ncmn = 0, mc_denom = 0;
|
||||
if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue;
|
||||
if (n_fixed >= usable) goto next_nc;
|
||||
@@ -152,10 +164,19 @@ static int hmx_compute_chunks(
|
||||
mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS);
|
||||
mc = hex_smin(mc, (size_t)m);
|
||||
|
||||
if (mc > 0 && mc * nc > best_mn) {
|
||||
best_mn = mc * nc;
|
||||
best_m = mc;
|
||||
best_n = nc;
|
||||
if (mc == 0) {
|
||||
goto next_nc;
|
||||
}
|
||||
|
||||
size_t mblocks = ((size_t) m + mc - 1) / mc;
|
||||
size_t nblocks = ((size_t) n + nc - 1) / nc;
|
||||
size_t cost = mblocks * m_block_cost + nblocks * n_block_cost;
|
||||
size_t mn = mc * nc;
|
||||
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
|
||||
best_cost = cost;
|
||||
best_mn = mn;
|
||||
best_m = mc;
|
||||
best_n = nc;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,7 +254,7 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
|
||||
// q4x4x2 stores two int4 values per byte. Keep only the selected nibble.
|
||||
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
// Shuffle before LUT
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
@@ -257,7 +278,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
// Load all 128 packed bytes (4 contiguous 32-byte groups)
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
// Shuffle before LUT
|
||||
@@ -277,10 +298,8 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
// Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter
|
||||
out[0] = v_lo; // group0 already in [0:63]
|
||||
out[1] = Q6_V_vror_VR(v_lo, 64); // group1 rotated to [0:63]
|
||||
out[2] = v_hi; // group2 already in [0:63]
|
||||
out[3] = Q6_V_vror_VR(v_hi, 64); // group3 rotated to [0:63]
|
||||
out[0] = v_lo; // group0 already in [0:63]
|
||||
out[1] = v_hi; // group2 already in [0:63]
|
||||
}
|
||||
|
||||
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
|
||||
@@ -384,8 +403,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
size_t row_stride, int weight_type,
|
||||
int start_tile, int end_tile) {
|
||||
|
||||
const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
|
||||
const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2);
|
||||
const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS;
|
||||
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
|
||||
const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block;
|
||||
|
||||
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
|
||||
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
|
||||
@@ -398,47 +418,46 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes)
|
||||
|
||||
for (int t = start_tile; t < end_tile; ) {
|
||||
int ct = t / n_k_tiles; // column tile index
|
||||
int kt = t % n_k_tiles; // K tile index
|
||||
unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index
|
||||
unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index
|
||||
for (unsigned t = start_tile; t < end_tile; ) {
|
||||
if (kt >= n_k_tiles) { kt = 0; ct++; }
|
||||
|
||||
// --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row ---
|
||||
if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) &&
|
||||
((t + 3) / n_k_tiles == ct)) {
|
||||
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
int packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
|
||||
int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE
|
||||
+ sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales
|
||||
// --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
|
||||
if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
|
||||
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
|
||||
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE
|
||||
+ sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales
|
||||
|
||||
__fp16 *tile_bases[4];
|
||||
for (int g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }
|
||||
for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }
|
||||
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
const uint8_t *r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row1 * row_stride;
|
||||
|
||||
HVX_Vector v0[4], v1[4];
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
|
||||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
|
||||
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
HVX_Vector v0[2];
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0);
|
||||
if (row1 < n_cols) {
|
||||
dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt, v1);
|
||||
} else {
|
||||
v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); }
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); }
|
||||
|
||||
|
||||
r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); }
|
||||
|
||||
t += 4;
|
||||
t += 4; kt += 4;
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -495,20 +514,19 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
// --- Single-tile fallback ---
|
||||
__fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) {
|
||||
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
int byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
|
||||
int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);
|
||||
if (is_q4) {
|
||||
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
|
||||
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);
|
||||
|
||||
HVX_Vector v_off = v_scat_base; // reset to column 0
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
|
||||
const uint8_t *r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row1 * row_stride;
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
|
||||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(
|
||||
r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
@@ -585,7 +603,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
}
|
||||
(void) *(volatile HVX_Vector *)(tile_base);
|
||||
}
|
||||
++t;
|
||||
++t; ++kt;
|
||||
}
|
||||
|
||||
// Drain HVX scatter write buffer: a vmem load on the same HW thread retires
|
||||
@@ -653,9 +671,13 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
|
||||
// --- End x4x2 dequantizers ---
|
||||
|
||||
// requires external HMX lock
|
||||
static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const __fp16 *weight, const __fp16 *scales,
|
||||
static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales,
|
||||
int n_row_tiles, int n_col_tiles, int n_dot_tiles) {
|
||||
hmx_set_output_scales(scales);
|
||||
__builtin_assume(n_row_tiles > 0);
|
||||
__builtin_assume(n_col_tiles > 0);
|
||||
__builtin_assume(n_dot_tiles > 0);
|
||||
|
||||
Q6_bias_mxmem2_A((void *)scales);
|
||||
|
||||
for (int r = 0; r < n_row_tiles; ++r) {
|
||||
for (int c = 0; c < n_col_tiles; ++c) {
|
||||
@@ -665,16 +687,55 @@ static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const
|
||||
const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
for (int k = 0; k < n_dot_tiles; ++k) {
|
||||
int offset = k * HMX_FP16_TILE_N_ELMS;
|
||||
hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset);
|
||||
Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
|
||||
row_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
__fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS;
|
||||
hmx_consume_accumulator_fp16(out_tile);
|
||||
Q6_mxmem_AR_after_hf(out_tile, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Async HMX matmul job (for pipeline overlap) ---
|
||||
|
||||
typedef struct {
|
||||
__fp16 * output;
|
||||
const __fp16 * activation;
|
||||
const __fp16 * weight;
|
||||
const __fp16 * scales;
|
||||
uint32_t n_row_tiles;
|
||||
uint32_t n_col_tiles;
|
||||
uint32_t n_dot_tiles;
|
||||
} hmx_matmul_job_t;
|
||||
|
||||
static void hmx_matmul_worker_fn(void * data) {
|
||||
hmx_matmul_job_t * job = (hmx_matmul_job_t *) data;
|
||||
FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles);
|
||||
core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles);
|
||||
}
|
||||
|
||||
static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
|
||||
__fp16 * output,
|
||||
const __fp16 * activation,
|
||||
const __fp16 * weight,
|
||||
const __fp16 * scales,
|
||||
int n_row_tiles,
|
||||
int n_col_tiles,
|
||||
int n_dot_tiles) {
|
||||
job->output = output;
|
||||
job->activation = activation;
|
||||
job->weight = weight;
|
||||
job->scales = scales;
|
||||
job->n_row_tiles = n_row_tiles;
|
||||
job->n_col_tiles = n_col_tiles;
|
||||
job->n_dot_tiles = n_dot_tiles;
|
||||
}
|
||||
|
||||
// --- End async HMX matmul job ---
|
||||
|
||||
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) {
|
||||
assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
|
||||
const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
|
||||
@@ -832,12 +893,13 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
||||
const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0;
|
||||
|
||||
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
|
||||
// FP16 weight: interleave and activation load have similar per-element cost.
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256,
|
||||
/*per_n=*/3 * vec_dot_size,
|
||||
/*per_m=*/group_size * vec_dot_size + f32_scratch_per_m,
|
||||
/*per_mn=*/sizeof(__fp16),
|
||||
params->m, params->n,
|
||||
&m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
/*per_n=*/3 * vec_dot_size,
|
||||
/*per_m=*/group_size * vec_dot_size + f32_scratch_per_m,
|
||||
/*per_mn=*/sizeof(__fp16), params->m, params->n,
|
||||
/*m_block_cost=*/(size_t) params->n,
|
||||
/*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__);
|
||||
return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
|
||||
}
|
||||
@@ -1006,13 +1068,15 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
|
||||
const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0;
|
||||
|
||||
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
|
||||
// FP16 weight: interleave and activation load have similar per-element cost.
|
||||
if (hmx_compute_chunks(vtcm_budget,
|
||||
/*overhead=*/ 256,
|
||||
/*per_n=*/ 3 * vec_dot_size, // W + S0 + S1
|
||||
/*per_m=*/ vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch
|
||||
/*per_mn=*/ sizeof(__fp16), // O
|
||||
m, n,
|
||||
&m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
/*overhead=*/256,
|
||||
/*per_n=*/3 * vec_dot_size, // W + S0 + S1
|
||||
/*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch
|
||||
/*per_mn=*/sizeof(__fp16), // O
|
||||
m, n,
|
||||
/*m_block_cost=*/(size_t) n,
|
||||
/*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
|
||||
return -1;
|
||||
}
|
||||
@@ -1157,6 +1221,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
|
||||
int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m,
|
||||
int k, int n, int w_type);
|
||||
|
||||
#define FALLBACK_TO_STANDARD 1
|
||||
|
||||
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
|
||||
const uint8_t *restrict permuted_weight, int m, int k, int n,
|
||||
int weight_type) {
|
||||
@@ -1169,9 +1235,12 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
|
||||
// for large m, k (e.g. prefill FFN Down), use out-stationary version
|
||||
if (m >= 128 && k > n && n > 1024) {
|
||||
FARF(MEDIUM, "hmx_matmul_qk: OUT-STATIONARY path m=%d k=%d n=%d type=%d (K_BLOCK=512, %d K-iters with fp16 intermediate)",
|
||||
m, k, n, weight_type, (k + 511) / 512);
|
||||
return mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type);
|
||||
int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type);
|
||||
if (rc != FALLBACK_TO_STANDARD) {
|
||||
return rc; // 0 success, -1 error
|
||||
}
|
||||
FARF(MEDIUM, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n);
|
||||
// fall through to standard path
|
||||
}
|
||||
|
||||
size_t row_stride = get_x4x2_row_stride(weight_type, k);
|
||||
@@ -1197,9 +1266,10 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
}
|
||||
|
||||
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256,
|
||||
per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost,
|
||||
m, n, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
// Quantized weight: dequant ~1.5x more expensive per element than activation load.
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost, m, n,
|
||||
/*m_block_cost=*/(size_t) n * 3,
|
||||
/*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d pipe=%d budget=%zu)",
|
||||
__func__, m, k, n, use_pipeline, vtcm_budget);
|
||||
return -1;
|
||||
@@ -1256,9 +1326,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols,
|
||||
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
|
||||
|
||||
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
|
||||
|
||||
if (!use_pipeline) {
|
||||
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
|
||||
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
|
||||
// transfer activation matrix chunk into VTCM
|
||||
size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
|
||||
@@ -1318,20 +1387,22 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
TIMER_STOP(output_store);
|
||||
}
|
||||
}
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
} else {
|
||||
// 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D)
|
||||
// stage B and D (dequantize and store) are expected to be on the critical path
|
||||
// HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D).
|
||||
|
||||
// A --> B: vtcm_qweight, 1 buffer
|
||||
// B --> C: vtcm_weight0/vtcm_weight1, 2 buffers
|
||||
// C --> D: vtcm_output0/vtcm_output1, 2 buffers
|
||||
|
||||
//
|
||||
// LD ||A3| | B3 ||
|
||||
// MM || C2 ||
|
||||
// ST || D1 | ||
|
||||
// Async timeline (C overlaps B+D):
|
||||
// main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2]
|
||||
// HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████]
|
||||
|
||||
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
|
||||
hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors
|
||||
|
||||
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
|
||||
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
|
||||
|
||||
@@ -1352,31 +1423,34 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
|
||||
}
|
||||
|
||||
// prologue: B0, A1, C0, B1
|
||||
// prologue: B0, A1, submit C0 (async), B1 (overlaps C0)
|
||||
{
|
||||
// B0
|
||||
// B0: wait for DMA, dequant weight chunk 0
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type);
|
||||
|
||||
// A1
|
||||
// A1: issue DMA for weight chunk 1
|
||||
const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
|
||||
if (1 < n_chunk_cnt) {
|
||||
const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
|
||||
}
|
||||
|
||||
// C0
|
||||
core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[0], vtcm_scales,
|
||||
hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
// submit C0 (non-blocking — HMX worker executes in parallel)
|
||||
hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation,
|
||||
(__fp16 *) vtcm_weight_bufs[0], vtcm_scales,
|
||||
hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
|
||||
hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0]));
|
||||
|
||||
// B1
|
||||
// B1: DMA pop + dequant (runs in parallel with C0 on HMX worker)
|
||||
if (1 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type);
|
||||
}
|
||||
}
|
||||
|
||||
// main loop
|
||||
// main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1})
|
||||
for (int i = 0; i < n_chunk_cnt; ++i) {
|
||||
const size_t nc = i * n_chunk_n_cols;
|
||||
const size_t nc_p1 = nc + 1 * n_chunk_n_cols;
|
||||
@@ -1386,36 +1460,41 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols);
|
||||
const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols);
|
||||
|
||||
// issue A_{i+2}
|
||||
// issue A_{i+2}: DMA push (non-blocking)
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
|
||||
}
|
||||
|
||||
// wait for HMX (C_{i}) -- C_{i} is done
|
||||
// wait C_i: block until prologue/previous C completes
|
||||
hmx_queue_pop(ctx->hmx_queue);
|
||||
|
||||
// result of B_{i+1} (input of C_{i+1}) should be ready now
|
||||
|
||||
// issue C_{i+1}
|
||||
// submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below)
|
||||
// job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's
|
||||
// counterpart — and (i+1)%2 was last used by C_{i-1} which completed
|
||||
// before C_i was submitted.
|
||||
if (i + 1 < n_chunk_cnt) {
|
||||
core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[(i + 1) % 2], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], vtcm_scales,
|
||||
hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2],
|
||||
(__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2],
|
||||
vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
|
||||
hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2]));
|
||||
}
|
||||
|
||||
// compute D_{i}
|
||||
// D_i: store output (multi-thread HVX, parallel with C_{i+1})
|
||||
float *output_chunk = dst + (mr * n + nc);
|
||||
transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n);
|
||||
|
||||
// wait for DMA (A_{i+2}), compute B_{i+2}
|
||||
// B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1})
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
hmx_queue_suspend(ctx->hmx_queue);
|
||||
}
|
||||
|
||||
TIMER_STOP(total);
|
||||
|
||||
@@ -1434,10 +1513,13 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
}
|
||||
|
||||
// C += AB
|
||||
void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp16 *col_scales, const __fp16 *eye_tile,
|
||||
void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile,
|
||||
int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) {
|
||||
__builtin_assume(n_row_tiles > 0);
|
||||
__builtin_assume(n_col_tiles > 0);
|
||||
__builtin_assume(n_dot_tiles > 0);
|
||||
|
||||
hmx_set_output_scales(col_scales);
|
||||
Q6_bias_mxmem2_A((void *)col_scales);
|
||||
|
||||
for (int i = 0; i < n_row_tiles; ++i) {
|
||||
for (int j = 0; j < n_col_tiles; ++j) {
|
||||
@@ -1448,15 +1530,17 @@ void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp
|
||||
|
||||
__fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS;
|
||||
if (!zero_init) {
|
||||
hmx_load_tile_pair_fp16(accum_tile, eye_tile);
|
||||
Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047);
|
||||
}
|
||||
|
||||
for (int k = 0; k < n_dot_tiles; ++k) {
|
||||
int offset = k * HMX_FP16_TILE_N_ELMS;
|
||||
hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset);
|
||||
Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
|
||||
row_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
hmx_consume_accumulator_fp16(accum_tile);
|
||||
Q6_mxmem_AR_after_hf(accum_tile, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1540,12 +1624,41 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
|
||||
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
|
||||
const size_t M_BLOCK_SIZE = 512;
|
||||
const size_t N_BLOCK_SIZE = 512;
|
||||
const size_t K_BLOCK_SIZE = 512;
|
||||
const size_t K_BLOCK_SIZE = 1024;
|
||||
|
||||
// Compute precise buffer sizes
|
||||
// Fallback: if k doesn't need K-blocking, out-stationary has no advantage
|
||||
const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE;
|
||||
if (k_iters_check <= 1) {
|
||||
FARF(MEDIUM, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k);
|
||||
return FALLBACK_TO_STANDARD;
|
||||
}
|
||||
|
||||
// Dynamic M,N search via hmx_compute_chunks
|
||||
const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE);
|
||||
const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32)
|
||||
+ K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles)
|
||||
const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant)
|
||||
+ K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles)
|
||||
const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary)
|
||||
// Alignment margin: hex_align_up can add up to 2047 bytes per buffer;
|
||||
// scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin
|
||||
const size_t align_margin = 4 * HMX_FP16_TILE_SIZE;
|
||||
const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment
|
||||
|
||||
size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used;
|
||||
// Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost.
|
||||
// From profiling: wt_dequant per element ≈ 1.5× activation load per element.
|
||||
// m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive).
|
||||
// n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper).
|
||||
const size_t m_block_cost = (size_t) n * 3;
|
||||
const size_t n_block_cost = (size_t) m * 2;
|
||||
if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE,
|
||||
&N_BLOCK_SIZE, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Compute precise buffer sizes from searched M,N and fixed K
|
||||
const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
@@ -1554,7 +1667,8 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
|
||||
|
||||
const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256;
|
||||
if (total_vtcm > vtcm_budget) {
|
||||
FARF(HIGH, "%s: VTCM too small: need %zu have %zu (m=%d k=%d n=%d)", __func__, total_vtcm, vtcm_budget, m, k, n);
|
||||
FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm,
|
||||
vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE);
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -1568,8 +1682,8 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
|
||||
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
|
||||
assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget);
|
||||
|
||||
FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", __func__, m, k, n, weight_type,
|
||||
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
|
||||
FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", __func__, m, k, n, weight_type,
|
||||
M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
|
||||
|
||||
// initialize eye tile (32x32 identity matrix)
|
||||
{
|
||||
|
||||
158
ggml/src/ggml-hexagon/htp/hmx-queue.c
Normal file
158
ggml/src/ggml-hexagon/htp/hmx-queue.c
Normal file
@@ -0,0 +1,158 @@
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <qurt_thread.h>
|
||||
#include <qurt_futex.h>
|
||||
|
||||
#include <HAP_compute_res.h>
|
||||
|
||||
#include "hmx-queue.h"
|
||||
|
||||
#define QURT_LOWEST_PRIO (254)
|
||||
|
||||
static inline void hmx_lock(struct hmx_queue *q)
|
||||
{
|
||||
if (!q->hmx_locked) {
|
||||
HAP_compute_res_hmx_lock(q->hap_rctx);
|
||||
q->hmx_locked = true;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hmx_unlock(struct hmx_queue *q)
|
||||
{
|
||||
if (q->hmx_locked) {
|
||||
HAP_compute_res_hmx_unlock(q->hap_rctx);
|
||||
q->hmx_locked = false;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) {
|
||||
unsigned int ir = atomic_load(&q->idx_read);
|
||||
|
||||
while (ir != atomic_load(&q->idx_write)) {
|
||||
struct hmx_queue_desc *d = &q->desc[ir];
|
||||
if (!d->done) {
|
||||
FARF(HIGH, "hmx-queue-process: ir %u func %p data %p", ir, d->func, d->data);
|
||||
|
||||
enum hmx_queue_signal sig = (enum hmx_queue_signal) (unsigned int) d->func;
|
||||
switch (sig) {
|
||||
case HMX_QUEUE_NOOP: /* noop */; break;
|
||||
case HMX_QUEUE_KILL: *killed = true; break;
|
||||
case HMX_QUEUE_SUSPEND: hmx_unlock(q); break;
|
||||
default:
|
||||
hmx_lock(q);
|
||||
d->func(d->data);
|
||||
break;
|
||||
}
|
||||
|
||||
atomic_fetch_add(&d->done, 1);
|
||||
}
|
||||
|
||||
ir = (ir + 1) & q->idx_mask;
|
||||
atomic_store(&q->idx_read, ir);
|
||||
}
|
||||
}
|
||||
|
||||
static void hmx_queue_thread(void * arg) {
|
||||
struct hmx_queue * q = (struct hmx_queue *) arg;
|
||||
|
||||
FARF(HIGH, "hmx-queue-thread: started");
|
||||
|
||||
bool killed = false;
|
||||
|
||||
unsigned int poll_cnt = HMX_QUEUE_POLL_COUNT;
|
||||
unsigned int prev_seqn = 0;
|
||||
while (!killed) {
|
||||
unsigned int seqn = atomic_load(&q->seqn);
|
||||
if (seqn == prev_seqn) {
|
||||
if (--poll_cnt) { hex_pause(); continue; }
|
||||
FARF(HIGH, "hmx-queue-thread: sleeping");
|
||||
qurt_futex_wait(&q->seqn, prev_seqn);
|
||||
continue;
|
||||
}
|
||||
prev_seqn = seqn;
|
||||
poll_cnt = HMX_QUEUE_POLL_COUNT;
|
||||
|
||||
FARF(HIGH, "hmx-queue-thread: new work");
|
||||
|
||||
hmx_queue_process(q, &killed);
|
||||
}
|
||||
|
||||
FARF(HIGH, "hmx-queue-thread: stopped");
|
||||
}
|
||||
|
||||
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx) {
|
||||
capacity = hex_ceil_pow2(capacity);
|
||||
|
||||
struct hmx_queue * q = (struct hmx_queue *) memalign(32, sizeof(struct hmx_queue));
|
||||
if (q == NULL) {
|
||||
FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__);
|
||||
return NULL;
|
||||
}
|
||||
memset(q, 0, sizeof(struct hmx_queue));
|
||||
q->capacity = capacity;
|
||||
q->idx_mask = capacity - 1;
|
||||
q->hap_rctx = hap_rctx;
|
||||
|
||||
q->desc = (struct hmx_queue_desc *) memalign(64, capacity * sizeof(struct hmx_queue_desc));
|
||||
if (!q->desc) {
|
||||
FARF(ERROR, "hmx-queue: failed to allocate HMX queue descriptors\n");
|
||||
return NULL;
|
||||
}
|
||||
memset(q->desc, 0, capacity * sizeof(struct hmx_queue_desc));
|
||||
|
||||
const size_t stack_size = HMX_QUEUE_THREAD_STACK_SIZE;
|
||||
q->stack = (unsigned char *) memalign(64, stack_size);
|
||||
if (!q->stack) {
|
||||
FARF(ERROR, "hmx-queue: thread stack allocation failed (%zu bytes)", stack_size);
|
||||
return NULL;
|
||||
}
|
||||
memset(q->stack, 0, stack_size);
|
||||
|
||||
// Match caller thread priority (same pattern as worker-pool.c).
|
||||
int prio = qurt_thread_get_priority(qurt_thread_get_id());
|
||||
if (prio < 1) {
|
||||
prio = 1;
|
||||
}
|
||||
if (prio > QURT_LOWEST_PRIO) {
|
||||
prio = QURT_LOWEST_PRIO;
|
||||
}
|
||||
|
||||
qurt_thread_attr_t attr;
|
||||
qurt_thread_attr_init(&attr);
|
||||
qurt_thread_attr_set_stack_addr(&attr, q->stack);
|
||||
qurt_thread_attr_set_stack_size(&attr, stack_size);
|
||||
qurt_thread_attr_set_priority(&attr, prio);
|
||||
qurt_thread_attr_set_name(&attr, "hmx-queue");
|
||||
|
||||
int err = qurt_thread_create(&q->thread, &attr, hmx_queue_thread, q);
|
||||
if (err) {
|
||||
FARF(ERROR, "hmx-worker: thread create failed (%d)", err);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
FARF(HIGH, "hmx-queue: capacity %u\n", capacity);
|
||||
|
||||
return q;
|
||||
}
|
||||
|
||||
void hmx_queue_delete(struct hmx_queue * q) {
|
||||
if (!q) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Tell the worker to exit.
|
||||
hmx_queue_flush(q);
|
||||
hmx_queue_signal(q, HMX_QUEUE_KILL);
|
||||
hmx_queue_flush(q);
|
||||
|
||||
int status;
|
||||
qurt_thread_join(q->thread, &status);
|
||||
|
||||
free(q->desc);
|
||||
free(q->stack);
|
||||
free(q);
|
||||
}
|
||||
134
ggml/src/ggml-hexagon/htp/hmx-queue.h
Normal file
134
ggml/src/ggml-hexagon/htp/hmx-queue.h
Normal file
@@ -0,0 +1,134 @@
|
||||
#ifndef HMX_QUEUE_H
|
||||
#define HMX_QUEUE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdatomic.h>
|
||||
|
||||
#include <hexagon_types.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <qurt_futex.h>
|
||||
#include <HAP_farf.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define HMX_QUEUE_THREAD_STACK_SIZE (16 * 1024)
|
||||
#define HMX_QUEUE_POLL_COUNT 2000
|
||||
|
||||
typedef void (*hmx_queue_func)(void *);
|
||||
|
||||
// Dummy funcs used as signals
|
||||
enum hmx_queue_signal {
|
||||
HMX_QUEUE_NOOP = 0, // aka NULL
|
||||
HMX_QUEUE_SUSPEND,
|
||||
HMX_QUEUE_KILL
|
||||
};
|
||||
|
||||
struct hmx_queue_desc {
|
||||
hmx_queue_func func;
|
||||
void * data;
|
||||
atomic_uint done;
|
||||
};
|
||||
|
||||
struct hmx_queue {
|
||||
struct hmx_queue_desc * desc;
|
||||
atomic_uint idx_write; // updated by producer (push)
|
||||
atomic_uint idx_read; // updated by consumer (process)
|
||||
unsigned int idx_pop; // updated by producer (pop)
|
||||
uint32_t idx_mask;
|
||||
uint32_t capacity;
|
||||
|
||||
atomic_uint seqn; // incremented for all pushes, used with futex
|
||||
qurt_thread_t thread;
|
||||
void * stack;
|
||||
uint32_t hap_rctx;
|
||||
bool hmx_locked;
|
||||
};
|
||||
|
||||
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx);
|
||||
void hmx_queue_delete(struct hmx_queue * q);
|
||||
|
||||
static inline struct hmx_queue_desc hmx_queue_make_desc(hmx_queue_func func, void * data) {
|
||||
struct hmx_queue_desc d = { func, data };
|
||||
return d;
|
||||
}
|
||||
|
||||
static inline bool hmx_queue_push(struct hmx_queue * q, struct hmx_queue_desc d) {
|
||||
unsigned int ir = atomic_load(&q->idx_read);
|
||||
unsigned int iw = q->idx_write;
|
||||
|
||||
if (((iw + 1) & q->idx_mask) == ir) {
|
||||
FARF(HIGH, "hmx-queue-push: queue is full\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
atomic_store(&d.done, 0);
|
||||
|
||||
FARF(HIGH, "hmx-queue-push: iw %u func %p data %p\n", iw, d.func, d.data);
|
||||
|
||||
q->desc[iw] = d;
|
||||
atomic_store(&q->idx_write, (iw + 1) & q->idx_mask);
|
||||
// wake up our thread
|
||||
atomic_fetch_add(&q->seqn, 1);
|
||||
qurt_futex_wake(&q->seqn, 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline bool hmx_queue_signal(struct hmx_queue *q, enum hmx_queue_signal sig) {
|
||||
return hmx_queue_push(q, hmx_queue_make_desc((hmx_queue_func) sig, NULL));
|
||||
}
|
||||
|
||||
static inline bool hmx_queue_empty(struct hmx_queue * q) {
|
||||
return q->idx_pop == q->idx_write;
|
||||
}
|
||||
|
||||
static inline uint32_t hmx_queue_depth(struct hmx_queue * q) {
|
||||
return (q->idx_read - q->idx_read) & q->idx_mask;
|
||||
}
|
||||
|
||||
static inline uint32_t hmx_queue_capacity(struct hmx_queue * q) {
|
||||
return q->capacity;
|
||||
}
|
||||
|
||||
static inline struct hmx_queue_desc hmx_queue_pop(struct hmx_queue * q) {
|
||||
unsigned int ip = q->idx_pop;
|
||||
unsigned int iw = q->idx_write;
|
||||
|
||||
struct hmx_queue_desc rd = { NULL, NULL };
|
||||
if (ip == iw) {
|
||||
return rd;
|
||||
}
|
||||
|
||||
// Wait for desc to complete
|
||||
struct hmx_queue_desc * d = &q->desc[ip];
|
||||
while (!atomic_load(&d->done)) {
|
||||
FARF(HIGH, "hmx-queue-pop: waiting for HMX queue : %u\n", ip);
|
||||
hex_pause();
|
||||
}
|
||||
|
||||
rd = *d;
|
||||
q->idx_pop = (ip + 1) & q->idx_mask;
|
||||
|
||||
FARF(HIGH, "hmx-queue-pop: ip %u func %p data %p\n", ip, rd.func, rd.data);
|
||||
return rd;
|
||||
}
|
||||
|
||||
static inline void hmx_queue_flush(struct hmx_queue * q) {
|
||||
while (hmx_queue_pop(q).func != NULL) ;
|
||||
}
|
||||
|
||||
static inline void hmx_queue_suspend(struct hmx_queue *q) {
|
||||
hmx_queue_signal(q, HMX_QUEUE_SUSPEND);
|
||||
hmx_queue_flush(q);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
#endif /* HMX_QUEUE_H */
|
||||
@@ -14,10 +14,6 @@
|
||||
|
||||
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
|
||||
|
||||
static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) {
|
||||
asm volatile("bias = mxmem2(%0)" :: "r"(scales));
|
||||
}
|
||||
|
||||
// Initialise aligned 256-byte area with scale vector + zero padding.
|
||||
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
|
||||
HVX_Vector *pv = (HVX_Vector *)out_scales;
|
||||
@@ -25,58 +21,6 @@ static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vecto
|
||||
*pv = Q6_V_vzero();
|
||||
}
|
||||
|
||||
// Load multiple contiguous tiles with :deep streaming.
|
||||
// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt].
|
||||
// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank
|
||||
// boundary, otherwise the mxmem instruction will raise a precise bus error.
|
||||
// Callers must ensure their VTCM layout satisfies this constraint.
|
||||
static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles,
|
||||
const __fp16 *col_tiles,
|
||||
size_t n_tiles) {
|
||||
size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1;
|
||||
asm volatile(
|
||||
"{ activation.hf = mxmem(%0, %1):deep\n"
|
||||
"weight.hf = mxmem(%2, %3) }\n"
|
||||
:: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Load a single activation+weight tile pair (no :deep streaming).
|
||||
// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula
|
||||
// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047.
|
||||
// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation
|
||||
// places a tile near a 4 MB bank boundary, the oversized region crosses it and
|
||||
// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly
|
||||
// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047).
|
||||
static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile,
|
||||
const __fp16 *wt_tile) {
|
||||
asm volatile(
|
||||
"{ activation.hf = mxmem(%0, %1)\n"
|
||||
"weight.hf = mxmem(%2, %3) }\n"
|
||||
:: "r"(act_tile), "r"(2047),
|
||||
"r"(wt_tile), "r"(2047)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) {
|
||||
// Use the combined convert-and-store instruction (matches the reference
|
||||
// Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence
|
||||
// "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter.
|
||||
asm volatile(
|
||||
"mxmem(%0, %1):after.hf = acc\n"
|
||||
:: "r"(out), "r"(0)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Compute inner product of two vectors of tiles and store result.
|
||||
static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out,
|
||||
const __fp16 *row_tiles,
|
||||
const __fp16 *col_tiles,
|
||||
size_t n_tiles) {
|
||||
hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles);
|
||||
hmx_consume_accumulator_fp16(out);
|
||||
}
|
||||
|
||||
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
|
||||
|
||||
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define HTP_CTX_H
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hmx-queue.h"
|
||||
#include "htp-ops.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
@@ -30,6 +31,8 @@ struct htp_spad {
|
||||
uint32_t size_per_thread; // size per thread
|
||||
};
|
||||
|
||||
struct htp_context;
|
||||
|
||||
// Context while processing an Op
|
||||
// TODO: fold this into the main context
|
||||
struct htp_ops_context {
|
||||
@@ -72,6 +75,10 @@ struct htp_context {
|
||||
atomic_bool vtcm_needs_release;
|
||||
|
||||
struct htp_ops_context octx;
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap
|
||||
#endif
|
||||
};
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx);
|
||||
|
||||
@@ -91,7 +91,12 @@ enum htp_op_code {
|
||||
#define HTP_OP_MAX_BUFS 8
|
||||
#define HTP_OP_MAX_REQS 256
|
||||
#define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS)
|
||||
|
||||
#if __HVX_ARCH__ < 75
|
||||
#define HTP_OP_MAX_VMEM (3167538380u)
|
||||
#else
|
||||
#define HTP_OP_MAX_VMEM (3221225472u)
|
||||
#endif
|
||||
|
||||
enum htp_tensor_flags {
|
||||
HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights)
|
||||
|
||||
@@ -116,9 +116,14 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) {
|
||||
#if __HVX_ARCH__ >= 81
|
||||
HVX_Vector q0 = Q6_Vqf32_equals_Vsf(v0);
|
||||
HVX_Vector q1 = Q6_Vqf32_equals_Vsf(v1);
|
||||
#else
|
||||
const HVX_Vector zero = Q6_V_vzero();
|
||||
HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
|
||||
HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
|
||||
#endif
|
||||
return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0));
|
||||
}
|
||||
|
||||
|
||||
@@ -18,8 +18,9 @@
|
||||
#include <remote.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hex-utils.h"
|
||||
#include "hex-dma.h"
|
||||
#include "hmx-queue.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
@@ -324,6 +325,14 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
ctx->hmx_enabled = use_hmx;
|
||||
ctx->hmx_queue = NULL;
|
||||
if (use_hmx) {
|
||||
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
|
||||
if (!ctx->hmx_queue) {
|
||||
FARF(ERROR, "hmx-queue-create failed");
|
||||
ctx->hmx_enabled = false;
|
||||
}
|
||||
}
|
||||
FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx);
|
||||
#endif
|
||||
|
||||
@@ -389,7 +398,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
ctx->hmx_enabled = 0;
|
||||
if (ctx->hmx_queue) {
|
||||
hmx_queue_delete(ctx->hmx_queue);
|
||||
ctx->hmx_queue = NULL;
|
||||
}
|
||||
ctx->hmx_enabled = false;
|
||||
#endif
|
||||
|
||||
vtcm_free(ctx);
|
||||
|
||||
@@ -250,6 +250,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
|
||||
case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break;
|
||||
case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break;
|
||||
case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break;
|
||||
case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
} break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
|
||||
@@ -1043,6 +1043,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
case GGML_UNARY_OP_XIELU:
|
||||
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
|
||||
default:
|
||||
return false;
|
||||
@@ -1159,6 +1160,23 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
return false;
|
||||
}
|
||||
switch (op->src[1]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
break;
|
||||
case GGML_TYPE_BF16:
|
||||
if (!has_bfloat) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
|
||||
@@ -127,6 +127,7 @@
|
||||
#define OP_UNARY_NUM_CEIL 118
|
||||
#define OP_UNARY_NUM_ROUND 119
|
||||
#define OP_UNARY_NUM_TRUNC 120
|
||||
#define OP_UNARY_NUM_XIELU 121
|
||||
|
||||
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
|
||||
#define OP_SUM_ROWS_NUM_MEAN 11
|
||||
|
||||
@@ -787,6 +787,13 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
args.max = ggml_get_op_params_f32(op, 1);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) {
|
||||
args.slope = ggml_get_op_params_f32(op, 1); // alpha_n
|
||||
args.scale = ggml_get_op_params_f32(op, 2); // alpha_p
|
||||
args.bias = ggml_get_op_params_f32(op, 3); // beta
|
||||
args.val = ggml_get_op_params_f32(op, 4); // eps
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
if (pipeline.c4) {
|
||||
|
||||
@@ -1177,6 +1177,15 @@ kernel void kernel_unary_impl(
|
||||
if (FC_OP == OP_UNARY_NUM_TRUNC) {
|
||||
dst_ptr[i0] = (T) trunc(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_XIELU) {
|
||||
const TC xi = x;
|
||||
const TC gate = TC(xi > TC(0.0f));
|
||||
const TC clamped = fmin(xi, TC(args.val));
|
||||
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
|
||||
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
|
||||
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
|
||||
@@ -20,6 +20,13 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
||||
#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
|
||||
|
||||
#include <vulkan/vulkan.hpp>
|
||||
// SPIRV-Headers: LunarG Windows SDK uses Include/spirv-headers/spirv.hpp (not spirv/unified1/). MinGW/MSYS2 and
|
||||
// Linux packages use Khronos layout spirv/unified1/spirv.hpp. See docs/build.md#vulkan.
|
||||
#if defined(_WIN32) && !defined(__MINGW32__)
|
||||
#include <spirv-headers/spirv.hpp>
|
||||
#else
|
||||
#include <spirv/unified1/spirv.hpp>
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
@@ -2131,6 +2138,66 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
||||
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
|
||||
|
||||
vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
|
||||
|
||||
// Patch SPIR-V to enable RTE rounding for FP16, avoiding the need for
|
||||
// separate shader variants compiled with -DRTE16.
|
||||
std::vector<uint32_t> spv;
|
||||
if (device->float_controls_rte_fp16) {
|
||||
const uint32_t* spv_words = reinterpret_cast<const uint32_t *>(spv_data);
|
||||
size_t word_count = spv_size / sizeof(uint32_t);
|
||||
spv.assign(spv_words, spv_words + word_count);
|
||||
|
||||
// Find insertion points respecting SPIR-V layout order:
|
||||
// Header(5) -> OpCapability -> OpExtension -> ... -> OpEntryPoint -> OpExecutionMode -> ...
|
||||
size_t pos = 5; // skip header
|
||||
size_t cap_insert_pos = pos;
|
||||
size_t ext_insert_pos = pos;
|
||||
size_t exec_insert_pos = pos;
|
||||
uint32_t entry_point_id = 0;
|
||||
|
||||
while (pos < spv.size()) {
|
||||
uint32_t opcode = spv[pos] & spv::OpCodeMask;
|
||||
uint32_t len = spv[pos] >> spv::WordCountShift;
|
||||
if (len == 0) break;
|
||||
|
||||
if (opcode == spv::OpCapability) {
|
||||
cap_insert_pos = pos + len;
|
||||
ext_insert_pos = pos + len;
|
||||
} else if (opcode == spv::OpExtension) {
|
||||
ext_insert_pos = pos + len;
|
||||
} else if (opcode == spv::OpEntryPoint) {
|
||||
entry_point_id = spv[pos + 2];
|
||||
exec_insert_pos = pos + len;
|
||||
} else if (opcode == spv::OpExecutionMode || opcode == spv::OpExecutionModeId) {
|
||||
exec_insert_pos = pos + len;
|
||||
} else if (entry_point_id != 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
pos += len;
|
||||
}
|
||||
|
||||
// Insert from latest position first so earlier indices stay valid.
|
||||
|
||||
// OpExecutionMode %entrypoint RoundingModeRTE 16
|
||||
uint32_t exec_mode[] = { (4u << spv::WordCountShift) | spv::OpExecutionMode, entry_point_id, spv::ExecutionModeRoundingModeRTE, 16 };
|
||||
spv.insert(spv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode));
|
||||
|
||||
// OpExtension "SPV_KHR_float_controls"
|
||||
const char ext_str[] = "SPV_KHR_float_controls";
|
||||
size_t ext_str_words = CEIL_DIV(sizeof(ext_str), sizeof(uint32_t));
|
||||
std::vector<uint32_t> extension(1 + ext_str_words, 0);
|
||||
extension[0] = (uint32_t)((1 + ext_str_words) << spv::WordCountShift) | spv::OpExtension;
|
||||
memcpy(&extension[1], ext_str, sizeof(ext_str));
|
||||
spv.insert(spv.begin() + ext_insert_pos, extension.begin(), extension.end());
|
||||
|
||||
// OpCapability RoundingModeRTE
|
||||
uint32_t capability[] = { (2u << spv::WordCountShift) | spv::OpCapability, spv::CapabilityRoundingModeRTE };
|
||||
spv.insert(spv.begin() + cap_insert_pos, std::begin(capability), std::end(capability));
|
||||
|
||||
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spv.size() * sizeof(uint32_t), spv.data());
|
||||
}
|
||||
|
||||
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
|
||||
|
||||
vk::PushConstantRange pcr(
|
||||
@@ -2858,11 +2925,10 @@ struct vk_fa_tuning_params {
|
||||
}
|
||||
};
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
GGML_UNUSED(kv_type);
|
||||
|
||||
vk_fa_tuning_params result{};
|
||||
result.path = FA_SCALAR;
|
||||
@@ -2914,7 +2980,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
|
||||
|
||||
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
|
||||
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
|
||||
result.block_rows /= 2;
|
||||
}
|
||||
|
||||
@@ -3080,6 +3146,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
||||
case GGML_TYPE_MXFP4:
|
||||
lut_size = 4*16;
|
||||
break;
|
||||
case GGML_TYPE_NVFP4:
|
||||
// Same kvalues budget as MXFP4 plus ue4m3_fp32_lut[128] (types.glsl, DATA_A_NVFP4).
|
||||
lut_size = 4*16 + 128u * (uint32_t)sizeof(float);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -3445,21 +3515,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
if (device->fp16) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
|
||||
}
|
||||
} else {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
|
||||
}
|
||||
}
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
@@ -3533,6 +3629,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_NVFP4], matmul_nvfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
|
||||
GGML_ASSERT(device->subgroup_ballot);
|
||||
|
||||
@@ -3563,6 +3660,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
||||
#undef CREATE_MM
|
||||
#undef CREATE_MM2
|
||||
} else
|
||||
@@ -3626,6 +3724,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
} else {
|
||||
CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
@@ -3649,6 +3748,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
}
|
||||
|
||||
GGML_ASSERT(device->subgroup_ballot);
|
||||
@@ -3683,6 +3783,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
||||
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
||||
#undef CREATE_MM2
|
||||
#undef CREATE_MM
|
||||
} else
|
||||
@@ -3748,6 +3849,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -3794,6 +3896,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -3839,6 +3942,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -3914,6 +4018,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -3958,6 +4063,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_subgroup_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
} else {
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
@@ -3985,6 +4091,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
}
|
||||
}
|
||||
// reusing CREATE_MM from the fp32 path
|
||||
@@ -4083,6 +4190,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
@@ -4108,6 +4216,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -4159,6 +4268,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_NVFP4], "mul_mat_vec_id_nvfp4_f32", arr_dmmv_id_nvfp4_f32_f32_len[reduc16], arr_dmmv_id_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
@@ -4214,6 +4324,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
|
||||
// get_rows
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||
@@ -4240,6 +4351,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_NVFP4], "get_rows_nvfp4", get_rows_nvfp4_len, get_rows_nvfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||
@@ -4266,6 +4378,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
@@ -4298,10 +4411,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
||||
|
||||
if (device->float_controls_rte_fp16 &&
|
||||
sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
|
||||
if (sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_len, rms_norm_mul_rope_f32_f16_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
@@ -4326,43 +4438,28 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_rte_len, cpy_f32_q1_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
|
||||
#define SET_ROWS(itype, rte) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## rte ## _len, set_rows_q1_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
#define SET_ROWS(itype) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## _len, set_rows_f32 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## _len, set_rows_f16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## _len, set_rows_bf16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## _len, set_rows_q1_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## _len, set_rows_q4_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## _len, set_rows_q4_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## _len, set_rows_q5_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
|
||||
if (device->float_controls_rte_fp16) {
|
||||
SET_ROWS(_i32, _rte)
|
||||
SET_ROWS(_i64, _rte)
|
||||
} else {
|
||||
SET_ROWS(_i32, )
|
||||
SET_ROWS(_i64, )
|
||||
}
|
||||
SET_ROWS(_i32)
|
||||
SET_ROWS(_i64)
|
||||
#undef SET_ROWS
|
||||
|
||||
|
||||
@@ -4382,11 +4479,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
return s;
|
||||
};
|
||||
|
||||
bool rte = device->float_controls_rte_fp16;
|
||||
#define CREATE_BINARY(name, namemod, spec, bindings) \
|
||||
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
|
||||
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
|
||||
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
|
||||
"main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
|
||||
|
||||
CREATE_BINARY(add, , {0}, 4)
|
||||
@@ -4429,13 +4525,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -4476,19 +4567,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_UNARY(floor)
|
||||
CREATE_UNARY(trunc)
|
||||
CREATE_UNARY(sgn)
|
||||
CREATE_UNARY(exp)
|
||||
#undef CREATE_UNARY
|
||||
|
||||
#define CREATE_UNARY_RTE(name) \
|
||||
if (device->float_controls_rte_fp16) { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
||||
}
|
||||
CREATE_UNARY_RTE(exp)
|
||||
#undef CREATE_UNARY_RTE
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -4498,13 +4579,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
#define CREATE_GLU(name) \
|
||||
if (device->float_controls_rte_fp16) { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
|
||||
|
||||
CREATE_GLU(geglu)
|
||||
CREATE_GLU(reglu)
|
||||
@@ -4537,25 +4613,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
|
||||
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
|
||||
uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
|
||||
@@ -4617,13 +4682,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
#define IM2COL(bda) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
if (device->float_controls_rte_fp16) { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
||||
if (device->shader_int64 && device->buffer_device_address) {
|
||||
IM2COL(_bda)
|
||||
} else {
|
||||
@@ -6064,6 +6124,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
@@ -6136,6 +6197,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
@@ -6202,6 +6264,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
@@ -6293,6 +6356,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
@@ -6362,6 +6426,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
@@ -8780,7 +8845,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
|
||||
GGML_UNUSED(f32acc);
|
||||
// Needs to be kept up to date on shader changes
|
||||
const uint32_t wg_size = params.workgroup_size;
|
||||
@@ -8789,21 +8854,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
|
||||
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
|
||||
const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
|
||||
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
|
||||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
|
||||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
|
||||
|
||||
// tmpsh is overestimated slightly
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
|
||||
|
||||
const uint32_t masksh = Bc * (Br + 1) * float_type_size;
|
||||
|
||||
const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
uint32_t Qf, kvsh, kblocksh_size;
|
||||
if (mmq) {
|
||||
// block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
|
||||
const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
|
||||
Qf = Br * (hsk / 32) * block_b_size;
|
||||
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
// kvsh uses D = HSV (K goes through kblocksh instead)
|
||||
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
|
||||
// block_a_cache size depends on quant type
|
||||
uint32_t block_a_size;
|
||||
switch (kv_type) {
|
||||
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
|
||||
default: block_a_size = 0; break;
|
||||
}
|
||||
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
|
||||
} else {
|
||||
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
kblocksh_size = 0;
|
||||
}
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
||||
return supported;
|
||||
}
|
||||
@@ -14262,8 +14357,7 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
|
||||
}
|
||||
|
||||
// conditions for pipeline creation
|
||||
if (!(ctx->device->float_controls_rte_fp16 &&
|
||||
sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
|
||||
if (sizeof(vk_op_rms_norm_mul_rope_push_constants) > ctx->device->properties.limits.maxPushConstantsSize) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -15318,6 +15412,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
@@ -15433,6 +15528,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
case GGML_TYPE_I32:
|
||||
return true;
|
||||
default:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "generic_unary_head.glsl"
|
||||
#include "dequant_funcs.glsl"
|
||||
|
||||
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
|
||||
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4)
|
||||
// 16 invocations needed for init_iq_shmem
|
||||
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
|
||||
#else
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#if defined(SET_ROWS) && QUANT_K == 1
|
||||
|
||||
@@ -450,6 +450,25 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint sub = iqs >> 4;
|
||||
const float d = ue4m3_to_fp32(data_a[a_offset + ib].d[sub]);
|
||||
const uint j = iqs & 7;
|
||||
const uint shift = (iqs & 8) >> 1; // 0 or 4
|
||||
const uint vui0 = uint(data_a[a_offset + ib].qs[sub * 8u + j]);
|
||||
const uint vui1 = uint(data_a[a_offset + ib].qs[sub * 8u + j + 1]);
|
||||
const uint qs0 = (vui0 >> shift) & 0xF;
|
||||
const uint qs1 = (vui1 >> shift) & 0xF;
|
||||
return vec2(float(kvalues_mxfp4[qs0]), float(kvalues_mxfp4[qs1])) * d * 0.5;
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const vec2 v0 = dequantize(ib, iqs, a_offset);
|
||||
const vec2 v1 = dequantize(ib, iqs + 2u, a_offset);
|
||||
return vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
return vec2(0, 0);
|
||||
@@ -484,6 +503,12 @@ vec2 get_dm(uint ib, uint a_offset) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
return vec2(1.0, 0.0);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);
|
||||
|
||||
@@ -697,6 +697,24 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4 {
|
||||
block_nvfp4 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint sub = (idx & 0x30) >> 4;
|
||||
const uint iqs = ((idx & 0x30) >> 1) + (idx & 0x7);
|
||||
const uint shift = (idx & 0x8) >> 1;
|
||||
const float d = ue4m3_to_fp32(bl.block.d[sub]);
|
||||
uint qs = uint(bl.block.qs[iqs]);
|
||||
qs = (qs >> shift) & 0xF;
|
||||
return float16_t(kvalues_mxfp4[qs] * d * 0.5);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q1_0)
|
||||
#define dequantFuncA dequantFuncQ1_0
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
@@ -743,6 +761,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
|
||||
#define dequantFuncA dequantFuncIQ4_NL
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define dequantFuncA dequantFuncMXFP4
|
||||
#elif defined(DATA_A_NVFP4)
|
||||
#define dequantFuncA dequantFuncNVFP4
|
||||
#elif defined(DATA_A_F32)
|
||||
#define dequantFuncA dequantFuncF32
|
||||
#endif
|
||||
|
||||
32
ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp
Normal file
32
ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp
Normal file
@@ -0,0 +1,32 @@
|
||||
#version 450
|
||||
|
||||
#include "dequant_head.glsl"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_nvfp4 data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
const uint tid = gl_LocalInvocationID.x % 64;
|
||||
const uint sub = tid / 16;
|
||||
const uint ir = tid % 16;
|
||||
const uint ib = 16 * i + ir;
|
||||
if (ib >= p.nel / 64) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint q_idx = 8 * sub;
|
||||
const uint b_idx = 1024 * i + 64 * ir + 16 * sub;
|
||||
|
||||
const float d = ue4m3_to_fp32(data_a[ib].d[sub]);
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 8; ++l) {
|
||||
data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
|
||||
data_b[b_idx + l + 8] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
|
||||
@@ -10,6 +10,13 @@
|
||||
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
|
||||
#endif
|
||||
|
||||
#ifdef MMQ
|
||||
#extension GL_EXT_integer_dot_product : require
|
||||
#extension GL_KHR_shader_subgroup_clustered : require
|
||||
|
||||
#include "mul_mmq_shmem_types.glsl"
|
||||
#endif
|
||||
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
|
||||
@@ -41,15 +48,34 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
|
||||
const uint32_t masksh_stride = Br + 1;
|
||||
shared FLOAT_TYPE masksh[Bc * masksh_stride];
|
||||
|
||||
#ifndef MMQ
|
||||
const uint32_t qf_stride = HSK / 4 + 1;
|
||||
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
|
||||
#else
|
||||
|
||||
const uint32_t qf_stride = HSK / 32;
|
||||
shared block_b_cache Qf[Br * qf_stride];
|
||||
#endif
|
||||
|
||||
#ifndef MMQ
|
||||
const uint32_t D = HSK > HSV ? HSK : HSV;
|
||||
#else
|
||||
const uint32_t D = HSV;
|
||||
#endif
|
||||
const uint32_t kvsh_stride = D / 4 + 1;
|
||||
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
|
||||
|
||||
#ifdef MMQ
|
||||
|
||||
shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1];
|
||||
#endif
|
||||
|
||||
shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
|
||||
|
||||
#ifdef MMQ
|
||||
#include "flash_attn_mmq_funcs.glsl"
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
@@ -82,10 +108,39 @@ void main() {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t r = (idx + tid) / (HSK / 4);
|
||||
if (r < Br && d < HSK / 4 &&
|
||||
i * Br + r < N) {
|
||||
const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N;
|
||||
#ifndef MMQ
|
||||
if (is_in_bounds) {
|
||||
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||
}
|
||||
#else
|
||||
const uint buf_ib = r * qf_stride + d / 8;
|
||||
const uint buf_iqs = d % 8;
|
||||
|
||||
FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f);
|
||||
const FLOAT_TYPEV4 abs_vals = abs(vals);
|
||||
|
||||
const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
|
||||
const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8);
|
||||
const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0);
|
||||
const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0);
|
||||
vals = round(vals * qd_inv);
|
||||
|
||||
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
|
||||
|
||||
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
|
||||
}
|
||||
#else // Q4_0, Q4_1, Q5_0, Q5_1
|
||||
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
|
||||
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
|
||||
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
barrier();
|
||||
|
||||
@@ -195,6 +250,7 @@ void main() {
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
barrier();
|
||||
#ifndef MMQ
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
@@ -214,9 +270,29 @@ void main() {
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
#else // MMQ
|
||||
const uint ints_per_block = 8 / QUANT_R_MMQ;
|
||||
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
|
||||
const uint32_t iqs = (idx + tid) % ints_per_block;
|
||||
const uint32_t ib = (idx + tid) / ints_per_block;
|
||||
const uint32_t c = ib / (HSK / 32);
|
||||
const uint32_t block = ib % (HSK / 32);
|
||||
if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) {
|
||||
const uint buf_ib = c * qf_stride + block;
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
const uint global_ib = (j * Bc + c) * k_stride + block;
|
||||
k_block_to_shmem(buf_ib, global_ib, iqs, k_offset);
|
||||
} else {
|
||||
k_block_to_shmem_zero(buf_ib, iqs);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MMQ
|
||||
barrier();
|
||||
}
|
||||
|
||||
#ifndef MMQ
|
||||
// More d iterations means Q register caching becomes relevant
|
||||
// Few iterations means the additional registers needed are worse than the speed-up from caching
|
||||
if (HSK_per_thread / 4 > 4) {
|
||||
@@ -275,6 +351,110 @@ void main() {
|
||||
}
|
||||
}
|
||||
}
|
||||
#else // MMQ
|
||||
const uint hsk4 = HSK_per_thread / 4;
|
||||
const uint d_per_step = (hsk4 % 8 == 0) ? 8 :
|
||||
(hsk4 % 4 == 0) ? 4 :
|
||||
(hsk4 % 2 == 0) ? 2 : 1;
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) {
|
||||
int32_t k_quants[d_per_step];
|
||||
ACC_TYPEV2 k_dm;
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
|
||||
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
uint vui = kblocksh[buf_ib].qs[d];
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
|
||||
const uint ib = coord / BLOCK_SIZE;
|
||||
const uint iqs = (coord % BLOCK_SIZE);
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
|
||||
#endif
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
#if defined(DATA_A_Q5_0)
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qh[1]));
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
|
||||
#endif
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
|
||||
#else
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
|
||||
#endif
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8;
|
||||
const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8;
|
||||
|
||||
int32_t acc = 0;
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]);
|
||||
}
|
||||
|
||||
Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x;
|
||||
if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) {
|
||||
Sf[r][c] += k_dot_correction(qib, k_dm);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MMQ
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
// Compute sum across the D_split
|
||||
|
||||
@@ -89,6 +89,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
|
||||
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
|
||||
#endif
|
||||
|
||||
#ifndef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 1
|
||||
#endif
|
||||
|
||||
149
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl
Normal file
149
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl
Normal file
@@ -0,0 +1,149 @@
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q4_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
return int32_t(vui & 0x0F0F0F0F);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q5_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qh[1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
uint qh_bits = (qh >> iqs) & 0xF;
|
||||
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
u8vec4 idx = unpack8(vui & 0x0F0F0F0F);
|
||||
return pack32(i8vec4(kvalues_iq4nl_const[idx.x],
|
||||
kvalues_iq4nl_const[idx.y],
|
||||
kvalues_iq4nl_const[idx.z],
|
||||
kvalues_iq4nl_const[idx.w]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
|
||||
}
|
||||
#else
|
||||
FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
|
||||
return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
|
||||
}
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
|
||||
}
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
|
||||
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
|
||||
kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
|
||||
kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
|
||||
kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
|
||||
kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
|
||||
#endif
|
||||
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
|
||||
#else
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
|
||||
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
return kblocksh[buf_ib].qs[pos];
|
||||
#endif
|
||||
}
|
||||
|
||||
ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
|
||||
#else
|
||||
return ACC_TYPE(0.0);
|
||||
#endif
|
||||
}
|
||||
|
||||
void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
|
||||
kblocksh[buf_ib].qs[iqs] = 0;
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
kblocksh[buf_ib].qs[iqs + 4] = 0;
|
||||
#endif
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
|
||||
#else
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "utils.glsl"
|
||||
#if RMS_NORM_ROPE_FUSION
|
||||
#include "rope_params.glsl"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#include "rte.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
|
||||
@@ -501,6 +501,23 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
kvalues_mxfp4[vui2 & 0xF] * d);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d,
|
||||
kvalues_mxfp4[vui2 >> 4] * d);
|
||||
#elif defined(DATA_A_NVFP4)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
// lo and hi nibbles are 8 elements apart, which doesn't quite line up with
|
||||
// how the thread mapping and buf_idx calculation works for other types.
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row & 3) + (row & ~3) * 2;
|
||||
|
||||
const uint ib = idx / 16u;
|
||||
const uint sub = (idx & 0xC) >> 2;
|
||||
const uint iqs = (idx & 0xF) * 2;
|
||||
const float d = ue4m3_to_fp32(data_a[ib].d[sub]) * 0.5;
|
||||
const uint vui = uint(data_a[ib].qs[iqs]);
|
||||
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d,
|
||||
kvalues_mxfp4[vui2 & 0xF] * d);
|
||||
buf_a[buf_idx + 4] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d,
|
||||
kvalues_mxfp4[vui2 >> 4] * d);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,12 @@ struct block_a_cache {
|
||||
int32_t qs[32/4];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
#include "utils.glsl"
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "rope_params.glsl"
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
#if !defined(GGML_ROPE_PARAMS)
|
||||
#define GGML_ROPE_PARAMS
|
||||
|
||||
#include "rte.glsl"
|
||||
|
||||
struct rope_params {
|
||||
uint rope_mode;
|
||||
uint nrows;
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
|
||||
#if RTE16
|
||||
#extension GL_EXT_spirv_intrinsics : enable
|
||||
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
||||
#endif // RTE16
|
||||
@@ -1,6 +1,5 @@
|
||||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
|
||||
@@ -1692,6 +1692,7 @@ struct block_iq4_nl_packed16
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
#define QUANT_K QUANT_K_IQ4_NL
|
||||
#define QUANT_R QUANT_R_IQ4_NL
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_iq4_nl
|
||||
#define A_TYPE_PACKED16 block_iq4_nl_packed16
|
||||
#endif
|
||||
@@ -1712,6 +1713,22 @@ struct block_mxfp4
|
||||
#define A_TYPE block_mxfp4
|
||||
#endif
|
||||
|
||||
#define QUANT_K_NVFP4 64
|
||||
#define QUANT_R_NVFP4 1
|
||||
|
||||
struct block_nvfp4
|
||||
{
|
||||
uint8_t d[QUANT_K_NVFP4 / 16];
|
||||
uint8_t qs[QUANT_K_NVFP4 / 2];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
#define QUANT_K QUANT_K_NVFP4
|
||||
#define QUANT_R QUANT_R_NVFP4
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_nvfp4
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
|
||||
const int8_t kvalues_iq4nl_const[16] = {
|
||||
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
|
||||
@@ -1731,7 +1748,7 @@ void init_iq_shmem(uvec3 wgsize)
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
#if defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4)
|
||||
const int8_t kvalues_mxfp4_const[16] = {
|
||||
int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
|
||||
int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
|
||||
@@ -1739,6 +1756,24 @@ const int8_t kvalues_mxfp4_const[16] = {
|
||||
|
||||
shared int8_t kvalues_mxfp4[16];
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
// UE4M3 scale in NVFP4 blocks use only 7 bits; sign (bit 7) is always zero.
|
||||
shared float ue4m3_fp32_lut[128];
|
||||
|
||||
float ue4m3_to_fp32_build(uint u) {
|
||||
if (u == 0u || u == 127u) {
|
||||
return 0.0;
|
||||
}
|
||||
const uint exp = (u >> 3) & 15u;
|
||||
const uint man = u & 7u;
|
||||
if (exp == 0u) {
|
||||
return float(man) * (1.0 / 512.0);
|
||||
}
|
||||
const uint bits = (exp + 120u) << 23 | (man << 20);
|
||||
return uintBitsToFloat(bits);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
@@ -1746,6 +1781,11 @@ void init_iq_shmem(uvec3 wgsize)
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) {
|
||||
kvalues_mxfp4[i] = kvalues_mxfp4_const[i];
|
||||
}
|
||||
#if defined(DATA_A_NVFP4)
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < 128u; i += wgsize.x) {
|
||||
ue4m3_fp32_lut[i] = ue4m3_to_fp32_build(i);
|
||||
}
|
||||
#endif
|
||||
barrier();
|
||||
}
|
||||
#endif
|
||||
@@ -1782,6 +1822,12 @@ float e8m0_to_fp32(uint8_t x) {
|
||||
return uintBitsToFloat(bits);
|
||||
}
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
float ue4m3_to_fp32(uint8_t x) {
|
||||
return ue4m3_fp32_lut[uint(x)];
|
||||
}
|
||||
#endif
|
||||
|
||||
#if BDA
|
||||
|
||||
#extension GL_EXT_buffer_reference : enable
|
||||
|
||||
@@ -66,6 +66,7 @@ const std::vector<std::string> type_names = {
|
||||
"iq4_xs",
|
||||
"iq4_nl",
|
||||
"mxfp4",
|
||||
"nvfp4",
|
||||
"bf16",
|
||||
};
|
||||
|
||||
@@ -406,8 +407,8 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
|
||||
}
|
||||
|
||||
static std::vector<std::future<void>> compiles;
|
||||
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
||||
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
||||
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") {
|
||||
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix;
|
||||
std::string out_path = join_paths(output_dir, name + ".spv");
|
||||
|
||||
if (input_filepath == "") {
|
||||
@@ -556,7 +557,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
std::string load_vec_quant = "2";
|
||||
if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
|
||||
load_vec_quant = "8";
|
||||
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4"))
|
||||
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4") || (tname == "nvfp4"))
|
||||
load_vec_quant = "4";
|
||||
|
||||
if (tname == "bf16") {
|
||||
@@ -625,15 +626,16 @@ void process_shaders() {
|
||||
for (const bool& fp16 : {false, true}) {
|
||||
std::map<std::string, std::string> base_dict;
|
||||
if (fp16) {
|
||||
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
|
||||
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
|
||||
} else {
|
||||
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
|
||||
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}};
|
||||
}
|
||||
|
||||
// flash attention
|
||||
for (const bool& f16acc : {false, true}) {
|
||||
std::map<std::string, std::string> fa_base_dict = base_dict;
|
||||
fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
|
||||
fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2";
|
||||
fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
|
||||
if (fp16 && f16acc) {
|
||||
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
@@ -672,6 +674,12 @@ void process_shaders() {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (tname != "f32") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -737,7 +745,7 @@ void process_shaders() {
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
|
||||
string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}}));
|
||||
string_to_spv("rms_norm_mul_rope_f32_f16", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
@@ -761,15 +769,12 @@ void process_shaders() {
|
||||
|
||||
for (std::string t : {"q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
||||
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
}
|
||||
|
||||
for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
||||
string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
}
|
||||
|
||||
auto get_type_str = [](bool f16) {
|
||||
@@ -786,12 +791,10 @@ void process_shaders() {
|
||||
for (auto src0_f16 : {false, true}) {
|
||||
for (auto src1_f16 : {false, true}) {
|
||||
for (auto dst_f16 : {false, true}) {
|
||||
for (auto rte : {false, true}) {
|
||||
auto source = op == "add_rms" ? std::string("add") : op;
|
||||
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
|
||||
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
|
||||
auto add_rms = op == "add_rms" ? "1" : "0";
|
||||
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
|
||||
}
|
||||
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , add_rms}});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -839,14 +842,11 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
for (auto rte : {false, true}) {
|
||||
std::string suffix = rte ? "_rte" : "";
|
||||
string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
}
|
||||
string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
@@ -900,21 +900,18 @@ void process_shaders() {
|
||||
string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
for (auto rte : {false, true}) {
|
||||
std::string suffix = rte ? "_rte" : "";
|
||||
string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
}
|
||||
string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("swiglu_oai_f16", "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("swiglu_oai_f32", "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
@@ -934,25 +931,18 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
||||
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
|
||||
@@ -975,7 +965,6 @@ void process_shaders() {
|
||||
std::string bda_def = bda ? "1" : "0";
|
||||
string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
|
||||
string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
|
||||
string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1028,8 +1017,8 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
|
||||
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "0"}});
|
||||
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "1"}});
|
||||
|
||||
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
|
||||
string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
|
||||
@@ -1082,8 +1071,8 @@ void write_output_files() {
|
||||
|
||||
std::string suffixes[2] = {"_f32", "_f16"};
|
||||
for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
|
||||
hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
|
||||
hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
|
||||
hdr << "extern const void * " << op << "_data[2][2][2];\n";
|
||||
hdr << "extern const uint64_t " << op << "_len[2][2][2];\n";
|
||||
|
||||
std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp";
|
||||
if (basename(input_filepath) != op_file) {
|
||||
@@ -1091,8 +1080,8 @@ void write_output_files() {
|
||||
}
|
||||
std::stringstream data = make_generic_stringstream();
|
||||
std::stringstream len = make_generic_stringstream();
|
||||
data << "const void * " << op << "_data[2][2][2][2] = ";
|
||||
len << "const uint64_t " << op << "_len[2][2][2][2] = ";
|
||||
data << "const void * " << op << "_data[2][2][2] = ";
|
||||
len << "const uint64_t " << op << "_len[2][2][2] = ";
|
||||
for (uint32_t t0 = 0; t0 < 2; ++t0) {
|
||||
if (t0 == 0) {
|
||||
data << "{";
|
||||
@@ -1108,20 +1097,10 @@ void write_output_files() {
|
||||
data << "{";
|
||||
len << "{";
|
||||
}
|
||||
for (uint32_t rte = 0; rte < 2; ++rte) {
|
||||
if (rte == 0) {
|
||||
data << "{";
|
||||
len << "{";
|
||||
}
|
||||
data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
|
||||
len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
|
||||
data << "_data,";
|
||||
len << "_len,";
|
||||
if (rte == 1) {
|
||||
data << "}, ";
|
||||
len << "}, ";
|
||||
}
|
||||
}
|
||||
data << op << suffixes[t0] << suffixes[t1] << suffixes[t2];
|
||||
len << op << suffixes[t0] << suffixes[t1] << suffixes[t2];
|
||||
data << "_data,";
|
||||
len << "_len,";
|
||||
if (t2 == 1) {
|
||||
data << "}, ";
|
||||
len << "}, ";
|
||||
|
||||
@@ -79,7 +79,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
||||
|
||||
/* Constants */
|
||||
|
||||
#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 32u
|
||||
#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 64u
|
||||
#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u
|
||||
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u
|
||||
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6)
|
||||
@@ -97,14 +97,6 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
||||
|
||||
/* End Constants */
|
||||
|
||||
static inline wgpu::CallbackMode ggml_webgpu_callback_mode() {
|
||||
#ifdef __EMSCRIPTEN__
|
||||
return wgpu::CallbackMode::AllowProcessEvents;
|
||||
#else
|
||||
return wgpu::CallbackMode::AllowSpontaneous;
|
||||
#endif
|
||||
}
|
||||
|
||||
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
|
||||
// their locations.
|
||||
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
|
||||
@@ -445,34 +437,25 @@ static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status,
|
||||
}
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
// iOS browsers seem to have very strict limits on the number of in-flight GPU commands, so we need to throttle to avoid failures.
|
||||
EM_JS(int, ggml_webgpu_is_ios_browser, (), {
|
||||
const ua = navigator.userAgent;
|
||||
return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0;
|
||||
});
|
||||
#endif
|
||||
|
||||
static uint32_t ggml_backend_webgpu_get_max_inflight_batches(const wgpu::AdapterInfo & info) {
|
||||
// TODO: these next two functions may want tuning across different platforms and workloads,
|
||||
static uint32_t ggml_backend_webgpu_get_max_inflight_batches() {
|
||||
#ifdef __EMSCRIPTEN__
|
||||
// iOS has very strict limits on the number of in-flight GPU commands,
|
||||
// so we need to throttle to avoid failures.
|
||||
if (ggml_webgpu_is_ios_browser()) {
|
||||
return 1;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(info);
|
||||
#endif
|
||||
|
||||
return UINT32_MAX;
|
||||
}
|
||||
|
||||
static uint32_t ggml_backend_webgpu_get_command_submit_batch_size(const wgpu::AdapterInfo & info) {
|
||||
#ifdef __EMSCRIPTEN__
|
||||
if (ggml_webgpu_is_ios_browser()) {
|
||||
return 16;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(info);
|
||||
#endif
|
||||
|
||||
static uint32_t ggml_backend_webgpu_get_command_submit_batch_size() {
|
||||
return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE;
|
||||
}
|
||||
|
||||
@@ -482,7 +465,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) {
|
||||
|
||||
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
||||
ctx->queue.OnSubmittedWorkDone(
|
||||
ggml_webgpu_callback_mode(),
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
callback_status = status;
|
||||
callback_message = std::string(message);
|
||||
@@ -502,7 +485,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
|
||||
std::string callback_message;
|
||||
|
||||
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
||||
buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(),
|
||||
buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
callback_status = status;
|
||||
callback_message = std::string(message);
|
||||
@@ -542,15 +525,15 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
||||
#endif
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx,
|
||||
const std::vector<webgpu_command> & commands,
|
||||
std::vector<wgpu::FutureWaitInfo> & futures) {
|
||||
static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx,
|
||||
const std::vector<webgpu_encoded_op> & commands,
|
||||
std::vector<wgpu::FutureWaitInfo> & futures) {
|
||||
for (const auto & command : commands) {
|
||||
auto label = command.pipeline_name;
|
||||
auto ts_bufs = command.timestamp_query_bufs;
|
||||
|
||||
wgpu::Future f = ts_bufs.host_buf.MapAsync(
|
||||
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(),
|
||||
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
|
||||
@@ -3428,7 +3411,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, ggml_webgpu_callback_mode(),
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
@@ -3449,8 +3432,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
}
|
||||
#endif
|
||||
ctx->webgpu_global_ctx->adapter.GetInfo(&info);
|
||||
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(info);
|
||||
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(info);
|
||||
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size();
|
||||
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches();
|
||||
wgpu::SupportedFeatures features;
|
||||
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
|
||||
// we require f16 support
|
||||
@@ -3501,8 +3484,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
dev_desc.requiredFeatures = required_features.data();
|
||||
dev_desc.requiredFeatureCount = required_features.size();
|
||||
dev_desc.SetDeviceLostCallback(
|
||||
ggml_webgpu_callback_mode(),
|
||||
[ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
if (reason == wgpu::DeviceLostReason::Destroyed) {
|
||||
return;
|
||||
}
|
||||
@@ -3535,7 +3518,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->adapter.RequestDevice(
|
||||
&dev_desc, ggml_webgpu_callback_mode(),
|
||||
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
if (status != wgpu::RequestDeviceStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
|
||||
|
||||
@@ -502,12 +502,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
|
||||
}
|
||||
|
||||
// Map k_in_block to loop structure:
|
||||
// Outer loop over 64-element groups (alternating q_b_idx)
|
||||
// Inner loop over 2 shifts per group
|
||||
@@ -523,15 +517,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
var sc: u32;
|
||||
var mn: u32;
|
||||
|
||||
let scale_base = block_byte_base + 4u;
|
||||
|
||||
if (is < 4u) {
|
||||
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
|
||||
let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
|
||||
let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
} else {
|
||||
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
|
||||
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
|
||||
let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
|
||||
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
@@ -578,11 +574,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
|
||||
}
|
||||
|
||||
// The original loop processes elements in groups of 64
|
||||
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
|
||||
@@ -603,15 +594,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
var sc: u32;
|
||||
var mn: u32;
|
||||
|
||||
let scale_base = block_byte_base + 4u;
|
||||
|
||||
if (is < 4u) {
|
||||
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
|
||||
let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
|
||||
let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
} else {
|
||||
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
|
||||
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
|
||||
let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
|
||||
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
|
||||
@@ -4,14 +4,14 @@ enable f16;
|
||||
#include "mul_mat_decls.tmpl"
|
||||
|
||||
#ifdef VEC
|
||||
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
|
||||
return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
|
||||
fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
|
||||
return vec4<f32>(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef SCALAR
|
||||
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
|
||||
return f32(acc[tm][tn]);
|
||||
fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
|
||||
return acc[tm][tn];
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -98,7 +98,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
|
||||
let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
|
||||
|
||||
var acc: array<array<f16, TILE_N>, TILE_M>;
|
||||
var acc: array<array<f32, TILE_N>, TILE_M>;
|
||||
|
||||
for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
|
||||
|
||||
@@ -122,7 +122,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
let src1_idx = src1_n * TILE_K + k_inner;
|
||||
let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
|
||||
for (var tm = 0u; tm < TILE_M; tm++) {
|
||||
acc[tm][tn] += src0_tile[tm] * src1_val;
|
||||
acc[tm][tn] += f32(src0_tile[tm]) * f32(src1_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,9 @@ enable chromium_experimental_subgroup_matrix;
|
||||
#include "common_decls.tmpl"
|
||||
#include "mul_mat_decls.tmpl"
|
||||
|
||||
// TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs.
|
||||
// See https://github.com/ggml-org/llama.cpp/issues/21602
|
||||
|
||||
#ifdef VEC
|
||||
fn store_dst(shmem_idx: u32, dst_idx: u32) {
|
||||
dst[dst_idx] = vec4<f32>(
|
||||
|
||||
161
models/templates/Reka-Edge.jinja
Normal file
161
models/templates/Reka-Edge.jinja
Normal file
@@ -0,0 +1,161 @@
|
||||
{%- macro render_content(content, num_img_tokens, num_video_frames) -%}
|
||||
{%- if content is string -%}
|
||||
{{- content -}}
|
||||
{%- elif content is sequence -%}
|
||||
{%- set ns = namespace(out="", prev_was_text=false) -%}
|
||||
{%- for item in content -%}
|
||||
{%- set item_type = item.get("type") -%}
|
||||
{%- if item_type == "text" or item.get("text") is not none -%}
|
||||
{%- set text = item.get("text", "") -%}
|
||||
{%- if text -%}
|
||||
{%- if ns.prev_was_text -%}
|
||||
{%- set ns.out = ns.out ~ " " -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.out = ns.out ~ text -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.prev_was_text = text != "" -%}
|
||||
{%- elif item_type in ["image", "image_url"] or item.get("image") is not none or item.get("image_url") is not none -%}
|
||||
{%- set ns.out = ns.out ~ "<image>" ~ ("<REKA_IMG_TOKEN>" * num_img_tokens) ~ "</image>" -%}
|
||||
{%- set ns.prev_was_text = false -%}
|
||||
{%- elif item_type in ["video", "video_url"] or item.get("video") is not none or item.get("video_url") is not none -%}
|
||||
{%- set repeat_tokens = num_img_tokens * num_video_frames -%}
|
||||
{%- set ns.out = ns.out ~ "<video>" ~ ("<REKA_IMG_TOKEN>" * repeat_tokens) ~ "</video>" -%}
|
||||
{%- set ns.prev_was_text = false -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ns.out -}}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{%- set ns = namespace(out="", last_query_index=messages|length - 1) -%}
|
||||
{%- for msg in messages[::-1] -%}
|
||||
{%- set idx = messages|length - 1 - loop.index0 -%}
|
||||
{%- if msg.get("role") == "user" -%}
|
||||
{%- set content = msg.get("content", "") -%}
|
||||
{%- if not (content is string and content.startswith("<tool_response>") and content.endswith("</tool_response>")) -%}
|
||||
{%- set ns.last_query_index = idx -%}
|
||||
{%- break -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set last_query_index = ns.last_query_index -%}
|
||||
{%- set num_img_tokens = num_img_tokens | default(64, true) | int -%}
|
||||
{%- set num_video_frames = num_video_frames | default(6, true) | int -%}
|
||||
{%- set start_idx = 0 -%}
|
||||
{%- set system_text = "" -%}
|
||||
{%- if messages|length > 0 and messages[0].get("role") in ["system", "developer"] -%}
|
||||
{%- set system_text = render_content(messages[0].get("content", ""), num_img_tokens, num_video_frames) -%}
|
||||
{%- set start_idx = 1 -%}
|
||||
{%- endif -%}
|
||||
{%- if tools or system_text -%}
|
||||
{%- set preamble_ns = namespace(text="") -%}
|
||||
{%- if system_text -%}
|
||||
{%- set preamble_ns.text = "system: " ~ system_text -%}
|
||||
{%- endif -%}
|
||||
{%- if tools -%}
|
||||
{%- if preamble_ns.text -%}
|
||||
{%- set preamble_ns.text = preamble_ns.text ~ "\n\n" -%}
|
||||
{%- else -%}
|
||||
{%- set preamble_ns.text = "system: " -%}
|
||||
{%- endif -%}
|
||||
{%- set preamble_ns.text = preamble_ns.text
|
||||
~ "# Tools\n\n"
|
||||
~ "You may call one or more functions to assist with the user query.\n\n"
|
||||
~ "You are provided with function signatures within <tools></tools> XML tags:\n"
|
||||
~ "<tools>" -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- set preamble_ns.text = preamble_ns.text ~ "\n" ~ (tool | tojson(ensure_ascii=True)) -%}
|
||||
{%- endfor -%}
|
||||
{%- set preamble_ns.text = preamble_ns.text
|
||||
~ "\n</tools>\n\n"
|
||||
~ "For each function call, return a json object with function name and arguments "
|
||||
~ "within <tool_call></tool_call> XML tags:\n"
|
||||
~ "<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.out = ns.out ~ preamble_ns.text ~ "\n\n<sep>" -%}
|
||||
{%- endif -%}
|
||||
{%- for idx in range(start_idx, messages|length) -%}
|
||||
{%- set message = messages[idx] -%}
|
||||
{%- set role = message.get("role") -%}
|
||||
{%- set content = message.get("content") -%}
|
||||
{%- if role == "user" -%}
|
||||
{%- set prefix_ns = namespace(value="human: ") -%}
|
||||
{%- if content is sequence and content is not string -%}
|
||||
{%- for item in content -%}
|
||||
{%- if item.get("type") == "text" or item.get("text") is not none -%}
|
||||
{%- set text = item.get("text", "") -%}
|
||||
{%- if text -%}
|
||||
{%- break -%}
|
||||
{%- endif -%}
|
||||
{%- elif item.get("type") in ["image", "image_url", "video", "video_url"] -%}
|
||||
{%- set prefix_ns.value = "human:" -%}
|
||||
{%- break -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.out = ns.out ~ prefix_ns.value ~ render_content(content, num_img_tokens, num_video_frames) ~ "<sep>" -%}
|
||||
{%- elif role == "assistant" -%}
|
||||
{%- set tool_calls = message.get("tool_calls") -%}
|
||||
{%- set content_text = render_content(content, num_img_tokens, num_video_frames) -%}
|
||||
{%- set reasoning_text = "" -%}
|
||||
{%- if message.get("reasoning_content") is string -%}
|
||||
{%- set reasoning_text = message.get("reasoning_content") -%}
|
||||
{%- elif "</think>" in content_text -%}
|
||||
{%- set reasoning_text = content_text.split("</think>", 1)[0].rstrip("\n").split("<think>")[-1].lstrip("\n") -%}
|
||||
{%- set content_text = content_text.split("</think>", 1)[1].lstrip("\n") -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.out = ns.out ~ "assistant: " -%}
|
||||
{%- set include_thinking = enable_thinking is true
|
||||
and idx > last_query_index
|
||||
and (idx == messages|length - 1 or reasoning_text)
|
||||
-%}
|
||||
{%- if include_thinking -%}
|
||||
{%- set ns.out = ns.out ~ "<think>\n" ~ (reasoning_text.strip() ) ~ "\n</think>\n\n" -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.out = ns.out ~ content_text -%}
|
||||
{%- if tool_calls -%}
|
||||
{%- if content_text and not ns.out.endswith("\n") -%}
|
||||
{%- set ns.out = ns.out ~ "\n" -%}
|
||||
{%- endif -%}
|
||||
{%- for tool_call in tool_calls -%}
|
||||
{%- if tool_call.get("function") is not none -%}
|
||||
{%- set tool_call = tool_call.get("function") -%}
|
||||
{%- endif -%}
|
||||
{%- set arguments = tool_call.get("arguments", {}) -%}
|
||||
{%- if arguments is string -%}
|
||||
{%- set arguments_json = arguments -%}
|
||||
{%- elif arguments is mapping -%}
|
||||
{%- set arguments_json = arguments | tojson(ensure_ascii=True) -%}
|
||||
{%- else -%}
|
||||
{%- set arguments_json = arguments | tojson(ensure_ascii=True) -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.out = ns.out
|
||||
~ "<tool_call>\n"
|
||||
~ "{\"name\": \"" ~ tool_call.get("name", "") ~ "\", \"arguments\": "
|
||||
~ arguments_json
|
||||
~ "}\n</tool_call>" -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- if not (continue_final_message and idx == messages|length - 1) -%}
|
||||
{%- set ns.out = ns.out ~ "\n\n<sep>" -%}
|
||||
{%- endif -%}
|
||||
{%- elif role == "tool" -%}
|
||||
{%- if idx == start_idx or messages[idx - 1].get("role") != "tool" -%}
|
||||
{%- set ns.out = ns.out ~ "human: " -%}
|
||||
{%- endif -%}
|
||||
{%- set response_text = render_content(content, num_img_tokens, num_video_frames) -%}
|
||||
{%- set ns.out = ns.out ~ "<tool_response>\n" ~ response_text ~ "\n</tool_response>" -%}
|
||||
{%- if idx == messages|length - 1 or messages[idx + 1].get("role") != "tool" -%}
|
||||
{%- set ns.out = ns.out ~ "<sep>" -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt
|
||||
and (messages|length == 0 or messages[-1].get("role") != "assistant")
|
||||
-%}
|
||||
{%- if enable_thinking is true -%}
|
||||
{%- set ns.out = ns.out ~ "assistant: <think>\n" -%}
|
||||
{%- else -%}
|
||||
{%- set ns.out = ns.out ~ "assistant:" -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{{- ns.out -}}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user