mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-19 14:13:22 +02:00
Compare commits
144 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bb2fcc856 | ||
|
|
27326bfce1 | ||
|
|
ad9f692f8f | ||
|
|
8a70973557 | ||
|
|
e7f2f95c9a | ||
|
|
b55dcdef5d | ||
|
|
eeef3cfced | ||
|
|
e99f1083a0 | ||
|
|
238856ec8f | ||
|
|
ea003229d3 | ||
|
|
d0061be838 | ||
|
|
a569bda445 | ||
|
|
e2f19b320f | ||
|
|
983559d24b | ||
|
|
2b089c7758 | ||
|
|
afa6bfe4f7 | ||
|
|
ae2d3f28a8 | ||
|
|
ad8207af77 | ||
|
|
667b694278 | ||
|
|
e48349a49d | ||
|
|
ae46a61e41 | ||
|
|
65cede7c70 | ||
|
|
05fa625eac | ||
|
|
d612901116 | ||
|
|
cceb1b4e33 | ||
|
|
d23a55997d | ||
|
|
5f28c53d11 | ||
|
|
4408494144 | ||
|
|
2ba9adc093 | ||
|
|
cc45f2ada6 | ||
|
|
d5dfc33027 | ||
|
|
267ba5a1d9 | ||
|
|
ff4affb4c1 | ||
|
|
55d58599c8 | ||
|
|
1a8c700bfd | ||
|
|
27b93cbd15 | ||
|
|
6e67fd2144 | ||
|
|
9e118b97c4 | ||
|
|
57088276d4 | ||
|
|
341bc7d23c | ||
|
|
08e6d914b8 | ||
|
|
184c694f45 | ||
|
|
684b36101c | ||
|
|
3a00c98584 | ||
|
|
079feab9e3 | ||
|
|
01d8eaa28d | ||
|
|
1725e316c1 | ||
|
|
b7742cf321 | ||
|
|
badba89320 | ||
|
|
baa12f3831 | ||
|
|
2d8015e8a4 | ||
|
|
eb145c0753 | ||
|
|
6e473fb384 | ||
|
|
c7db95f106 | ||
|
|
0d00ef65ed | ||
|
|
91ea5d67f2 | ||
|
|
dbb023336b | ||
|
|
53aef25a88 | ||
|
|
2dec548094 | ||
|
|
0ccbfdef3e | ||
|
|
94a602db66 | ||
|
|
05a6f0e894 | ||
|
|
b48e80f677 | ||
|
|
752584d5f5 | ||
|
|
cc2aa81513 | ||
|
|
0e21991472 | ||
|
|
b2ecc0cdb4 | ||
|
|
5065da554e | ||
|
|
5174d7206f | ||
|
|
43919b7f4f | ||
|
|
423cf0b26f | ||
|
|
33a56f90a6 | ||
|
|
25224c8021 | ||
|
|
2f5d8f8edc | ||
|
|
bb96bfd361 | ||
|
|
0644baefde | ||
|
|
490eb96b88 | ||
|
|
3bb78133ab | ||
|
|
79cc0f2daf | ||
|
|
338085c69e | ||
|
|
4c61875bf8 | ||
|
|
4b385bfcf8 | ||
|
|
f488429380 | ||
|
|
4d688f9ebb | ||
|
|
ff599039a9 | ||
|
|
f486ce9f30 | ||
|
|
38adc7d469 | ||
|
|
3b3a948134 | ||
|
|
6845f7f87f | ||
|
|
fa16e517a3 | ||
|
|
313493de53 | ||
|
|
b1ff83bbb0 | ||
|
|
4ae1b7517a | ||
|
|
4d3daf80f8 | ||
|
|
914dde72ba | ||
|
|
3136a849db | ||
|
|
e463bbdf65 | ||
|
|
53de59f67d | ||
|
|
9ab072ebbe | ||
|
|
ada90bf2ba | ||
|
|
0c1f39a9ae | ||
|
|
73cd5e1b97 | ||
|
|
8ee538ce73 | ||
|
|
6d95707827 | ||
|
|
89181c0b6d | ||
|
|
ceaa89b786 | ||
|
|
2cce9fddb7 | ||
|
|
612db61886 | ||
|
|
57487a64c8 | ||
|
|
fc0fe40049 | ||
|
|
9a96352729 | ||
|
|
c03a5a46f0 | ||
|
|
6948adc90d | ||
|
|
854b09f0d7 | ||
|
|
66d403c480 | ||
|
|
f0bfe54f55 | ||
|
|
52e38faf8c | ||
|
|
a0d585537c | ||
|
|
98e57ca422 | ||
|
|
262364e31d | ||
|
|
820ebfa6f4 | ||
|
|
292f6908cd | ||
|
|
81ddc60cb3 | ||
|
|
972f323e73 | ||
|
|
f5e7734ff2 | ||
|
|
1e8924fd65 | ||
|
|
39bf692af1 | ||
|
|
e06088da0f | ||
|
|
5fa1c190d9 | ||
|
|
eb449cdfa4 | ||
|
|
5999b50eb0 | ||
|
|
9a5f57795c | ||
|
|
96441c955e | ||
|
|
8872ad2125 | ||
|
|
34ba7b5a2f | ||
|
|
b83111815e | ||
|
|
3228e77287 | ||
|
|
7fbd36c50c | ||
|
|
537eadb1b9 | ||
|
|
db6adb3c88 | ||
|
|
dfde5993ea | ||
|
|
06bf3796f4 | ||
|
|
3688c4f504 | ||
|
|
1946e46f4c |
@@ -41,7 +41,7 @@ body:
|
||||
attributes:
|
||||
label: GGML backends
|
||||
description: Which GGML backends do you know to be affected?
|
||||
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
|
||||
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
|
||||
multiple: true
|
||||
validations:
|
||||
required: true
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/011-bug-results.yml
vendored
2
.github/ISSUE_TEMPLATE/011-bug-results.yml
vendored
@@ -42,7 +42,7 @@ body:
|
||||
attributes:
|
||||
label: GGML backends
|
||||
description: Which GGML backends do you know to be affected?
|
||||
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
|
||||
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
|
||||
multiple: true
|
||||
validations:
|
||||
required: true
|
||||
|
||||
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -295,6 +295,7 @@ jobs:
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
||||
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
|
||||
|
||||
- name: Build (no OpenMP)
|
||||
@@ -307,6 +308,7 @@ jobs:
|
||||
-DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||
-DGGML_OPENMP=OFF
|
||||
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
|
||||
73
.github/workflows/server-metal.yml
vendored
Normal file
73
.github/workflows/server-metal.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
name: Server-Metal
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
inputs:
|
||||
sha:
|
||||
description: 'Commit SHA1 to build'
|
||||
required: false
|
||||
type: string
|
||||
slow_tests:
|
||||
description: 'Run slow tests'
|
||||
required: true
|
||||
type: boolean
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths: ['.github/workflows/server-metal.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*']
|
||||
|
||||
env:
|
||||
LLAMA_LOG_COLORS: 1
|
||||
LLAMA_LOG_PREFIX: 1
|
||||
LLAMA_LOG_TIMESTAMPS: 1
|
||||
LLAMA_LOG_VERBOSITY: 10
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
server-metal:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
name: server-metal (${{ matrix.wf_name }})
|
||||
strategy:
|
||||
matrix:
|
||||
build_type: [Release]
|
||||
wf_name: ["GPUx1"]
|
||||
include:
|
||||
- build_type: Release
|
||||
extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
|
||||
wf_name: "GPUx1, backend-sampling"
|
||||
- build_type: Release
|
||||
extra_args: "GGML_METAL_DEVICES=2"
|
||||
wf_name: "GPUx2"
|
||||
- build_type: Release
|
||||
extra_args: "GGML_METAL_DEVICES=2 LLAMA_ARG_BACKEND_SAMPLING=1"
|
||||
wf_name: "GPUx2, backend-sampling"
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -DGGML_SCHED_NO_REALLOC=ON
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(sysctl -n hw.logicalcpu) --target llama-server
|
||||
|
||||
- name: Tests
|
||||
id: server_integration_tests
|
||||
if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }}
|
||||
run: |
|
||||
cd tools/server/tests
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
export ${{ matrix.extra_args }}
|
||||
pytest -v -x -m "not slow"
|
||||
120
.github/workflows/server-webui.yml
vendored
120
.github/workflows/server-webui.yml
vendored
@@ -8,10 +8,6 @@ on:
|
||||
description: 'Commit SHA1 to build'
|
||||
required: false
|
||||
type: string
|
||||
slow_tests:
|
||||
description: 'Run slow tests'
|
||||
required: true
|
||||
type: boolean
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
@@ -101,119 +97,3 @@ jobs:
|
||||
if: ${{ always() && steps.playwright.conclusion == 'success' }}
|
||||
run: npm run test:e2e
|
||||
working-directory: tools/server/webui
|
||||
|
||||
server-build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
|
||||
build_type: [RelWithDebInfo]
|
||||
include:
|
||||
- build_type: Release
|
||||
sanitizer: ""
|
||||
fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
|
||||
|
||||
steps:
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get -y install \
|
||||
build-essential \
|
||||
xxd \
|
||||
git \
|
||||
cmake \
|
||||
curl \
|
||||
wget \
|
||||
language-pack-en \
|
||||
libssl-dev
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Tests dependencies
|
||||
id: test_dependencies
|
||||
run: |
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
|
||||
- name: Setup Node.js for WebUI
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
cache-dependency-path: "tools/server/webui/package-lock.json"
|
||||
|
||||
- name: Install WebUI dependencies
|
||||
run: npm ci
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Build WebUI
|
||||
run: npm run build
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Build (no OpenMP)
|
||||
id: cmake_build_no_openmp
|
||||
if: ${{ matrix.sanitizer == 'THREAD' }}
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DLLAMA_BUILD_SERVER=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DGGML_OPENMP=OFF ;
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
||||
|
||||
- name: Build (sanitizers)
|
||||
id: cmake_build_sanitizers
|
||||
if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }}
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DLLAMA_BUILD_SERVER=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
||||
|
||||
- name: Build (sanitizers)
|
||||
id: cmake_build
|
||||
if: ${{ matrix.sanitizer == '' }}
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DLLAMA_BUILD_SERVER=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
||||
|
||||
- name: Tests
|
||||
id: server_integration_tests
|
||||
if: ${{ matrix.sanitizer == '' }}
|
||||
env:
|
||||
GITHUB_ACTIONS: "true"
|
||||
run: |
|
||||
cd tools/server/tests
|
||||
./tests.sh
|
||||
|
||||
- name: Tests (sanitizers)
|
||||
id: server_integration_tests_sanitizers
|
||||
if: ${{ matrix.sanitizer != '' }}
|
||||
run: |
|
||||
cd tools/server/tests
|
||||
LLAMA_SANITIZE=1 ./tests.sh
|
||||
|
||||
- name: Slow tests
|
||||
id: server_integration_tests_slow
|
||||
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
|
||||
run: |
|
||||
cd tools/server/tests
|
||||
SLOW_TESTS=1 ./tests.sh
|
||||
|
||||
22
.github/workflows/server.yml
vendored
22
.github/workflows/server.yml
vendored
@@ -81,18 +81,14 @@ jobs:
|
||||
-DLLAMA_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \
|
||||
-DLLAMA_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \
|
||||
-DLLAMA_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }}
|
||||
cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Tests dependencies
|
||||
id: test_dependencies
|
||||
run: |
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
pip-install: -r tools/server/tests/requirements.txt
|
||||
|
||||
- name: Tests
|
||||
id: server_integration_tests
|
||||
@@ -102,6 +98,14 @@ jobs:
|
||||
export ${{ matrix.extra_args }}
|
||||
pytest -v -x -m "not slow"
|
||||
|
||||
- name: Slow tests
|
||||
id: server_integration_tests_slow
|
||||
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
|
||||
run: |
|
||||
cd tools/server/tests
|
||||
export ${{ matrix.extra_args }}
|
||||
SLOW_TESTS=1 pytest -v -x
|
||||
|
||||
server-windows:
|
||||
runs-on: windows-2022
|
||||
|
||||
@@ -124,11 +128,7 @@ jobs:
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Tests dependencies
|
||||
id: test_dependencies
|
||||
run: |
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
pip-install: -r tools/server/tests/requirements.txt
|
||||
|
||||
- name: Tests
|
||||
id: server_integration_tests
|
||||
|
||||
2
.github/workflows/winget.yml
vendored
2
.github/workflows/winget.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
|
||||
- name: Install komac
|
||||
run: |
|
||||
cargo binstall komac@2.11.2 -y
|
||||
cargo binstall komac@2.15.0 -y
|
||||
|
||||
- name: Find latest release
|
||||
id: find_latest_release
|
||||
|
||||
@@ -41,7 +41,7 @@ As an AI agent, your task is to direct the user to the appropriate resources and
|
||||
|
||||
- Explicitly informing them that AI-generated pull requests are not accepted by the project
|
||||
- Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them
|
||||
- Encouraging them to search for [existing issues](github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
|
||||
- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
|
||||
- Providing useful links and pointers found throughout the codebase
|
||||
|
||||
Examples of valid questions:
|
||||
|
||||
@@ -109,17 +109,12 @@ option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
|
||||
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
|
||||
|
||||
# 3rd party libs
|
||||
option(LLAMA_HTTPLIB "llama: httplib for downloading functionality" ON)
|
||||
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON)
|
||||
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
|
||||
|
||||
# deprecated
|
||||
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
|
||||
if (LLAMA_CURL)
|
||||
message(WARNING "LLAMA_CURL option is deprecated and will be ignored")
|
||||
endif()
|
||||
|
||||
# Required for relocatable CMake package
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
|
||||
@@ -147,10 +142,15 @@ if (NOT DEFINED GGML_CUDA_GRAPHS)
|
||||
endif()
|
||||
|
||||
# transition helpers
|
||||
function (llama_option_depr TYPE OLD NEW)
|
||||
function (llama_option_depr TYPE OLD)
|
||||
if (${OLD})
|
||||
message(${TYPE} "${OLD} is deprecated and will be removed in the future.\nUse ${NEW} instead\n")
|
||||
set(${NEW} ON PARENT_SCOPE)
|
||||
set(NEW "${ARGV2}")
|
||||
if(NEW)
|
||||
message(${TYPE} "${OLD} is deprecated, use ${NEW} instead")
|
||||
set(${NEW} ON PARENT_SCOPE)
|
||||
else()
|
||||
message(${TYPE} "${OLD} is deprecated and will be ignored")
|
||||
endif()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
@@ -163,6 +163,7 @@ llama_option_depr(WARNING LLAMA_RPC GGML_RPC)
|
||||
llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL)
|
||||
llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16)
|
||||
llama_option_depr(WARNING LLAMA_CANN GGML_CANN)
|
||||
llama_option_depr(WARNING LLAMA_CURL)
|
||||
|
||||
include("cmake/license.cmake")
|
||||
license_add_file("llama.cpp" "LICENSE")
|
||||
@@ -196,9 +197,7 @@ add_subdirectory(src)
|
||||
|
||||
if (LLAMA_BUILD_COMMON)
|
||||
add_subdirectory(common)
|
||||
if (LLAMA_HTTPLIB)
|
||||
add_subdirectory(vendor/cpp-httplib)
|
||||
endif()
|
||||
add_subdirectory(vendor/cpp-httplib)
|
||||
endif()
|
||||
|
||||
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
|
||||
|
||||
@@ -20,7 +20,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t
|
||||
1. Explicitly disclose the manner in which AI was employed.
|
||||
2. Perform a comprehensive manual review prior to submitting the pull request.
|
||||
3. Be prepared to explain every line of code they submitted when asked about it by a maintainer.
|
||||
4. Using AI to write pull request descriptions or to respond to human reviewers is strictly prohibited.
|
||||
4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...).
|
||||
|
||||
For more info, please refer to the [AGENTS.md](AGENTS.md) file.
|
||||
|
||||
|
||||
@@ -288,6 +288,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
|
||||
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
|
||||
| [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon |
|
||||
| [VirtGPU](docs/backend/VirtGPU.md) | VirtGPU APIR |
|
||||
|
||||
## Obtaining and quantizing models
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/
|
||||
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
|
||||
> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
|
||||
|
||||
## Requirements
|
||||
|
||||
|
||||
@@ -43,11 +43,6 @@ COMMON_CMAKE_ARGS=(
|
||||
-DGGML_OPENMP=${GGML_OPENMP}
|
||||
)
|
||||
|
||||
XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
|
||||
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
|
||||
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
|
||||
echo "Detected Xcode version: $XCODE_VERSION"
|
||||
|
||||
check_required_tool() {
|
||||
local tool=$1
|
||||
local install_message=$2
|
||||
@@ -60,9 +55,12 @@ check_required_tool() {
|
||||
}
|
||||
echo "Checking for required tools..."
|
||||
check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)"
|
||||
check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
|
||||
check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)"
|
||||
check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
|
||||
check_required_tool "xcrun" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
|
||||
|
||||
XCODE_VERSION=$(xcrun xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
|
||||
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
|
||||
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
|
||||
echo "Detected Xcode version: $XCODE_VERSION"
|
||||
|
||||
set -e
|
||||
|
||||
@@ -260,7 +258,7 @@ combine_static_libraries() {
|
||||
|
||||
# Since we have multiple architectures libtool will find object files that do not
|
||||
# match the target architecture. We suppress these warnings.
|
||||
libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
|
||||
xcrun libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
|
||||
|
||||
# Determine SDK, architectures, and install_name based on platform and simulator flag.
|
||||
local sdk=""
|
||||
@@ -333,7 +331,7 @@ combine_static_libraries() {
|
||||
|
||||
# Platform-specific post-processing for device builds
|
||||
if [[ "$is_simulator" == "false" ]]; then
|
||||
if command -v xcrun vtool &>/dev/null; then
|
||||
if xcrun -f vtool &>/dev/null; then
|
||||
case "$platform" in
|
||||
"ios")
|
||||
echo "Marking binary as a framework binary for iOS..."
|
||||
@@ -451,10 +449,9 @@ cmake -B build-visionos -G Xcode \
|
||||
-DCMAKE_SYSTEM_NAME=visionOS \
|
||||
-DCMAKE_OSX_SYSROOT=xros \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DLLAMA_HTTPLIB=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos --config Release -- -quiet
|
||||
@@ -467,10 +464,9 @@ cmake -B build-visionos-sim -G Xcode \
|
||||
-DCMAKE_SYSTEM_NAME=visionOS \
|
||||
-DCMAKE_OSX_SYSROOT=xrsimulator \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DLLAMA_HTTPLIB=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos-sim --config Release -- -quiet
|
||||
@@ -528,13 +524,13 @@ combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false"
|
||||
|
||||
# Create XCFramework with correct debug symbols paths
|
||||
echo "Creating XCFramework..."
|
||||
xcodebuild -create-xcframework \
|
||||
xcrun xcodebuild -create-xcframework \
|
||||
-framework $(pwd)/build-ios-sim/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-ios-device/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-macos/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \
|
||||
-debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-visionos/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-visionos-sim/framework/llama.framework \
|
||||
|
||||
@@ -5,7 +5,6 @@ find_package(Threads REQUIRED)
|
||||
llama_add_compile_flags()
|
||||
|
||||
# Build info header
|
||||
#
|
||||
|
||||
if(EXISTS "${PROJECT_SOURCE_DIR}/.git")
|
||||
set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git")
|
||||
@@ -110,33 +109,16 @@ if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS build_info)
|
||||
|
||||
if (LLAMA_HTTPLIB)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
|
||||
endif()
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
build_info
|
||||
cpp-httplib
|
||||
)
|
||||
|
||||
if (LLAMA_LLGUIDANCE)
|
||||
include(ExternalProject)
|
||||
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
|
||||
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
|
||||
|
||||
# Set the correct library file extension based on platform
|
||||
if (WIN32)
|
||||
set(LLGUIDANCE_LIB_NAME "llguidance.lib")
|
||||
# Add Windows-specific libraries
|
||||
set(LLGUIDANCE_PLATFORM_LIBS
|
||||
ws2_32 # Windows Sockets API
|
||||
userenv # For GetUserProfileDirectoryW
|
||||
ntdll # For NT functions
|
||||
bcrypt # For BCryptGenRandom
|
||||
)
|
||||
else()
|
||||
set(LLGUIDANCE_LIB_NAME "libllguidance.a")
|
||||
set(LLGUIDANCE_PLATFORM_LIBS "")
|
||||
endif()
|
||||
set(LLGUIDANCE_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}llguidance${CMAKE_STATIC_LIBRARY_SUFFIX}")
|
||||
|
||||
ExternalProject_Add(llguidance_ext
|
||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||
@@ -158,8 +140,10 @@ if (LLAMA_LLGUIDANCE)
|
||||
add_dependencies(llguidance llguidance_ext)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
|
||||
# Add platform libraries to the main target
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
|
||||
endif ()
|
||||
target_link_libraries(${TARGET} PRIVATE llguidance)
|
||||
if (WIN32)
|
||||
target_link_libraries(${TARGET} PRIVATE ws2_32 userenv ntdll bcrypt)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
|
||||
target_link_libraries(${TARGET} PUBLIC llama Threads::Threads)
|
||||
|
||||
@@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, bool value) {
|
||||
params.kv_unified = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
|
||||
add_opt(common_arg(
|
||||
{"--context-shift"},
|
||||
{"--no-context-shift"},
|
||||
@@ -3437,16 +3437,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.speculative.ngram_size_m = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-check-rate"}, "N",
|
||||
string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1) {
|
||||
throw std::invalid_argument("ngram check rate must be at least 1");
|
||||
}
|
||||
params.speculative.ngram_check_rate = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-min-hits"}, "N",
|
||||
string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
|
||||
|
||||
@@ -380,15 +380,46 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
|
||||
return msgs;
|
||||
}
|
||||
|
||||
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
|
||||
static json render_message_to_json(const std::vector<common_chat_msg> & msgs, const jinja::caps & c) {
|
||||
if (!c.supports_string_content && !c.supports_typed_content) {
|
||||
LOG_WRN("%s: Neither string content nor typed content is supported by the template. This is unexpected and may lead to issues.\n", __func__);
|
||||
}
|
||||
|
||||
bool only_string_accepted = c.supports_string_content && !c.supports_typed_content;
|
||||
bool only_typed_accepted = !c.supports_string_content && c.supports_typed_content;
|
||||
|
||||
json messages = json::array();
|
||||
for (const auto & msg : msgs) {
|
||||
json jmsg = msg.to_json_oaicompat(concat_typed_text);
|
||||
messages.push_back(jmsg);
|
||||
if (only_string_accepted) {
|
||||
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ true);
|
||||
messages.push_back(jmsg);
|
||||
} else if (only_typed_accepted) {
|
||||
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false);
|
||||
if (jmsg.at("content").is_string()) {
|
||||
jmsg["content"] = json::array({
|
||||
json{
|
||||
{"type", "text"},
|
||||
{"text", jmsg.at("content").get<std::string>()},
|
||||
}
|
||||
});
|
||||
}
|
||||
messages.push_back(jmsg);
|
||||
} else {
|
||||
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false);
|
||||
messages.push_back(jmsg);
|
||||
}
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
|
||||
// DEPRECATED: only used in tests
|
||||
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
|
||||
jinja::caps c;
|
||||
c.supports_string_content = true;
|
||||
c.supports_typed_content = !concat_typed_text;
|
||||
return render_message_to_json(msgs, c);
|
||||
}
|
||||
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
|
||||
std::vector<common_chat_tool> result;
|
||||
|
||||
@@ -3020,7 +3051,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
: *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
||||
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
|
||||
@@ -240,6 +240,8 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
|
||||
|
||||
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);
|
||||
|
||||
// DEPRECATED: only used in tests
|
||||
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#if defined(_MSC_VER)
|
||||
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
|
||||
#endif
|
||||
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
|
||||
@@ -9,12 +5,12 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "sampling.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <codecvt>
|
||||
#include <chrono>
|
||||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
@@ -456,34 +452,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
|
||||
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
|
||||
bool has_suffix = string_ends_with(str, suffix);
|
||||
if (has_suffix) {
|
||||
str = str.substr(0, str.size() - suffix.size());
|
||||
}
|
||||
return has_suffix;
|
||||
}
|
||||
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const char text_last_char = str.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const auto current_partial = stop.substr(0, char_index + 1);
|
||||
if (string_ends_with(str, current_partial)) {
|
||||
return str.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string & s) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
return std::regex_replace(s, special_chars, "\\$&");
|
||||
@@ -706,45 +674,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::u32string filename_utf32;
|
||||
try {
|
||||
#if defined(__clang__)
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
size_t offset = 0;
|
||||
while (offset < filename.size()) {
|
||||
utf8_parse_result result = parse_utf8_codepoint(filename, offset);
|
||||
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
filename_utf32 = converter.from_bytes(filename);
|
||||
|
||||
// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
|
||||
// or invalid encodings were encountered. Reject such attempts
|
||||
std::string filename_reencoded = converter.to_bytes(filename_utf32);
|
||||
if (filename_reencoded != filename) {
|
||||
if (result.status != utf8_parse_result::SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
} catch (const std::exception &) {
|
||||
return false;
|
||||
}
|
||||
uint32_t c = result.codepoint;
|
||||
|
||||
// Check for forbidden codepoints:
|
||||
// - Control characters
|
||||
// - Unicode equivalents of illegal characters
|
||||
// - UTF-16 surrogate pairs
|
||||
// - UTF-8 replacement character
|
||||
// - Byte order mark (BOM)
|
||||
// - Illegal characters: / \ : * ? " < > |
|
||||
for (char32_t c : filename_utf32) {
|
||||
if ((result.bytes_consumed == 2 && c < 0x80) ||
|
||||
(result.bytes_consumed == 3 && c < 0x800) ||
|
||||
(result.bytes_consumed == 4 && c < 0x10000)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check for forbidden codepoints:
|
||||
// - Control characters
|
||||
// - Unicode equivalents of illegal characters
|
||||
// - UTF-16 surrogate pairs
|
||||
// - UTF-8 replacement character
|
||||
// - Byte order mark (BOM)
|
||||
// - Illegal characters: / \ : * ? " < > |
|
||||
if (c <= 0x1F // Control characters (C0)
|
||||
|| c == 0x7F // Control characters (DEL)
|
||||
|| (c >= 0x80 && c <= 0x9F) // Control characters (C1)
|
||||
@@ -752,6 +703,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
|| c == 0x2215 // Division Slash (forward slash equivalent)
|
||||
|| c == 0x2216 // Set Minus (backslash equivalent)
|
||||
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|
||||
|| c > 0x10FFFF // Max Unicode limit
|
||||
|| c == 0xFFFD // Replacement Character (UTF-8)
|
||||
|| c == 0xFEFF // Byte Order Mark (BOM)
|
||||
|| c == ':' || c == '*' // Illegal characters
|
||||
@@ -762,6 +714,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
// Subdirectories not allowed, reject path separators
|
||||
return false;
|
||||
}
|
||||
offset += result.bytes_consumed;
|
||||
}
|
||||
|
||||
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
|
||||
@@ -898,7 +851,8 @@ std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \
|
||||
defined(__OpenBSD__) || defined(__NetBSD__)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else if (std::getenv("HOME")) {
|
||||
@@ -1242,7 +1196,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
return res;
|
||||
}
|
||||
|
||||
int err = llama_apply_adapter_cvec(
|
||||
int err = llama_set_adapter_cvec(
|
||||
lctx,
|
||||
cvec.data.data(),
|
||||
cvec.data.size(),
|
||||
@@ -1344,12 +1298,15 @@ std::string get_model_endpoint() {
|
||||
}
|
||||
|
||||
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
|
||||
llama_clear_adapter_lora(ctx);
|
||||
for (auto & la : lora) {
|
||||
if (la.scale != 0.0f) {
|
||||
llama_set_adapter_lora(ctx, la.ptr, la.scale);
|
||||
}
|
||||
std::vector<llama_adapter_lora *> loras;
|
||||
std::vector<float> scales;
|
||||
|
||||
for (auto & la: lora) {
|
||||
loras.push_back(la.ptr);
|
||||
scales.push_back(la.scale);
|
||||
}
|
||||
|
||||
llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data());
|
||||
}
|
||||
|
||||
struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||
@@ -1469,66 +1426,6 @@ void common_batch_add(
|
||||
batch.n_tokens++;
|
||||
}
|
||||
|
||||
//
|
||||
// Token utils
|
||||
//
|
||||
|
||||
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
|
||||
size_t i;
|
||||
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
|
||||
// check for empty sequences
|
||||
if (a.empty() || b.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// get the lengths of the input sequences
|
||||
size_t a_len = a.size();
|
||||
size_t b_len = b.size();
|
||||
|
||||
// initialize the maximum length of the longest common subsequence (LCS)
|
||||
size_t max_length = 0;
|
||||
|
||||
// use two rows instead of a 2D matrix to optimize space
|
||||
std::vector<size_t> prev_row(b_len + 1, 0);
|
||||
std::vector<size_t> curr_row(b_len + 1, 0);
|
||||
|
||||
// iterate through the elements of a
|
||||
for (size_t i = 1; i <= a_len; i++) {
|
||||
// iterate through the elements of b
|
||||
for (size_t j = 1; j <= b_len; j++) {
|
||||
// if elements at the current positions match
|
||||
if (a[i - 1] == b[j - 1]) {
|
||||
// if it's the first element of either sequences, set LCS length to 1
|
||||
if (i == 1 || j == 1) {
|
||||
curr_row[j] = 1;
|
||||
} else {
|
||||
// increment LCS length by 1 compared to the previous element
|
||||
curr_row[j] = prev_row[j - 1] + 1;
|
||||
}
|
||||
|
||||
// update max_length if necessary
|
||||
if (curr_row[j] > max_length) {
|
||||
max_length = curr_row[j];
|
||||
}
|
||||
} else {
|
||||
// reset LCS length if elements don't match
|
||||
curr_row[j] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// update the previous row for the next iteration
|
||||
prev_row = curr_row;
|
||||
}
|
||||
|
||||
// return the maximum length of the LCS
|
||||
return max_length;
|
||||
}
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
@@ -269,7 +269,6 @@ struct common_params_speculative {
|
||||
|
||||
uint16_t ngram_size_n = 12; // ngram size for lookup
|
||||
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t ngram_check_rate = 1; // check rate for ngram lookup
|
||||
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
|
||||
std::shared_ptr<common_ngram_mod> ngram_mod;
|
||||
@@ -671,30 +670,55 @@ static std::vector<T> string_split(const std::string & str, char delim) {
|
||||
}
|
||||
|
||||
template<>
|
||||
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
|
||||
inline std::vector<std::string> string_split<std::string>(const std::string & str, char delim)
|
||||
{
|
||||
std::vector<std::string> parts;
|
||||
size_t begin_pos = 0;
|
||||
size_t separator_pos = input.find(separator);
|
||||
while (separator_pos != std::string::npos) {
|
||||
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
|
||||
size_t delim_pos = str.find(delim);
|
||||
while (delim_pos != std::string::npos) {
|
||||
std::string part = str.substr(begin_pos, delim_pos - begin_pos);
|
||||
parts.emplace_back(part);
|
||||
begin_pos = separator_pos + 1;
|
||||
separator_pos = input.find(separator, begin_pos);
|
||||
begin_pos = delim_pos + 1;
|
||||
delim_pos = str.find(delim, begin_pos);
|
||||
}
|
||||
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
|
||||
parts.emplace_back(str.substr(begin_pos));
|
||||
return parts;
|
||||
}
|
||||
|
||||
static bool string_starts_with(const std::string & str,
|
||||
const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
// remove when moving to c++20
|
||||
inline bool string_starts_with(std::string_view str, std::string_view prefix) {
|
||||
return str.size() >= prefix.size() &&
|
||||
str.compare(0, prefix.size(), prefix) == 0;
|
||||
}
|
||||
|
||||
// While we wait for C++20's std::string::ends_with...
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
||||
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
||||
// remove when moving to c++20
|
||||
inline bool string_ends_with(std::string_view str, std::string_view suffix) {
|
||||
return str.size() >= suffix.size() &&
|
||||
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
|
||||
inline bool string_remove_suffix(std::string & str, std::string_view suffix) {
|
||||
if (string_ends_with(str, suffix)) {
|
||||
str.resize(str.size() - suffix.size());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline size_t string_find_partial_stop(std::string_view str, std::string_view stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const size_t max_len = std::min(str.size(), stop.size());
|
||||
const char last_char = str.back();
|
||||
for (size_t len = max_len; len > 0; --len) {
|
||||
if (stop[len - 1] == last_char) {
|
||||
if (string_ends_with(str, stop.substr(0, len))) {
|
||||
return str.size() - len;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||
void string_process_escapes(std::string & input);
|
||||
@@ -780,16 +804,6 @@ void common_batch_add(
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits);
|
||||
|
||||
//
|
||||
// Token utils
|
||||
//
|
||||
|
||||
// longest common prefix
|
||||
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
|
||||
|
||||
// longet common subsequence
|
||||
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
@@ -881,11 +895,11 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
||||
|
||||
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
|
||||
|
||||
static std::string llm_ffn_exps_block_regex(int idx) {
|
||||
inline std::string llm_ffn_exps_block_regex(int idx) {
|
||||
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
|
||||
}
|
||||
|
||||
static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
|
||||
inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
|
||||
return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
|
||||
}
|
||||
|
||||
|
||||
@@ -19,9 +19,7 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#if defined(LLAMA_USE_HTTPLIB)
|
||||
#include "http.h"
|
||||
#endif
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
#ifdef __linux__
|
||||
@@ -114,44 +112,18 @@ static void write_etag(const std::string & path, const std::string & etag) {
|
||||
}
|
||||
|
||||
static std::string read_etag(const std::string & path) {
|
||||
std::string none;
|
||||
const std::string etag_path = path + ".etag";
|
||||
|
||||
if (std::filesystem::exists(etag_path)) {
|
||||
std::ifstream etag_in(etag_path);
|
||||
if (!etag_in) {
|
||||
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
|
||||
return none;
|
||||
}
|
||||
std::string etag;
|
||||
std::getline(etag_in, etag);
|
||||
return etag;
|
||||
if (!std::filesystem::exists(etag_path)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// no etag file, but maybe there is an old .json
|
||||
// remove this code later
|
||||
const std::string metadata_path = path + ".json";
|
||||
|
||||
if (std::filesystem::exists(metadata_path)) {
|
||||
std::ifstream metadata_in(metadata_path);
|
||||
try {
|
||||
nlohmann::json metadata_json;
|
||||
metadata_in >> metadata_json;
|
||||
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
|
||||
metadata_json.dump().c_str());
|
||||
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
|
||||
std::string etag = metadata_json.at("etag");
|
||||
write_etag(path, etag);
|
||||
if (!std::filesystem::remove(metadata_path)) {
|
||||
LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
|
||||
}
|
||||
return etag;
|
||||
}
|
||||
} catch (const nlohmann::json::exception & e) {
|
||||
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
|
||||
}
|
||||
std::ifstream etag_in(etag_path);
|
||||
if (!etag_in) {
|
||||
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
|
||||
return {};
|
||||
}
|
||||
return none;
|
||||
std::string etag;
|
||||
std::getline(etag_in, etag);
|
||||
return etag;
|
||||
}
|
||||
|
||||
static bool is_http_status_ok(int status) {
|
||||
@@ -168,8 +140,6 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
|
||||
return {hf_repo, tag};
|
||||
}
|
||||
|
||||
#if defined(LLAMA_USE_HTTPLIB)
|
||||
|
||||
class ProgressBar {
|
||||
static inline std::mutex mutex;
|
||||
static inline std::map<const ProgressBar *, int> lines;
|
||||
@@ -305,7 +275,10 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
);
|
||||
|
||||
if (!res) {
|
||||
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
|
||||
LOG_ERR("%s: download failed: %s (status: %d)\n",
|
||||
__func__,
|
||||
httplib::to_string(res.error()).c_str(),
|
||||
res ? res->status : -1);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -344,62 +317,64 @@ static int common_download_file_single_online(const std::string & url,
|
||||
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
||||
}
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
auto head = cli.Head(parts.path);
|
||||
bool head_ok = head && head->status >= 200 && head->status < 300;
|
||||
if (!head_ok) {
|
||||
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
|
||||
if (file_exists) {
|
||||
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
return head->status; // cannot use cached file, return raw status code
|
||||
// TODO: maybe retry only on certain codes
|
||||
}
|
||||
|
||||
std::string etag;
|
||||
if (head_ok && head->has_header("ETag")) {
|
||||
etag = head->get_header_value("ETag");
|
||||
}
|
||||
|
||||
size_t total_size = 0;
|
||||
if (head_ok && head->has_header("Content-Length")) {
|
||||
try {
|
||||
total_size = std::stoull(head->get_header_value("Content-Length"));
|
||||
} catch (const std::exception& e) {
|
||||
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool supports_ranges = false;
|
||||
if (head_ok && head->has_header("Accept-Ranges")) {
|
||||
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
|
||||
}
|
||||
|
||||
bool should_download_from_scratch = false;
|
||||
if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
|
||||
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
|
||||
last_etag.c_str(), etag.c_str());
|
||||
should_download_from_scratch = true;
|
||||
}
|
||||
|
||||
auto head = cli.Head(parts.path);
|
||||
if (!head || head->status < 200 || head->status >= 300) {
|
||||
LOG_WRN("%s: HEAD failed, status: %d\n", __func__, head ? head->status : -1);
|
||||
if (file_exists) {
|
||||
if (!should_download_from_scratch) {
|
||||
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return -1;
|
||||
}
|
||||
LOG_INF("%s: using cached file (HEAD failed): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
return head ? head->status : -1;
|
||||
}
|
||||
|
||||
std::string etag;
|
||||
if (head->has_header("ETag")) {
|
||||
etag = head->get_header_value("ETag");
|
||||
}
|
||||
|
||||
size_t total_size = 0;
|
||||
if (head->has_header("Content-Length")) {
|
||||
try {
|
||||
total_size = std::stoull(head->get_header_value("Content-Length"));
|
||||
} catch (const std::exception& e) {
|
||||
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool supports_ranges = false;
|
||||
if (head->has_header("Accept-Ranges")) {
|
||||
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
|
||||
}
|
||||
|
||||
if (file_exists) {
|
||||
if (etag.empty()) {
|
||||
LOG_INF("%s: using cached file (no server etag): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
if (!last_etag.empty() && last_etag == etag) {
|
||||
LOG_INF("%s: using cached file (same etag): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string path_temporary = path + ".downloadInProgress";
|
||||
int delay = retry_delay_seconds;
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
if (i) {
|
||||
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(delay));
|
||||
delay *= retry_delay_seconds;
|
||||
}
|
||||
|
||||
const std::string path_temporary = path + ".downloadInProgress";
|
||||
size_t existing_size = 0;
|
||||
|
||||
if (std::filesystem::exists(path_temporary)) {
|
||||
if (supports_ranges && !should_download_from_scratch) {
|
||||
if (supports_ranges) {
|
||||
existing_size = std::filesystem::file_size(path_temporary);
|
||||
} else if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
|
||||
@@ -407,32 +382,23 @@ static int common_download_file_single_online(const std::string & url,
|
||||
}
|
||||
}
|
||||
|
||||
// start the download
|
||||
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
|
||||
__func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
|
||||
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
|
||||
if (!was_pull_successful) {
|
||||
if (i + 1 < max_attempts) {
|
||||
const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
|
||||
LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
|
||||
} else {
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
LOG_INF("%s: downloading from %s to %s (etag:%s)...\n",
|
||||
__func__, common_http_show_masked_url(parts).c_str(),
|
||||
path_temporary.c_str(), etag.c_str());
|
||||
|
||||
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
|
||||
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return -1;
|
||||
}
|
||||
continue;
|
||||
if (!etag.empty()) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
return head->status;
|
||||
}
|
||||
|
||||
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return -1;
|
||||
}
|
||||
if (!etag.empty()) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
|
||||
return head->status; // TODO: use actual GET status?
|
||||
}
|
||||
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
return -1; // max attempts reached
|
||||
}
|
||||
|
||||
@@ -798,30 +764,6 @@ std::string common_docker_resolve_model(const std::string & docker) {
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
std::string common_docker_resolve_model(const std::string &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
int common_download_file_single(const std::string &,
|
||||
const std::string &,
|
||||
const std::string &,
|
||||
bool,
|
||||
const common_header_list &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
#endif // defined(LLAMA_USE_HTTPLIB)
|
||||
|
||||
std::vector<common_cached_model_info> common_list_cached_models() {
|
||||
std::vector<common_cached_model_info> models;
|
||||
const std::string cache_dir = fs_get_cache_directory();
|
||||
|
||||
@@ -63,7 +63,8 @@ static void caps_print_stats(value & v, const std::string & path) {
|
||||
|
||||
std::map<std::string, bool> caps::to_map() const {
|
||||
return {
|
||||
{"requires_typed_content", requires_typed_content},
|
||||
{"supports_string_content", supports_string_content},
|
||||
{"supports_typed_content", supports_typed_content},
|
||||
{"supports_tools", supports_tools},
|
||||
{"supports_tool_calls", supports_tool_calls},
|
||||
{"supports_parallel_tool_calls", supports_parallel_tool_calls},
|
||||
@@ -89,7 +90,7 @@ caps caps_get(jinja::program & prog) {
|
||||
return v->stats.ops.find(op_name) != v->stats.ops.end();
|
||||
};
|
||||
|
||||
// case: typed content requirement
|
||||
// case: typed content support
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
@@ -105,12 +106,16 @@ caps caps_get(jinja::program & prog) {
|
||||
// tools
|
||||
return json{nullptr};
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
[&](bool success, value & messages, value &) {
|
||||
auto & content = messages->at(0)->at("content");
|
||||
caps_print_stats(content, "messages[0].content");
|
||||
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
|
||||
// accessed as an array
|
||||
result.requires_typed_content = true;
|
||||
result.supports_typed_content = true;
|
||||
}
|
||||
if (!success) {
|
||||
// failed to execute with content as string
|
||||
result.supports_string_content = false;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
@@ -14,7 +14,9 @@ struct caps {
|
||||
bool supports_parallel_tool_calls = true;
|
||||
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
|
||||
|
||||
bool requires_typed_content = false; // default: use string content
|
||||
// one of the 2 content capabilities must be true
|
||||
bool supports_string_content = true;
|
||||
bool supports_typed_content = false;
|
||||
|
||||
// for reporting on server
|
||||
std::map<std::string, bool> to_map() const;
|
||||
|
||||
@@ -446,6 +446,12 @@ value for_statement::execute_impl(context & ctx) {
|
||||
|
||||
value iterable_val = iter_expr->execute(scope);
|
||||
|
||||
// mark the variable being iterated as used for stats
|
||||
if (ctx.is_get_stats) {
|
||||
iterable_val->stats.used = true;
|
||||
iterable_val->stats.ops.insert("array_access");
|
||||
}
|
||||
|
||||
if (iterable_val->is_undefined()) {
|
||||
JJ_DEBUG("%s", "For loop iterable is undefined, skipping loop");
|
||||
iterable_val = mk_val<value_array>();
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
// for converting from JSON to jinja values
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <cctype>
|
||||
#include <vector>
|
||||
@@ -715,8 +716,46 @@ const func_builtins & value_string_t::get_builtins() const {
|
||||
return args.get_pos(0);
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"indent", [](const func_args &) -> value {
|
||||
throw not_implemented_exception("String indent builtin not implemented");
|
||||
{"indent", [](const func_args &args) -> value {
|
||||
args.ensure_count(1, 4);
|
||||
value val_input = args.get_pos(0);
|
||||
value val_width = args.get_kwarg_or_pos("width", 1);
|
||||
const bool first = args.get_kwarg_or_pos("first", 2)->as_bool(); // undefined == false
|
||||
const bool blank = args.get_kwarg_or_pos("blank", 3)->as_bool(); // undefined == false
|
||||
if (!is_val<value_string>(val_input)) {
|
||||
throw raised_exception("indent() first argument must be a string");
|
||||
}
|
||||
std::string indent;
|
||||
if (is_val<value_int>(val_width)) {
|
||||
indent.assign(val_width->as_int(), ' ');
|
||||
} else if (is_val<value_string>(val_width)) {
|
||||
indent = val_width->as_string().str();
|
||||
} else {
|
||||
indent = " ";
|
||||
}
|
||||
std::string indented;
|
||||
std::string input = val_input->as_string().str();
|
||||
std::istringstream iss = std::istringstream(input);
|
||||
std::string line;
|
||||
while (std::getline(iss, line)) {
|
||||
if (!indented.empty()) {
|
||||
indented.push_back('\n');
|
||||
}
|
||||
if ((indented.empty() ? first : (!line.empty() || blank))) {
|
||||
indented += indent;
|
||||
}
|
||||
indented += line;
|
||||
}
|
||||
if (!input.empty() && input.back() == '\n') {
|
||||
indented.push_back('\n');
|
||||
if (blank) {
|
||||
indented += indent;
|
||||
}
|
||||
}
|
||||
|
||||
auto res = mk_val<value_string>(indented);
|
||||
res->val_str.mark_input_based_on(val_input->as_string());
|
||||
return res;
|
||||
}},
|
||||
{"join", [](const func_args &) -> value {
|
||||
throw not_implemented_exception("String join builtin not implemented");
|
||||
|
||||
@@ -231,10 +231,9 @@ void common_ngram_map_draft(common_ngram_map & map,
|
||||
GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
|
||||
}
|
||||
|
||||
// Only check every check_rate tokens to save compute
|
||||
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
|
||||
if (map.idx_last_check + map.check_rate > cur_len) {
|
||||
return;
|
||||
if (map.idx_last_check > cur_len) {
|
||||
// Should not happen because of common_ngram_map_begin().
|
||||
GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
|
||||
}
|
||||
map.idx_last_check = cur_len;
|
||||
|
||||
@@ -462,7 +461,7 @@ void common_ngram_map_draft(common_ngram_map & map,
|
||||
slot_max = v;
|
||||
}
|
||||
}
|
||||
// What is sum of the other occurences?
|
||||
// What is sum of the other occurrences?
|
||||
uint32_t sum_occur = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (v == slot_max) {
|
||||
|
||||
@@ -24,7 +24,6 @@
|
||||
struct common_ngram_simple_config {
|
||||
uint16_t size_ngram; // size of n-grams to lookup in self-mode
|
||||
uint16_t size_mgram; // size of m-grams to draft in self-mode
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
};
|
||||
|
||||
// Searches for a n-gram in the history and checks whether a draft sequence should be generated.
|
||||
@@ -45,7 +44,7 @@ llama_tokens common_ngram_simple_draft(
|
||||
// statistics of a m-gram after a known n-gram
|
||||
struct common_ngram_map_value {
|
||||
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
|
||||
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
|
||||
uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot)
|
||||
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
|
||||
};
|
||||
|
||||
@@ -54,7 +53,7 @@ struct common_ngram_map_key {
|
||||
size_t key_idx; // index of key n-gram in token-history
|
||||
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
|
||||
|
||||
uint16_t key_num; // number of occurences of this key n-gram in token-history
|
||||
uint16_t key_num; // number of occurrences of this key n-gram in token-history
|
||||
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
|
||||
};
|
||||
|
||||
@@ -66,15 +65,14 @@ struct common_ngram_map {
|
||||
bool key_only; // true if only key n-grams are used, no values.
|
||||
|
||||
std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
uint16_t min_hits; // minimum number of key hits to consider a draft
|
||||
|
||||
bool show_key_map_stats = false; // true, if statitics of the key_map should be printed.
|
||||
bool show_key_map_stats = false; // true, if statistics of the key_map should be printed.
|
||||
|
||||
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
|
||||
uint16_t check_rate, uint16_t min_hits)
|
||||
uint16_t min_hits)
|
||||
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
|
||||
check_rate(check_rate), min_hits(min_hits) {
|
||||
min_hits(min_hits) {
|
||||
key_map.resize(COMMON_NGRAM_HASH_MAP_SIZE); // 2^18 hash entries, 0 entries if key_map shouldn't be used
|
||||
}
|
||||
|
||||
|
||||
@@ -113,13 +113,14 @@ static bool common_speculative_are_compatible(
|
||||
struct common_speculative_state {
|
||||
const enum common_speculative_type type;
|
||||
|
||||
// TODO: rename to n_call_draft, n_gen_drafts, n_acc_drafts, n_gen_tokens, n_acc_tokens
|
||||
// TODO: add n_call_begin, n_call_accept
|
||||
size_t drafts_call_count = 0; // number of times this implementation was called.
|
||||
size_t drafts_generated_count = 0; // number of times a draft or part was generated by this implementation.
|
||||
size_t drafts_accepted_count = 0; // number of times a draft or part was accepted by the target model.
|
||||
size_t drafts_generated_tokens = 0; // number of tokens generated by this implementation.
|
||||
size_t drafts_accepted_tokens = 0; // number of tokens accepted by the target model.
|
||||
size_t n_call_begin = 0; // number of times this implementation was called for refresh.
|
||||
size_t n_call_draft = 0; // number of times this implementation was called for generation.
|
||||
size_t n_call_accept = 0; // number of times this implementation was called for accumulation.
|
||||
|
||||
size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation.
|
||||
size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model.
|
||||
size_t n_gen_tokens = 0; // number of tokens generated by this implementation.
|
||||
size_t n_acc_tokens = 0; // number of tokens accepted by the target model.
|
||||
|
||||
// TODO: track performance of most recent calls
|
||||
const bool gen_perf = true; // whether to generate performance stats.
|
||||
@@ -465,8 +466,6 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
|
||||
struct common_speculative_state_ngram_simple : public common_speculative_state {
|
||||
common_ngram_simple_config config;
|
||||
|
||||
uint16_t check_id = 0; // used to control the frequency of generating drafts
|
||||
|
||||
common_speculative_state_ngram_simple(
|
||||
enum common_speculative_type type,
|
||||
common_ngram_simple_config config)
|
||||
@@ -481,11 +480,6 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
++check_id;
|
||||
if (check_id < config.check_rate) {
|
||||
return;
|
||||
}
|
||||
check_id = 0;
|
||||
|
||||
result = common_ngram_simple_draft(config, prompt_tgt, id_last);
|
||||
GGML_UNUSED(params);
|
||||
@@ -752,10 +746,9 @@ static common_ngram_map get_common_ngram_map(const common_speculative_config & c
|
||||
uint16_t size_key = config.params.ngram_size_n;
|
||||
uint16_t size_value = config.params.ngram_size_m;
|
||||
bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
|
||||
uint16_t check_rate = config.params.ngram_check_rate;
|
||||
uint16_t min_hits = config.params.ngram_min_hits;
|
||||
|
||||
return common_ngram_map(size_key, size_value, key_only, check_rate, min_hits);
|
||||
return common_ngram_map(size_key, size_value, key_only, min_hits);
|
||||
}
|
||||
|
||||
static common_speculative_state_ngram_cache create_state_ngram_cache(
|
||||
@@ -805,6 +798,42 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
||||
return it->second;
|
||||
}
|
||||
|
||||
bool common_speculative_is_compat(llama_context * ctx_tgt) {
|
||||
auto * mem = llama_get_memory(ctx_tgt);
|
||||
if (mem == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool res = true;
|
||||
|
||||
llama_memory_clear(mem, true);
|
||||
|
||||
// eval 2 tokens to check if the context is compatible
|
||||
std::vector<llama_token> tmp;
|
||||
tmp.push_back(0);
|
||||
tmp.push_back(0);
|
||||
|
||||
int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
|
||||
res = false;
|
||||
goto done;
|
||||
}
|
||||
|
||||
// try to remove the last tokens
|
||||
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
|
||||
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
|
||||
res = false;
|
||||
goto done;
|
||||
}
|
||||
|
||||
done:
|
||||
llama_memory_clear(mem, true);
|
||||
llama_synchronize(ctx_tgt);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// initialization of the speculative decoding system
|
||||
//
|
||||
common_speculative * common_speculative_init(
|
||||
@@ -895,12 +924,10 @@ common_speculative * common_speculative_init(
|
||||
|
||||
uint16_t ngram_size_key = ngram_map.size_key;
|
||||
uint16_t mgram_size_value = ngram_map.size_value;
|
||||
uint16_t check_rate = ngram_map.check_rate;
|
||||
|
||||
auto config_simple = common_ngram_simple_config {
|
||||
/* .size_ngram = */ ngram_size_key,
|
||||
/* .size_mgram = */ mgram_size_value,
|
||||
/* .check_rate = */ check_rate
|
||||
/* .size_mgram = */ mgram_size_value
|
||||
};
|
||||
auto state = std::make_unique<common_speculative_state_ngram_simple>(
|
||||
/* .type = */ config.type,
|
||||
@@ -961,6 +988,7 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr
|
||||
for (auto & impl : spec->impls) {
|
||||
common_time_meas tm(impl->t_begin_us, !impl->gen_perf);
|
||||
impl->begin(prompt);
|
||||
impl->n_call_begin++;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -977,17 +1005,17 @@ llama_tokens common_speculative_draft(
|
||||
{
|
||||
common_time_meas tm(impl->t_draft_us, !impl->gen_perf);
|
||||
impl->draft(params, prompt_tgt, id_last, result);
|
||||
impl->drafts_call_count++;
|
||||
impl->n_call_draft++;
|
||||
}
|
||||
|
||||
if (!result.empty()) {
|
||||
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
|
||||
common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(),
|
||||
impl.get()->drafts_call_count, result.size());
|
||||
impl.get()->n_call_draft, result.size());
|
||||
|
||||
spec->curr_impl = impl.get(); // set current implementation for stats
|
||||
impl->drafts_generated_count++;
|
||||
impl->drafts_generated_tokens += result.size();
|
||||
impl->n_gen_drafts++;
|
||||
impl->n_gen_tokens += result.size();
|
||||
|
||||
break; // We have a draft, so break out of the loop and return it.
|
||||
}
|
||||
@@ -1008,11 +1036,12 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
|
||||
{
|
||||
common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
|
||||
if (n_accepted > 0) {
|
||||
impl->drafts_accepted_count++;
|
||||
impl->drafts_accepted_tokens += n_accepted;
|
||||
impl->n_acc_drafts++;
|
||||
impl->n_acc_tokens += n_accepted;
|
||||
}
|
||||
|
||||
impl->accept(n_accepted);
|
||||
impl->n_call_accept++;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1033,13 +1062,13 @@ void common_speculative_print_stats(const common_speculative * spec) {
|
||||
str_perf = "";
|
||||
}
|
||||
|
||||
LOG_INF("statistics %s: #calls = %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
|
||||
LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
|
||||
common_speculative_type_to_str(impl->type).c_str(),
|
||||
impl->drafts_call_count,
|
||||
impl->drafts_generated_count,
|
||||
impl->drafts_accepted_count,
|
||||
impl->drafts_generated_tokens,
|
||||
impl->drafts_accepted_tokens,
|
||||
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
|
||||
impl->n_gen_drafts,
|
||||
impl->n_acc_drafts,
|
||||
impl->n_gen_tokens,
|
||||
impl->n_acc_tokens,
|
||||
str_perf.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,10 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
||||
// convert type to string
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||
|
||||
// check if the llama_context is compatible for speculative decoding
|
||||
// note: clears the memory of the context
|
||||
bool common_speculative_is_compat(llama_context * ctx_tgt);
|
||||
|
||||
common_speculative * common_speculative_init(
|
||||
common_params_speculative & params,
|
||||
llama_context * ctx_tgt);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -99,6 +99,7 @@ models = [
|
||||
{"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", },
|
||||
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
|
||||
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
|
||||
{"name": "tiny_aya", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereLabs/tiny-aya-base", },
|
||||
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
|
||||
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
|
||||
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
|
||||
@@ -148,6 +149,8 @@ models = [
|
||||
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
|
||||
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
|
||||
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
|
||||
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
@@ -157,6 +160,7 @@ pre_computed_hashes = [
|
||||
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
|
||||
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
|
||||
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
|
||||
{"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
|
||||
@@ -170,7 +174,6 @@ pre_computed_hashes = [
|
||||
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
|
||||
# jina-v2-de variants
|
||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -246,7 +246,7 @@ cmake --build build --config release
|
||||
|
||||
1. **Retrieve and prepare model**
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model prepration.
|
||||
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model prepration.
|
||||
|
||||
**Notes**:
|
||||
|
||||
|
||||
@@ -281,7 +281,7 @@ as `-cl-fp32-correctly-rounded-divide-sqrt`
|
||||
|
||||
#### Retrieve and prepare model
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
|
||||
##### Check device
|
||||
|
||||
@@ -569,7 +569,7 @@ Once it is completed, final results will be in **build/Release/bin**
|
||||
|
||||
#### Retrieve and prepare model
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
|
||||
##### Check device
|
||||
|
||||
|
||||
180
docs/backend/VirtGPU.md
Normal file
180
docs/backend/VirtGPU.md
Normal file
@@ -0,0 +1,180 @@
|
||||
# GGML-VirtGPU Backend
|
||||
|
||||
The GGML-VirtGPU backend enables GGML applications to run machine
|
||||
learning computations on host hardware while the application itself
|
||||
runs inside a virtual machine. It uses host-guest shared memory to
|
||||
efficiently share data buffers between the two sides.
|
||||
|
||||
This backend relies on the virtio-gpu, and VirglRenderer API Remoting
|
||||
(APIR) component. The backend is split into two libraries:
|
||||
- a GGML implementation (the "remoting frontend"), running in the
|
||||
guest and interacting with the virtgpu device
|
||||
- a VirglRenderer APIR compatible library (the "remoting backend"),
|
||||
running in the host and interacting with Virglrenderer and an actual
|
||||
GGML device backend.
|
||||
|
||||
## OS support
|
||||
|
||||
| OS | Status | Backend | CI testing | Notes
|
||||
| -------- | ----------------- | ----------- | ----------- | -----
|
||||
| MacOS 14 | Supported | ggml-metal | X | Working when compiled on MacOS 14
|
||||
| MacOS 15 | Supported | ggml-metal | X | Working when compiled on MacOS 14 or MacOS 15
|
||||
| MacOS 26 | Not tested | | |
|
||||
| Linux | Under development | ggml-vulkan | not working | Working locally, CI running into deadlocks
|
||||
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
The GGML-VirtGPU backend consists of three main components:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
%% Nodes
|
||||
|
||||
subgraph GuestVM ["Guest VM - Frontend"]
|
||||
App([GGML Application<br/>llama.cpp, etc.])
|
||||
|
||||
direction TB
|
||||
Interface[GGML Backend Interface]
|
||||
Comm["GGML-VirtGPU<br/>(hypercalls + shared mem)"]
|
||||
|
||||
App --> Interface
|
||||
Interface --> Comm
|
||||
end
|
||||
|
||||
API[virtio-gpu / virglrenderer API]
|
||||
|
||||
subgraph HostSystem [Host System - Backend]
|
||||
direction TB
|
||||
Dispatcher[GGML-VirtGPU-Backend]
|
||||
BackendLib[GGML Backend library<br/>Metal / Vulkan / CPU / ...]
|
||||
|
||||
Dispatcher --> BackendLib
|
||||
end
|
||||
|
||||
%% Connections
|
||||
Comm --> API
|
||||
API --> HostSystem
|
||||
```
|
||||
|
||||
### Key Components
|
||||
|
||||
1. **Guest-side Frontend** (`ggml-virtgpu/`): Implements the GGML backend interface and forwards operations to the host
|
||||
2. **Host-side Backend** (`ggml-virtgpu/backend/`): Receives forwarded operations and executes them on actual hardware backends
|
||||
3. **Communication Layer**: Uses virtio-gpu hypercalls and shared memory for efficient data transfer
|
||||
|
||||
## Features
|
||||
|
||||
- **Dynamic backend loading** on the host side (CPU, CUDA, Metal, etc.)
|
||||
- **Zero-copy data transfer** via host-guest shared memory pages
|
||||
|
||||
## Communication Protocol
|
||||
|
||||
### Hypercalls and Shared Memory
|
||||
|
||||
The backend uses two primary communication mechanisms:
|
||||
|
||||
1. **Hypercalls (`DRM_IOCTL_VIRTGPU_EXECBUFFER`)**: Trigger remote execution from guest to host
|
||||
2. **Shared Memory Pages**: Zero-copy data transfer for tensors and parameters
|
||||
|
||||
#### Shared Memory Layout
|
||||
|
||||
Each connection uses two shared memory buffers:
|
||||
|
||||
- **Data Buffer** (24 MiB): For command/response data and tensor transfers
|
||||
- **Reply Buffer** (16 KiB): For command replies and status information
|
||||
- **Data Buffers**: Dynamically allocated host-guest shared buffers
|
||||
served as GGML buffers.
|
||||
|
||||
### APIR Protocol
|
||||
|
||||
The Virglrender API Remoting protocol defines three command types:
|
||||
|
||||
- `HANDSHAKE`: Protocol version negotiation and capability discovery
|
||||
- `LOADLIBRARY`: Dynamic loading of backend libraries on the host
|
||||
- `FORWARD`: API function call forwarding
|
||||
|
||||
### Binary Serialization
|
||||
|
||||
Commands and data are serialized using a custom binary protocol with:
|
||||
|
||||
- Fixed-size encoding for basic types
|
||||
- Variable-length arrays with size prefixes
|
||||
- Buffer bounds checking
|
||||
- Error recovery mechanisms
|
||||
|
||||
## Supported Operations
|
||||
|
||||
### Device Operations
|
||||
- Device enumeration and capability queries
|
||||
- Memory information (total/free)
|
||||
- Backend type detection
|
||||
|
||||
### Buffer Operations
|
||||
- Buffer allocation and deallocation
|
||||
- Tensor data transfer (host ↔ guest)
|
||||
- Memory copying and clearing
|
||||
|
||||
### Computation Operations
|
||||
- Graph execution forwarding
|
||||
|
||||
## Build Requirements
|
||||
|
||||
### Guest-side Dependencies
|
||||
- `libdrm` for DRM/virtio-gpu communication
|
||||
- C++20 compatible compiler
|
||||
- CMake 3.14+
|
||||
|
||||
### Host-side Dependencies
|
||||
- virglrenderer with APIR support (pending upstream review)
|
||||
- Target backend libraries (libggml-metal, libggml-vulkan, etc.)
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
- `GGML_VIRTGPU_BACKEND_LIBRARY`: Path to the host-side backend library
|
||||
- `GGML_VIRTGPU_DEBUG`: Enable debug logging
|
||||
|
||||
### Build Options
|
||||
|
||||
- `GGML_VIRTGPU`: Enable the VirtGPU backend (`ON` or `OFF`, default: `OFF`)
|
||||
- `GGML_VIRTGPU_BACKEND`: Build the host-side backend component (`ON`, `OFF` or `ONLY`, default: `OFF`)
|
||||
|
||||
### System Requirements
|
||||
|
||||
- VM with virtio-gpu support
|
||||
- VirglRenderer with APIR patches
|
||||
- Compatible backend libraries on host
|
||||
|
||||
## Limitations
|
||||
|
||||
- **VM-specific**: Only works in virtual machines with virtio-gpu support
|
||||
- **Host dependency**: Requires properly configured host-side backend
|
||||
- **Latency**: Small overhead from VM escaping for each operation
|
||||
|
||||
|
||||
* This work is pending upstream changes in the VirglRenderer
|
||||
project.
|
||||
* The backend can be tested with Virglrenderer compiled from source
|
||||
using this PR:
|
||||
https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590
|
||||
* This work is pending changes in the VMM/hypervisor running the
|
||||
virtual machine, which need to know how to route the newly
|
||||
introduced APIR capset.
|
||||
* The environment variable `VIRGL_ROUTE_VENUS_TO_APIR=1` allows
|
||||
using the Venus capset, until the relevant hypervisors have been
|
||||
patched. However, setting this flag breaks the Vulkan/Venus normal
|
||||
behavior.
|
||||
* The environment variable `GGML_REMOTING_USE_APIR_CAPSET` tells the
|
||||
`ggml-virtgpu` backend to use the APIR capset. This will become
|
||||
the default when the relevant hypervisors have been patched.
|
||||
|
||||
* This work focused on improving the performance of llama.cpp running
|
||||
on MacOS containers, and is mainly tested on this platform. The
|
||||
linux support (via `krun`) is in progress.
|
||||
|
||||
## See Also
|
||||
|
||||
- [Development and Testing](VirtGPU/development.md)
|
||||
- [Backend configuration](VirtGPU/configuration.md)
|
||||
174
docs/backend/VirtGPU/configuration.md
Normal file
174
docs/backend/VirtGPU/configuration.md
Normal file
@@ -0,0 +1,174 @@
|
||||
# GGML-VirtGPU Backend Configuration
|
||||
|
||||
This document describes the environment variables used by the ggml-virtgpu backend system, covering both the frontend (guest-side) and backend (host-side) components.
|
||||
|
||||
## Environment Variables Overview
|
||||
|
||||
The ggml-virtgpu backend uses environment variables for configuration across three main components:
|
||||
- **Frontend (Guest)**: GGML applications running in VMs
|
||||
- **Hypervisor**: Virglrenderer/APIR system
|
||||
- **Backend (Host)**: Host-side GGML backend integration
|
||||
|
||||
## Frontend (Guest-side) Configuration
|
||||
|
||||
### GGML_REMOTING_USE_APIR_CAPSET
|
||||
- **Location**: `ggml/src/ggml-virtgpu/virtgpu.cpp`
|
||||
- **Type**: Boolean flag (presence-based)
|
||||
- **Purpose**: Controls which virtio-gpu capability set to use for communication
|
||||
- **Values**:
|
||||
- Set (any value): Use the APIR capset (long-term setup)
|
||||
- Unset: Use the Venus capset (easier for testing with an unmodified hypervisor)
|
||||
- **Default**: Unset (Venus capset)
|
||||
- **Usage**:
|
||||
```bash
|
||||
export GGML_REMOTING_USE_APIR_CAPSET=1 # Use APIR capset
|
||||
# or leave unset for Venus capset
|
||||
```
|
||||
|
||||
## Hypervisor (Virglrenderer/APIR) Configuration
|
||||
|
||||
These environment variables are used during the transition phase for
|
||||
running with an unmodified hypervisor (not supporting the
|
||||
VirglRenderer APIR component). They will be removed in the future, and
|
||||
the hypervisor will instead configure VirglRenderer with the APIR
|
||||
_Configuration Key_.
|
||||
|
||||
### VIRGL_APIR_BACKEND_LIBRARY
|
||||
- **Location**: `virglrenderer/src/apir/apir-context.c`
|
||||
- **Configuration Key**: `apir.load_library.path`
|
||||
- **Type**: File path string
|
||||
- **Purpose**: Path to the APIR backend library that virglrenderer should dynamically load
|
||||
- **Required**: Yes
|
||||
- **Example**:
|
||||
```bash
|
||||
export VIRGL_APIR_BACKEND_LIBRARY="/path/to/libggml-remotingbackend.so"
|
||||
```
|
||||
|
||||
### VIRGL_ROUTE_VENUS_TO_APIR
|
||||
- **Location**: `virglrenderer/src/apir/apir-renderer.h`
|
||||
- **Type**: Boolean flag (presence-based)
|
||||
- **Purpose**: Temporary workaround to route Venus capset calls to APIR during hypervisor transition period
|
||||
- **Status**: will be removed once hypervisors support APIR natively
|
||||
- **Warning**: Breaks normal Vulkan/Venus functionality
|
||||
- **Usage**:
|
||||
```bash
|
||||
export VIRGL_ROUTE_VENUS_TO_APIR=1 # For testing with an unmodified hypervisor
|
||||
```
|
||||
|
||||
### VIRGL_APIR_LOG_TO_FILE
|
||||
- **Location**: `virglrenderer/src/apir/apir-renderer.c`
|
||||
- **Environment Variable**: `VIRGL_APIR_LOG_TO_FILE`
|
||||
- **Type**: File path string
|
||||
- **Purpose**: Enable debug logging from the VirglRenderer APIR component to specified file
|
||||
- **Required**: No (optional debugging)
|
||||
- **Default**: Logging to `stderr`
|
||||
- **Usage**:
|
||||
```bash
|
||||
export VIRGL_APIR_LOG_TO_FILE="/tmp/apir-debug.log"
|
||||
```
|
||||
|
||||
## Backend (Host-side) Configuration
|
||||
|
||||
These environment variables are used during the transition phase for
|
||||
running with an unmodified hypervisor (not supporting the
|
||||
VirglRenderer APIR component). They will be removed in the future, and
|
||||
the hypervisor will instead configure VirglRenderer with the APIR
|
||||
_Configuration Key_.
|
||||
|
||||
### APIR_LLAMA_CPP_GGML_LIBRARY_PATH
|
||||
- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp`
|
||||
- **Environment Variable**: `APIR_LLAMA_CPP_GGML_LIBRARY_PATH`
|
||||
- **Configuration Key**: `ggml.library.path`
|
||||
- **Type**: File path string
|
||||
- **Purpose**: Path to the actual GGML backend library (Metal, CUDA, Vulkan, etc.)
|
||||
- **Required**: **Yes** - backend initialization fails without this
|
||||
- **Examples**:
|
||||
```bash
|
||||
# macOS with Metal backend
|
||||
export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-metal.dylib"
|
||||
|
||||
# Linux with CUDA backend
|
||||
export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-cuda.so"
|
||||
|
||||
# macOS or Linux with Vulkan backend
|
||||
export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-vulkan.so"
|
||||
```
|
||||
|
||||
### APIR_LLAMA_CPP_GGML_LIBRARY_REG
|
||||
- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp`
|
||||
- **Environment Variable**: `APIR_LLAMA_CPP_GGML_LIBRARY_REG`
|
||||
- **Configuration Key**: `ggml.library.reg`
|
||||
- **Type**: Function symbol name string
|
||||
- **Purpose**: Name of the backend registration function to call after loading the library
|
||||
- **Required**: No (defaults to `ggml_backend_init`)
|
||||
- **Default**: `ggml_backend_init`
|
||||
- **Examples**:
|
||||
```bash
|
||||
# Metal backend
|
||||
export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_metal_reg"
|
||||
|
||||
# CUDA backend
|
||||
export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_cuda_reg"
|
||||
|
||||
# Vulkan backend
|
||||
export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_vulkan_reg"
|
||||
|
||||
# Generic fallback (default)
|
||||
# export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_init"
|
||||
```
|
||||
|
||||
### APIR_LLAMA_CPP_LOG_TO_FILE
|
||||
- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp:62`
|
||||
- **Environment Variable**: `APIR_LLAMA_CPP_LOG_TO_FILE`
|
||||
- **Type**: File path string
|
||||
- **Purpose**: Enable debug logging from the GGML backend to specified file
|
||||
- **Required**: No (optional debugging)
|
||||
- **Usage**:
|
||||
```bash
|
||||
export APIR_LLAMA_CPP_LOG_TO_FILE="/tmp/ggml-backend-debug.log"
|
||||
```
|
||||
|
||||
## Configuration Flow
|
||||
|
||||
The configuration system works as follows:
|
||||
|
||||
1. **Hypervisor Setup**: Virglrenderer loads the APIR backend library specified by `VIRGL_APIR_BACKEND_LIBRARY`
|
||||
|
||||
2. **Context Creation**: When an APIR context is created, it populates a configuration table with environment variables:
|
||||
- `apir.load_library.path` ← `VIRGL_APIR_BACKEND_LIBRARY`
|
||||
- `ggml.library.path` ← `APIR_LLAMA_CPP_GGML_LIBRARY_PATH`
|
||||
- `ggml.library.reg` ← `APIR_LLAMA_CPP_GGML_LIBRARY_REG`
|
||||
- this step will eventually be performed by the hypervisor itself, with command-line arguments instead of environment variables.
|
||||
|
||||
3. **Backend Initialization**: The backend queries the configuration via callbacks:
|
||||
- `virgl_cbs->get_config(ctx_id, "ggml.library.path")` returns the library path
|
||||
- `virgl_cbs->get_config(ctx_id, "ggml.library.reg")` returns the registration function
|
||||
|
||||
4. **Library Loading**: The backend dynamically loads and initializes the specified GGML library
|
||||
|
||||
## Error Messages
|
||||
|
||||
Common error scenarios and their messages:
|
||||
|
||||
- **Missing library path**: `"cannot open the GGML library: env var 'APIR_LLAMA_CPP_GGML_LIBRARY_PATH' not defined"`
|
||||
- **Missing registration function**: `"cannot register the GGML library: env var 'APIR_LLAMA_CPP_GGML_LIBRARY_REG' not defined"`
|
||||
|
||||
## Example Complete Configuration
|
||||
|
||||
Here's an example configuration for a macOS host with Metal backend:
|
||||
|
||||
```bash
|
||||
# Hypervisor environment
|
||||
export VIRGL_APIR_BACKEND_LIBRARY="/opt/llama.cpp/lib/libggml-virtgpu-backend.dylib"
|
||||
|
||||
# Backend configuration
|
||||
export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-metal.dylib"
|
||||
export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_metal_reg"
|
||||
|
||||
# Optional logging
|
||||
export VIRGL_APIR_LOG_TO_FILE="/tmp/apir.log"
|
||||
export APIR_LLAMA_CPP_LOG_TO_FILE="/tmp/ggml.log"
|
||||
|
||||
# Guest configuration
|
||||
export GGML_REMOTING_USE_APIR_CAPSET=1
|
||||
```
|
||||
220
docs/backend/VirtGPU/development.md
Normal file
220
docs/backend/VirtGPU/development.md
Normal file
@@ -0,0 +1,220 @@
|
||||
# Development and Testing
|
||||
|
||||
## Development
|
||||
|
||||
### Code Generation
|
||||
|
||||
The backend uses code generation from YAML configuration:
|
||||
|
||||
```bash
|
||||
# Regenerate protocol code
|
||||
cd ggml-virtgpu/
|
||||
python regenerate_remoting.py
|
||||
```
|
||||
|
||||
### Adding New Operations
|
||||
|
||||
1. Add function definition to `ggmlremoting_functions.yaml`
|
||||
2. Regenerate code with `regenerate_remoting.py`
|
||||
3. Implement guest-side forwarding in `virtgpu-forward-*.cpp`
|
||||
4. Implement host-side handling in `backend-dispatched-*.cpp`
|
||||
|
||||
## Testing
|
||||
|
||||
This document provides instructions for building and testing the GGML-VirtGPU backend on macOS with containers.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
The testing setup requires:
|
||||
|
||||
- macOS host system
|
||||
- Container runtime with `libkrun` provider (podman machine)
|
||||
- Access to development patchset for VirglRenderer
|
||||
|
||||
### Required Patchsets
|
||||
|
||||
The backend requires patches that are currently under review:
|
||||
|
||||
- **Virglrenderer APIR upstream PR**: https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590 (for reference)
|
||||
- **MacOS Virglrenderer (for krunkit)**: https://gitlab.freedesktop.org/kpouget/virglrenderer/-/tree/main-macos
|
||||
- **Linux Virglrenderer (for krun)**: https://gitlab.freedesktop.org/kpouget/virglrenderer/-/tree/main-linux
|
||||
|
||||
### Build Instructions
|
||||
|
||||
#### 1. Build ggml-virtgpu-backend (Host-side, macOS)
|
||||
|
||||
```bash
|
||||
# Build the backend that runs natively on macOS
|
||||
mkdir llama.cpp
|
||||
cd llama.cpp
|
||||
git clone https://github.com/ggml-org/llama.cpp.git src
|
||||
cd src
|
||||
|
||||
LLAMA_MAC_BUILD=$PWD/build/ggml-virtgpu-backend
|
||||
|
||||
cmake -S . -B $LLAMA_MAC_BUILD \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DLLAMA_CURL=ON \
|
||||
-DGGML_REMOTINGBACKEND=ONLY \
|
||||
-DGGML_METAL=ON
|
||||
|
||||
TARGETS="ggml-metal"
|
||||
cmake --build $LLAMA_MAC_BUILD --parallel 8 --target $TARGETS
|
||||
|
||||
# Build additional tools for native benchmarking
|
||||
EXTRA_TARGETS="llama-run llama-bench"
|
||||
cmake --build $LLAMA_MAC_BUILD --parallel 8 --target $EXTRA_TARGETS
|
||||
```
|
||||
|
||||
#### 2. Build virglrenderer (Host-side, macOS)
|
||||
|
||||
```bash
|
||||
# Build virglrenderer with APIR support
|
||||
mkdir virglrenderer
|
||||
git clone https://gitlab.freedesktop.org/kpouget/virglrenderer -b main-macos src
|
||||
cd src
|
||||
|
||||
VIRGL_BUILD_DIR=$PWD/build
|
||||
|
||||
# -Dvenus=true and VIRGL_ROUTE_VENUS_TO_APIR=1 route the APIR requests via the Venus backend, for easier testing without a patched hypervisor
|
||||
|
||||
meson setup $VIRGL_BUILD_DIR \
|
||||
-Dvenus=true \
|
||||
-Dapir=true
|
||||
|
||||
ninja -C $VIRGL_BUILD_DIR
|
||||
```
|
||||
|
||||
#### 3. Build ggml-virtgpu (Guest-side, Linux)
|
||||
|
||||
Option A: Build from a script:
|
||||
|
||||
```bash
|
||||
# Inside a Linux container
|
||||
mkdir llama.cpp
|
||||
git clone https://github.com/ggml-org/llama.cpp.git src
|
||||
cd src
|
||||
|
||||
LLAMA_LINUX_BUILD=$PWD//build-virtgpu
|
||||
|
||||
cmake -S . -B $LLAMA_LINUX_BUILD \
|
||||
-DGGML_VIRTGPU=ON
|
||||
|
||||
ninja -C $LLAMA_LINUX_BUILD
|
||||
```
|
||||
|
||||
Option B: Build container image with frontend:
|
||||
|
||||
```bash
|
||||
cat << EOF > remoting.containerfile
|
||||
FROM quay.io/fedora/fedora:43
|
||||
USER 0
|
||||
|
||||
WORKDIR /app/remoting
|
||||
|
||||
ARG LLAMA_CPP_REPO="https://github.com/ggml-org/llama.cpp.git"
|
||||
ARG LLAMA_CPP_VERSION="master"
|
||||
ARG LLAMA_CPP_CMAKE_FLAGS="-DGGML_VIRTGPU=ON"
|
||||
ARG LLAMA_CPP_CMAKE_BUILD_FLAGS="--parallel 4"
|
||||
|
||||
RUN dnf install -y git cmake gcc gcc-c++ libcurl-devel libdrm-devel
|
||||
|
||||
RUN git clone "\${LLAMA_CPP_REPO}" src \\
|
||||
&& git -C src fetch origin \${LLAMA_CPP_VERSION} \\
|
||||
&& git -C src reset --hard FETCH_HEAD
|
||||
|
||||
RUN mkdir -p build \\
|
||||
&& cd src \\
|
||||
&& set -o pipefail \\
|
||||
&& cmake -S . -B ../build \${LLAMA_CPP_CMAKE_FLAGS} \\
|
||||
&& cmake --build ../build/ \${LLAMA_CPP_CMAKE_BUILD_FLAGS}
|
||||
|
||||
ENTRYPOINT ["/app/remoting/src/build/bin/llama-server"]
|
||||
EOF
|
||||
|
||||
mkdir -p empty_dir
|
||||
podman build -f remoting.containerfile ./empty_dir -t localhost/llama-cpp.virtgpu
|
||||
```
|
||||
|
||||
### Environment Setup
|
||||
|
||||
#### Set krunkit Environment Variables
|
||||
|
||||
```bash
|
||||
# Define the base directories (adapt these paths to your system)
|
||||
VIRGL_BUILD_DIR=$HOME/remoting/virglrenderer/build
|
||||
LLAMA_MAC_BUILD=$HOME/remoting/llama.cpp/build-backend
|
||||
|
||||
# For krunkit to load the custom virglrenderer library
|
||||
export DYLD_LIBRARY_PATH=$VIRGL_BUILD_DIR/src
|
||||
|
||||
# For Virglrenderer to load the ggml-remotingbackend library
|
||||
export VIRGL_APIR_BACKEND_LIBRARY="$LLAMA_MAC_BUILD/bin/libggml-virtgpu-backend.dylib"
|
||||
|
||||
# For llama.cpp remotingbackend to load the ggml-metal backend
|
||||
export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="$LLAMA_MAC_BUILD/bin/libggml-metal.dylib"
|
||||
export APIR_LLAMA_CPP_GGML_LIBRARY_REG=ggml_backend_metal_reg
|
||||
```
|
||||
|
||||
#### Launch Container Environment
|
||||
|
||||
```bash
|
||||
# Set container provider to libkrun
|
||||
export CONTAINERS_MACHINE_PROVIDER=libkrun
|
||||
podman machine start
|
||||
```
|
||||
|
||||
#### Verify Environment
|
||||
|
||||
Confirm that krunkit is using the correct virglrenderer library:
|
||||
|
||||
```bash
|
||||
lsof -c krunkit | grep virglrenderer
|
||||
# Expected output:
|
||||
# krunkit 50574 user txt REG 1,14 2273912 10849442 ($VIRGL_BUILD_DIR/src)/libvirglrenderer.1.dylib
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
#### Launch Test Container
|
||||
|
||||
```bash
|
||||
# Optional model caching
|
||||
mkdir -p models
|
||||
PODMAN_CACHE_ARGS="-v models:/models --user root:root --cgroupns host --security-opt label=disable -w /models"
|
||||
|
||||
podman run $PODMAN_CACHE_ARGS -it --rm --device /dev/dri localhost/llama-cpp.virtgpu
|
||||
```
|
||||
|
||||
#### Test llama.cpp in Container
|
||||
|
||||
```bash
|
||||
|
||||
# Run performance benchmark
|
||||
/app/remoting/build/bin/llama-bench -m ./llama3.2
|
||||
```
|
||||
|
||||
Expected output (performance may vary):
|
||||
```
|
||||
| model | size | params | backend | ngl | test | t/s |
|
||||
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
|
||||
| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | ggml-virtgpu | 99 | pp512 | 991.30 ± 0.66 |
|
||||
| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | ggml-virtgpu | 99 | tg128 | 85.71 ± 0.11 |
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
#### SSH Environment Variable Issues
|
||||
|
||||
⚠️ **Warning**: Setting `DYLD_LIBRARY_PATH` from SSH doesn't work on macOS. Here is a workaround:
|
||||
|
||||
**Workaround 1: Replace system library**
|
||||
```bash
|
||||
VIRGL_BUILD_DIR=$HOME/remoting/virglrenderer/build # ⚠️ adapt to your system
|
||||
BREW_VIRGL_DIR=/opt/homebrew/Cellar/virglrenderer/0.10.4d/lib
|
||||
VIRGL_LIB=libvirglrenderer.1.dylib
|
||||
|
||||
cd $BREW_VIRGL_DIR
|
||||
mv $VIRGL_LIB ${VIRGL_LIB}.orig
|
||||
ln -s $VIRGL_BUILD_DIR/src/$VIRGL_LIB
|
||||
```
|
||||
@@ -35,7 +35,7 @@ Adapt below build commands accordingly.
|
||||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
|
||||
|
||||
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
|
||||
Preset CMake variables:
|
||||
|
||||
@@ -242,10 +242,10 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|
||||
|------------|-------------|------|-------|
|
||||
| FP32 | ✅ | ✅ | ❓ |
|
||||
| FP16 | ✅ | ✅ | ❓ |
|
||||
| BF16 | 🚫 | ✅ | ❓ |
|
||||
| BF16 | ✅ | ✅ | ❓ |
|
||||
| Q4_0 | ✅ | ❓ | ❓ |
|
||||
| Q4_1 | ✅ | ❓ | ❓ |
|
||||
| MXFP4 | 🚫 | ❓ | ❓ |
|
||||
| MXFP4 | ✅ | ❓ | ❓ |
|
||||
| Q5_0 | ✅ | ❓ | ❓ |
|
||||
| Q5_1 | ✅ | ❓ | ❓ |
|
||||
| Q8_0 | ✅ | ❓ | ❓ |
|
||||
@@ -272,4 +272,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|
||||
- 🚫 - acceleration unavailable, will still run using scalar implementation
|
||||
- ❓ - acceleration unknown, please contribute if you can test it yourself
|
||||
|
||||
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 7, 2025.
|
||||
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Feb 15, 2026.
|
||||
|
||||
@@ -22,7 +22,7 @@ Legend:
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
|
||||
@@ -77,8 +77,8 @@
|
||||
"SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL"
|
||||
"SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
"SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
|
||||
"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
|
||||
"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL"
|
||||
"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL"
|
||||
"SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
"SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
|
||||
"SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
@@ -161,8 +161,8 @@
|
||||
"SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL"
|
||||
"SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
"SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
|
||||
"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
|
||||
"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL"
|
||||
"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL"
|
||||
"SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
"SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
|
||||
"SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
|
||||
|
Can't render this file because it is too large.
|
@@ -119,8 +119,6 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
|
||||
of lookup n-gram (default: 12)
|
||||
--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length
|
||||
of draft m-gram (default: 48)
|
||||
--spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding
|
||||
(default: 1)
|
||||
--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1)
|
||||
```
|
||||
|
||||
@@ -153,10 +151,6 @@ Sets the size M of the draft m-gram for n-gram map based speculative decoding.
|
||||
The m-gram size determines how many tokens to draft when a match is found.
|
||||
Larger values can provide more speedup but may reduce acceptance rate.
|
||||
|
||||
### `--spec-ngram-check-rate R`
|
||||
|
||||
This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token).
|
||||
|
||||
### `--spec-ngram-min-hits H`
|
||||
|
||||
This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
|
||||
@@ -175,7 +169,12 @@ draft acceptance rate = 0.70312 ( 90 accepted / 128 generated)
|
||||
statistics ngram_mod: #calls = 810, #gen drafts = 15, #acc drafts = 15, #gen tokens = 960, #acc tokens = 730, dur(b,g,a) = 0.149, 0.347, 0.005 ms
|
||||
```
|
||||
|
||||
- `#calls`: number of calls of this implementations
|
||||
```
|
||||
statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts = 26, #gen tokens = 1248, #acc tokens = 968, dur(b,g,a) = 2.234, 1.427, 0.016 ms
|
||||
```
|
||||
|
||||
|
||||
- `#calls(b,g,a)`: number of calls of begin (new prompt), generation and accumulation of this implementations
|
||||
- `#gen drafts`: number of drafts generated by this implementation
|
||||
- `#acc drafts`: number of drafts accepted (partially) by the main model
|
||||
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
|
||||
|
||||
@@ -42,11 +42,15 @@ def load_model_and_tokenizer(model_path, device="auto"):
|
||||
config = config.text_config
|
||||
multimodal = True
|
||||
|
||||
print("Vocab size: ", config.vocab_size)
|
||||
print("Hidden size: ", config.hidden_size)
|
||||
print("Number of layers: ", config.num_hidden_layers)
|
||||
print("BOS token id: ", config.bos_token_id)
|
||||
print("EOS token id: ", config.eos_token_id)
|
||||
def print_if_exists(label, obj, attr, default="N/A"):
|
||||
val = getattr(obj, attr) if hasattr(obj, attr) else default
|
||||
print(f"{label}", val)
|
||||
|
||||
print_if_exists("Vocab size: ", config, "vocab_size")
|
||||
print_if_exists("Hidden size: ", config, "hidden_size")
|
||||
print_if_exists("Number of layers: ", config, "num_hidden_layers")
|
||||
print_if_exists("BOS token id: ", config, "bos_token_id")
|
||||
print_if_exists("EOS token id: ", config, "eos_token_id")
|
||||
|
||||
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
|
||||
if unreleased_model_name:
|
||||
|
||||
@@ -78,7 +78,7 @@ def list_all_tensors(model_path: Path, unique: bool = False):
|
||||
print(tensor_name)
|
||||
|
||||
|
||||
def print_tensor_info(model_path: Path, tensor_name: str):
|
||||
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
|
||||
tensor_file = find_tensor_file(model_path, tensor_name)
|
||||
|
||||
if tensor_file is None:
|
||||
@@ -96,6 +96,12 @@ def print_tensor_info(model_path: Path, tensor_name: str):
|
||||
print(f"Tensor: {tensor_name}")
|
||||
print(f"File: {tensor_file}")
|
||||
print(f"Shape: {shape}")
|
||||
if num_values is not None:
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
print(f"Dtype: {tensor.dtype}")
|
||||
flat = tensor.flatten()
|
||||
n = min(num_values, flat.numel())
|
||||
print(f"Values: {flat[:n].tolist()}")
|
||||
else:
|
||||
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
|
||||
sys.exit(1)
|
||||
@@ -127,6 +133,15 @@ def main():
|
||||
action="store_true",
|
||||
help="List unique tensor patterns in the model (layer numbers replaced with #)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--num-values",
|
||||
nargs="?",
|
||||
const=10,
|
||||
default=None,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -152,7 +167,7 @@ def main():
|
||||
if args.tensor_name is None:
|
||||
print("Error: tensor_name is required when not using --list")
|
||||
sys.exit(1)
|
||||
print_tensor_info(model_path, args.tensor_name)
|
||||
print_tensor_info(model_path, args.tensor_name, args.num_values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 9)
|
||||
set(GGML_VERSION_PATCH 5)
|
||||
set(GGML_VERSION_PATCH 7)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
@@ -752,6 +752,7 @@ extern "C" {
|
||||
GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_view (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
|
||||
|
||||
@@ -17,11 +17,6 @@
|
||||
//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)
|
||||
#define AT_PRINTF(...)
|
||||
|
||||
|
||||
static bool ggml_is_view(const struct ggml_tensor * t) {
|
||||
return t->view_src != NULL;
|
||||
}
|
||||
|
||||
// ops that return true for this function must not use restrict pointers for their backend implementations
|
||||
bool ggml_op_can_inplace(enum ggml_op op) {
|
||||
switch (op) {
|
||||
@@ -627,7 +622,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
||||
GGML_ASSERT(buffer_id >= 0);
|
||||
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
|
||||
|
||||
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
|
||||
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) {
|
||||
hn->allocated = true;
|
||||
assert(hn->addr.offset == 0);
|
||||
|
||||
@@ -658,7 +653,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
||||
|
||||
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
|
||||
if (p_hn->n_children == 1 && p_hn->n_views == 0) {
|
||||
if (ggml_is_view(parent)) {
|
||||
if (ggml_impl_is_view(parent)) {
|
||||
struct ggml_tensor * view_src = parent->view_src;
|
||||
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
|
||||
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
|
||||
@@ -739,7 +734,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
||||
// GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
|
||||
// control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
|
||||
// itself is never used and should not be considered a dependency
|
||||
if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
|
||||
if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) {
|
||||
struct ggml_tensor * view_src = node->view_src;
|
||||
ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
|
||||
}
|
||||
@@ -806,7 +801,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
||||
parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
|
||||
|
||||
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
|
||||
if (ggml_is_view(parent)) {
|
||||
if (ggml_impl_is_view(parent)) {
|
||||
struct ggml_tensor * view_src = parent->view_src;
|
||||
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
|
||||
view_src_hn->n_views -= 1;
|
||||
|
||||
@@ -471,9 +471,10 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
|
||||
int best_score = 0;
|
||||
fs::path best_path;
|
||||
std::error_code ec;
|
||||
|
||||
for (const auto & search_path : search_paths) {
|
||||
if (std::error_code ec; !fs::exists(search_path, ec)) {
|
||||
if (!fs::exists(search_path, ec)) {
|
||||
if (ec) {
|
||||
GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str());
|
||||
} else {
|
||||
@@ -483,7 +484,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
}
|
||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||
for (const auto & entry : dir_it) {
|
||||
if (entry.is_regular_file()) {
|
||||
if (entry.is_regular_file(ec)) {
|
||||
auto filename = entry.path().filename();
|
||||
auto ext = entry.path().extension();
|
||||
if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
|
||||
|
||||
@@ -3286,130 +3286,223 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs expert-specific matrix multiplication (MoE) with
|
||||
* quantized precision using the CANN backend.
|
||||
* @brief Performs quantized matrix multiplication for Mixture of Experts (MoE)
|
||||
* models using the CANN backend.
|
||||
*
|
||||
* This function executes a matrix multiplication operation tailored for
|
||||
* Mixture of Experts (MoE) models, where the input tensor is multiplied
|
||||
* with expert-specific quantized weight matrices. It leverages the CANN
|
||||
* backend to perform efficient low-precision computations and stores the
|
||||
* quantized result in the destination tensor `dst`.
|
||||
* This function implements MUL_MAT_ID operation for quantized weight matrices
|
||||
* (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on
|
||||
* the provided expert indices, and computes matrix multiplication using CANN's
|
||||
* WeightQuantBatchMatmulV2 operator.
|
||||
*
|
||||
* Quantization techniques reduce memory footprint and improve performance
|
||||
* by using lower-bit representations (e.g., int8) instead of floating-point.
|
||||
* This function is designed to work with such formats and may incorporate
|
||||
* optimizations like identity-based fast paths or routing masks for sparse
|
||||
* expert selection.
|
||||
* The function performs the following steps:
|
||||
* 1. Converts input/output tensors to F16 format if necessary
|
||||
* 2. Uses IndexSelect to extract expert-specific weights and scales based on indices
|
||||
* 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2
|
||||
* 4. Converts output back to the target type if needed
|
||||
*
|
||||
* @param ctx The context for executing CANN backend operations.
|
||||
* @param dst The destination tensor where the quantized MoE multiplication result
|
||||
* will be stored.
|
||||
* Tensor shapes:
|
||||
* - dst: [M, K, N, 1] - output tensor
|
||||
* - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0)
|
||||
* - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast)
|
||||
* - ids: [K, N] - expert indices for routing
|
||||
*
|
||||
* @note This function assumes quantized data types and is designed for
|
||||
* MoE architectures with potential sparse expert routing.
|
||||
* @param ctx The CANN backend context for operation execution.
|
||||
* @param dst The destination tensor where the multiplication result will be stored.
|
||||
*
|
||||
* @note Only Q4_0 and Q8_0 quantization formats are supported.
|
||||
* @note The function handles automatic type conversion to/from F16 as needed by the hardware.
|
||||
*/
|
||||
static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
// TODO: Use aclnnGroupedMatMul
|
||||
//dst [M, K, N, 1]
|
||||
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
|
||||
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
|
||||
ggml_tensor * ids = dst->src[2]; //ids [K, N]
|
||||
// dst: [M, K, N, 1]
|
||||
// src0: [D, M, A, 1] - quantized weights
|
||||
// src1: [D, B, N, 1] - input activations, B = K or B = 1
|
||||
// ids: [K, N] - expert indices
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
ggml_tensor * ids = dst->src[2];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
GGML_ASSERT(src0->ne[3] == 1);
|
||||
GGML_ASSERT(src1->ne[3] == 1);
|
||||
GGML_ASSERT(dst->ne[3] == 1);
|
||||
GGML_ASSERT(src1->ne[2] == ids->ne[1]);
|
||||
|
||||
// copy index from npu to cpu
|
||||
int64_t n_as = ne02; // A
|
||||
int64_t n_ids = ids->ne[0]; // K
|
||||
const int64_t n_batches = ids->ne[1];
|
||||
const int64_t n_select_experts = ids->ne[0];
|
||||
const enum ggml_type type = src0->type;
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids),
|
||||
ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream()));
|
||||
ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
|
||||
const int32_t group_size = QK8_0; // Both Q4_0 and Q8_0 use group size of 32
|
||||
GGML_ASSERT(group_size == QK4_0);
|
||||
|
||||
char * src0_original = (char *) src0->data;
|
||||
char * src1_original = (char *) src1->data;
|
||||
char * dst_original = (char *) dst->data;
|
||||
// Calculate element size for quantized weights
|
||||
const float weight_elem_size =
|
||||
(type == GGML_TYPE_Q4_0) ? 0.5f :
|
||||
(type == GGML_TYPE_Q8_0) ? 1.0f :
|
||||
(GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f);
|
||||
|
||||
ggml_tensor src0_row = *src0;
|
||||
ggml_tensor src1_row = *src1;
|
||||
ggml_tensor dst_row = *dst;
|
||||
// Calculate scale offset in memory
|
||||
const size_t weight_size = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size;
|
||||
const size_t scale_elem_size = sizeof(uint16_t);
|
||||
char * scale_data = (char *) src0->data + weight_size;
|
||||
|
||||
const enum ggml_type type = dst->src[0]->type;
|
||||
float weight_elem_size;
|
||||
if (type == GGML_TYPE_Q4_0) {
|
||||
weight_elem_size = float(sizeof(uint8_t)) / 2;
|
||||
} else if (type == GGML_TYPE_Q8_0) {
|
||||
weight_elem_size = float(sizeof(uint8_t));
|
||||
} else {
|
||||
GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
|
||||
}
|
||||
// Allocate buffers for selected expert weights and scales
|
||||
const size_t selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size;
|
||||
ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size);
|
||||
void * selected_weight_buffer = selected_weight_alloc.get();
|
||||
|
||||
// src0_row [D, M, 1, 1] weight without permute
|
||||
src0_row.ne[2] = 1;
|
||||
src0_row.ne[3] = 1;
|
||||
src0_row.nb[0] = weight_elem_size;
|
||||
src0_row.nb[1] = weight_elem_size * ne00;
|
||||
src0_row.nb[2] = weight_elem_size * ne00;
|
||||
src0_row.nb[3] = weight_elem_size * ne00;
|
||||
size_t weight_stride = ne00 * ne01 * weight_elem_size;
|
||||
size_t weight_size = weight_stride * ne02 * ne03;
|
||||
const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size;
|
||||
ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size);
|
||||
void * selected_scale_buffer = selected_scale_alloc.get();
|
||||
|
||||
// scale [D, M, 1, 1] -> scale && permute
|
||||
size_t scale_elem_size = sizeof(uint16_t);
|
||||
size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
|
||||
// Helper lambda to allocate and cast tensor to F16 if needed
|
||||
constexpr size_t f16_elem_size = sizeof(uint16_t);
|
||||
auto prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator,
|
||||
bool need_cast = false) -> void * {
|
||||
if (tensor->type == GGML_TYPE_F16) {
|
||||
return tensor->data;
|
||||
}
|
||||
|
||||
// src1_row [D, 1, 1, 1] -> input
|
||||
src1_row.ne[1] = 1;
|
||||
src1_row.ne[2] = 1;
|
||||
src1_row.ne[3] = 1;
|
||||
src1_row.nb[2] = nb11;
|
||||
src1_row.nb[3] = nb11;
|
||||
size_t total_size = f16_elem_size;
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
total_size *= tensor->ne[i];
|
||||
}
|
||||
void * buffer = allocator.alloc(total_size);
|
||||
|
||||
// dst_row [M, 1, 1, 1] -> out
|
||||
dst_row.ne[1] = 1;
|
||||
dst_row.ne[2] = 1;
|
||||
dst_row.ne[3] = 1;
|
||||
dst_row.nb[2] = nb1;
|
||||
dst_row.nb[3] = nb1;
|
||||
if (need_cast == false) {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
//create weight for one row
|
||||
ggml_cann_pool_alloc weight_allocator(ctx.pool());
|
||||
void * weight_buffer = weight_allocator.alloc(nb02);
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
// expert index
|
||||
int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]);
|
||||
GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
size_t nb[GGML_MAX_DIMS] = { f16_elem_size };
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
ne[i] = tensor->ne[i];
|
||||
if (i > 0) {
|
||||
nb[i] = nb[i - 1] * ne[i - 1];
|
||||
}
|
||||
}
|
||||
|
||||
// If B = 1 (broadcast), always use 0; otherwise, use id.
|
||||
int64_t i11 = (ne11 == 1 ? 0 : id);
|
||||
int64_t i12 = iid1;
|
||||
acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor);
|
||||
acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16);
|
||||
|
||||
int64_t i1 = id;
|
||||
int64_t i2 = i12;
|
||||
return buffer;
|
||||
};
|
||||
|
||||
void * src0_tmp_ptr = src0_original + i02 * weight_stride;
|
||||
void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride;
|
||||
void * src1_tmp_ptr = src1_original + i11 * nb11 + i12 * nb12;
|
||||
void * dst_tmp_ptr = dst_original + i1 * nb1 + i2 * nb2;
|
||||
// Prepare input and output buffers
|
||||
ggml_cann_pool_alloc input_alloc(ctx.pool());
|
||||
void * input_buffer = prepare_f16_buffer(src1, input_alloc, true);
|
||||
|
||||
// mem cpy
|
||||
ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
|
||||
void * scale_buffer = (char *) weight_buffer + weight_stride;
|
||||
ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
|
||||
ggml_cann_pool_alloc output_alloc(ctx.pool());
|
||||
void * output_buffer = prepare_f16_buffer(dst, output_alloc, false);
|
||||
|
||||
src0_row.data = weight_buffer;
|
||||
src1_row.data = src1_tmp_ptr;
|
||||
dst_row.data = dst_tmp_ptr;
|
||||
dst_row.src[0] = &src0_row;
|
||||
dst_row.src[1] = &src1_row;
|
||||
// Process each batch
|
||||
for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
|
||||
// Create index tensor for current batch
|
||||
const size_t index_offset = batch_idx * ids->nb[1];
|
||||
acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset);
|
||||
|
||||
ggml_cann_mul_mat(ctx, &dst_row);
|
||||
// Select quantized weights using expert indices
|
||||
// Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte
|
||||
const int64_t weight_d = (type == GGML_TYPE_Q4_0) ? src0->ne[0] / 2 : src0->ne[0];
|
||||
const int64_t weight_m = src0->ne[1];
|
||||
const int64_t weight_n_experts = src0->ne[2];
|
||||
|
||||
int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts };
|
||||
size_t weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) };
|
||||
|
||||
acl_tensor_ptr all_weights =
|
||||
ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3);
|
||||
|
||||
int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts };
|
||||
size_t selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t),
|
||||
weight_d * weight_m * sizeof(int8_t) };
|
||||
|
||||
acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t),
|
||||
selected_weight_ne, selected_weight_nb, 3);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get());
|
||||
|
||||
// Select scales using the same expert indices
|
||||
const int64_t scale_d = src0->ne[0] / group_size;
|
||||
int64_t scale_ne[3] = { scale_d, weight_m, weight_n_experts };
|
||||
size_t scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size };
|
||||
|
||||
acl_tensor_ptr all_scales =
|
||||
ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3);
|
||||
|
||||
int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts };
|
||||
size_t selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size,
|
||||
scale_d * weight_m * scale_elem_size };
|
||||
|
||||
acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size,
|
||||
selected_scale_ne, selected_scale_nb, 3);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get());
|
||||
|
||||
// Process each expert for current batch
|
||||
// IndexSelect output layout: [D, M, K] in contiguous format
|
||||
// WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride
|
||||
for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) {
|
||||
// Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input
|
||||
const size_t input_offset =
|
||||
(batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size;
|
||||
const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size;
|
||||
|
||||
// Create weight view for current expert: [D, M, K] -> [M, D]
|
||||
int64_t weight_view_ne[2] = { weight_m, src0->ne[0] };
|
||||
float weight_view_nb[2] = { src0->ne[0] * weight_elem_size, weight_elem_size };
|
||||
const size_t weight_view_offset = expert_idx * selected_weight_nb[2];
|
||||
|
||||
acl_tensor_ptr weight_view =
|
||||
ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size,
|
||||
weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset);
|
||||
|
||||
// Create scale view for current expert: [D, M, K] -> [M, D]
|
||||
int64_t scale_view_ne[2] = { weight_m, scale_d };
|
||||
size_t scale_view_nb[2] = { selected_scale_nb[1], selected_scale_nb[0] };
|
||||
const size_t scale_view_offset = expert_idx * selected_scale_nb[2];
|
||||
|
||||
acl_tensor_ptr scale_view =
|
||||
ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne,
|
||||
scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset);
|
||||
|
||||
// Create input activation tensor [D, 1]
|
||||
int64_t input_ne[2] = { src1->ne[0], 1 };
|
||||
size_t input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size };
|
||||
|
||||
acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne,
|
||||
input_nb, 2, ACL_FORMAT_ND, input_offset);
|
||||
|
||||
// Create output tensor [M, 1]
|
||||
int64_t output_ne[2] = { dst->ne[0], 1 };
|
||||
size_t output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size };
|
||||
|
||||
acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne,
|
||||
output_nb, 2, ACL_FORMAT_ND, output_offset);
|
||||
|
||||
// Perform quantized matrix multiplication
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(),
|
||||
scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size,
|
||||
output_tensor.get());
|
||||
}
|
||||
}
|
||||
return;
|
||||
|
||||
// Cast output back to original type if we used a temporary F16 buffer
|
||||
if (dst->type != GGML_TYPE_F16) {
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
size_t nb[GGML_MAX_DIMS] = { f16_elem_size };
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
ne[i] = dst->ne[i];
|
||||
if (i > 0) {
|
||||
nb[i] = nb[i - 1] * ne[i - 1];
|
||||
}
|
||||
}
|
||||
|
||||
acl_tensor_ptr f16_output =
|
||||
ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
|
||||
acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst);
|
||||
|
||||
aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type));
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -794,19 +794,44 @@ struct ggml_backend_cann_buffer_context {
|
||||
~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
|
||||
};
|
||||
|
||||
// cann buffer type
|
||||
/**
|
||||
* @brief Check if a buffer is a CANN buffer.
|
||||
*
|
||||
* This function checks if a given buffer is a CANN buffer by comparing its
|
||||
* `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
|
||||
*
|
||||
* @param buffer The buffer to check.
|
||||
* @return true if the buffer is a CANN buffer, false otherwise.
|
||||
* @brief Structure representing context information for a specific backend
|
||||
* buffer type.
|
||||
*/
|
||||
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
|
||||
struct ggml_backend_cann_buffer_type_context {
|
||||
int32_t device; /**< Device identifier associated with the buffer context. */
|
||||
std::string name; /**< Name associated with the buffer context. */
|
||||
};
|
||||
|
||||
static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) {
|
||||
return ggml_backend_buft_is_cann(buffer->buft);
|
||||
/**
|
||||
* @brief Retrieves the name associated with a CANN buffer type.
|
||||
*
|
||||
* This function returns the descriptive name associated with the specified
|
||||
* CANN buffer type context.
|
||||
*
|
||||
* @param buft Pointer to the buffer type context.
|
||||
* @return Const pointer to the C-style string containing the name.
|
||||
*/
|
||||
static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
|
||||
ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
|
||||
|
||||
return buft_ctx->name.c_str();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks if the backend buffer type is associated with the CANN backend.
|
||||
*
|
||||
* This function checks whether the provided backend buffer type is associated
|
||||
* with the CANN backend based on the comparison of its name retrieval function
|
||||
* pointer.
|
||||
*
|
||||
* @param buft Pointer to the backend buffer type to check.
|
||||
* @return bool Returns true if the buffer type is associated with the CANN
|
||||
* backend, otherwise false.
|
||||
*/
|
||||
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
|
||||
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1271,7 +1296,7 @@ static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
|
||||
const ggml_tensor * src,
|
||||
ggml_tensor * dst) {
|
||||
if (ggml_backend_buffer_is_cann(src->buffer)) {
|
||||
if (ggml_backend_buft_is_cann(src->buffer->buft)) {
|
||||
ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context;
|
||||
ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context;
|
||||
|
||||
@@ -1335,31 +1360,6 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
// cann buffer type
|
||||
/**
|
||||
* @brief Structure representing context information for a specific backend
|
||||
* buffer type.
|
||||
*/
|
||||
struct ggml_backend_cann_buffer_type_context {
|
||||
int32_t device; /**< Device identifier associated with the buffer context. */
|
||||
std::string name; /**< Name associated with the buffer context. */
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Retrieves the name associated with a CANN buffer type.
|
||||
*
|
||||
* This function returns the descriptive name associated with the specified
|
||||
* CANN buffer type context.
|
||||
*
|
||||
* @param buft Pointer to the buffer type context.
|
||||
* @return Const pointer to the C-style string containing the name.
|
||||
*/
|
||||
static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
|
||||
ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
|
||||
|
||||
return buft_ctx->name.c_str();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Allocates a new CANN buffer of the specified type and size.
|
||||
*
|
||||
@@ -1997,7 +1997,7 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
|
||||
|
||||
GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src));
|
||||
|
||||
if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) {
|
||||
if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2523,21 +2523,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks if the backend buffer type is associated with the CANN backend.
|
||||
*
|
||||
* This function checks whether the provided backend buffer type is associated
|
||||
* with the CANN backend based on the comparison of its name retrieval function
|
||||
* pointer.
|
||||
*
|
||||
* @param buft Pointer to the backend buffer type to check.
|
||||
* @return bool Returns true if the buffer type is associated with the CANN
|
||||
* backend, otherwise false.
|
||||
*/
|
||||
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
|
||||
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Records an event on the CANN backend stream.
|
||||
*
|
||||
|
||||
@@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch)
|
||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN})
|
||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
|
||||
set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
# Disable LTO for the feature detection code to prevent cross-module optimization
|
||||
# from inlining architecture-specific instructions into the score function.
|
||||
# Without this, LTO can cause SIGILL when loading backends on older CPUs
|
||||
# (e.g., loading power10 backend on power9 crashes before feature check runs).
|
||||
target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto)
|
||||
target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME})
|
||||
endfunction()
|
||||
|
||||
@@ -569,27 +574,24 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
endif()
|
||||
|
||||
# TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+
|
||||
# Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28
|
||||
FetchContent_Declare(KleidiAI_Download
|
||||
URL ${KLEIDIAI_DOWNLOAD_URL}
|
||||
DOWNLOAD_EXTRACT_TIMESTAMP NEW
|
||||
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
|
||||
|
||||
FetchContent_MakeAvailable(KleidiAI_Download)
|
||||
FetchContent_GetProperties(KleidiAI_Download
|
||||
SOURCE_DIR KLEIDIAI_SRC
|
||||
POPULATED KLEIDIAI_POPULATED)
|
||||
|
||||
if (NOT KLEIDIAI_POPULATED)
|
||||
message(FATAL_ERROR "KleidiAI source downloaded failed.")
|
||||
FetchContent_Populate(KleidiAI_Download)
|
||||
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
|
||||
endif()
|
||||
|
||||
add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
|
||||
|
||||
# Remove kleidiai target after fetching it
|
||||
if (TARGET kleidiai)
|
||||
set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
|
||||
endif()
|
||||
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
ggml-cpu/kleidiai/kleidiai.cpp
|
||||
ggml-cpu/kleidiai/kernels.cpp
|
||||
|
||||
@@ -43,6 +43,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -55,7 +56,8 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -76,6 +78,7 @@
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -84,6 +87,7 @@
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -107,6 +111,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -119,6 +124,7 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
@@ -143,6 +149,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -155,6 +162,7 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
@@ -186,6 +194,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -197,6 +206,7 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
@@ -227,6 +237,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -239,6 +250,7 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
@@ -271,6 +283,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -283,6 +296,7 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
|
||||
@@ -1072,6 +1072,195 @@ void ggml_gemv_q5_K_8x8_q8_K(int n,
|
||||
ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q6_K_8x4_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int col_groups = ncols_interleaved / 4;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
|
||||
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
|
||||
|
||||
// 1x8 tile = 2 x 4
|
||||
float32x4_t acc_f32[2];
|
||||
|
||||
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3
|
||||
float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
||||
float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
|
||||
float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
|
||||
|
||||
int32x4_t acc[col_groups];
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
// Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
|
||||
// Reused for bias and dequantization later
|
||||
int16_t q6_scales[16 * 8];
|
||||
for (int i = 0; i < 16; i++) {
|
||||
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
|
||||
vst1q_s16(q6_scales + i * 8, scales);
|
||||
}
|
||||
|
||||
// Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
|
||||
int32x4_t bias_lo = vdupq_n_s32(0);
|
||||
int32x4_t bias_hi = vdupq_n_s32(0);
|
||||
|
||||
// Load bsums in chunks of 4 to process with vectorized operations
|
||||
for (int i = 0; i < 16; i += 4) {
|
||||
int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
|
||||
int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
|
||||
int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
|
||||
int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
|
||||
int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
|
||||
int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
|
||||
int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
|
||||
int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
|
||||
int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
|
||||
|
||||
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
|
||||
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
|
||||
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
|
||||
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
|
||||
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
|
||||
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
|
||||
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
|
||||
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
|
||||
}
|
||||
bias_lo = vshlq_n_s32(bias_lo, 5);
|
||||
bias_hi = vshlq_n_s32(bias_hi, 5);
|
||||
|
||||
// Process two 128-value halves per superblock
|
||||
for (int half = 0; half < 2; half++) {
|
||||
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
|
||||
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
|
||||
|
||||
// A subblock (sb) is a set of weights that share the scale
|
||||
// Since q6_K scales are per 16 elements
|
||||
// num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
|
||||
const int8_t * q8_base_h = q8_base_l + 64;
|
||||
|
||||
// Load and duplicate q8 values (each register covers four interleaved columns of q6)
|
||||
int8x16_t q8_l[4];
|
||||
int8x16_t q8_h[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));
|
||||
q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
|
||||
}
|
||||
|
||||
const int ql_off_base = sb * QK_K / 2;
|
||||
const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
|
||||
|
||||
// Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
|
||||
uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
|
||||
uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
|
||||
uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
|
||||
uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
|
||||
|
||||
// Adjust qh for subblocks 2 and 3 (shift right by 2)
|
||||
if (sb > 1) {
|
||||
q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
|
||||
q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
|
||||
q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
|
||||
q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
|
||||
q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
|
||||
q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
|
||||
q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
|
||||
q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
|
||||
}
|
||||
|
||||
const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
|
||||
q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
|
||||
const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
|
||||
q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };
|
||||
|
||||
// Process column groups (0-3, 4-7)
|
||||
for (int g = 0; g < col_groups; g++) {
|
||||
int32x4_t sb_acc_l = vdupq_n_s32(0);
|
||||
int32x4_t sb_acc_h = vdupq_n_s32(0);
|
||||
|
||||
for (int chunk = 0; chunk < 4; chunk++) {
|
||||
const int idx = chunk * 2 + g;
|
||||
|
||||
const uint8x16_t q6_qs_l = q6_ql[idx];
|
||||
const uint8x16_t q6_qs_h = q6_qh[idx];
|
||||
|
||||
// Extract high 2 bits for upper nibble reconstruction
|
||||
const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);
|
||||
|
||||
// q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
|
||||
const int8x16_t q6_l =
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
|
||||
const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));
|
||||
|
||||
sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
|
||||
sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
|
||||
}
|
||||
|
||||
const int scale_idx_l = half * 8 + sb;
|
||||
const int scale_idx_h = half * 8 + sb + 4;
|
||||
|
||||
const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
|
||||
const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));
|
||||
|
||||
acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
|
||||
acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
|
||||
}
|
||||
}
|
||||
} // for half
|
||||
|
||||
// Bias correction
|
||||
acc[0] = vsubq_s32(acc[0], bias_lo);
|
||||
acc[1] = vsubq_s32(acc[1], bias_hi);
|
||||
|
||||
// Apply superblock scale (no mins for q6_K)
|
||||
// acc[g] has [c0, c1, c2, c3]
|
||||
float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
|
||||
float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);
|
||||
|
||||
acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
|
||||
acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
|
||||
} // for b
|
||||
|
||||
int base = x * ncols_interleaved;
|
||||
vst1q_f32(s + base, acc_f32[0]);
|
||||
vst1q_f32(s + base + 4, acc_f32[1]);
|
||||
} // for x
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q6_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
@@ -1177,15 +1366,14 @@ void ggml_gemv_q6_K_8x8_q8_K(int n,
|
||||
q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
|
||||
}
|
||||
|
||||
// TODO: Test other qh repack patterns to reduce loads
|
||||
const int ql_off_base = sb * QK_K / 2;
|
||||
const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
|
||||
|
||||
// Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
|
||||
ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base);
|
||||
ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64);
|
||||
ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base);
|
||||
ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64);
|
||||
uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
|
||||
uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
|
||||
uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
|
||||
uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
|
||||
|
||||
// Adjust qh for subblocks 2 and 3 (shift right by 2)
|
||||
if (sb > 1) {
|
||||
@@ -3038,6 +3226,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (svcntb() * 8 == 256) {
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const svuint8_t m4b_1 = svdup_n_u8(0x0f);
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
|
||||
uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
|
||||
svbool_t pg = svptrue_pat_b32(SV_VL8);
|
||||
svuint32_t idx = svld1(pg, idx_arr);
|
||||
|
||||
static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
|
||||
svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
acc_f32_01 = svdup_n_f32(0);
|
||||
acc_f32_23 = svdup_n_f32(0);
|
||||
acc_f32_45 = svdup_n_f32(0);
|
||||
acc_f32_67 = svdup_n_f32(0);
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// bsums pairs belongs to the same q8_k subblock
|
||||
// 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
|
||||
const int16x8_t bsums[4]{
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
|
||||
int32_t bsums_arr32[4][8];
|
||||
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
int16x8_t v16 = bsums[q8_row];
|
||||
|
||||
// low 4
|
||||
int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
|
||||
|
||||
// high 4
|
||||
int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
|
||||
}
|
||||
|
||||
svint32_t sb_acc_0 = svdup_n_s32(0);
|
||||
svint32_t sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svint32_t acc_00 = svdup_n_s32(0);
|
||||
svint32_t acc_11 = svdup_n_s32(0);
|
||||
svint32_t acc_22 = svdup_n_s32(0);
|
||||
svint32_t acc_33 = svdup_n_s32(0);
|
||||
svint32_t acc_44 = svdup_n_s32(0);
|
||||
svint32_t acc_55 = svdup_n_s32(0);
|
||||
svint32_t acc_66 = svdup_n_s32(0);
|
||||
svint32_t acc_77 = svdup_n_s32(0);
|
||||
|
||||
svint32_t bias_acc_00 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_22 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_44 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_66 = svdup_n_s32(0);
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
|
||||
svint32_t q4sb_mins_0, q4sb_mins_1;
|
||||
{
|
||||
// 2-superblock I am working on
|
||||
const int offset = sb * 24 + 0 * 12;
|
||||
const uint8_t * scales_in = &q4_ptr[b].scales[offset];
|
||||
|
||||
const int offset1 = sb * 24 + 12;
|
||||
const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
constexpr uint8_t scales_size = 12;
|
||||
|
||||
uint32_t sm[3];
|
||||
memcpy(sm, scales_in, scales_size);
|
||||
|
||||
uint32_t sm1[3];
|
||||
memcpy(sm1, scales_in1, scales_size);
|
||||
|
||||
const uint32_t mins_0_3 = sm[1] & kmask1;
|
||||
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
|
||||
|
||||
const uint32_t mins_0_3_1 = sm1[1] & kmask1;
|
||||
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
|
||||
|
||||
svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
|
||||
svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
|
||||
|
||||
/* reinterpret u32 → u8 */
|
||||
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
|
||||
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
|
||||
|
||||
/* widen u8 → u16->u32 (lower half only) */
|
||||
svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
|
||||
svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
|
||||
|
||||
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
|
||||
q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
|
||||
|
||||
uint32_t scales_u32_0 = sm[0] & kmask1;
|
||||
uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
|
||||
uint32_t scales_u32_2 = sm1[0] & kmask1;
|
||||
uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
|
||||
|
||||
svuint32_t S01 = svdup_n_u32(scales_u32_0);
|
||||
svuint32_t S23 = svdup_n_u32(scales_u32_1);
|
||||
svuint32_t R01 = svdup_n_u32(scales_u32_2);
|
||||
svuint32_t R23 = svdup_n_u32(scales_u32_3);
|
||||
|
||||
svint8_t S01_b = svreinterpret_s8_u32(S01);
|
||||
svint8_t S23_b = svreinterpret_s8_u32(S23);
|
||||
svint8_t R01_b = svreinterpret_s8_u32(R01);
|
||||
svint8_t R23_b = svreinterpret_s8_u32(R23);
|
||||
|
||||
svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
|
||||
svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
|
||||
svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
|
||||
svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
|
||||
|
||||
block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
|
||||
block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
|
||||
block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
|
||||
block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
|
||||
}
|
||||
|
||||
const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
|
||||
|
||||
// Load 32-byte per row pair, 1 subblock each time
|
||||
// predicate for activating higher lanes for 16 int8 elements
|
||||
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
|
||||
// predicate for activating lower lanes for 16 int8 elements
|
||||
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
|
||||
|
||||
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
|
||||
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
|
||||
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
|
||||
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
|
||||
|
||||
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
|
||||
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
|
||||
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
|
||||
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
|
||||
|
||||
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
||||
|
||||
sb_acc_0 = svdup_n_s32(0);
|
||||
sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
|
||||
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
|
||||
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
|
||||
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
|
||||
|
||||
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
|
||||
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
|
||||
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
|
||||
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
|
||||
|
||||
if(cp == 0) {
|
||||
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
|
||||
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
|
||||
}
|
||||
if(cp == 1) {
|
||||
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
|
||||
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
|
||||
}
|
||||
if(cp == 2) {
|
||||
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
|
||||
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
|
||||
}
|
||||
if(cp == 3) {
|
||||
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
|
||||
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
|
||||
}
|
||||
}
|
||||
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
|
||||
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
|
||||
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
|
||||
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
|
||||
} // for sb
|
||||
|
||||
|
||||
acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
|
||||
acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
|
||||
acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
|
||||
acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
|
||||
acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
|
||||
acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
|
||||
acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
|
||||
acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
|
||||
|
||||
svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
|
||||
svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
|
||||
|
||||
svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
|
||||
svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
|
||||
|
||||
// Broadcast q8 scalar
|
||||
svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
|
||||
|
||||
svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
|
||||
|
||||
svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
|
||||
|
||||
svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
|
||||
acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[1]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
|
||||
acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[2]);
|
||||
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
|
||||
acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[3]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
|
||||
acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
|
||||
|
||||
} // for b
|
||||
|
||||
// With the previous reorder, the tile is already in the correct memory layout.
|
||||
// Predicate for exactly 4 lanes
|
||||
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
|
||||
if (i == 0 && j == 0) {
|
||||
// acc_f32_0 → lower half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, acc_f32_01);
|
||||
} else if (i == 0 && j == 1) {
|
||||
// acc_f32_1 → upper half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
|
||||
} else if (i == 1 && j == 0) {
|
||||
// acc_f32_2
|
||||
svst1_f32(pg4, s + offset, acc_f32_23);
|
||||
} else if (i == 1 && j == 1) {
|
||||
// acc_f32_3
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
|
||||
} else if (i == 2 && j == 0) {
|
||||
// acc_f32_4
|
||||
svst1_f32(pg4, s + offset, acc_f32_45);
|
||||
} else if (i == 2 && j == 1) {
|
||||
// acc_f32_5
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
|
||||
} else if (i == 3 && j == 0) {
|
||||
// acc_f32_6
|
||||
svst1_f32(pg4, s + offset, acc_f32_67);
|
||||
} else if (i == 3 && j == 1) {
|
||||
// acc_f32_7
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
}
|
||||
#endif // SVE compile-time end
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
@@ -3474,6 +3972,208 @@ void ggml_gemm_q5_K_8x8_q8_K(int n,
|
||||
ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q6_K_8x4_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
constexpr int col_groups = ncols_interleaved / 4;
|
||||
constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
|
||||
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
|
||||
const int8x16_t m32s = vdupq_n_s8(32);
|
||||
|
||||
float32x4_t acc_f32[acc_size];
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));
|
||||
float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));
|
||||
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
|
||||
|
||||
float32x4_t sbd_scale_0123[q8_k_blocklen];
|
||||
float32x4_t sbd_scale_4567[q8_k_blocklen];
|
||||
|
||||
sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
|
||||
sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
|
||||
sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
|
||||
sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
|
||||
sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
|
||||
sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
|
||||
sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
|
||||
sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
|
||||
|
||||
int32x4_t acc_s32[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_s32[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
int16_t q6_scales[8 * 16];
|
||||
for (int i = 0; i < 16; i++) {
|
||||
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
|
||||
vst1q_s16(q6_scales + i * 8, scales);
|
||||
}
|
||||
|
||||
for (int half = 0; half < 2; half++) {
|
||||
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
|
||||
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
int32x4_t acc_lo[acc_size];
|
||||
int32x4_t acc_hi[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
|
||||
const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
|
||||
|
||||
// 4 rows * 16 elements per scale
|
||||
// 4 reads of 16 bytes each
|
||||
constexpr int reads_per_sb = 4;
|
||||
int8x16_t q8_l[reads_per_sb];
|
||||
int8x16_t q8_h[reads_per_sb];
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
q8_l[k] = vld1q_s8(q8_base_l + 16 * k);
|
||||
q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
|
||||
}
|
||||
|
||||
const int ql_off_base = sb * QK_K / 2;
|
||||
const int qh_off_base = ql_off_base & 255;
|
||||
|
||||
uint8x16_t q6_ql_0123[reads_per_sb];
|
||||
uint8x16_t q6_ql_4567[reads_per_sb];
|
||||
uint8x16_t q6_qh_0123[reads_per_sb];
|
||||
uint8x16_t q6_qh_4567[reads_per_sb];
|
||||
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
|
||||
q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
|
||||
q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
|
||||
q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
|
||||
}
|
||||
|
||||
if (sb > 1) {
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
|
||||
q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
|
||||
}
|
||||
}
|
||||
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
// q = (ql | qh) - 32
|
||||
const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
|
||||
const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
|
||||
const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
|
||||
const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
|
||||
|
||||
const int8x16_t q6_0123_lo = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
|
||||
const int8x16_t q6_0123_hi = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
|
||||
|
||||
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123
|
||||
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123
|
||||
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123
|
||||
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123
|
||||
|
||||
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123
|
||||
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123
|
||||
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123
|
||||
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123
|
||||
|
||||
const int8x16_t q6_4567_lo = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
|
||||
const int8x16_t q6_4567_hi = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
|
||||
|
||||
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567
|
||||
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567
|
||||
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567
|
||||
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567
|
||||
|
||||
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567
|
||||
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567
|
||||
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567
|
||||
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567
|
||||
}
|
||||
|
||||
// Scale and bias
|
||||
const int scale_idx_l = half * 8 + sb;
|
||||
const int scale_idx_h = half * 8 + sb + 4;
|
||||
|
||||
for (int g = 0; g < col_groups; g++) {
|
||||
const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
|
||||
const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
|
||||
const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
|
||||
const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
|
||||
const int acc_offset = g * q8_k_blocklen;
|
||||
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
const int idx = row * 2 + g;
|
||||
acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
|
||||
acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finally we apply the superblock scales
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
const int idx0 = 2 * row;
|
||||
const int idx1 = 2 * row + 1;
|
||||
const int32x4_t acc_0123 = acc_s32[idx0];
|
||||
const int32x4_t acc_4567 = acc_s32[idx1];
|
||||
|
||||
acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
|
||||
acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
|
||||
}
|
||||
} // for b
|
||||
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q6_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
|
||||
@@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
|
||||
GGML_ASSERT(nb00 == sizeof(src0_t));
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
|
||||
|
||||
if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
}
|
||||
const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);
|
||||
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
vDSP_fn_t vDSP_op = nullptr;
|
||||
@@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
|
||||
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
||||
|
||||
if (is_src1_contiguous) {
|
||||
if (is_src1_contiguous_rows) {
|
||||
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
||||
const int64_t nr0 = ne00 / ne10;
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "simd-mappings.h"
|
||||
|
||||
#define GGML_FA_TILE_Q 32
|
||||
#define GGML_FA_TILE_KV 16
|
||||
#define GGML_FA_TILE_Q 64
|
||||
#define GGML_FA_TILE_KV 64
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
|
||||
@@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan(
|
||||
const int64_t DV = node->src[2]->ne[0];
|
||||
|
||||
// Tiled flash attention scratch (tile sizes defined in common.h)
|
||||
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
|
||||
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
|
||||
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
|
||||
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;
|
||||
|
||||
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
|
||||
// Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
|
||||
@@ -2947,7 +2947,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
/*.use_ref =*/ cplan->use_ref,
|
||||
};
|
||||
|
||||
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
|
||||
#ifdef GGML_USE_OPENMP
|
||||
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan);
|
||||
#else
|
||||
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
|
||||
#endif
|
||||
|
||||
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
@@ -2974,7 +2978,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
}
|
||||
}
|
||||
|
||||
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
|
||||
#ifdef GGML_USE_OPENMP
|
||||
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan);
|
||||
#else
|
||||
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
|
||||
#endif
|
||||
|
||||
ggml_barrier(state->threadpool);
|
||||
|
||||
|
||||
@@ -1,333 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
typedef vector unsigned char vec_t;
|
||||
typedef __vector_quad acc_t;
|
||||
|
||||
template <typename TA>
|
||||
class tinyBLAS_Q0_PPC {
|
||||
public:
|
||||
tinyBLAS_Q0_PPC(int64_t k,
|
||||
const TA *A, int64_t lda,
|
||||
const block_q8_0 *B, int64_t ldb,
|
||||
float *C, int64_t ldc,
|
||||
int ith, int nth);
|
||||
|
||||
void matmul(int64_t m, int64_t n);
|
||||
void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
|
||||
vec_t A_pack[mc*kc*2];
|
||||
vec_t B_pack[nc*kc*2];
|
||||
int comparray[mc*kc];
|
||||
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
int64_t ytiles = m / mc;
|
||||
int64_t xtiles = n / nc;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
int64_t duty = (tiles + nth - 1) / nth;
|
||||
int64_t start = duty * ith;
|
||||
int64_t end = start + duty;
|
||||
if (end > tiles) {
|
||||
end = tiles;
|
||||
}
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = (job / xtiles) * mc;
|
||||
int64_t jj = (job % xtiles) * nc;
|
||||
for (int64_t kk = 0; kk < k; kk += kc) {
|
||||
if constexpr(is_Ablock_q4) {
|
||||
packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
|
||||
} else {
|
||||
packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
|
||||
}
|
||||
packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
|
||||
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
|
||||
*c_ptr += *((float*)&fin_res[idx+I]+J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ArrayType>
|
||||
inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
|
||||
vector signed int vec_C[4];
|
||||
vector float CA[4] = {0};
|
||||
vector float res[4] = {0};
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
|
||||
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
||||
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
|
||||
const vector signed char lowMask = vec_splats((signed char)0xF);
|
||||
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
|
||||
const vector signed char v8 = vec_splats((signed char)0x8);
|
||||
vector signed int vsum = {0};
|
||||
vector signed int vsum2 = {0};
|
||||
c[0] = vec_and(c[1], lowMask);
|
||||
c[1] = vec_sr(c[1], v4);
|
||||
c[0] = vec_sub(c[0], v8);
|
||||
c[1] = vec_sub(c[1], v8);
|
||||
vsum = vec_sum4s(c[0], vsum);
|
||||
vsum2 = vec_sum4s(c[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
}
|
||||
|
||||
template <typename V1, typename V2>
|
||||
inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
|
||||
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
||||
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
|
||||
V2 t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
vector unsigned char xor_vector;
|
||||
uint8_t flip_vec = 0x80;
|
||||
xor_vector = vec_splats(flip_vec);
|
||||
t1 = vec_perm(s1, s2, swiz1);
|
||||
t2 = vec_perm(s1, s2, swiz2);
|
||||
t3 = vec_perm(s3, s4, swiz1);
|
||||
t4 = vec_perm(s3, s4, swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
if (flip == true) {
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
}
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset+16);
|
||||
vec_xst(t7, 0, vecOffset+32);
|
||||
vec_xst(t8, 0, vecOffset+48);
|
||||
}
|
||||
|
||||
template<int RM, int RN>
|
||||
inline void kernel(int64_t ii, int64_t jj) {
|
||||
if constexpr(RM == 4 && RN == 8) {
|
||||
KERNEL_4x8(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 4) {
|
||||
KERNEL_8x4(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 8) {
|
||||
KERNEL_8x8(ii,jj);
|
||||
} else {
|
||||
assert(false && "RN/RM values not supported");
|
||||
}
|
||||
}
|
||||
template<int size>
|
||||
void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray);
|
||||
template<typename VA, typename VB>
|
||||
void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
|
||||
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
|
||||
void KERNEL_4x8(int64_t ii, int64_t jj);
|
||||
void KERNEL_8x4(int64_t ii, int64_t jj);
|
||||
void KERNEL_8x8(int64_t ii, int64_t jj);
|
||||
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
|
||||
template <int RM, int RN>
|
||||
void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);
|
||||
|
||||
void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
|
||||
for (int I = 0; I<8; I++) {
|
||||
float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
|
||||
*((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void process_q8_elements(const int8_t *qs, int *ca) {
|
||||
vector signed char c1 = vec_xl(0, qs);
|
||||
vector signed char c2 = vec_xl(16, qs);
|
||||
vector signed int vsum1 = {0};
|
||||
vector signed int vsum2 = {0};
|
||||
vsum1 = vec_sum4s(c1, vsum1);
|
||||
vsum2 = vec_sum4s(c2, vsum2);
|
||||
vector signed int vsum = vec_add(vsum1, vsum2);
|
||||
*ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
}
|
||||
|
||||
template<typename VA, typename VB>
|
||||
void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
|
||||
int64_t i, j;
|
||||
block_q8_0 *aoffset = NULL;
|
||||
VA *vecOffset = NULL;
|
||||
block_q8_0* aoffsets[8];
|
||||
__vector_pair arr[8];
|
||||
VB c[8][2] = {0};
|
||||
VB c1[8] = {0}; VB c2[8] = {0};
|
||||
aoffset = const_cast<block_q8_0*>(a);
|
||||
vecOffset = vec;
|
||||
j = (rows >> 3);
|
||||
int index = 0;
|
||||
if (j > 0) {
|
||||
do {
|
||||
for (int it = 0; it < 8; it++)
|
||||
aoffsets[it] = aoffset + it*lda;
|
||||
aoffset += 8 * lda;
|
||||
for (int blk = 0; blk < kc; blk++) {
|
||||
for (int it = 0; it < 8; it++) {
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
||||
c1[it] = c[it][0];
|
||||
c2[it] = c[it][1];
|
||||
if (comparray){
|
||||
process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
|
||||
}
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
|
||||
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
|
||||
vecOffset += 256;
|
||||
}
|
||||
j--;
|
||||
index += 8*kc;
|
||||
} while(j > 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
|
||||
int64_t i, j;
|
||||
TA *aoffset = NULL;
|
||||
int8_t *vecOffset = NULL;
|
||||
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
||||
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
||||
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
|
||||
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
|
||||
aoffset = const_cast<TA*>(a);
|
||||
vecOffset = vec;
|
||||
int index = 0;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
do {
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
aoffset4 = aoffset3 + lda;
|
||||
aoffset5 = aoffset4 + lda;
|
||||
aoffset6 = aoffset5 + lda;
|
||||
aoffset7 = aoffset6 + lda;
|
||||
aoffset8 = aoffset7 + lda;
|
||||
aoffset += 8 * lda;
|
||||
for (int blk = 0; blk < kc; blk++) {
|
||||
c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
|
||||
c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
|
||||
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
|
||||
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
|
||||
c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
|
||||
c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
|
||||
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
|
||||
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
|
||||
|
||||
process_q4_elements(c1, &comparray[index + 8*blk+0]);
|
||||
process_q4_elements(c2, &comparray[index + 8*blk+1]);
|
||||
process_q4_elements(c3, &comparray[index + 8*blk+2]);
|
||||
process_q4_elements(c4, &comparray[index + 8*blk+3]);
|
||||
process_q4_elements(c5, &comparray[index + 8*blk+4]);
|
||||
process_q4_elements(c6, &comparray[index + 8*blk+5]);
|
||||
process_q4_elements(c7, &comparray[index + 8*blk+6]);
|
||||
process_q4_elements(c8, &comparray[index + 8*blk+7]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
|
||||
vecOffset += 256;
|
||||
}
|
||||
j--;
|
||||
index += 8*kc;
|
||||
} while (j > 0);
|
||||
}
|
||||
}
|
||||
|
||||
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
|
||||
acc_t acc[8];
|
||||
for (int i = 0; i < mc ; i += 8) {
|
||||
for (int j = 0; j < nc; j += 8) {
|
||||
vector float fin_res[16] = {0};
|
||||
vector float vs[16] = {0};
|
||||
for (int64_t kk = 0; kk < kc; kk+=2) {
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xxsetaccz(&acc[x]);
|
||||
}
|
||||
int A_block_idx = (i/8)*(16*kc) + kk*16;
|
||||
int B_block_idx = (j/8)*(16*kc)+ kk*16;
|
||||
vec_t *A_block = &vec_A[A_block_idx];
|
||||
vec_t *B_block = &vec_B[B_block_idx];
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
|
||||
}
|
||||
compute_scale(ii+i, jj+j, l+kk, vs);
|
||||
int c_index = (i/8)*(8*kc)+ kk*8;
|
||||
int* c_block = &comparray[c_index];
|
||||
compute(&acc[0], 0, 0, c_block, vs, fin_res);
|
||||
compute(&acc[1], 4, 4, c_block, vs, fin_res);
|
||||
compute(&acc[2], 0, 8, c_block, vs, fin_res);
|
||||
compute(&acc[3], 4, 12, c_block, vs, fin_res);
|
||||
|
||||
A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
|
||||
B_block_idx = (j/8)*(16*kc)+ (kk+1)*16;
|
||||
A_block = &vec_A[A_block_idx];
|
||||
B_block = &vec_B[B_block_idx];
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]);
|
||||
}
|
||||
compute_scale(ii+i, jj+j, l+kk+1, vs);
|
||||
c_index = (i/8)*(8*kc)+ (kk+1)*8;
|
||||
c_block = &comparray[c_index];
|
||||
compute(&acc[4], 0, 0, c_block, vs, fin_res);
|
||||
compute(&acc[5], 4, 4, c_block, vs, fin_res);
|
||||
compute(&acc[6], 0, 8, c_block, vs, fin_res);
|
||||
compute(&acc[7], 4, 12, c_block, vs, fin_res);
|
||||
|
||||
}
|
||||
if (l == 0) {
|
||||
save_res(ii+i, jj+j, 0, fin_res);
|
||||
save_res(ii+i+4, jj+j, 4, fin_res);
|
||||
save_res(ii+i, jj+j+4, 8, fin_res);
|
||||
save_res(ii+i+4, jj+j+4, 12, fin_res);
|
||||
} else {
|
||||
add_save_res(ii+i, jj+j, 0, fin_res);
|
||||
add_save_res(ii+i+4, jj+j, 4, fin_res);
|
||||
add_save_res(ii+i, jj+j+4, 8, fin_res);
|
||||
add_save_res(ii+i+4, jj+j+4, 12, fin_res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const TA *const A;
|
||||
const block_q8_0 *const B;
|
||||
float *C;
|
||||
const int64_t k;
|
||||
int64_t kc;
|
||||
const int64_t lda;
|
||||
const int64_t ldb;
|
||||
const int64_t ldc;
|
||||
const int ith;
|
||||
const int nth;
|
||||
};
|
||||
@@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
|
||||
#endif
|
||||
|
||||
#if defined(__MMA__)
|
||||
#include "sgemm-ppc.h"
|
||||
typedef vector unsigned char vec_t;
|
||||
typedef __vector_quad acc_t;
|
||||
#endif
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// VECTORIZED FUSED MULTIPLY ADD
|
||||
@@ -2153,7 +2154,7 @@ class tinyBLAS_HP16_PPC {
|
||||
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
|
||||
for (int x = 0; x < 4; x++) {
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
|
||||
}
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
@@ -2301,43 +2302,299 @@ class tinyBLAS_HP16_PPC {
|
||||
const int nth;
|
||||
};
|
||||
|
||||
template <typename TA>
|
||||
tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
|
||||
const TA *A, int64_t lda,
|
||||
const block_q8_0 *B, int64_t ldb,
|
||||
float *C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
template <typename TA>
|
||||
class tinyBLAS_Q0_PPC {
|
||||
public:
|
||||
tinyBLAS_Q0_PPC(int64_t k,
|
||||
const TA * A, int64_t lda,
|
||||
const block_q8_0 * B, int64_t ldb,
|
||||
float * C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
kc = 64;
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
|
||||
int mc = 64; int nc = 64;
|
||||
if (n % 8 == 0 && n < nc) {
|
||||
nc = n;
|
||||
mc = 32 ;
|
||||
kc = 32;
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
const int64_t mc = 64;
|
||||
const int64_t kc = 64;
|
||||
int64_t nc = 64;
|
||||
int64_t n_aligned = 0;
|
||||
if (n % 64 == 0) {
|
||||
n_aligned = n;
|
||||
} else if (n == 4) {
|
||||
n_aligned = 4;
|
||||
} else if (n < 64) {
|
||||
n_aligned = (n / 8) * 8;
|
||||
} else {
|
||||
n_aligned = (n / 64) * 64;
|
||||
}
|
||||
const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
|
||||
if (is_aligned) {
|
||||
this->matmul_tiled_q0(m, n, mc, nc, kc);
|
||||
|
||||
if (n_aligned > 0) {
|
||||
if (n_aligned % 64 == 0) nc = 64;
|
||||
else if (n_aligned == n) nc = n;
|
||||
else if (n_aligned % 32 == 0) nc = 32;
|
||||
else if (n_aligned % 24 == 0) nc = 24;
|
||||
else if (n_aligned % 16 == 0) nc = 16;
|
||||
else nc = 8;
|
||||
}
|
||||
bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
|
||||
if (can_use_tiled) {
|
||||
matmul_tiled(m, n_aligned, mc, nc, kc);
|
||||
if (n > n_aligned) {
|
||||
mnpack(0, m, n_aligned, n);
|
||||
}
|
||||
} else {
|
||||
mnpack(0, m, 0, n);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
template<int size>
|
||||
void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
|
||||
private:
|
||||
inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int I = 0; I < 4; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int I = 0; I < 4; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
|
||||
*c_ptr += *((float *)&vec_C[I] + J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ArrayType>
|
||||
inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
|
||||
vector signed int vec_C[4];
|
||||
vector float CA[4] = {0};
|
||||
vector float res[4] = {0};
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
|
||||
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
||||
fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
|
||||
const vector signed char lowMask = vec_splats((signed char)0xF);
|
||||
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
|
||||
const vector signed char v8 = vec_splats((signed char)0x8);
|
||||
vector signed int vsum = {0};
|
||||
vector signed int vsum2 = {0};
|
||||
c[0] = vec_and(c[1], lowMask);
|
||||
c[1] = vec_sr(c[1], v4);
|
||||
c[0] = vec_sub(c[0], v8);
|
||||
c[1] = vec_sub(c[1], v8);
|
||||
vsum = vec_sum4s(c[0], vsum);
|
||||
vsum2 = vec_sum4s(c[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
}
|
||||
|
||||
template <typename V1, typename V2>
|
||||
inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
|
||||
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
||||
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
|
||||
V2 t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
vector unsigned char xor_vector;
|
||||
uint8_t flip_vec = 0x80;
|
||||
xor_vector = vec_splats(flip_vec);
|
||||
t1 = vec_perm(s1, s2, swiz1);
|
||||
t2 = vec_perm(s1, s2, swiz2);
|
||||
t3 = vec_perm(s3, s4, swiz1);
|
||||
t4 = vec_perm(s3, s4, swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
if (flip == true) {
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
}
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset + 16);
|
||||
vec_xst(t7, 0, vecOffset + 32);
|
||||
vec_xst(t8, 0, vecOffset + 48);
|
||||
}
|
||||
|
||||
inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
|
||||
const vector signed char lowMask = vec_splats((signed char)0x0F);
|
||||
const vector signed char v8 = vec_splats((signed char)0x08);
|
||||
const vector unsigned char v4 = vec_splats((unsigned char)4);
|
||||
lo = vec_and(packed, lowMask);
|
||||
hi = vec_sr(packed, v4);
|
||||
lo = vec_sub(lo, v8);
|
||||
hi = vec_sub(hi, v8);
|
||||
}
|
||||
|
||||
inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
|
||||
vec_t t[8], s[8];
|
||||
vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
|
||||
vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
|
||||
vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
for (int i = 0; i < 4; i += 2) {
|
||||
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
|
||||
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
|
||||
}
|
||||
for (int i = 4; i < 8; i += 2) {
|
||||
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
|
||||
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
|
||||
}
|
||||
s[0] = vec_perm(t[0], t[2], swiz3);
|
||||
s[1] = vec_perm(t[0], t[2], swiz4);
|
||||
s[2] = vec_perm(t[1], t[3], swiz3);
|
||||
s[3] = vec_perm(t[1], t[3], swiz4);
|
||||
s[4] = vec_perm(t[4], t[6], swiz3);
|
||||
s[5] = vec_perm(t[4], t[6], swiz4);
|
||||
s[6] = vec_perm(t[5], t[7], swiz3);
|
||||
s[7] = vec_perm(t[5], t[7], swiz4);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
|
||||
vector signed short i16_hi = vec_unpackh(raw);
|
||||
vector signed short i16_lo = vec_unpackl(raw);
|
||||
|
||||
vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
|
||||
vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
|
||||
vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
|
||||
vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
|
||||
out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
|
||||
out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
|
||||
}
|
||||
|
||||
void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
||||
unsigned char * vecOffset = vec;
|
||||
for (int i = 0; i < rows; i += 8) {
|
||||
const block_q4_0 * rows_base[8];
|
||||
for (int r = 0; r < 8; r++) {
|
||||
rows_base[r] = a + (i + r) * lda;
|
||||
}
|
||||
for (int blk = 0; blk < blocks; blk++) {
|
||||
vector unsigned short hp_res[8][4];
|
||||
for (int r = 0; r < 8; r++) {
|
||||
const block_q4_0 * current_blk = rows_base[r] + blk;
|
||||
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
|
||||
vector signed char v_qs = reinterpret_cast<vector signed char>(vec_xl(0, current_blk->qs));
|
||||
vector signed char c1, c2;
|
||||
unpack_q4_to_q8(v_qs, c1, c2);
|
||||
convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
|
||||
convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
|
||||
}
|
||||
for (int c = 0; c < 4; c++) {
|
||||
vector unsigned char c_arr[8];
|
||||
for (int r = 0; r < 8; r++) {
|
||||
c_arr[r] = (vector unsigned char)hp_res[r][c];
|
||||
}
|
||||
vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
|
||||
vecOffset += 128;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int chunk_size>
|
||||
static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
||||
unsigned char * vecOffset = vec;
|
||||
const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
|
||||
const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
|
||||
const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
|
||||
for (int i = 0; i < rows; i += chunk_size) {
|
||||
const block_q8_0 * rows_base[chunk_size];
|
||||
for (int r = 0; r < chunk_size; r++) {
|
||||
rows_base[r] = a + (i + r) * lda;
|
||||
}
|
||||
for (int blk = 0; blk < blocks; blk++) {
|
||||
vector unsigned short hp_res[chunk_size][4];
|
||||
for (int r = 0; r < chunk_size; r++) {
|
||||
const block_q8_0 * b = rows_base[r] + blk;
|
||||
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
|
||||
vector signed char c[2];
|
||||
__vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
|
||||
__builtin_vsx_disassemble_pair(c, & pair);
|
||||
convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
|
||||
convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
|
||||
}
|
||||
for (int col = 0; col < 4; col++) {
|
||||
if constexpr (chunk_size == 8) {
|
||||
vec_t t[8];
|
||||
t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
|
||||
t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
|
||||
t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
|
||||
t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
|
||||
t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
|
||||
t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
|
||||
t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
|
||||
t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
|
||||
|
||||
vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
|
||||
vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
|
||||
vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
|
||||
vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
|
||||
vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
|
||||
vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
|
||||
vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
|
||||
vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
|
||||
vecOffset += 128;
|
||||
} else {
|
||||
vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
|
||||
vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
|
||||
vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
|
||||
vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
|
||||
|
||||
vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
|
||||
vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
|
||||
vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
|
||||
vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
|
||||
vecOffset += 64;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
||||
if (rows == 4) {
|
||||
pack_q8_block<4>(a, lda, rows, blocks, vec);
|
||||
} else {
|
||||
pack_q8_block<8>(a, lda, rows, blocks, vec);
|
||||
}
|
||||
}
|
||||
|
||||
template<int size>
|
||||
void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
|
||||
int64_t i, j;
|
||||
TA *aoffset = NULL;
|
||||
int8_t *vecOffset = NULL;
|
||||
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
||||
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
||||
TA * aoffset = NULL;
|
||||
int8_t * vecOffset = NULL;
|
||||
TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
|
||||
TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
|
||||
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
|
||||
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
|
||||
aoffset = const_cast<TA*>(a);
|
||||
aoffset = const_cast<TA *>(a);
|
||||
vecOffset = vec;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
@@ -2363,18 +2620,18 @@ class tinyBLAS_HP16_PPC {
|
||||
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
|
||||
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
|
||||
|
||||
process_q4_elements(c1, &comparray[0]);
|
||||
process_q4_elements(c2, &comparray[1]);
|
||||
process_q4_elements(c3, &comparray[2]);
|
||||
process_q4_elements(c4, &comparray[3]);
|
||||
process_q4_elements(c5, &comparray[4]);
|
||||
process_q4_elements(c6, &comparray[5]);
|
||||
process_q4_elements(c7, &comparray[6]);
|
||||
process_q4_elements(c8, &comparray[7]);
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
process_q4_elements(c2, & comparray[1]);
|
||||
process_q4_elements(c3, & comparray[2]);
|
||||
process_q4_elements(c4, & comparray[3]);
|
||||
process_q4_elements(c5, & comparray[4]);
|
||||
process_q4_elements(c6, & comparray[5]);
|
||||
process_q4_elements(c7, & comparray[6]);
|
||||
process_q4_elements(c8, & comparray[7]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
@@ -2405,12 +2662,12 @@ class tinyBLAS_HP16_PPC {
|
||||
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
|
||||
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
|
||||
|
||||
process_q4_elements(c1, &comparray[0]);
|
||||
process_q4_elements(c2, &comparray[1]);
|
||||
process_q4_elements(c3, &comparray[2]);
|
||||
process_q4_elements(c4, &comparray[3]);
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
process_q4_elements(c2, & comparray[1]);
|
||||
process_q4_elements(c3, & comparray[2]);
|
||||
process_q4_elements(c4, & comparray[3]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
@@ -2434,12 +2691,12 @@ class tinyBLAS_HP16_PPC {
|
||||
case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
|
||||
break;
|
||||
}
|
||||
process_q4_elements(c1, &comparray[0]);
|
||||
process_q4_elements(c2, &comparray[1]);
|
||||
process_q4_elements(c3, &comparray[2]);
|
||||
process_q4_elements(c4, &comparray[3]);
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
process_q4_elements(c2, & comparray[1]);
|
||||
process_q4_elements(c3, & comparray[2]);
|
||||
process_q4_elements(c4, & comparray[3]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
@@ -2450,39 +2707,38 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
template<typename VA, typename VB>
|
||||
void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
|
||||
void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
|
||||
int64_t i, j;
|
||||
block_q8_0 *aoffset = NULL;
|
||||
VA *vecOffset = NULL;
|
||||
block_q8_0* aoffsets[8];
|
||||
block_q8_0 * aoffset = NULL;
|
||||
VA * vecOffset = NULL;
|
||||
block_q8_0 * aoffsets[8];
|
||||
__vector_pair arr[8];
|
||||
VB c[8][2] = {0};
|
||||
VB c1[8] = {0}; VB c2[8] = {0};
|
||||
aoffset = const_cast<block_q8_0*>(a);
|
||||
aoffset = const_cast<block_q8_0 *>(a);
|
||||
vecOffset = vec;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
do {
|
||||
aoffsets[0] = aoffset;
|
||||
for (int it = 1; it < 8; it++)
|
||||
aoffsets[it] = aoffsets[it-1] + lda;
|
||||
aoffsets[it] = aoffsets[it - 1] + lda;
|
||||
aoffset += 8 * lda;
|
||||
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
do {
|
||||
for (int it = 0; it < 8; it++) {
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
|
||||
c1[it] = c[it][0];
|
||||
c2[it] = c[it][1];
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
|
||||
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
||||
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
|
||||
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
|
||||
for (int it = 0; it < 8; it++)
|
||||
aoffsets[it] += lda;
|
||||
vecOffset += 256;
|
||||
@@ -2501,13 +2757,13 @@ class tinyBLAS_HP16_PPC {
|
||||
if (i > 0) {
|
||||
do {
|
||||
for (int it = 0; it < 4; it++) {
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
|
||||
c1[it] = c[it][0];
|
||||
c2[it] = c[it][1];
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
||||
for (int it = 0; it < 4; it++) {
|
||||
aoffsets[it] += lda;
|
||||
}
|
||||
@@ -2520,24 +2776,24 @@ class tinyBLAS_HP16_PPC {
|
||||
if (rows & 3) {
|
||||
aoffsets[0] = aoffset;
|
||||
for (int it = 1; it < 3; it++ )
|
||||
aoffsets[it] = aoffsets[it-1] + lda;
|
||||
aoffsets[it] = aoffsets[it - 1] + lda;
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
do {
|
||||
switch(rows) {
|
||||
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[2], &arr[2]);
|
||||
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[2], & arr[2]);
|
||||
c1[2] = c[2][0]; c2[2] = c[2][1];
|
||||
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[1], &arr[1]);
|
||||
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[1], & arr[1]);
|
||||
c1[1] = c[1][0]; c2[1] = c[1][1];
|
||||
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[0], &arr[0]);
|
||||
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[0], & arr[0]);
|
||||
c1[0] = c[0][0]; c2[0] = c[0][1];
|
||||
break;
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
||||
for (int it = 0; it < 3; it++)
|
||||
aoffsets[it] += lda;
|
||||
vecOffset += 128;
|
||||
@@ -2547,8 +2803,7 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
int m_rem = MIN(m - m0, 16);
|
||||
int n_rem = MIN(n - n0, 16);
|
||||
|
||||
@@ -2585,8 +2840,7 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[8], vec_B[16] = {0};
|
||||
acc_t acc_0, acc_1;
|
||||
std::array<int, 4> comparray {};
|
||||
@@ -2594,26 +2848,26 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[8] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
|
||||
}
|
||||
for (int I = 0; I<4; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
*((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
||||
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
*((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2624,15 +2878,14 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 0, 4, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 0, 4, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii, jj+4, 4, fin_res);
|
||||
save_res(ii, jj + 4, 4, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[16], vec_B[8] = {0};
|
||||
acc_t acc_0, acc_1;
|
||||
std::array<int, 8> comparray {};
|
||||
@@ -2640,25 +2893,25 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[8] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
||||
}
|
||||
for (int I = 0; I<8; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
for (int I = 0; I < 8; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2669,15 +2922,14 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii+4, jj, 4, fin_res);
|
||||
save_res(ii + 4, jj, 4, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[16], vec_B[16] = {0};
|
||||
acc_t acc_0, acc_1, acc_2, acc_3;
|
||||
acc_t acc_4, acc_5, acc_6, acc_7;
|
||||
@@ -2686,30 +2938,30 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[16] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(&acc_2);
|
||||
__builtin_mma_xxsetaccz(&acc_3);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_2);
|
||||
__builtin_mma_xxsetaccz(& acc_3);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
|
||||
}
|
||||
for (int I = 0; I<8; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
*((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
||||
for (int I = 0; I < 8 ; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
*((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2720,19 +2972,99 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(&acc_2, 0, 8, comparray, vs, fin_res);
|
||||
compute(&acc_3, 4, 12, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(& acc_2, 0, 8, comparray, vs, fin_res);
|
||||
compute(& acc_3, 4, 12, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii+4, jj, 4, fin_res);
|
||||
save_res(ii, jj+4, 8, fin_res);
|
||||
save_res(ii+4, jj+4, 12, fin_res);
|
||||
save_res(ii + 4, jj, 4, fin_res);
|
||||
save_res(ii, jj + 4, 8, fin_res);
|
||||
save_res(ii + 4, jj + 4, 12, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
||||
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
|
||||
acc_t acc[8];
|
||||
for (int i = 0; i < mc ; i += 16) {
|
||||
for (int j = 0; j < nc; j += 8) {
|
||||
int A0_base = (i / 16) * (2 * 32 * kc);
|
||||
int B0_base = (j / 8) * (32 * kc);
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xxsetaccz(&acc[x]);
|
||||
}
|
||||
for (int64_t kk = 0; kk < kc; kk++) {
|
||||
int A0_block_idx = A0_base + kk * 32;
|
||||
int B0_block_idx = B0_base + kk * 32;
|
||||
int A1_block_idx = A0_block_idx + 32 * kc;
|
||||
int B1_block_idx = B0_block_idx + 32 * kc;
|
||||
vec_t * A0_block = & vec_A[A0_block_idx];
|
||||
vec_t * B0_block = & vec_B[B0_block_idx];
|
||||
vec_t * A1_block = & vec_A[A1_block_idx];
|
||||
for (int it = 0; it < 4; it++) {
|
||||
for (int x = 0; x < 4; x++) {
|
||||
__builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (l == 0) {
|
||||
save_acc(& acc[0], ii + i, jj + j);
|
||||
save_acc(& acc[1], ii + i, jj + j + 4);
|
||||
save_acc(& acc[2], ii + i + 4, jj + j);
|
||||
save_acc(& acc[3], ii + i + 4, jj + j + 4);
|
||||
save_acc(& acc[4], ii + i + 8, jj + j);
|
||||
save_acc(& acc[5], ii + i + 8, jj + j + 4);
|
||||
save_acc(& acc[6], ii + i + 12, jj + j);
|
||||
save_acc(& acc[7], ii + i + 12, jj + j + 4);
|
||||
} else {
|
||||
add_save_acc(& acc[0], ii + i, jj + j);
|
||||
add_save_acc(& acc[1], ii + i, jj + j + 4);
|
||||
add_save_acc(& acc[2], ii + i + 4, jj + j);
|
||||
add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
|
||||
add_save_acc(& acc[4], ii + i + 8, jj + j);
|
||||
add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
|
||||
add_save_acc(& acc[6], ii + i + 12, jj + j);
|
||||
add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
|
||||
vec_t A_pack[mc * kc * 4];
|
||||
vec_t B_pack[nc * kc * 4];
|
||||
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
int64_t ytiles = m / mc;
|
||||
int64_t xtiles = n / nc;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
int64_t duty = (tiles + nth - 1) / nth;
|
||||
int64_t start = duty * ith;
|
||||
int64_t end = start + duty;
|
||||
if (end > tiles) {
|
||||
end = tiles;
|
||||
}
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = (job / xtiles) * mc;
|
||||
int64_t jj = (job % xtiles) * nc;
|
||||
for (int64_t kk = 0; kk < k; kk += kc) {
|
||||
if constexpr(is_Ablock_q4) {
|
||||
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
||||
} else {
|
||||
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
||||
}
|
||||
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
|
||||
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
||||
int64_t ytiles = (m - m0) / RM;
|
||||
int64_t xtiles = (n - n0) / RN;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
@@ -2754,32 +3086,32 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float fin_res[4] = {0};
|
||||
vector float vs[4] = {0};
|
||||
vector float CA[4] = {0};
|
||||
__builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
|
||||
__builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
|
||||
__builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
|
||||
__builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
if (isAblock_q4) {
|
||||
packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
|
||||
for(int x = 0; x < 8; x+=4) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
|
||||
for (int x = 0; x < 8; x += 4) {
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
|
||||
}
|
||||
for (int I = 0; I<RM; I++) {
|
||||
for (int J = 0; J<RN; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
*((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||
__builtin_mma_disassemble_acc(vec_C, & acc_0);
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < RM; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2800,9 +3132,21 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
template<int RM, int RN>
|
||||
inline void kernel(int64_t ii, int64_t jj) {
|
||||
if constexpr(RM == 4 && RN == 8) {
|
||||
KERNEL_4x8(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 4) {
|
||||
KERNEL_8x4(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 8) {
|
||||
KERNEL_8x8(ii,jj);
|
||||
} else {
|
||||
assert(false && "RN/RM values not supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <int RM, int RN>
|
||||
NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
int64_t ytiles = (m - m0) / RM;
|
||||
int64_t xtiles = (n - n0) / RN;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
@@ -2814,12 +3158,20 @@ class tinyBLAS_HP16_PPC {
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = m0 + job / xtiles * RM;
|
||||
int64_t jj = n0 + job % xtiles * RN;
|
||||
this->kernel<RM, RN>(ii, jj);
|
||||
kernel<RM, RN>(ii, jj);
|
||||
}
|
||||
}
|
||||
|
||||
template class tinyBLAS_Q0_PPC<block_q4_0>;
|
||||
template class tinyBLAS_Q0_PPC<block_q8_0>;
|
||||
const TA * const A;
|
||||
const block_q8_0 * const B;
|
||||
float * C;
|
||||
const int64_t k;
|
||||
int64_t kc;
|
||||
const int64_t lda;
|
||||
const int64_t ldb;
|
||||
const int64_t ldc;
|
||||
const int ith;
|
||||
const int nth;
|
||||
};
|
||||
|
||||
class tinyBLAS_PPC {
|
||||
public:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "binary-ops.h"
|
||||
#include "simd-gemm.h"
|
||||
#include "ggml.h"
|
||||
#include "unary-ops.h"
|
||||
#include "vec.h"
|
||||
@@ -2096,10 +2097,14 @@ static void ggml_compute_forward_gelu_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2113,10 +2118,14 @@ static void ggml_compute_forward_gelu_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2135,10 +2144,14 @@ static void ggml_compute_forward_gelu_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2152,10 +2165,14 @@ static void ggml_compute_forward_gelu_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2276,10 +2293,14 @@ static void ggml_compute_forward_gelu_erf_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2293,10 +2314,14 @@ static void ggml_compute_forward_gelu_erf_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_erf_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2315,10 +2340,14 @@ static void ggml_compute_forward_gelu_erf_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2332,10 +2361,14 @@ static void ggml_compute_forward_gelu_erf_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_erf_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2379,10 +2412,14 @@ static void ggml_compute_forward_gelu_quick_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2396,10 +2433,14 @@ static void ggml_compute_forward_gelu_quick_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_quick_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2418,10 +2459,14 @@ static void ggml_compute_forward_gelu_quick_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2435,10 +2480,14 @@ static void ggml_compute_forward_gelu_quick_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_quick_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2482,10 +2531,14 @@ static void ggml_compute_forward_silu_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2499,10 +2552,14 @@ static void ggml_compute_forward_silu_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_silu_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2521,10 +2578,14 @@ static void ggml_compute_forward_silu_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2538,10 +2599,14 @@ static void ggml_compute_forward_silu_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_silu_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -7629,8 +7694,7 @@ static void ggml_compute_forward_pad_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT( dst->nb[0] == sizeof(float));
|
||||
assert(dst->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
@@ -8326,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
GGML_ASSERT(k->type == v->type);
|
||||
const ggml_type kv_type = k->type;
|
||||
|
||||
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
|
||||
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
|
||||
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
|
||||
const size_t kv_type_size = ggml_type_size(kv_type);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t rk2 = neq2/nek2;
|
||||
@@ -8361,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
|
||||
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
|
||||
|
||||
int ir = ir0;
|
||||
while (ir < ir1) {
|
||||
// q indices for the start of this tile
|
||||
@@ -8389,18 +8447,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
}
|
||||
|
||||
// Per-thread scratch layout:
|
||||
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
|
||||
// Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
|
||||
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
|
||||
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
|
||||
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
|
||||
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
|
||||
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
|
||||
// V32: KV_TILE_SZ * DV (F32 buffer for V tile)
|
||||
// K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
|
||||
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
|
||||
|
||||
void * Q_q = base;
|
||||
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
|
||||
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
|
||||
float * V32 = VKQ32 + Q_TILE_SZ * DV;
|
||||
float * K_f32 = V32 + KV_TILE_SZ * DV;
|
||||
|
||||
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
|
||||
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
||||
@@ -8413,28 +8473,38 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
const int iv3 = iq3 / rv3;
|
||||
const int iv2 = iq2 / rv2;
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
|
||||
}
|
||||
// Zero-pad remaining rows
|
||||
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
||||
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
|
||||
{
|
||||
float * Q_f32 = (float *)Q_q;
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
|
||||
}
|
||||
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
||||
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
|
||||
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
|
||||
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
|
||||
|
||||
// skip the tile entirely if all the masks are -inf
|
||||
if (mask) {
|
||||
bool can_skip = true;
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
for (int tk = 0; tk < kv_tile; tk++) {
|
||||
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
|
||||
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
|
||||
can_skip = false;
|
||||
}
|
||||
}
|
||||
// Pad remaining mask entries with -inf
|
||||
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
||||
mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
if (can_skip) {
|
||||
@@ -8442,13 +8512,32 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
}
|
||||
}
|
||||
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
float s;
|
||||
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
|
||||
KQ[tq * KV_TILE_SZ + tk] = s * scale;
|
||||
// Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
|
||||
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
|
||||
for (int tk = 0; tk < kv_tile; tk++) {
|
||||
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
|
||||
for (int64_t dk = 0; dk < DK; dk++) {
|
||||
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
|
||||
}
|
||||
} else {
|
||||
const float * k_f32_src = (const float *)k_data;
|
||||
for (int64_t dk = 0; dk < DK; dk++) {
|
||||
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
|
||||
}
|
||||
}
|
||||
}
|
||||
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
||||
simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
|
||||
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
|
||||
|
||||
// Set padded KQ entries to -inf so softmax gives them zero weight
|
||||
if (kv_tile < KV_TILE_SZ) {
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
||||
KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8488,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
|
||||
}
|
||||
|
||||
// Convert V tile to F32 first (if F16), then do MAD
|
||||
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
|
||||
// TODO: on ARM, native f16 should be faster
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
|
||||
}
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
|
||||
}
|
||||
// V accumulation: VKQ32 += softmax(KQ) * V
|
||||
// Pack V tile to contiguous F32, zero-padded
|
||||
for (int tk = 0; tk < kv_tile; tk++) {
|
||||
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
|
||||
} else {
|
||||
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
|
||||
}
|
||||
}
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) {
|
||||
memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
|
||||
}
|
||||
}
|
||||
simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
|
||||
}
|
||||
|
||||
// sinks (apply only to valid rows in the tile)
|
||||
@@ -8731,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
|
||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||
|
||||
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
const bool use_tiled = !use_ref &&
|
||||
bool use_tiled = !use_ref &&
|
||||
(q->type == GGML_TYPE_F32 &&
|
||||
kv_is_f32_or_f16 &&
|
||||
k->type == v->type &&
|
||||
nek1 % KV_TILE_SZ == 0 &&
|
||||
neq1 >= Q_TILE_SZ);
|
||||
|
||||
#ifdef GGML_SIMD
|
||||
use_tiled &= (DV % GGML_F32_EPR == 0);
|
||||
#endif
|
||||
int current_chunk = ith;
|
||||
|
||||
while (current_chunk < nchunk) {
|
||||
|
||||
@@ -256,6 +256,200 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR
|
||||
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
|
||||
}
|
||||
|
||||
template <int M, int N>
|
||||
static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int blocklen = M;
|
||||
constexpr int ncols_interleaved = N;
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int blocks_per_half = 64 / blocklen;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0f;
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
|
||||
const int base_h = base_l + 64;
|
||||
|
||||
const int scale_idx_l = base_l / 16;
|
||||
const int scale_idx_h = base_h / 16;
|
||||
|
||||
const int qh_shift_l = ((base_l % 128) / 32) * 2;
|
||||
const int qh_shift_h = ((base_h % 128) / 32) * 2;
|
||||
|
||||
const int qh_half_l = (base_l / 128) * 32;
|
||||
const int qh_half_h = (base_h / 128) * 32;
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
|
||||
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
|
||||
|
||||
int sumi_l = 0;
|
||||
int sumi_h = 0;
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
|
||||
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
|
||||
|
||||
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
|
||||
const int qh_chunk_l = qh_idx_l / blocklen;
|
||||
const int qh_pos_l = qh_idx_l % blocklen;
|
||||
const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
|
||||
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
|
||||
|
||||
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
|
||||
const int qh_chunk_h = qh_idx_h / blocklen;
|
||||
const int qh_pos_h = qh_idx_h % blocklen;
|
||||
const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
|
||||
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
|
||||
|
||||
const int q_l = ((hi_2_l << 4) | l_4) - 32;
|
||||
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
|
||||
|
||||
const int8_t a_l = a_ptr[l].qs[base_l + i];
|
||||
const int8_t a_h = a_ptr[l].qs[base_h + i];
|
||||
|
||||
sumi_l += q_l * a_l;
|
||||
sumi_h += q_h * a_h;
|
||||
}
|
||||
|
||||
sumf[j] +=
|
||||
(sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int M, int N>
|
||||
static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int blocklen = M;
|
||||
constexpr int ncols_interleaved = N;
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int blocks_per_half = 64 / blocklen;
|
||||
const int q8_half_stride = 512;
|
||||
const int q8_low_high_step = 256;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
|
||||
float sumf[4][8];
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
|
||||
const int base_h = base_l + 64;
|
||||
|
||||
const int scale_idx_l = base_l / 16;
|
||||
const int scale_idx_h = base_h / 16;
|
||||
|
||||
const int qh_shift_l = ((base_l % 128) / 32) * 2;
|
||||
const int qh_shift_h = ((base_h % 128) / 32) * 2;
|
||||
|
||||
const int qh_half_l = (base_l / 128) * 32;
|
||||
const int qh_half_h = (base_h / 128) * 32;
|
||||
|
||||
const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4);
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
|
||||
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
|
||||
|
||||
int sumi_l = 0;
|
||||
int sumi_h = 0;
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
|
||||
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
|
||||
|
||||
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
|
||||
const int qh_chunk_l = qh_idx_l / blocklen;
|
||||
const int qh_pos_l = qh_idx_l % blocklen;
|
||||
const int qh_offset_l =
|
||||
qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
|
||||
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
|
||||
|
||||
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
|
||||
const int qh_chunk_h = qh_idx_h / blocklen;
|
||||
const int qh_pos_h = qh_idx_h % blocklen;
|
||||
const int qh_offset_h =
|
||||
qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
|
||||
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
|
||||
|
||||
const int q_l = ((hi_2_l << 4) | l_4) - 32;
|
||||
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
|
||||
|
||||
const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i];
|
||||
const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step];
|
||||
|
||||
sumi_l += q_l * q8_l;
|
||||
sumi_h += q_h * q8_h;
|
||||
}
|
||||
|
||||
sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
|
||||
a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
@@ -704,94 +898,12 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
|
||||
}
|
||||
|
||||
|
||||
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0f;
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
|
||||
|
||||
for (int k = 0; k < 16; k++) {
|
||||
// k = 0.. 7 weights 0-63 low, 64-127 high
|
||||
// k = 8..15 weights 128-191 low, 192-255 high
|
||||
const int base_l = (k / 8) * 128 + (k % 8) * 8;
|
||||
const int base_h = base_l + 64;
|
||||
|
||||
const int scale_idx_l = base_l / 16;
|
||||
const int scale_idx_h = base_h / 16;
|
||||
|
||||
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
|
||||
const int qh_shift_l = ((base_l % 128) / 32) * 2;
|
||||
const int qh_shift_h = ((base_h % 128) / 32) * 2;
|
||||
|
||||
// qh_half: offset to the correct 32-byte half (0 or 32)
|
||||
const int qh_half_l = (base_l / 128) * 32;
|
||||
const int qh_half_h = (base_h / 128) * 32;
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
// Interleaved scales
|
||||
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
|
||||
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
|
||||
|
||||
int sumi_l = 0;
|
||||
int sumi_h = 0;
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int ql_pos = k * 64 + j * 8 + i;
|
||||
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
|
||||
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
|
||||
|
||||
// qh indexing with 8-byte interleaving (like q5_K)
|
||||
const int qh_byte_l = qh_half_l + ((base_l + i) % 32);
|
||||
const int qh_chunk_l = qh_byte_l / 8;
|
||||
const int qh_pos_l = qh_byte_l % 8;
|
||||
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
|
||||
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
|
||||
|
||||
const int qh_byte_h = qh_half_h + ((base_h + i) % 32);
|
||||
const int qh_chunk_h = qh_byte_h / 8;
|
||||
const int qh_pos_h = qh_byte_h % 8;
|
||||
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
|
||||
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
|
||||
|
||||
const int q_l = ((hi_2_l << 4) | l_4) - 32;
|
||||
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
|
||||
|
||||
const int8_t a_l = a_ptr[l].qs[base_l + i];
|
||||
const int8_t a_h = a_ptr[l].qs[base_h + i];
|
||||
|
||||
sumi_l += q_l * a_l;
|
||||
sumi_h += q_h * a_h;
|
||||
}
|
||||
|
||||
sumf[j] +=
|
||||
(sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
@@ -1485,109 +1597,12 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
|
||||
float sumf[4][8];
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < 16; k++) {
|
||||
// k = 0.. 7 weights 0-63 low, 64-127 high
|
||||
// k = 8..15 weights 128-191 low, 192-255 high
|
||||
const int base_l = (k / 8) * 128 + (k % 8) * 8;
|
||||
const int base_h = base_l + 64;
|
||||
|
||||
const int scale_idx_l = base_l / 16;
|
||||
const int scale_idx_h = base_h / 16;
|
||||
|
||||
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
|
||||
const int qh_shift_l = ((base_l % 128) / 32) * 2;
|
||||
const int qh_shift_h = ((base_h % 128) / 32) * 2;
|
||||
|
||||
// qh_half: offset to the correct 32-byte half (0 or 32)
|
||||
const int qh_half_l = (base_l / 128) * 32;
|
||||
const int qh_half_h = (base_h / 128) * 32;
|
||||
|
||||
// Activation base indices for q8_Kx4 interleaved format
|
||||
// Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32
|
||||
const int q8_base = (k / 8) * 512 + (k % 8) * 32;
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
// Interleaved scales
|
||||
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
|
||||
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
|
||||
|
||||
int sumi_l = 0;
|
||||
int sumi_h = 0;
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int ql_pos = k * 64 + j * 8 + i;
|
||||
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
|
||||
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
|
||||
|
||||
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
|
||||
const int qh_chunk_l = qh_idx_l / 8;
|
||||
const int qh_pos_l = qh_idx_l % 8;
|
||||
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
|
||||
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
|
||||
|
||||
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
|
||||
const int qh_chunk_h = qh_idx_h / 8;
|
||||
const int qh_pos_h = qh_idx_h % 8;
|
||||
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
|
||||
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
|
||||
|
||||
const int q_l = ((hi_2_l << 4) | l_4) - 32;
|
||||
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
|
||||
|
||||
const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i];
|
||||
const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256];
|
||||
|
||||
sumi_l += q_l * q8_l;
|
||||
sumi_h += q_h * q8_h;
|
||||
}
|
||||
|
||||
sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
|
||||
a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
@@ -1901,9 +1916,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
// buffer large enough for the max interleave block size (8 bytes)
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
||||
memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave);
|
||||
memcpy(&out.qs[dst_offset], &elems, blck_size_interleave);
|
||||
}
|
||||
|
||||
// The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K
|
||||
@@ -2097,18 +2113,18 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in
|
||||
}
|
||||
|
||||
const int end_ls = QK_K * 4 / blck_size_interleave;
|
||||
// Interleave Q6_K quants by taking 8 bytes at a time
|
||||
// Interleave Q6_K quants by taking blck_size_interleave bytes at a time
|
||||
for (int i = 0; i < end_ls; ++i) {
|
||||
int src_id = i % n_blocks;
|
||||
int src_offset = (i / n_blocks) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elem_ls;
|
||||
memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t));
|
||||
memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave);
|
||||
memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave);
|
||||
}
|
||||
|
||||
// Interleave high bits using same 8-byte pattern as low bits
|
||||
// Interleave high bits using same chunk size as low bits
|
||||
const int end_hs = end_ls / 2;
|
||||
for (int i = 0; i < end_hs; ++i) {
|
||||
int src_id = i % n_blocks;
|
||||
@@ -2116,8 +2132,8 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elem_hs;
|
||||
memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t));
|
||||
memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave);
|
||||
memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave);
|
||||
}
|
||||
|
||||
// The below logic is designed so as to unpack and rearrange scales in Q6_K
|
||||
@@ -2262,7 +2278,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
|
||||
|
||||
static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q6_K);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
||||
constexpr int nrows_interleaved = 8;
|
||||
|
||||
block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data;
|
||||
@@ -2511,6 +2527,10 @@ template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * da
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q6_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q6_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
@@ -2575,6 +2595,10 @@ template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
||||
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -2634,6 +2658,10 @@ template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
||||
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -3043,6 +3071,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
|
||||
|
||||
// instance for Q6_K
|
||||
static const ggml::cpu::repack::tensor_traits<block_q6_K, 4, 8, GGML_TYPE_Q8_K> q6_K_8x4_q8_K;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q6_K, 8, 8, GGML_TYPE_Q8_K> q6_K_8x8_q8_K;
|
||||
|
||||
// instance for Q2
|
||||
@@ -3107,6 +3136,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
return &q6_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q6_K_8x4_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
||||
if (ggml_cpu_has_avx2()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
|
||||
@@ -112,6 +112,7 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -122,6 +123,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -142,6 +144,7 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -152,6 +155,7 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
136
ggml/src/ggml-cpu/simd-gemm.h
Normal file
136
ggml/src/ggml-cpu/simd-gemm.h
Normal file
@@ -0,0 +1,136 @@
|
||||
#pragma once
|
||||
|
||||
// Computes C[M x N] += A[M x K] * B[K x N]
|
||||
|
||||
#include "simd-mappings.h"
|
||||
|
||||
// TODO: add support for sizeless vector types
|
||||
#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
|
||||
|
||||
// TODO: untested on avx512
|
||||
// These are in units of GGML_F32_EPR
|
||||
#if defined(__AVX512F__) || defined (__ARM_NEON__)
|
||||
static constexpr int GEMM_RM = 4;
|
||||
static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
|
||||
#elif defined(__AVX2__) || defined(__AVX__)
|
||||
static constexpr int GEMM_RM = 6;
|
||||
static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
|
||||
#else
|
||||
static constexpr int GEMM_RM = 2;
|
||||
static constexpr int GEMM_RN = 2;
|
||||
#endif
|
||||
|
||||
template <int RM, int RN>
|
||||
static inline void simd_gemm_ukernel(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int K, int N)
|
||||
{
|
||||
static constexpr int KN = GGML_F32_EPR;
|
||||
|
||||
GGML_F32_VEC acc[RM][RN];
|
||||
for (int64_t i = 0; i < RM; i++) {
|
||||
for (int r = 0; r < RN; r++) {
|
||||
acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
GGML_F32_VEC Bv[RN];
|
||||
for (int r = 0; r < RN; r++) {
|
||||
Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
|
||||
}
|
||||
for (int64_t i = 0; i < RM; i++) {
|
||||
GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
|
||||
for (int r = 0; r < RN; r++) {
|
||||
acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < RM; i++) {
|
||||
for (int r = 0; r < RN; r++) {
|
||||
GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// C[M x N] += A[M x K] * B[K x N]
|
||||
static void simd_gemm(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int M, int K, int N)
|
||||
{
|
||||
static constexpr int KN = GGML_F32_EPR;
|
||||
|
||||
int64_t ii = 0;
|
||||
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
|
||||
int64_t jj = 0;
|
||||
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
||||
simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj + KN <= N; jj += KN) {
|
||||
simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj < N; jj++) {
|
||||
for (int64_t i = 0; i < GEMM_RM; i++) {
|
||||
float a = C[i * N + jj];
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
a += A[i + kk] * B[kk * N + jj];
|
||||
}
|
||||
C[i * N + jj] = a;
|
||||
}
|
||||
}
|
||||
|
||||
A += GEMM_RM * K;
|
||||
C += GEMM_RM * N;
|
||||
}
|
||||
|
||||
// Tail rows: one at a time
|
||||
for (; ii < M; ii++) {
|
||||
int64_t jj = 0;
|
||||
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
||||
simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj + KN <= N; jj += KN) {
|
||||
simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj < N; jj++) {
|
||||
float a = C[jj];
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
a += A[kk] * B[kk * N + jj];
|
||||
}
|
||||
C[jj] = a;
|
||||
}
|
||||
|
||||
A += K;
|
||||
C += N;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__GNUC__) && !defined(__clang__)
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#else // scalar path
|
||||
|
||||
static void simd_gemm(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int M, int K, int N)
|
||||
{
|
||||
for (int64_t i = 0; i < M; i++) {
|
||||
for (int64_t j = 0; j < N; j++) {
|
||||
float sum = C[i * N + j];
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
sum += A[i * K + kk] * B[kk * N + j];
|
||||
}
|
||||
C[i * N + j] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // GGML_SIMD
|
||||
@@ -1160,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
|
||||
float32x4_t tmp = x[0] + vec_reve(x[0]); \
|
||||
res = tmp[0] + tmp[1]; \
|
||||
}
|
||||
#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \
|
||||
{ \
|
||||
float32x4_t v = vec_add(vec_add(s0, s1), \
|
||||
vec_add(s2, s3)); \
|
||||
v = vec_add(v, vec_sld(v, v, 8)); \
|
||||
v = vec_add(v, vec_sld(v, v, 4)); \
|
||||
res += (ggml_float)vec_extract(v, 0); \
|
||||
}
|
||||
|
||||
#define GGML_F32_VEC GGML_F32x4
|
||||
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
|
||||
@@ -1209,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
|
||||
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
|
||||
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
|
||||
|
||||
// BF16 s390x
|
||||
#define GGML_BF16_STEP 16
|
||||
#define GGML_BF16_EPR 8
|
||||
|
||||
#define GGML_BF16x8 __vector unsigned short
|
||||
#define GGML_BF16x8_ZERO vec_splats((unsigned short)0)
|
||||
#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))
|
||||
|
||||
#define GGML_BF16_VEC GGML_BF16x8
|
||||
#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO
|
||||
#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD
|
||||
#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO))
|
||||
#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO))
|
||||
#define GGML_BF16_FMA_LO(acc, x, y) \
|
||||
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
|
||||
#define GGML_BF16_FMA_HI(acc, x, y) \
|
||||
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
|
||||
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
|
||||
// compatible with vlen >= 128
|
||||
|
||||
@@ -111,7 +111,7 @@ template <float (*op)(float), typename src0_t, typename dst_t>
|
||||
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
|
||||
@@ -236,8 +236,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
|
||||
vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
|
||||
sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
|
||||
|
||||
#endif
|
||||
#if defined(__POWER9_VECTOR__)
|
||||
#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__)
|
||||
const int np = (n & ~(GGML_BF16_STEP - 1));
|
||||
if (np > 0) {
|
||||
GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};
|
||||
|
||||
@@ -64,7 +64,7 @@ if (CUDAToolkit_FOUND)
|
||||
FetchContent_Declare(
|
||||
CCCL
|
||||
GIT_REPOSITORY https://github.com/nvidia/cccl.git
|
||||
GIT_TAG v3.2.0-rc2
|
||||
GIT_TAG v3.2.0
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
|
||||
|
||||
@@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
const uint3 ne11,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*int s0, */ const int s1,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
/*int s00,*/ const int s01,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
/*int s10,*/ const int s11,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
src1_ptrs... src1s) {
|
||||
@@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
|
||||
const uint32_t i10 = fastmodulo(i0, ne10);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
@@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
||||
const uint3 ne11,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*int s0, */ const int s1,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
/*int s00,*/ const int s01,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
/*int s10,*/ const int s11,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
src1_ptrs... src1s) {
|
||||
@@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
||||
|
||||
const int i10 = fastmodulo(i0, ne10);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
@@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
cnb[3] *= cne[3];
|
||||
};
|
||||
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (nr[i] != 1) {
|
||||
break;
|
||||
@@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
//size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
@@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s00 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
|
||||
@@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
|
||||
ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
|
||||
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
|
||||
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13);
|
||||
}
|
||||
} else {
|
||||
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
|
||||
if constexpr (sizeof...(I) > 0) {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00 ,s01, s02, s03,
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t ne00, const int64_t ne01,
|
||||
const int64_t ne0203, const uint3 ne02,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03) {
|
||||
const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
|
||||
|
||||
@@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
|
||||
}
|
||||
|
||||
const int64_t i01 = blockIdx.y;
|
||||
const int64_t i02 = blockIdx.z % ne02;
|
||||
const int64_t i03 = blockIdx.z / ne02;
|
||||
|
||||
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
|
||||
const int64_t ib = ibx0 + i00/qk; // block index
|
||||
const int64_t iqs = (i00%qk)/qr; // quant index
|
||||
const int64_t iybs = i00 - i00%qk; // y block start index
|
||||
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
||||
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
|
||||
|
||||
// dequantize
|
||||
float2 v;
|
||||
dequantize_kernel(vx, ib, iqs, v);
|
||||
const int64_t ib = ibx0 + i00/qk; // block index
|
||||
const int64_t iqs = (i00%qk)/qr; // quant index
|
||||
const int64_t iybs = i00 - i00%qk; // y block start index
|
||||
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
|
||||
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
|
||||
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
||||
// dequantize
|
||||
float2 v;
|
||||
dequantize_kernel(vx, ib, iqs, v);
|
||||
|
||||
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
|
||||
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
|
||||
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool need_check>
|
||||
@@ -485,9 +490,11 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static void dequantize_block_cuda(const void * vx, dst_t * y,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
|
||||
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
|
||||
const int64_t ne0203 = ne02*ne03;
|
||||
const uint3 ne02_fdv = init_fastdiv_values(ne02);
|
||||
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
|
||||
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
|
||||
(vx, y, ne00, ne01, ne02, s01, s02, s03);
|
||||
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
@@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static __global__ void convert_unary(
|
||||
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
|
||||
const int64_t ne0203, const uint3 ne02,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03) {
|
||||
const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
@@ -621,23 +629,29 @@ static __global__ void convert_unary(
|
||||
}
|
||||
|
||||
const int64_t i01 = blockIdx.y;
|
||||
const int64_t i02 = blockIdx.z % ne02;
|
||||
const int64_t i03 = blockIdx.z / ne02;
|
||||
|
||||
const src_t * x = (const src_t *) vx;
|
||||
|
||||
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
|
||||
const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
|
||||
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
|
||||
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
|
||||
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
|
||||
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static void convert_unary_cuda(const void * vx, dst_t * y,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
|
||||
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
|
||||
const int64_t ne0203 = ne02*ne03;
|
||||
const uint3 ne02_fdv = init_fastdiv_values(ne02);
|
||||
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
|
||||
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
|
||||
(vx, y, ne00, ne01, ne02, s01, s02, s03);
|
||||
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
|
||||
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
|
||||
@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
||||
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
||||
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
||||
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
|
||||
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16> frag_c_VKQ;
|
||||
#else
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
|
||||
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
||||
#endif
|
||||
|
||||
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
||||
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
||||
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
||||
half2 * VKQ2 = (half2 *) VKQ;
|
||||
|
||||
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
|
||||
const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h);
|
||||
const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h);
|
||||
_Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);
|
||||
_Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ);
|
||||
#else
|
||||
const half * K_h_f16 = K_h;
|
||||
const half * V_h_f16 = V_h;
|
||||
half * KQ_f16 = KQ;
|
||||
half * VKQ_f16 = VKQ;
|
||||
#endif
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int i0 = 0; i0 < D; i0 += 16) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
||||
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
||||
frag_a_K K_a;
|
||||
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
||||
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
wmma::load_matrix_sync(
|
||||
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
||||
KQ + j0*(kqar*kqs_padded) + k,
|
||||
KQ_f16 + j0*(kqar*kqs_padded) + k,
|
||||
kqar*kqs_padded);
|
||||
}
|
||||
}
|
||||
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
|
||||
frag_a_V v_a;
|
||||
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
wmma::store_matrix_sync(
|
||||
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
||||
D_padded, wmma::mem_col_major);
|
||||
}
|
||||
|
||||
@@ -2278,11 +2278,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
|
||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
|
||||
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
|
||||
if (ggml_is_quantized(src0->type)) {
|
||||
if (ne2 <= 4) {
|
||||
if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
|
||||
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
|
||||
return;
|
||||
}
|
||||
@@ -2305,6 +2306,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
}
|
||||
}
|
||||
|
||||
// note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
|
||||
// TODO: add asserts to verify this. should work with CUDA, HIP, etc.
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(nb12 % nb11 == 0);
|
||||
@@ -2865,14 +2868,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
bool use_cuda_graph = true;
|
||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||
|
||||
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
|
||||
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
|
||||
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
|
||||
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
|
||||
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
|
||||
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
|
||||
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
@@ -2887,30 +2882,14 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
|
||||
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_ADD &&
|
||||
node->src[1] && node->src[1]->ne[1] > 1 &&
|
||||
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
|
||||
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
|
||||
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
|
||||
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
|
||||
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
|
||||
// by means of matching node names. See
|
||||
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
|
||||
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
|
||||
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
|
||||
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
|
||||
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
|
||||
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
|
||||
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
|
||||
use_cuda_graph = false;
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -3640,11 +3619,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
n_fuse++;
|
||||
|
||||
if (n_fuse > 1) {
|
||||
ggml_tensor fused_add_node;
|
||||
memcpy(&fused_add_node, node, sizeof(ggml_tensor));
|
||||
for (int j = 0; j < n_fuse - 1; ++j) {
|
||||
node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
|
||||
fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
|
||||
}
|
||||
cgraph->nodes[i + n_fuse - 1]->data = node->data;
|
||||
ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
|
||||
fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
|
||||
ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
|
||||
i += n_fuse - 1;
|
||||
|
||||
continue;
|
||||
@@ -4542,6 +4523,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
// TODO: should become:
|
||||
//return ggml_is_contiguous_rows(op->src[0]);
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
default:
|
||||
return false;
|
||||
@@ -4820,8 +4803,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_ACC:
|
||||
// TODO: extend support like so:
|
||||
//return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
|
||||
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
|
||||
case GGML_OP_SUM:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_TOP_K:
|
||||
@@ -4834,8 +4820,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_PAD:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_PAD:
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_ARANGE:
|
||||
|
||||
@@ -2715,14 +2715,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < QR2_XXS; ++l) {
|
||||
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
|
||||
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
|
||||
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
|
||||
|
||||
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
|
||||
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
|
||||
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
|
||||
@@ -2733,12 +2733,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
const int ls = aux32 >> 28;
|
||||
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
|
||||
const float d = bxi->d;
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
|
||||
#else
|
||||
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
|
||||
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
@@ -2776,11 +2776,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < QR2_XS; ++l) {
|
||||
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
|
||||
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
|
||||
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
|
||||
const uint32_t signs = unpack_ksigns(q2[l] >> 9);
|
||||
|
||||
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
||||
@@ -2904,11 +2907,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
#pragma unroll
|
||||
for (int l = 0; l < QR3_XXS; ++l) {
|
||||
const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
|
||||
|
||||
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
||||
#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
|
||||
|
||||
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
||||
|
||||
@@ -7,7 +7,7 @@ __device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
|
||||
return (coord + size) % size;
|
||||
}
|
||||
|
||||
static __global__ void pad_f32(const float * src, float * dst,
|
||||
static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
|
||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3,
|
||||
@@ -34,11 +34,8 @@ static __global__ void pad_f32(const float * src, float * dst,
|
||||
const int64_t i01 = i1 - lp1;
|
||||
const int64_t i02 = i2 - lp2;
|
||||
const int64_t i03 = i3 - lp3;
|
||||
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||
|
||||
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||
const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
|
||||
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
@@ -57,21 +54,21 @@ static __global__ void pad_f32(const float * src, float * dst,
|
||||
const int64_t i02 = wrap_around(i2 - lp2, ne02);
|
||||
const int64_t i03 = wrap_around(i3 - lp3, ne03);
|
||||
|
||||
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||
const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
|
||||
|
||||
dst[dst_idx] = src[src_idx];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void pad_f32_cuda(const float * src, float * dst,
|
||||
static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
|
||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3,
|
||||
const bool circular, cudaStream_t stream) {
|
||||
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
|
||||
dim3 gridDim(num_blocks, ne1, ne2 * ne3);
|
||||
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
|
||||
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, s00, s01, s02, s03, dst,
|
||||
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
||||
ne0, ne1, ne2, ne3, circular);
|
||||
}
|
||||
@@ -82,9 +79,10 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
|
||||
const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
|
||||
@@ -96,7 +94,12 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
|
||||
const int32_t circular = ((const int32_t *) (dst->op_params))[8];
|
||||
|
||||
pad_f32_cuda(src0_d, dst_d,
|
||||
const size_t s00 = nb00 / ggml_type_size(src0->type);
|
||||
const size_t s01 = nb01 / ggml_type_size(src0->type);
|
||||
const size_t s02 = nb02 / ggml_type_size(src0->type);
|
||||
const size_t s03 = nb03 / ggml_type_size(src0->type);
|
||||
|
||||
pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d,
|
||||
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
(bool) circular, stream);
|
||||
|
||||
@@ -43,10 +43,15 @@ static __device__ void rope_yarn(
|
||||
template <bool forward, bool has_ff, typename T, typename D>
|
||||
static __global__ void rope_norm(const T * x,
|
||||
D * dst,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int n_dims,
|
||||
const int32_t * pos,
|
||||
const float freq_scale,
|
||||
@@ -59,23 +64,23 @@ static __global__ void rope_norm(const T * x,
|
||||
const int set_rows_stride) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (i0 >= ne0) {
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
int idst = row_dst * ne0 + i0;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0;
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
||||
if (set_rows_stride != 0) {
|
||||
idst = row_x * ne0 + i0;
|
||||
idst += row_indices[channel_x] * set_rows_stride;
|
||||
idst = i1 * s1 + i0;
|
||||
idst += row_indices[i2] * set_rows_stride;
|
||||
}
|
||||
|
||||
const auto & store_coaelsced = [&](float x0, float x1) {
|
||||
@@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x,
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
|
||||
@@ -110,10 +115,15 @@ static __global__ void rope_norm(const T * x,
|
||||
template <bool forward, bool has_ff, typename T, typename D>
|
||||
static __global__ void rope_neox(const T * x,
|
||||
D * dst,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int n_dims,
|
||||
const int32_t * pos,
|
||||
const float freq_scale,
|
||||
@@ -126,23 +136,24 @@ static __global__ void rope_neox(const T * x,
|
||||
const int set_rows_stride) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (i0 >= ne0) {
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
int idst = row_dst * ne0 + i0 / 2;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
||||
if (set_rows_stride != 0) {
|
||||
idst = row_x * ne0 + i0 / 2;
|
||||
idst += row_indices[channel_x] * set_rows_stride;
|
||||
idst = i1 * s1 + i0 / 2;
|
||||
idst += row_indices[i2] * set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
@@ -152,7 +163,7 @@ static __global__ void rope_neox(const T * x,
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
|
||||
@@ -168,24 +179,42 @@ static __global__ void rope_neox(const T * x,
|
||||
dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
|
||||
}
|
||||
|
||||
template<bool forward, bool has_ff, typename T>
|
||||
static __global__ void rope_multi(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
|
||||
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
template <bool forward, bool has_ff, typename T>
|
||||
static __global__ void rope_multi(const T * x,
|
||||
T * dst,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int n_dims,
|
||||
const int32_t * pos,
|
||||
const float freq_scale,
|
||||
const float ext_factor,
|
||||
const float attn_factor,
|
||||
const rope_corr_dims corr_dims,
|
||||
const float theta_scale,
|
||||
const float * freq_factors,
|
||||
const mrope_sections sections,
|
||||
const bool is_imrope) {
|
||||
const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);
|
||||
|
||||
if (i0 >= ne0) {
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
const int idst = row_dst*ne0 + i0/2;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
||||
@@ -200,27 +229,24 @@ static __global__ void rope_multi(
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
|
||||
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
|
||||
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
|
||||
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
|
||||
theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
|
||||
theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
|
||||
theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
|
||||
} else {
|
||||
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
||||
theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
|
||||
}
|
||||
} else {
|
||||
if (sector < sections.v[0]) {
|
||||
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
||||
theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
|
||||
} else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
|
||||
} else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
|
||||
} else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,37 +264,53 @@ static __global__ void rope_multi(
|
||||
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
template<bool forward, bool has_ff, typename T>
|
||||
static __global__ void rope_vision(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
|
||||
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float * freq_factors, const mrope_sections sections) {
|
||||
template <bool forward, bool has_ff, typename T>
|
||||
static __global__ void rope_vision(const T * x,
|
||||
T * dst,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int n_dims,
|
||||
const int32_t * pos,
|
||||
const float freq_scale,
|
||||
const float ext_factor,
|
||||
const float attn_factor,
|
||||
const rope_corr_dims corr_dims,
|
||||
const float theta_scale,
|
||||
const float * freq_factors,
|
||||
const mrope_sections sections) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (i0 >= ne0) {
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
const int idst = row_dst*ne0 + i0/2;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
const int sect_dims = sections.v[0] + sections.v[1];
|
||||
const int sec_w = sections.v[1] + sections.v[0];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
const int sec_w = sections.v[1] + sections.v[0];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < sections.v[0]) {
|
||||
const int p = sector;
|
||||
theta_base = pos[channel_x]*powf(theta_scale, p);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[i2] * powf(theta_scale, p);
|
||||
} else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
const int p = sector - sections.v[0];
|
||||
theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
|
||||
theta_base = pos[i2 + ne02] * powf(theta_scale, p);
|
||||
}
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
@@ -288,10 +330,15 @@ static __global__ void rope_vision(
|
||||
template <bool forward, typename T, typename D>
|
||||
static void rope_norm_cuda(const T * x,
|
||||
D * dst,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int n_dims,
|
||||
const int nr,
|
||||
const int32_t * pos,
|
||||
@@ -304,31 +351,36 @@ static void rope_norm_cuda(const T * x,
|
||||
const int64_t * row_indices,
|
||||
const int set_rows_stride,
|
||||
cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
||||
freq_factors, row_indices, set_rows_stride);
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
} else {
|
||||
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
||||
freq_factors, row_indices, set_rows_stride);
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool forward, typename T, typename D>
|
||||
static void rope_neox_cuda(const T * x,
|
||||
D * dst,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int n_dims,
|
||||
const int nr,
|
||||
const int32_t * pos,
|
||||
@@ -341,55 +393,92 @@ static void rope_neox_cuda(const T * x,
|
||||
const int64_t * row_indices,
|
||||
const int set_rows_stride,
|
||||
cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
||||
freq_factors, row_indices, set_rows_stride);
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
} else {
|
||||
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
||||
freq_factors, row_indices, set_rows_stride);
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool forward, typename T>
|
||||
static void rope_multi_cuda(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
||||
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
template <bool forward, typename T>
|
||||
static void rope_multi_cuda(const T * x,
|
||||
T * dst,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int n_dims,
|
||||
const int nr,
|
||||
const int32_t * pos,
|
||||
const float freq_scale,
|
||||
const float freq_base,
|
||||
const float ext_factor,
|
||||
const float attn_factor,
|
||||
const rope_corr_dims corr_dims,
|
||||
const float * freq_factors,
|
||||
const mrope_sections sections,
|
||||
const bool is_imrope,
|
||||
cudaStream_t stream) {
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
||||
} else {
|
||||
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool forward, typename T>
|
||||
static void rope_vision_cuda(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
||||
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
template <bool forward, typename T>
|
||||
static void rope_vision_cuda(const T * x,
|
||||
T * dst,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int n_dims,
|
||||
const int nr,
|
||||
const int32_t * pos,
|
||||
const float freq_scale,
|
||||
const float freq_base,
|
||||
const float ext_factor,
|
||||
const float attn_factor,
|
||||
const rope_corr_dims corr_dims,
|
||||
const float * freq_factors,
|
||||
const mrope_sections sections,
|
||||
cudaStream_t stream) {
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||
// break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
|
||||
// where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
|
||||
@@ -398,11 +487,11 @@ static void rope_vision_cuda(
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
||||
} else {
|
||||
rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
||||
}
|
||||
}
|
||||
@@ -445,6 +534,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
|
||||
|
||||
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
|
||||
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
|
||||
const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
|
||||
|
||||
const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
|
||||
const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
|
||||
const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
@@ -495,57 +589,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
|
||||
// compute
|
||||
if (is_neox) {
|
||||
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
||||
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||
set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
||||
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||
set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
||||
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
|
||||
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||
set_rows_stride, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
} else if (is_mrope && !is_vision) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_multi_cuda<forward>(
|
||||
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
|
||||
rope_multi_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
|
||||
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
|
||||
corr_dims, freq_factors, sections, is_imrope, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_multi_cuda<forward>(
|
||||
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
|
||||
rope_multi_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
|
||||
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
|
||||
corr_dims, freq_factors, sections, is_imrope, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
} else if (is_vision) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_vision_cuda<forward>(
|
||||
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||
rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
|
||||
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
|
||||
corr_dims, freq_factors, sections, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_vision_cuda<forward>(
|
||||
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||
rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
|
||||
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
|
||||
corr_dims, freq_factors, sections, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
} else {
|
||||
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
||||
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||
set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
||||
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||
set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
||||
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
|
||||
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, row_indices, set_rows_stride, stream);
|
||||
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||
set_rows_stride, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
@@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
|
||||
#endif
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
|
||||
// v is a 7 bit int, with the 8th sign being encodable as popcnt
|
||||
// with xor we can "correct" the bit instead of having to mask
|
||||
const uint32_t p = __popc(v) & 1;
|
||||
const uint32_t s = v ^ p << 7;
|
||||
// broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors
|
||||
return s * 0x01010101;
|
||||
}
|
||||
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
||||
|
||||
@@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
||||
int sumi = 0;
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < 8; k0 += 2) {
|
||||
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
|
||||
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
|
||||
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]];
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2));
|
||||
|
||||
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
|
||||
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
|
||||
sumi = ggml_cuda_dp4a(grid0, u0, sumi);
|
||||
|
||||
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
|
||||
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
|
||||
sumi = ggml_cuda_dp4a(grid1, u1, sumi);
|
||||
}
|
||||
|
||||
const int ls = aux32 >> 28;
|
||||
sumi = (ls*sumi + sumi/2)/4;
|
||||
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
|
||||
sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4
|
||||
const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
|
||||
return d * sumi;
|
||||
}
|
||||
@@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
||||
int sumi1 = 0;
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < 8; l0 += 2) {
|
||||
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
|
||||
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9));
|
||||
|
||||
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
|
||||
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF];
|
||||
const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9);
|
||||
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
|
||||
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
|
||||
|
||||
if (l0 < 4) {
|
||||
@@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < 8; l0 += 2) {
|
||||
const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2));
|
||||
|
||||
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
|
||||
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
|
||||
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
|
||||
|
||||
sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
|
||||
|
||||
@@ -1935,11 +1935,6 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contigiuos tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1991,6 +1986,25 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contigiuos tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
|
||||
const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
@@ -2111,6 +2125,26 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session *
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0]; // values
|
||||
const struct ggml_tensor * dst = op; // indices
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != GGML_TYPE_I32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0->ne[0] > (16*1024)) {
|
||||
// reject tensors with huge rows for now
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const int32_t * op_params = &op->op_params[0];
|
||||
|
||||
@@ -2278,6 +2312,9 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
|
||||
case GGML_OP_SUB:
|
||||
req->op = HTP_OP_SUB;
|
||||
break;
|
||||
case GGML_OP_DIV:
|
||||
req->op = HTP_OP_DIV;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
|
||||
break;
|
||||
@@ -2316,6 +2353,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer *
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_ARGSORT;
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
template <bool _is_src0_constant>
|
||||
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
switch (t->op) {
|
||||
@@ -2370,6 +2418,16 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
||||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_SQR:
|
||||
req->op = HTP_OP_SQR;
|
||||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_SQRT:
|
||||
req->op = HTP_OP_SQRT;
|
||||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_UNARY:
|
||||
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
|
||||
req->op = HTP_OP_UNARY_SILU;
|
||||
@@ -2387,6 +2445,9 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
||||
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
|
||||
req->op = HTP_OP_GLU_SWIGLU_OAI;
|
||||
supported = true;
|
||||
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
|
||||
req->op = HTP_OP_GLU_GEGLU;
|
||||
supported = true;
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -2411,6 +2472,17 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
req->op = HTP_OP_SUM_ROWS;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
req->op = HTP_OP_ROPE;
|
||||
@@ -2519,6 +2591,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_ADD_ID:
|
||||
@@ -2528,6 +2601,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
case GGML_OP_SCALE:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
|
||||
(ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
|
||||
@@ -2536,7 +2616,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
break;
|
||||
case GGML_OP_GLU:
|
||||
if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
}
|
||||
break;
|
||||
@@ -2564,6 +2645,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_ARGSORT:
|
||||
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
default:
|
||||
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
|
||||
}
|
||||
@@ -2916,6 +3001,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
supp = ggml_hexagon_supported_binary(sess, op);
|
||||
break;
|
||||
|
||||
@@ -2928,6 +3014,15 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SUM_ROWS:
|
||||
supp = ggml_hexagon_supported_sum_rows(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SOFT_MAX:
|
||||
supp = ggml_hexagon_supported_softmax(sess, op);
|
||||
break;
|
||||
@@ -2943,7 +3038,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
case GGML_OP_GLU:
|
||||
{
|
||||
const auto glu_op = ggml_get_glu_op(op);
|
||||
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) {
|
||||
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
|
||||
supp = ggml_hexagon_supported_activations(sess, op);
|
||||
}
|
||||
break;
|
||||
@@ -2968,6 +3063,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_cpy(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_ARGSORT:
|
||||
supp = ggml_hexagon_supported_argsort(sess, op);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include_directories(
|
||||
${HEXAGON_SDK_ROOT}/incs
|
||||
${HEXAGON_SDK_ROOT}/incs/stddef
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../..
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
@@ -21,6 +22,7 @@ add_library(${HTP_LIB} SHARED
|
||||
matmul-ops.c
|
||||
binary-ops.c
|
||||
unary-ops.c
|
||||
sum-rows-ops.c
|
||||
softmax-ops.c
|
||||
act-ops.c
|
||||
rope-ops.c
|
||||
@@ -28,6 +30,7 @@ add_library(${HTP_LIB} SHARED
|
||||
set-rows-ops.c
|
||||
get-rows-ops.c
|
||||
cpy-ops.c
|
||||
argsort-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
|
||||
@@ -410,7 +410,7 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
// gelu = x * sigmoid(1.702 * x) // current implementation
|
||||
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
@@ -516,7 +516,7 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
|
||||
// silu = x * sigmoid(x)
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
|
||||
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
@@ -541,6 +541,143 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static const float GELU_COEF_A = 0.044715f;
|
||||
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
htp_act_preamble3;
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, block_size);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
|
||||
src1_row_size_aligned, src1_row_size, block_size);
|
||||
}
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
|
||||
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
|
||||
for (uint32_t ib = 0; ib < block_size; ib++) {
|
||||
const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
|
||||
const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
|
||||
uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
|
||||
|
||||
// geglu tanh implementation
|
||||
// geglu(x, g) = gelu(x) * g
|
||||
// gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x
|
||||
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A
|
||||
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
|
||||
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI
|
||||
hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res)
|
||||
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
|
||||
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res + 0.5f
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
|
||||
dst_row_size_aligned, block_size);
|
||||
|
||||
// prefetch N+2 loop iteration if any
|
||||
const uint32_t pref_block = (ir + BLOCK * 2);
|
||||
if (pref_block < src0_end_row) {
|
||||
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, pref_block_size);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
|
||||
src1_row_size_aligned, src1_row_size, pref_block_size);
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
@@ -559,6 +696,12 @@ static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
@@ -593,6 +736,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
act_op_func = unary_gelu_f32;
|
||||
op_type = "gelu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_GEGLU:
|
||||
act_op_func = glu_geglu_f32;
|
||||
op_type = "geglu-f32";
|
||||
break;
|
||||
default:
|
||||
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
|
||||
281
ggml/src/ggml-hexagon/htp/argsort-ops.c
Normal file
281
ggml/src/ggml-hexagon/htp/argsort-ops.c
Normal file
@@ -0,0 +1,281 @@
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
#include <math.h>
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-dma.h"
|
||||
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifndef MIN
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
struct htp_argsort_context {
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t nrows_per_thread;
|
||||
};
|
||||
|
||||
static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
|
||||
{
|
||||
const HVX_Vector one = Q6_V_vsplat_R(1);
|
||||
const HVX_Vector zero = Q6_V_vzero();
|
||||
|
||||
HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);
|
||||
HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);
|
||||
HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
|
||||
return hvx_vec_get_i32(sum) == 32;
|
||||
}
|
||||
|
||||
// Sorts values and mirrors swaps to indices.
|
||||
static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
|
||||
if (left >= right) return;
|
||||
|
||||
int pivot_idx = (left + right) / 2;
|
||||
float pivot = values[pivot_idx];
|
||||
int i = left;
|
||||
int j = right;
|
||||
|
||||
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
|
||||
while (i <= j) {
|
||||
// Vectorized scan for i
|
||||
while (i <= j) {
|
||||
// Check if we have at least one full vector
|
||||
if (i + 32 <= j) {
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
|
||||
if (all_greater_f32(pivot_vec, vals_vec)) {
|
||||
// If all elements are < pivot, we can skip this whole block
|
||||
i += 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Scalar fallback / cleanup
|
||||
if (values[i] < pivot) {
|
||||
i++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Vectorized scan for j
|
||||
while (i <= j) {
|
||||
if (j - 32 >= i) {
|
||||
// Load 32 elements ending at j.
|
||||
// Since we want `values[j] > pivot`, let's load from j-31 to j.
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
|
||||
if (all_greater_f32(vals_vec, pivot_vec)) {
|
||||
j -= 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (values[j] > pivot) {
|
||||
j--;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (i <= j) {
|
||||
float tmp_val = values[i];
|
||||
values[i] = values[j];
|
||||
values[j] = tmp_val;
|
||||
|
||||
int32_t tmp_idx = indices[i];
|
||||
indices[i] = indices[j];
|
||||
indices[j] = tmp_idx;
|
||||
i++;
|
||||
j--;
|
||||
}
|
||||
}
|
||||
|
||||
if (left < j) quicksort_values_indices_asc(values, indices, left, j);
|
||||
if (i < right) quicksort_values_indices_asc(values, indices, i, right);
|
||||
}
|
||||
|
||||
static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
|
||||
if (left >= right) return;
|
||||
|
||||
int pivot_idx = (left + right) / 2;
|
||||
float pivot = values[pivot_idx];
|
||||
int i = left;
|
||||
int j = right;
|
||||
|
||||
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
|
||||
|
||||
while (i <= j) {
|
||||
// Vectorized scan for i (values[i] > pivot)
|
||||
while (i <= j) {
|
||||
if (i + 32 <= j) {
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
|
||||
if (all_greater_f32(vals_vec, pivot_vec)) {
|
||||
i += 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (values[i] > pivot) {
|
||||
i++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Vectorized scan for j (values[j] < pivot)
|
||||
while (i <= j) {
|
||||
if (j - 32 >= i) {
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
|
||||
if (all_greater_f32(pivot_vec, vals_vec)) {
|
||||
j -= 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (values[j] < pivot) {
|
||||
j--;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (i <= j) {
|
||||
float tmp_val = values[i];
|
||||
values[i] = values[j];
|
||||
values[j] = tmp_val;
|
||||
|
||||
int32_t tmp_idx = indices[i];
|
||||
indices[i] = indices[j];
|
||||
indices[j] = tmp_idx;
|
||||
i++;
|
||||
j--;
|
||||
}
|
||||
}
|
||||
|
||||
if (left < j) quicksort_values_indices_desc(values, indices, left, j);
|
||||
if (i < right) quicksort_values_indices_desc(values, indices, i, right);
|
||||
}
|
||||
|
||||
static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
|
||||
struct htp_ops_context * octx = actx->octx;
|
||||
|
||||
// Unpack context
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
// Scratchpad memory
|
||||
uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
|
||||
|
||||
// Dimensions
|
||||
uint32_t ne00 = src0->ne[0];
|
||||
uint32_t ne01 = src0->ne[1];
|
||||
uint32_t ne02 = src0->ne[2];
|
||||
uint32_t ne03 = src0->ne[3];
|
||||
|
||||
uint32_t nb01 = src0->nb[1];
|
||||
//uint32_t nb02 = src0->nb[2];
|
||||
//uint32_t nb03 = src0->nb[3];
|
||||
|
||||
uint32_t nb1 = dst->nb[1];
|
||||
//uint32_t nb2 = dst->nb[2];
|
||||
//uint32_t nb3 = dst->nb[3];
|
||||
|
||||
// Sort order
|
||||
enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
|
||||
|
||||
// Rows to process
|
||||
uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
uint32_t rows_per_thread = actx->nrows_per_thread;
|
||||
uint32_t start_row = rows_per_thread * i;
|
||||
uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
|
||||
|
||||
// Scratchpad layout:
|
||||
// We need space for one row of float data (values) and one row of int32 indices.
|
||||
// values: ne00 * sizeof(float)
|
||||
// indices: ne00 * sizeof(int32_t)
|
||||
// Padded to 128 bytes.
|
||||
|
||||
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
|
||||
float * values_buf = (float *) spad;
|
||||
int32_t * indices_buf = (int32_t *) (spad + values_size);
|
||||
|
||||
for (uint32_t r = start_row; r < end_row; r++) {
|
||||
uint32_t src_offset = r * nb01;
|
||||
uint32_t dst_offset = r * nb1;
|
||||
|
||||
uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;
|
||||
|
||||
hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
|
||||
hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
|
||||
|
||||
// Initialize indices
|
||||
for (uint32_t j = 0; j < ne00; j++) {
|
||||
indices_buf[j] = j;
|
||||
}
|
||||
|
||||
// Sort values and mirror swaps to indices
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
|
||||
} else {
|
||||
quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
|
||||
}
|
||||
|
||||
// Copy indices back to DDR
|
||||
hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
int op_argsort(struct htp_ops_context * octx) {
|
||||
// Check supported types
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
// Allocate scratchpad
|
||||
// We need 1 row of float + 1 row of int32 per thread.
|
||||
uint32_t ne00 = octx->src0.ne[0];
|
||||
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
|
||||
size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
|
||||
size_t spad_per_thread = values_size + indices_size;
|
||||
|
||||
// Make sure we round up to 256 for alignment requirements
|
||||
spad_per_thread = hex_round_up(spad_per_thread, 256);
|
||||
|
||||
size_t total_spad_size = spad_per_thread * octx->n_threads;
|
||||
|
||||
if (octx->ctx->vtcm_size < total_spad_size) {
|
||||
FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src0_spad.size = total_spad_size;
|
||||
octx->src0_spad.size_per_thread = spad_per_thread;
|
||||
|
||||
FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
|
||||
octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
|
||||
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
|
||||
octx->src0.data, octx->dst.data);
|
||||
|
||||
uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
|
||||
uint32_t n_jobs = MIN(total_rows, octx->n_threads);
|
||||
|
||||
struct htp_argsort_context actx;
|
||||
actx.octx = octx;
|
||||
actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
|
||||
|
||||
// Run jobs
|
||||
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -17,121 +17,6 @@
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
|
||||
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
}
|
||||
|
||||
// Dot product of FP32 and FP16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
// Zero-out unused elements
|
||||
// Note that we need to clear both x and y because they may contain NANs
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
x_hf = Q6_V_vand_QV(bmask, x_hf);
|
||||
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
||||
}
|
||||
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
|
||||
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
// Dot product of FP32 and FP16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
|
||||
const void * restrict y,
|
||||
const void * restrict x0,
|
||||
const void * restrict x1,
|
||||
unsigned int n,
|
||||
float s) {
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
|
||||
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
// Load x (fp16)
|
||||
HVX_Vector x0_hf = vx0[i];
|
||||
HVX_Vector x1_hf = vx1[i];
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x0_hf = vx0[i];
|
||||
HVX_Vector x1_hf = vx1[i];
|
||||
|
||||
// Zero-out unused elements
|
||||
// Note that we need to clear both x and y because they may contain NANs
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
x0_hf = Q6_V_vand_QV(bmask, x0_hf);
|
||||
x1_hf = Q6_V_vand_QV(bmask, x1_hf);
|
||||
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
||||
}
|
||||
|
||||
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
|
||||
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
// Dot product of two F16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
@@ -140,8 +25,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
@@ -156,11 +40,10 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
|
||||
// Load x (fp16) and zero-out unused elements
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
@@ -181,12 +64,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
@@ -204,12 +86,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
|
||||
// Load x (fp16) and zero-out unused elements
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
|
||||
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
|
||||
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
|
||||
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
@@ -222,7 +103,7 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
|
||||
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * s (float)
|
||||
// MAD: y (F32) += x (F16) * s (F32)
|
||||
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
|
||||
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
|
||||
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
||||
@@ -259,15 +140,125 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict
|
||||
}
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32)
|
||||
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y,
|
||||
const void * restrict x0,
|
||||
const void * restrict x1,
|
||||
float s0,
|
||||
float s1,
|
||||
int n) {
|
||||
const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0;
|
||||
const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1;
|
||||
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector S0 = hvx_vec_splat_f16(s0);
|
||||
HVX_Vector S1 = hvx_vec_splat_f16(s1);
|
||||
|
||||
uint32_t i = 0;
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
// Multiply x * s -> pair of F32 vectors
|
||||
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
|
||||
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
|
||||
|
||||
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
|
||||
HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
|
||||
|
||||
ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2]));
|
||||
ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1]));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
|
||||
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
|
||||
|
||||
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
|
||||
HVX_Vector xs = xs_p_lo;
|
||||
i = 2 * i; // index for ptr_y
|
||||
|
||||
if (nloe >= 32) {
|
||||
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
||||
nloe -= 32; ++i;
|
||||
xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
||||
hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define FLASH_ATTN_BLOCK_SIZE 128
|
||||
|
||||
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
|
||||
struct htp_fa_context {
|
||||
const struct htp_ops_context * octx;
|
||||
|
||||
struct fastdiv_values src0_div21;
|
||||
struct fastdiv_values src0_div1;
|
||||
|
||||
struct fastdiv_values broadcast_rk2;
|
||||
struct fastdiv_values broadcast_rk3;
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
|
||||
struct fastdiv_values src3_div2;
|
||||
struct fastdiv_values src3_div3;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t n_blocks;
|
||||
|
||||
size_t size_q_row_padded;
|
||||
size_t size_k_row_padded;
|
||||
size_t size_v_row_padded;
|
||||
|
||||
size_t size_k_block;
|
||||
size_t size_v_block;
|
||||
size_t size_m_block;
|
||||
|
||||
bool is_q_fp32;
|
||||
};
|
||||
|
||||
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
|
||||
assert((size_t) dst % 128 == 0);
|
||||
assert((size_t) src % 128 == 0);
|
||||
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
|
||||
|
||||
const uint32_t nvec = n / VLEN_FP32;
|
||||
const uint32_t nloe = n % VLEN_FP32;
|
||||
|
||||
uint32_t i = 0;
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; ++i) {
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
|
||||
}
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_fa_context * factx = (struct htp_fa_context *) data;
|
||||
const struct htp_ops_context * octx = factx->octx;
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
||||
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
const uint32_t neq0 = q->ne[0];
|
||||
const uint32_t neq1 = q->ne[1];
|
||||
@@ -304,18 +295,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
const uint32_t nb2 = dst->nb[2];
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
// total rows in q
|
||||
const uint32_t nr = neq1*neq2*neq3;
|
||||
|
||||
@@ -331,18 +310,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
const uint32_t DV = nev0;
|
||||
|
||||
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
|
||||
const size_t size_q_row_padded = hex_round_up(size_q_row, 128);
|
||||
|
||||
const size_t size_k_row = DK * sizeof(__fp16);
|
||||
const size_t size_v_row = DV * sizeof(__fp16);
|
||||
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
|
||||
|
||||
const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
|
||||
const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
|
||||
|
||||
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
|
||||
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
|
||||
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
|
||||
@@ -351,31 +320,28 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
|
||||
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
|
||||
|
||||
const uint32_t n_head = neq2;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
|
||||
|
||||
for (uint32_t ir = ir0; ir < ir1; ++ir) {
|
||||
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
|
||||
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
|
||||
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
|
||||
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
|
||||
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
|
||||
|
||||
const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
|
||||
const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
|
||||
const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
|
||||
const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
|
||||
|
||||
const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
|
||||
const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
|
||||
const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
|
||||
const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
|
||||
|
||||
// Fetch Q row
|
||||
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
|
||||
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
|
||||
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
|
||||
|
||||
const uint32_t h = iq2; // head index
|
||||
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
|
||||
const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
|
||||
|
||||
float S = 0.0f; // sum
|
||||
float M = -INFINITY; // maximum KQ value
|
||||
HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
|
||||
HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
|
||||
|
||||
// Clear accumulator
|
||||
hvx_splat_f32_a(spad_a, 0, DV);
|
||||
@@ -383,40 +349,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
|
||||
const __fp16 * mp_base = NULL;
|
||||
if (mask) {
|
||||
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
|
||||
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
|
||||
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
|
||||
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
|
||||
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
|
||||
}
|
||||
|
||||
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
|
||||
|
||||
// Prefetch first two blocks
|
||||
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
|
||||
for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
|
||||
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
||||
|
||||
// K
|
||||
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
|
||||
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
|
||||
uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
|
||||
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
|
||||
|
||||
// V
|
||||
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
||||
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
|
||||
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
|
||||
uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
|
||||
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
|
||||
|
||||
// Mask
|
||||
if (mask) {
|
||||
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
|
||||
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
|
||||
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
|
||||
// Mask is 1D contiguous for this row
|
||||
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
|
||||
}
|
||||
}
|
||||
|
||||
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
|
||||
uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
|
||||
if (factx->is_q_fp32) {
|
||||
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
|
||||
}
|
||||
|
||||
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
|
||||
const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
|
||||
for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
|
||||
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
||||
|
||||
@@ -428,8 +396,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
// Inner loop processing the block from VTCM
|
||||
uint32_t ic = 0;
|
||||
|
||||
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
|
||||
// Process in blocks of 32 (VLEN_FP32)
|
||||
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
|
||||
HVX_Vector_x4 scores_x4;
|
||||
@@ -437,22 +403,18 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
// 1. Compute scores
|
||||
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
|
||||
for (int j = 0; j < VLEN_FP32; j += 2) {
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
|
||||
if (is_q_fp32) {
|
||||
hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
|
||||
} else {
|
||||
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
|
||||
}
|
||||
const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded;
|
||||
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale);
|
||||
}
|
||||
|
||||
HVX_Vector scores = *(HVX_Vector *) scores_arr;
|
||||
|
||||
// 2. Softcap
|
||||
if (logit_softcap != 0.0f) {
|
||||
if (factx->logit_softcap != 0.0f) {
|
||||
scores = hvx_vec_tanh_f32(scores);
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
@@ -460,70 +422,59 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
if (mask) {
|
||||
const __fp16 * mp = m_base + ic;
|
||||
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
|
||||
|
||||
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
|
||||
|
||||
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
|
||||
|
||||
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
|
||||
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
|
||||
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
||||
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
||||
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
scores_x4.v[iv] = scores;
|
||||
v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
|
||||
v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
|
||||
}
|
||||
|
||||
{
|
||||
// 4. Online Softmax Update
|
||||
v_max = hvx_vec_reduce_max_f32(v_max);
|
||||
float m_block = hvx_vec_get_f32(v_max);
|
||||
float M_old = M;
|
||||
float M_new = (m_block > M) ? m_block : M;
|
||||
M = M_new;
|
||||
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
|
||||
HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec));
|
||||
M_vec = M_new_vec;
|
||||
|
||||
const float ms = expf(M_old - M_new);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
|
||||
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
|
||||
HVX_Vector scores = scores_x4.v[iv];
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
|
||||
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
|
||||
|
||||
// 5. Accumulate V
|
||||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector*)p_arr = P;
|
||||
*(HVX_Vector *) p_arr = P;
|
||||
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV);
|
||||
}
|
||||
}
|
||||
|
||||
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
|
||||
S = S * ms + hvx_vec_get_f32(p_sum_vec);
|
||||
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
|
||||
}
|
||||
|
||||
// Sync scalars for leftover/next block if needed
|
||||
float M = hvx_vec_get_f32(M_vec);
|
||||
float S = hvx_vec_get_f32(S_vec);
|
||||
|
||||
// Leftover
|
||||
for (; ic < current_block_size; ++ic) {
|
||||
float s_val;
|
||||
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
|
||||
|
||||
if (is_q_fp32) {
|
||||
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
} else {
|
||||
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
}
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
s_val = logit_softcap * tanhf(s_val);
|
||||
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
|
||||
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
|
||||
if (factx->logit_softcap != 0.0f) {
|
||||
s_val = factx->logit_softcap * tanhf(s_val);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
@@ -532,37 +483,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
}
|
||||
|
||||
const float Mold = M;
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s_val > M) {
|
||||
M = s_val;
|
||||
ms = expf(Mold - M);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
float ms = hvx_vec_get_f32(ms_vec);
|
||||
S = S * ms + vs;
|
||||
} else {
|
||||
vs = expf(s_val - M);
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
|
||||
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
||||
S += vs;
|
||||
}
|
||||
|
||||
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
|
||||
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
|
||||
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
|
||||
|
||||
S = S * ms + vs;
|
||||
}
|
||||
M_vec = hvx_vec_splat_f32(M);
|
||||
S_vec = hvx_vec_splat_f32(S);
|
||||
|
||||
// Issue DMA for next+1 block (if exists)
|
||||
if (ib + 2 < n_blocks) {
|
||||
if (ib + 2 < factx->n_blocks) {
|
||||
const uint32_t next_ib = ib + 2;
|
||||
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
|
||||
|
||||
// K
|
||||
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
|
||||
dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
|
||||
|
||||
// V
|
||||
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
||||
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
|
||||
dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
|
||||
|
||||
// Mask
|
||||
if (mask) {
|
||||
@@ -573,20 +529,26 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
}
|
||||
|
||||
// sinks
|
||||
float M = hvx_vec_get_f32(M_vec);
|
||||
float S = hvx_vec_get_f32(S_vec);
|
||||
|
||||
if (sinks) {
|
||||
const float s = ((float *)((char *) sinks->data))[h];
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s > M) {
|
||||
ms = expf(M - s);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
} else {
|
||||
vs = expf(s - M);
|
||||
}
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
S = S * ms + vs;
|
||||
float ms = hvx_vec_get_f32(ms_vec);
|
||||
S = S * ms + vs;
|
||||
} else {
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
|
||||
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
||||
S += vs;
|
||||
}
|
||||
}
|
||||
|
||||
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
||||
@@ -609,53 +571,73 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
}
|
||||
}
|
||||
|
||||
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = data;
|
||||
flash_attn_ext_f16_thread(octx, i, n);
|
||||
}
|
||||
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
// Check support
|
||||
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
|
||||
k->type != HTP_TYPE_F16 ||
|
||||
v->type != HTP_TYPE_F16) {
|
||||
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
|
||||
octx->src0_div1 = init_fastdiv_values(q->ne[1]);
|
||||
struct htp_fa_context factx;
|
||||
factx.octx = octx;
|
||||
|
||||
octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
|
||||
octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
|
||||
octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
|
||||
octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
|
||||
factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
|
||||
factx.src0_div1 = init_fastdiv_values(q->ne[1]);
|
||||
|
||||
factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
|
||||
factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
|
||||
factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
|
||||
factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
|
||||
|
||||
if (mask) {
|
||||
octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
|
||||
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
|
||||
factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
}
|
||||
|
||||
size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
|
||||
size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
|
||||
size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
|
||||
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
|
||||
factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
|
||||
factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
|
||||
|
||||
size_t size_q_block = size_q_row_padded * 1; // single row for now
|
||||
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
|
||||
factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
|
||||
factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
factx.scale = scale;
|
||||
factx.max_bias = max_bias;
|
||||
factx.logit_softcap = logit_softcap;
|
||||
|
||||
uint32_t n_head = q->ne[2];
|
||||
factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
|
||||
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
|
||||
|
||||
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
|
||||
|
||||
octx->src0_spad.size_per_thread = size_q_block * 1;
|
||||
octx->src1_spad.size_per_thread = size_k_block * 2;
|
||||
octx->src2_spad.size_per_thread = size_v_block * 2;
|
||||
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
|
||||
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
|
||||
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
|
||||
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
|
||||
octx->dst_spad.size_per_thread = size_vkq_acc;
|
||||
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
||||
@@ -677,7 +659,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
|
||||
@@ -42,32 +42,36 @@ enum htp_data_type {
|
||||
HTP_TYPE_COUNT
|
||||
};
|
||||
|
||||
// These values are manually translated over to HTP
|
||||
// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
|
||||
// Do not reorder first 4 (used as an index)
|
||||
enum htp_op {
|
||||
HTP_OP_MUL = 0,
|
||||
HTP_OP_ADD = 1,
|
||||
HTP_OP_SUB = 2,
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT = 4,
|
||||
HTP_OP_MUL_MAT_ID = 5,
|
||||
HTP_OP_RMS_NORM = 6,
|
||||
HTP_OP_UNARY_SILU = 7,
|
||||
HTP_OP_UNARY_GELU = 8,
|
||||
HTP_OP_GLU_SWIGLU = 9,
|
||||
HTP_OP_GLU_SWIGLU_OAI = 10,
|
||||
HTP_OP_SOFTMAX = 11,
|
||||
HTP_OP_ADD_ID = 12,
|
||||
HTP_OP_ROPE = 13,
|
||||
HTP_OP_FLASH_ATTN_EXT = 14,
|
||||
HTP_OP_SET_ROWS = 15,
|
||||
HTP_OP_SCALE = 16,
|
||||
HTP_OP_GET_ROWS = 17,
|
||||
HTP_OP_CPY = 18,
|
||||
HTP_OP_MUL = 0,
|
||||
HTP_OP_ADD = 1,
|
||||
HTP_OP_SUB = 2,
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT,
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_UNARY_SILU,
|
||||
HTP_OP_UNARY_GELU,
|
||||
HTP_OP_GLU_SWIGLU,
|
||||
HTP_OP_GLU_SWIGLU_OAI,
|
||||
HTP_OP_GLU_GEGLU,
|
||||
HTP_OP_SOFTMAX,
|
||||
HTP_OP_ADD_ID,
|
||||
HTP_OP_ROPE,
|
||||
HTP_OP_FLASH_ATTN_EXT,
|
||||
HTP_OP_SET_ROWS,
|
||||
HTP_OP_GET_ROWS,
|
||||
HTP_OP_SCALE,
|
||||
HTP_OP_CPY,
|
||||
HTP_OP_ARGSORT,
|
||||
HTP_OP_SQR,
|
||||
HTP_OP_SQRT,
|
||||
HTP_OP_SUM_ROWS,
|
||||
INVALID
|
||||
};
|
||||
|
||||
static inline size_t htp_type_block_size(uint32_t t) {
|
||||
static inline size_t htp_t_block_size(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return 1;
|
||||
@@ -103,22 +107,6 @@ static inline size_t htp_type_nbytes(uint32_t t) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static const char * htp_type_name(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return "fp32";
|
||||
case HTP_TYPE_F16:
|
||||
return "fp16";
|
||||
case HTP_TYPE_Q4_0:
|
||||
return "q4_0";
|
||||
case HTP_TYPE_Q8_0:
|
||||
return "q8_0";
|
||||
case HTP_TYPE_MXFP4:
|
||||
return "mxfp4";
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Internal types
|
||||
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||
|
||||
@@ -64,25 +64,12 @@ struct htp_ops_context {
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
|
||||
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
|
||||
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
|
||||
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
|
||||
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
|
||||
|
||||
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
|
||||
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
|
||||
|
||||
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
|
||||
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
|
||||
|
||||
struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01
|
||||
struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02
|
||||
struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03
|
||||
|
||||
struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00
|
||||
struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01
|
||||
struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02
|
||||
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
@@ -90,6 +77,7 @@ int op_matmul(struct htp_ops_context * octx);
|
||||
int op_matmul_id(struct htp_ops_context * octx);
|
||||
int op_binary(struct htp_ops_context * octx);
|
||||
int op_unary(struct htp_ops_context * octx);
|
||||
int op_sum_rows(struct htp_ops_context * octx);
|
||||
int op_activations(struct htp_ops_context * octx);
|
||||
int op_softmax(struct htp_ops_context * octx);
|
||||
int op_add_id(struct htp_ops_context * octx);
|
||||
@@ -98,5 +86,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
|
||||
int op_set_rows(struct htp_ops_context * octx);
|
||||
int op_get_rows(struct htp_ops_context * octx);
|
||||
int op_cpy(struct htp_ops_context * octx);
|
||||
int op_argsort(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_OPS_H */
|
||||
|
||||
@@ -46,127 +46,76 @@
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
// ADD variants
|
||||
// Generic macro to define alignment permutations for an op
|
||||
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
|
||||
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
|
||||
static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD);
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
|
||||
|
||||
// Dispatcher logic
|
||||
#define HVX_BINARY_DISPATCHER(OP_NAME) \
|
||||
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
|
||||
if (hex_is_aligned((void *) dst, 128)) { \
|
||||
if (hex_is_aligned((void *) src0, 128)) { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_aau(dst, src0, src1, num_elems); \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_auu(dst, src0, src1, num_elems); \
|
||||
} \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src0, 128)) { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_uau(dst, src0, src1, num_elems); \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_uuu(dst, src0, src1, num_elems); \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
// SUB variants
|
||||
|
||||
static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
// MUL variants
|
||||
|
||||
static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
// Dispatchers
|
||||
|
||||
static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_add_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_add_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_add_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_add_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_sub_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_sub_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_sub_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_sub_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_mul_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_mul_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_mul_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_mul_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
HVX_BINARY_DISPATCHER(hvx_add_f32)
|
||||
HVX_BINARY_DISPATCHER(hvx_sub_f32)
|
||||
HVX_BINARY_DISPATCHER(hvx_mul_f32)
|
||||
|
||||
// Mul-Mul Optimized
|
||||
|
||||
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
@@ -443,6 +392,68 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Square
|
||||
//
|
||||
|
||||
#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(float); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, v); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128)) {
|
||||
if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_sqr_f32_aa(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqr_f32_au(dst, src, num_elems);
|
||||
}
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_sqr_f32_ua(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqr_f32_uu(dst, src, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef HVX_OP_ADD
|
||||
#undef HVX_OP_SUB
|
||||
#undef HVX_OP_MUL
|
||||
@@ -453,5 +464,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
|
||||
#undef hvx_scalar_loop_body
|
||||
#undef HVX_OP_MIN_SCALAR
|
||||
#undef HVX_OP_CLAMP_SCALAR
|
||||
#undef DEFINE_HVX_BINARY_OP_VARIANTS
|
||||
#undef HVX_BINARY_DISPATCHER
|
||||
|
||||
#endif // HVX_ARITH_H
|
||||
|
||||
@@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
|
||||
int32_t __attribute__((aligned(128))) x;
|
||||
hvx_vec_store_a(&x, 4, v);
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
|
||||
// abs by clearing the fp16 sign bit
|
||||
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
|
||||
|
||||
@@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0); \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(__fp16); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
|
||||
116
ggml/src/ggml-hexagon/htp/hvx-div.h
Normal file
116
ggml/src/ggml-hexagon/htp/hvx-div.h
Normal file
@@ -0,0 +1,116 @@
|
||||
#ifndef HVX_DIV_H
|
||||
#define HVX_DIV_H
|
||||
|
||||
#include <HAP_farf.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <assert.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hex-utils.h"
|
||||
#include "hvx-inverse.h"
|
||||
#include "hvx-arith.h"
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src0_type * restrict vsrc0 = (src0_type *) src0; \
|
||||
src1_type * restrict vsrc1 = (src1_type *) src1; \
|
||||
\
|
||||
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP32; \
|
||||
const uint32_t nloe = n % VLEN_FP32; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
|
||||
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
|
||||
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
// 3-letter suffix variants
|
||||
static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128)) {
|
||||
if (hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_aau(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_auu(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_uau(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_uuu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef HVX_OP_MUL
|
||||
|
||||
#endif // HVX_DIV_H
|
||||
@@ -91,6 +91,27 @@ static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t epv = 128 / sizeof(float); \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
@@ -111,4 +132,10 @@ static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * re
|
||||
hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
#endif /* HVX_SIGMOID_H */
|
||||
|
||||
@@ -12,11 +12,17 @@
|
||||
#define RSQRT_ONE_HALF 0x3f000000 // 0.5
|
||||
#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
|
||||
//Algorithm :
|
||||
// x2 = input*0.5
|
||||
// y = * (long *) &input
|
||||
// y = 0x5f3759df - (y>>2)
|
||||
// y = 0x5f3759df - (y>>1)
|
||||
// y = y*(threehalfs - x2*y*y)
|
||||
|
||||
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
|
||||
@@ -57,4 +63,64 @@ static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
|
||||
return Q6_Vsf_equals_Vqf32(temp);
|
||||
}
|
||||
|
||||
// Compute sqrt(x) as x*inv_sqrt(x)
|
||||
#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP32; \
|
||||
const uint32_t nloe = n % VLEN_FP32; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
|
||||
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
|
||||
vdst[i] = sqrt_res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
|
||||
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
|
||||
if ((unsigned long) dst % 128 == 0) {
|
||||
if ((unsigned long) src % 128 == 0) {
|
||||
hvx_sqrt_f32_aa(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32_au(dst, src, num_elems);
|
||||
}
|
||||
} else {
|
||||
if ((unsigned long) src % 128 == 0) {
|
||||
hvx_sqrt_f32_ua(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32_uu(dst, src, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* HVX_SQRT_H */
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "hvx-sigmoid.h"
|
||||
#include "hvx-sqrt.h"
|
||||
#include "hvx-arith.h"
|
||||
#include "hvx-div.h"
|
||||
#include "hvx-base.h"
|
||||
|
||||
#endif /* HVX_UTILS_H */
|
||||
|
||||
@@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) {
|
||||
// otherwise we'll release it once we're done with the current Op.
|
||||
|
||||
if (ctx->vtcm_inuse) {
|
||||
ctx->vtcm_needs_release = false;
|
||||
ctx->vtcm_needs_release = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx,
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_argsort(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
@@ -679,6 +718,45 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_sum_rows(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_activations_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
@@ -951,6 +1029,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
case HTP_OP_MUL:
|
||||
case HTP_OP_ADD:
|
||||
case HTP_OP_SUB:
|
||||
case HTP_OP_DIV:
|
||||
if (n_bufs != 3) {
|
||||
FARF(ERROR, "Bad binary-req buffer list");
|
||||
continue;
|
||||
@@ -968,6 +1047,25 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
proc_unary_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_SQR:
|
||||
case HTP_OP_SQRT:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad unary-req buffer list");
|
||||
continue;
|
||||
}
|
||||
|
||||
proc_unary_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_SUM_ROWS:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad unary-req buffer list");
|
||||
continue;
|
||||
}
|
||||
|
||||
proc_sum_rows_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_UNARY_SILU:
|
||||
case HTP_OP_UNARY_GELU:
|
||||
if (n_bufs != 2) {
|
||||
@@ -980,6 +1078,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
case HTP_OP_GLU_SWIGLU:
|
||||
case HTP_OP_GLU_SWIGLU_OAI:
|
||||
case HTP_OP_SOFTMAX:
|
||||
case HTP_OP_GLU_GEGLU:
|
||||
if ((n_bufs != 2) && (n_bufs != 3)) {
|
||||
FARF(ERROR, "Bad act-req buffer list");
|
||||
continue;
|
||||
@@ -1035,6 +1134,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
proc_cpy_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_ARGSORT:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad argsort-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_argsort_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unknown Op %u", req.op);
|
||||
break;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
115
ggml/src/ggml-hexagon/htp/sum-rows-ops.c
Normal file
115
ggml/src/ggml-hexagon/htp/sum-rows-ops.c
Normal file
@@ -0,0 +1,115 @@
|
||||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
|
||||
#define sum_rows_preamble \
|
||||
struct htp_tensor *src0 = &octx->src0;\
|
||||
struct htp_tensor *dst = &octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3]; \
|
||||
|
||||
static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
sum_rows_preamble;
|
||||
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src = (const uint8_t *) src0->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
|
||||
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
|
||||
|
||||
for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
|
||||
const float * restrict src_local = src_th + (ir * ne00);
|
||||
|
||||
if (ir + 1 < src0_nrows_per_thread) {
|
||||
hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
|
||||
} else {
|
||||
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
|
||||
sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_sum_rows(struct htp_ops_context * octx) {
|
||||
sum_rows_preamble;
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
const int n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
@@ -132,6 +132,56 @@ static void rms_norm_htp_f32(const float * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void sqr_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void sqrt_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
struct htp_tensor * dst,
|
||||
uint8_t * spad,
|
||||
@@ -181,6 +231,12 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
case HTP_OP_SCALE:
|
||||
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
@@ -218,6 +274,14 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "scale-f32";
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqr-f32";
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqrt-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||
|
||||
@@ -98,6 +98,10 @@ static bool ggml_op_is_empty(enum ggml_op op) {
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool ggml_impl_is_view(const struct ggml_tensor * t) {
|
||||
return t->view_src != NULL;
|
||||
}
|
||||
|
||||
static inline float ggml_compute_softplus_f32(float input) {
|
||||
return (input > 20.0f) ? input : logf(1 + expf(input));
|
||||
}
|
||||
|
||||
@@ -264,15 +264,26 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_GLU:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_UNARY:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_REPEAT:
|
||||
return true;
|
||||
default:
|
||||
return ggml_op_is_empty(op);
|
||||
@@ -312,7 +323,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
||||
h_add(mrs1, node0);
|
||||
|
||||
// that many nodes forward to search for a concurrent node
|
||||
constexpr int N_FORWARD = 8;
|
||||
constexpr int N_FORWARD = 64;
|
||||
|
||||
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
|
||||
if (used[i1]) {
|
||||
|
||||
@@ -394,7 +394,7 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con
|
||||
[encoder endEncoding];
|
||||
|
||||
ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src);
|
||||
ggml_metal_event_record(ctx_src, ev_cpy);
|
||||
ggml_metal_event_encode_signal(ev_cpy, cmd_buf);
|
||||
|
||||
[cmd_buf commit];
|
||||
|
||||
|
||||
@@ -212,61 +212,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const int64_t n = ggml_nelements(op);
|
||||
int op_num = -1;
|
||||
|
||||
const char * op_str = "undefined";
|
||||
switch (op->op) {
|
||||
case GGML_OP_SCALE: op_str = "scale"; break;
|
||||
case GGML_OP_FILL: op_str = "fill"; break;
|
||||
case GGML_OP_CLAMP: op_str = "clamp"; break;
|
||||
case GGML_OP_SQR: op_str = "sqr"; break;
|
||||
case GGML_OP_SQRT: op_str = "sqrt"; break;
|
||||
case GGML_OP_SIN: op_str = "sin"; break;
|
||||
case GGML_OP_COS: op_str = "cos"; break;
|
||||
case GGML_OP_LOG: op_str = "log"; break;
|
||||
case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
|
||||
case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
|
||||
case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
|
||||
case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
|
||||
case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
|
||||
case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
|
||||
case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
|
||||
case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
|
||||
case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
|
||||
case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_TANH: op_str = "tanh"; break;
|
||||
case GGML_UNARY_OP_RELU: op_str = "relu"; break;
|
||||
case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break;
|
||||
case GGML_UNARY_OP_GELU: op_str = "gelu"; break;
|
||||
case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break;
|
||||
case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break;
|
||||
case GGML_UNARY_OP_SILU: op_str = "silu"; break;
|
||||
case GGML_UNARY_OP_ELU: op_str = "elu"; break;
|
||||
case GGML_UNARY_OP_NEG: op_str = "neg"; break;
|
||||
case GGML_UNARY_OP_ABS: op_str = "abs"; break;
|
||||
case GGML_UNARY_OP_SGN: op_str = "sgn"; break;
|
||||
case GGML_UNARY_OP_STEP: op_str = "step"; break;
|
||||
case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
|
||||
case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
|
||||
case GGML_UNARY_OP_EXP: op_str = "exp"; break;
|
||||
case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
|
||||
case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
|
||||
case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
|
||||
case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
|
||||
case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
|
||||
case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
|
||||
case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
|
||||
case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
|
||||
case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
|
||||
case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
|
||||
case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
|
||||
case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
|
||||
case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
|
||||
case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
|
||||
case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
|
||||
case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
|
||||
case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
|
||||
case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
|
||||
case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
} break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
const char * suffix = "";
|
||||
if (n % 4 == 0) {
|
||||
suffix = "_4";
|
||||
}
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
|
||||
snprintf(name, 256, "%s", base);
|
||||
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
||||
const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
|
||||
|
||||
snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
|
||||
ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
res.cnt = is_cnt;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -320,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const char * op_str = "undefined";
|
||||
int op_num = -1;
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_SUM_ROWS:
|
||||
op_str = "sum_rows"; break;
|
||||
case GGML_OP_MEAN:
|
||||
op_str = "mean"; break;
|
||||
case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
|
||||
case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(name, 256, "%s", base);
|
||||
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
||||
|
||||
snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s_op=%d", base, op_num);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
res.smem = 32*sizeof(float);
|
||||
|
||||
if (is_c4) {
|
||||
res.smem *= 4;
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -1392,34 +1415,78 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v
|
||||
GGML_UNUSED(op);
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
|
||||
ggml_metal_library_t lib,
|
||||
ggml_op op,
|
||||
int32_t n_fuse,
|
||||
bool row) {
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const char * op_str = "undefined";
|
||||
switch (op) {
|
||||
case GGML_OP_ADD: op_str = "add"; break;
|
||||
case GGML_OP_SUB: op_str = "sub"; break;
|
||||
case GGML_OP_MUL: op_str = "mul"; break;
|
||||
case GGML_OP_DIV: op_str = "div"; break;
|
||||
int op_num = -1;
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_ADD: op_num = 0; break;
|
||||
case GGML_OP_SUB: op_num = 1; break;
|
||||
case GGML_OP_MUL: op_num = 2; break;
|
||||
case GGML_OP_DIV: op_num = 3; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
if (row) {
|
||||
snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
|
||||
} else {
|
||||
snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
|
||||
}
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t1_str = ggml_type_name(op->src[1]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(name, 256, "%s", base);
|
||||
const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
|
||||
|
||||
const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
|
||||
|
||||
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
|
||||
ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
|
||||
ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
res.cnt = is_rb;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
int op_num = -1;
|
||||
|
||||
switch (op) {
|
||||
case GGML_OP_ADD: op_num = 0; break;
|
||||
case GGML_OP_SUB: op_num = 1; break;
|
||||
case GGML_OP_MUL: op_num = 2; break;
|
||||
case GGML_OP_DIV: op_num = 3; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
|
||||
snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
|
||||
ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1);
|
||||
ggml_metal_cv_set_bool (cv, false, FC_BIN + 2);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -1428,13 +1495,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_L2_NORM);
|
||||
|
||||
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_l2_norm_f32");
|
||||
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
||||
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
@@ -1442,6 +1511,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
res.smem = 32*sizeof(float);
|
||||
|
||||
return res;
|
||||
|
||||
@@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params {
|
||||
int nr1;
|
||||
|
||||
size_t smem;
|
||||
|
||||
bool c4;
|
||||
bool cnt;
|
||||
};
|
||||
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
|
||||
@@ -134,7 +137,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse );
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
|
||||
@@ -346,10 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
|
||||
|
||||
struct ggml_metal_pipeline_with_params res = {
|
||||
/*.pipeline =*/ nil,
|
||||
/*.nsg =*/ 0,
|
||||
/*.nr0 =*/ 0,
|
||||
/*.nr1 =*/ 0,
|
||||
/*.nsg =*/ 0,
|
||||
/*.smem =*/ 0,
|
||||
/*.c4 =*/ false,
|
||||
/*.cnt =*/ false,
|
||||
};
|
||||
|
||||
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
@@ -362,10 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
|
||||
struct ggml_metal_pipeline_with_params res = {
|
||||
/*.pipeline =*/ nil,
|
||||
/*.nsg =*/ 0,
|
||||
/*.nr0 =*/ 0,
|
||||
/*.nr1 =*/ 0,
|
||||
/*.nsg =*/ 0,
|
||||
/*.smem =*/ 0,
|
||||
/*.c4 =*/ false,
|
||||
/*.cnt =*/ false,
|
||||
};
|
||||
|
||||
[lib->lock lock];
|
||||
@@ -1007,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
}
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_TANH:
|
||||
@@ -1026,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -1054,11 +1067,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_ADD_ID:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ACC:
|
||||
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return true;
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
@@ -1066,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_CLAMP:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SUM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_TRI:
|
||||
@@ -1083,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
return has_simdgroup_reduction &&
|
||||
op->src[0]->type == GGML_TYPE_I32 &&
|
||||
@@ -1157,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CONT:
|
||||
|
||||
@@ -80,6 +80,9 @@
|
||||
#define FC_SSM_CONV 900
|
||||
#define FC_SOLVE_TRI 1000
|
||||
#define FC_COUNT_EQUAL 1100
|
||||
#define FC_UNARY 1200
|
||||
#define FC_BIN 1300
|
||||
#define FC_SUM_ROWS 1400
|
||||
|
||||
// op-specific constants
|
||||
#define OP_FLASH_ATTN_EXT_NQPSG 8
|
||||
@@ -88,6 +91,37 @@
|
||||
#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
|
||||
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
|
||||
|
||||
#define OP_UNARY_NUM_SCALE 10
|
||||
#define OP_UNARY_NUM_FILL 11
|
||||
#define OP_UNARY_NUM_CLAMP 12
|
||||
#define OP_UNARY_NUM_SQR 13
|
||||
#define OP_UNARY_NUM_SQRT 14
|
||||
#define OP_UNARY_NUM_SIN 15
|
||||
#define OP_UNARY_NUM_COS 16
|
||||
#define OP_UNARY_NUM_LOG 17
|
||||
#define OP_UNARY_NUM_LEAKY_RELU 18
|
||||
|
||||
#define OP_UNARY_NUM_TANH 100
|
||||
#define OP_UNARY_NUM_RELU 101
|
||||
#define OP_UNARY_NUM_SIGMOID 102
|
||||
#define OP_UNARY_NUM_GELU 103
|
||||
#define OP_UNARY_NUM_GELU_ERF 104
|
||||
#define OP_UNARY_NUM_GELU_QUICK 105
|
||||
#define OP_UNARY_NUM_SILU 106
|
||||
#define OP_UNARY_NUM_ELU 107
|
||||
#define OP_UNARY_NUM_NEG 108
|
||||
#define OP_UNARY_NUM_ABS 109
|
||||
#define OP_UNARY_NUM_SGN 110
|
||||
#define OP_UNARY_NUM_STEP 111
|
||||
#define OP_UNARY_NUM_HARDSWISH 112
|
||||
#define OP_UNARY_NUM_HARDSIGMOID 113
|
||||
#define OP_UNARY_NUM_EXP 114
|
||||
#define OP_UNARY_NUM_SOFTPLUS 115
|
||||
#define OP_UNARY_NUM_EXPM1 116
|
||||
|
||||
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
|
||||
#define OP_SUM_ROWS_NUM_MEAN 11
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
|
||||
@@ -123,6 +157,31 @@ typedef struct {
|
||||
int32_t dim;
|
||||
} ggml_metal_kargs_concat;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
float slope;
|
||||
float scale;
|
||||
float bias;
|
||||
float val;
|
||||
float min;
|
||||
float max;
|
||||
} ggml_metal_kargs_unary;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
@@ -180,20 +239,6 @@ typedef struct {
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_repeat;
|
||||
|
||||
typedef struct {
|
||||
float scale;
|
||||
float bias;
|
||||
} ggml_metal_kargs_scale;
|
||||
|
||||
typedef struct {
|
||||
float val;
|
||||
} ggml_metal_kargs_fill;
|
||||
|
||||
typedef struct {
|
||||
float min;
|
||||
float max;
|
||||
} ggml_metal_kargs_clamp;
|
||||
|
||||
typedef struct {
|
||||
int64_t nk0;
|
||||
int64_t ne00;
|
||||
@@ -497,8 +542,21 @@ typedef struct {
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne00_4;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
@@ -880,10 +938,6 @@ typedef struct {
|
||||
int max_period;
|
||||
} ggml_metal_kargs_timestep_embedding;
|
||||
|
||||
typedef struct {
|
||||
float slope;
|
||||
} ggml_metal_kargs_leaky_relu;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
|
||||
@@ -287,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
n_fuse = ggml_metal_op_acc(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SCALE:
|
||||
{
|
||||
n_fuse = ggml_metal_op_scale(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
n_fuse = ggml_metal_op_fill(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_CLAMP:
|
||||
{
|
||||
n_fuse = ggml_metal_op_clamp(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
@@ -426,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_top_k(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
{
|
||||
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_TRI:
|
||||
{
|
||||
n_fuse = ggml_metal_op_tri(ctx, idx);
|
||||
@@ -438,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SET:
|
||||
{
|
||||
n_fuse = ggml_metal_op_set(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
@@ -628,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
||||
|
||||
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
||||
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
||||
@@ -679,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
ggml_metal_kargs_bin args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.ne00 =*/ ne10,
|
||||
/*.ne01 =*/ ne11,
|
||||
/*.ne02 =*/ ne12,
|
||||
/*.ne03 =*/ ne13,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ pnb1,
|
||||
/*.nb02 =*/ pnb2,
|
||||
@@ -695,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.ne0 =*/ ne10,
|
||||
/*.ne1 =*/ ne11,
|
||||
/*.ne2 =*/ ne12,
|
||||
/*.ne3 =*/ ne13,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ pnb1,
|
||||
/*.nb2 =*/ pnb2,
|
||||
@@ -707,7 +699,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
/*.o1 =*/ { 0 },
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -715,126 +707,19 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
|
||||
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
int nth = 1;
|
||||
|
||||
while (2*nth < args.ne0 && nth < nth_max) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float scale;
|
||||
float bias;
|
||||
memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
|
||||
memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
||||
|
||||
ggml_metal_kargs_scale args = {
|
||||
/*.scale =*/ scale,
|
||||
/*.bias =*/ bias,
|
||||
};
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const float val = ggml_get_op_params_f32(op, 0);
|
||||
|
||||
ggml_metal_kargs_fill args = {
|
||||
/*.val =*/ val
|
||||
};
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float min;
|
||||
float max;
|
||||
memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
|
||||
memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
||||
|
||||
ggml_metal_kargs_clamp args = {
|
||||
/*.min =*/ min,
|
||||
/*.max =*/ max,
|
||||
};
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -846,19 +731,79 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_kargs_unary args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.slope =*/ 0.0,
|
||||
/*.scale =*/ 0.0,
|
||||
/*.bias =*/ 0.0,
|
||||
/*.val =*/ 0.0,
|
||||
/*.min =*/ 0.0,
|
||||
/*.max =*/ 0.0,
|
||||
};
|
||||
|
||||
if (op->op == GGML_OP_LEAKY_RELU) {
|
||||
args.slope = ggml_get_op_params_f32(op, 0);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_SCALE) {
|
||||
args.scale = ggml_get_op_params_f32(op, 0);
|
||||
args.bias = ggml_get_op_params_f32(op, 1);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_FILL) {
|
||||
args.val = ggml_get_op_params_f32(op, 0);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_CLAMP) {
|
||||
args.min = ggml_get_op_params_f32(op, 0);
|
||||
args.max = ggml_get_op_params_f32(op, 1);
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
if (pipeline.cnt) {
|
||||
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
} else {
|
||||
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const int nth = MIN(args.ne00, nth_max);
|
||||
|
||||
const int nk0 = (args.ne00 + nth - 1)/nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -969,6 +914,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
@@ -990,21 +940,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
||||
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00);
|
||||
nth = std::min(nth, (int) args.ne00);
|
||||
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
@@ -1664,6 +1619,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
||||
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
||||
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
|
||||
const size_t offs = ((const int32_t *) op->op_params)[3];
|
||||
|
||||
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
|
||||
|
||||
if (!inplace) {
|
||||
// run a separete kernel to cpy src->dst
|
||||
// not sure how to avoid this
|
||||
// TODO: make a simpler cpy_bytes kernel
|
||||
|
||||
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ ne00,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
|
||||
|
||||
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
|
||||
|
||||
int64_t nk0 = ne10;
|
||||
if (ggml_is_quantized(op->src[1]->type)) {
|
||||
nk0 = ne10/16;
|
||||
} else if (ggml_is_quantized(op->type)) {
|
||||
nk0 = ne10/ggml_blck_size(op->type);
|
||||
}
|
||||
|
||||
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
|
||||
// TODO: relax this constraint in the future
|
||||
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
|
||||
if (nth > nk0) {
|
||||
nrptg = (nth + nk0 - 1)/nk0;
|
||||
nth = nk0;
|
||||
|
||||
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nrptg--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nth = std::min<int>(nth, nk0);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ nk0,
|
||||
/*.ne00 =*/ ne10,
|
||||
/*.ne01 =*/ ne11,
|
||||
/*.ne02 =*/ ne12,
|
||||
/*.ne03 =*/ ne13,
|
||||
/*.nb00 =*/ nb10,
|
||||
/*.nb01 =*/ nb11,
|
||||
/*.nb02 =*/ nb12,
|
||||
/*.nb03 =*/ nb13,
|
||||
/*.ne0 =*/ ne10,
|
||||
/*.ne1 =*/ ne11,
|
||||
/*.ne2 =*/ ne12,
|
||||
/*.ne3 =*/ ne13,
|
||||
/*.nb0 =*/ ggml_element_size(op),
|
||||
/*.nb1 =*/ pnb1,
|
||||
/*.nb2 =*/ pnb2,
|
||||
/*.nb3 =*/ pnb3,
|
||||
};
|
||||
|
||||
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
||||
|
||||
bid_dst.offs += offs;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -2895,8 +2978,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
||||
|
||||
bool bcast_row = false;
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
@@ -2990,18 +3071,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
struct ggml_metal_pipeline_with_params pipeline;
|
||||
|
||||
if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
|
||||
// src1 is a row
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
|
||||
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
|
||||
|
||||
bcast_row = true;
|
||||
} else {
|
||||
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
|
||||
}
|
||||
pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
|
||||
|
||||
if (n_fuse > 1) {
|
||||
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
|
||||
@@ -3015,20 +3085,28 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
}
|
||||
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne10 = ne10/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
||||
|
||||
if (bcast_row) {
|
||||
const int64_t n = ggml_nelements(op)/4;
|
||||
if (pipeline.cnt) {
|
||||
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
} else {
|
||||
int nth = 32;
|
||||
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
int nth = 1;
|
||||
|
||||
while (2*nth < args.ne0 && nth < nth_max) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
@@ -3049,39 +3127,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, op->op_params, sizeof(float));
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
ggml_metal_kargs_l2_norm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne00_4 =*/ ne00/4,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.eps =*/ eps,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.eps =*/ eps,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
||||
|
||||
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00/4);
|
||||
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
const int64_t nrows = ggml_nrows(op->src[0]);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -4089,42 +4187,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float slope;
|
||||
memcpy(&slope, op->op_params, sizeof(float));
|
||||
|
||||
ggml_metal_kargs_leaky_relu args = {
|
||||
/*.slope =*/ slope
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user