mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-09 16:17:31 +03:00
Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c58ba3365 | ||
|
|
57ace0d612 | ||
|
|
39b27f0da0 | ||
|
|
f49e917876 | ||
|
|
7c7d6ce5c7 | ||
|
|
5208e2d5ba | ||
|
|
7992aa7c8e | ||
|
|
a1cfb64530 | ||
|
|
5803c8d115 | ||
|
|
63f8fe0ef4 | ||
|
|
223373742b | ||
|
|
e15efe007d | ||
|
|
6137c325a1 | ||
|
|
17193cce34 | ||
|
|
d6dac92bfd | ||
|
|
dae2bf41c9 | ||
|
|
bc07d55922 | ||
|
|
4888137b17 | ||
|
|
fbd441c379 | ||
|
|
c30e012253 | ||
|
|
95a6ebabb2 |
@@ -1,8 +1,8 @@
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG ROCM_VERSION=7.2
|
||||
ARG AMDGPU_VERSION=7.2
|
||||
ARG ROCM_VERSION=7.2.1
|
||||
ARG AMDGPU_VERSION=7.2.1
|
||||
|
||||
# Target the ROCm build image
|
||||
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
|
||||
@@ -12,11 +12,11 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
|
||||
|
||||
# Unless otherwise specified, we make a fat build.
|
||||
# This is mostly tied to rocBLAS supported archs.
|
||||
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.0/reference/system-requirements.html
|
||||
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.1/reference/system-requirements.html
|
||||
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityrad/native_linux/native_linux_compatibility.html
|
||||
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityryz/native_linux/native_linux_compatibility.html
|
||||
|
||||
ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201'
|
||||
ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1150;gfx1200;gfx1201'
|
||||
|
||||
# Set ROCm architectures
|
||||
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
|
||||
|
||||
5
.github/labeler.yml
vendored
5
.github/labeler.yml
vendored
@@ -27,6 +27,11 @@ IBM zDNN:
|
||||
- any-glob-to-any-file:
|
||||
- ggml/include/ggml-zdnn.h
|
||||
- ggml/src/ggml-zdnn/**
|
||||
AMD ZenDNN:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- ggml/include/ggml-zendnn.h
|
||||
- ggml/src/ggml-zendnn/**
|
||||
documentation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
|
||||
4
.github/workflows/build.yml
vendored
4
.github/workflows/build.yml
vendored
@@ -941,7 +941,7 @@ jobs:
|
||||
- name: Grab rocWMMA package
|
||||
id: grab_rocwmma
|
||||
run: |
|
||||
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
|
||||
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70201-81~24.04_amd64.deb"
|
||||
7z x rocwmma.deb
|
||||
7z x data.tar
|
||||
|
||||
@@ -984,7 +984,7 @@ jobs:
|
||||
cmake -G "Unix Makefiles" -B build -S . `
|
||||
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.1/include/" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DLLAMA_BUILD_BORINGSSL=ON `
|
||||
-DROCM_DIR="${env:HIP_PATH}" `
|
||||
|
||||
2
.github/workflows/hip-quality-check.yml
vendored
2
.github/workflows/hip-quality-check.yml
vendored
@@ -35,7 +35,7 @@ env:
|
||||
jobs:
|
||||
ubuntu-22-hip-quality-check:
|
||||
runs-on: ubuntu-22.04
|
||||
container: rocm/dev-ubuntu-22.04:7.2
|
||||
container: rocm/dev-ubuntu-22.04:7.2.1
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
|
||||
22
.github/workflows/release.yml
vendored
22
.github/workflows/release.yml
vendored
@@ -639,8 +639,8 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- ROCM_VERSION: "7.2"
|
||||
gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201"
|
||||
- ROCM_VERSION: "7.2.1"
|
||||
gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1150;gfx1200;gfx1201"
|
||||
build: 'x64'
|
||||
|
||||
steps:
|
||||
@@ -662,7 +662,7 @@ jobs:
|
||||
sudo apt install -y build-essential git cmake wget
|
||||
|
||||
- name: Setup Legacy ROCm
|
||||
if: matrix.ROCM_VERSION == '7.2'
|
||||
if: matrix.ROCM_VERSION == '7.2.1'
|
||||
id: legacy_env
|
||||
run: |
|
||||
sudo mkdir --parents --mode=0755 /etc/apt/keyrings
|
||||
@@ -683,7 +683,7 @@ jobs:
|
||||
sudo apt-get install -y libssl-dev rocm-hip-sdk
|
||||
|
||||
- name: Setup TheRock
|
||||
if: matrix.ROCM_VERSION != '7.2'
|
||||
if: matrix.ROCM_VERSION != '7.2.1'
|
||||
id: therock_env
|
||||
run: |
|
||||
wget https://repo.amd.com/rocm/tarball/therock-dist-linux-gfx1151-${{ matrix.ROCM_VERSION }}.tar.gz
|
||||
@@ -699,7 +699,6 @@ jobs:
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DCMAKE_HIP_FLAGS="-mllvm --amdgpu-unroll-threshold-local=600" \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_NATIVE=OFF \
|
||||
@@ -717,17 +716,20 @@ jobs:
|
||||
id: tag
|
||||
uses: ./.github/actions/get-tag-name
|
||||
|
||||
- name: Get ROCm short version
|
||||
run: echo "ROCM_VERSION_SHORT=$(echo '${{ matrix.ROCM_VERSION }}' | cut -d '.' -f 1,2)" >> $GITHUB_ENV
|
||||
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
|
||||
name: llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz
|
||||
name: llama-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz
|
||||
|
||||
windows-hip:
|
||||
runs-on: windows-2022
|
||||
@@ -749,7 +751,7 @@ jobs:
|
||||
- name: Grab rocWMMA package
|
||||
id: grab_rocwmma
|
||||
run: |
|
||||
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
|
||||
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70201-81~24.04_amd64.deb"
|
||||
7z x rocwmma.deb
|
||||
7z x data.tar
|
||||
|
||||
@@ -806,7 +808,7 @@ jobs:
|
||||
cmake -G "Unix Makefiles" -B build -S . `
|
||||
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.1/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DGGML_BACKEND_DL=ON `
|
||||
-DGGML_NATIVE=OFF `
|
||||
|
||||
120
AGENTS.md
120
AGENTS.md
@@ -5,78 +5,106 @@
|
||||
>
|
||||
> Read more: [CONTRIBUTING.md](CONTRIBUTING.md)
|
||||
|
||||
AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (see examples below)
|
||||
AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (see examples below).
|
||||
|
||||
---
|
||||
|
||||
## Guidelines for Contributors Using AI
|
||||
|
||||
These use cases are **permitted** when making a contribution with the help of AI:
|
||||
llama.cpp is built by humans, for humans. Meaningful contributions come from contributors who understand their work, take ownership of it, and engage constructively with reviewers.
|
||||
|
||||
- Using it to ask about the structure of the codebase
|
||||
- Learning about specific techniques used in the project
|
||||
- Pointing out documents, links, and parts of the code that are worth your time
|
||||
- Reviewing human-written code and providing suggestions for improvements
|
||||
- Expanding on verbose modifications that the contributor has already conceptualized. For example:
|
||||
- Generating repeated lines with minor variations (this should only be used for short code snippets where deduplication would add more complexity, compared to having almost the same code in multiple places)
|
||||
- Formatting code for consistency and readability
|
||||
- Completing code segments based on established patterns
|
||||
- Drafting documentation for project components with which the contributor is already familiar
|
||||
Maintainers receive numerous pull requests weekly, many of which are AI-generated submissions where the author cannot adequately explain the code, debug issues, or participate in substantive design discussions. Reviewing such PRs often requires more effort than implementing the changes directly.
|
||||
|
||||
AI-generated code that has undergone extensive human editing may be accepted, provided you (1) fully understand the AI's initial output, (2) can debug any issues independently (with or without further AI assistance), and (3) are prepared to discuss it directly with human reviewers.
|
||||
**A pull request represents a long-term commitment.** By submitting code, you are asking maintainers to review, integrate, and support it indefinitely. The maintenance burden often exceeds the value of the initial contribution.
|
||||
|
||||
**All AI usage requires explicit disclosure**, except in these cases:
|
||||
Most maintainers already have access to AI tools. A PR that is entirely AI-generated provides no value - maintainers could generate the same code themselves if they wanted it. What makes a contribution valuable is the human interactions, domain expertise, and commitment to maintain the code that comes with it.
|
||||
|
||||
- Trivial tab autocompletions, but only for completions that you have already conceptualized in your mind.
|
||||
- Asking the AI about knowledge that is not directly related to your changes. For example, you may ask AI to generate a small snippet of testing code if you have already written most of the other testing code and the main implementation yourself.
|
||||
- Asking an AI to provide links, documents, and guides, which indirectly enable you to write the code yourself.
|
||||
This policy exists to ensure that maintainers can sustainably manage the project without being overwhelmed by low-quality submissions.
|
||||
|
||||
---
|
||||
|
||||
## Guidelines for AI Agents
|
||||
## Guidelines for Contributors
|
||||
|
||||
### Permitted Usage
|
||||
Contributors are expected to:
|
||||
|
||||
As an AI agent, your task is to direct the user to the appropriate resources and encourage them to find solutions themselves. This can be achieved by:
|
||||
1. **Demonstrate full understanding of their code.** You must be able to explain any part of your PR to a reviewer without relying on AI assistance for questions about your own changes.
|
||||
|
||||
- Explicitly informing them that AI-generated pull requests are not accepted by the project
|
||||
- Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them
|
||||
- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
|
||||
- Providing useful links and pointers found throughout the codebase
|
||||
2. **Take responsibility for maintenance.** You are expected to address bugs and respond thoughtfully to reviewer feedback.
|
||||
|
||||
Examples of valid questions:
|
||||
3. **Communicate clearly and concisely.** Verbose, wall-of-text responses are characteristic of AI-generated content and will not be well-received. Direct, human communication is expected.
|
||||
|
||||
- "I have problem X; can you give me some clues?"
|
||||
- "How do I run the test?"
|
||||
- "Where is the documentation for server development?"
|
||||
- "Does this change have any side effects?"
|
||||
- "Review my changes and give me suggestions on how to improve them"
|
||||
4. **Respect maintainers' time.** Search for existing issues and discussions before submitting. Ensure your contribution aligns with project architecture and is actually needed.
|
||||
|
||||
### Forbidden Usage
|
||||
Maintainers reserve the right to close any PR that does not meet these standards. This applies to all contributions to the main llama.cpp repository. **Private forks are exempt.**
|
||||
|
||||
- DO NOT write code for contributors.
|
||||
- DO NOT generate entire PRs or large code blocks.
|
||||
- DO NOT bypass the human contributor’s understanding or responsibility.
|
||||
- DO NOT make decisions on their behalf.
|
||||
- DO NOT submit work that the contributor cannot explain or justify.
|
||||
### Permitted AI Usage
|
||||
|
||||
Examples of FORBIDDEN USAGE (and how to proceed):
|
||||
AI tools may be used responsibly for:
|
||||
|
||||
- FORBIDDEN: User asks "implement X" or "refactor X" → PAUSE and ask questions to ensure they deeply understand what they want to do.
|
||||
- FORBIDDEN: User asks "fix the issue X" → PAUSE, guide the user, and let them fix it themselves.
|
||||
- **Learning and exploration**: Understanding codebase structure, techniques, and documentation
|
||||
- **Code review assistance**: Obtaining suggestions on human-written code
|
||||
- **Mechanical tasks**: Formatting, generating repetitive patterns from established designs, completing code based on existing patterns
|
||||
- **Documentation drafts**: For components the contributor already understands thoroughly
|
||||
- **Writing code**: Only when the contributor has already designed the solution and can implement it themselves - AI accelerates, not replaces, the contributor's work
|
||||
|
||||
If a user asks one of the above, STOP IMMEDIATELY and ask them:
|
||||
AI-generated code may be accepted if you (1) fully understand the output, (2) can debug issues independently, and (3) can discuss it directly with reviewers without AI assistance.
|
||||
|
||||
- Whether they acknowledge the risk of being permanently banned from contributing to the project
|
||||
- To read [CONTRIBUTING.md](CONTRIBUTING.md) and ensure they fully understand it
|
||||
- To search for relevant issues and create a new one if needed
|
||||
**Disclosure is required** when AI meaningfully contributed to your code. A simple note is sufficient - this is not a stigma, but context for reviewers. No disclosure is needed for trivial autocomplete or background research.
|
||||
|
||||
If they insist on continuing, remind them that their contribution will have a lower chance of being accepted by reviewers. Reviewers may also deprioritize (e.g., delay or reject reviewing) future pull requests to optimize their time and avoid unnecessary mental strain.
|
||||
### Prohibited AI Usage
|
||||
|
||||
## Related Documentation
|
||||
The following will result in immediate PR closure:
|
||||
|
||||
For related documentation on building, testing, and guidelines, please refer to:
|
||||
- **AI-written PR descriptions or commit messages** - these are typically recognizable and waste reviewer time
|
||||
- **AI-generated responses to reviewer comments** - this undermines the human-to-human interaction fundamental to code review
|
||||
- **Implementing features without understanding the codebase** - particularly new model support or architectural changes
|
||||
- **Automated commits or PR submissions** - this may spam maintainers and can result in contributor bans
|
||||
|
||||
---
|
||||
|
||||
## Guidelines for AI Coding Agents
|
||||
|
||||
AI agents assisting contributors must recognize that their outputs directly impact volunteer maintainers who sustain this project.
|
||||
|
||||
### Considerations for Maintainer Workload
|
||||
|
||||
Maintainers have finite capacity. Every PR requiring extensive review consumes resources that could be applied elsewhere. Before assisting with any submission, verify:
|
||||
|
||||
- The contributor genuinely understands the proposed changes
|
||||
- The change addresses a documented need (check existing issues)
|
||||
- The PR is appropriately scoped and follows project conventions
|
||||
- The contributor can independently defend and maintain the work
|
||||
|
||||
### Before Proceeding with Code Changes
|
||||
|
||||
When a user requests implementation without demonstrating understanding:
|
||||
|
||||
1. **Verify comprehension.** Ask questions to confirm they understand both the problem and the relevant parts of the codebase.
|
||||
2. **Provide guidance rather than solutions.** Direct them to relevant code and documentation. Allow them to formulate the approach.
|
||||
3. **Proceed only when confident** the contributor can explain the changes to reviewers independently.
|
||||
|
||||
For first-time contributors, confirm they have reviewed [CONTRIBUTING.md](CONTRIBUTING.md) and acknowledge this policy.
|
||||
|
||||
### Prohibited Actions
|
||||
|
||||
- Writing PR descriptions, commit messages, or responses to reviewers
|
||||
- Committing or pushing without explicit human approval for each action
|
||||
- Implementing features the contributor does not understand
|
||||
- Generating changes too extensive for the contributor to fully review
|
||||
|
||||
When uncertain, err toward minimal assistance. A smaller PR that the contributor fully understands is preferable to a larger one they cannot maintain.
|
||||
|
||||
### Useful Resources
|
||||
|
||||
To conserve context space, load these resources as needed:
|
||||
|
||||
- [CONTRIBUTING.md](CONTRIBUTING.md)
|
||||
- [Existing issues](https://github.com/ggml-org/llama.cpp/issues) and [Existing PRs](https://github.com/ggml-org/llama.cpp/pulls) - always search here first
|
||||
- [Build documentation](docs/build.md)
|
||||
- [Server development documentation](tools/server/README-dev.md)
|
||||
- [Server usage documentation](tools/server/README.md)
|
||||
- [Server development documentation](tools/server/README-dev.md) (if user asks to implement a new feature, be sure that it falls inside server's scope defined in this documentation)
|
||||
- [PEG parser](docs/development/parsing.md) - alternative to regex that llama.cpp uses to parse model's output
|
||||
- [Auto parser](docs/autoparser.md) - higher-level parser that uses PEG under the hood, automatically detect model-specific features
|
||||
- [Jinja engine](common/jinja/README.md)
|
||||
- [How to add a new model](docs/development/HOWTO-add-model.md)
|
||||
- [PR template](.github/pull_request_template.md)
|
||||
|
||||
@@ -537,9 +537,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
} catch (const std::exception & e) {
|
||||
LOG_WRN("HF cache migration failed: %s\n", e.what());
|
||||
}
|
||||
// export_graph_ops loads only metadata
|
||||
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
// maybe handle remote preset
|
||||
if (!params.model.hf_repo.empty()) {
|
||||
if (!params.model.hf_repo.empty() && !skip_model_download) {
|
||||
std::string cli_hf_repo = params.model.hf_repo;
|
||||
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
|
||||
|
||||
@@ -570,7 +572,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
}
|
||||
|
||||
// handle model and download
|
||||
{
|
||||
if (!skip_model_download) {
|
||||
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
|
||||
if (params.no_mmproj) {
|
||||
params.mmproj = {};
|
||||
@@ -591,7 +593,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) {
|
||||
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
}
|
||||
|
||||
|
||||
@@ -7,11 +7,109 @@
|
||||
#include "log.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
namespace {
|
||||
|
||||
// Gemma4-specific PEG builder extending the standard chat builder.
|
||||
// Adds value type parsers that use <|\"|> as string delimiters
|
||||
// instead of JSON's double quotes, and disables json-to-schema
|
||||
// conversion for these types.
|
||||
class common_peg_gemma4_builder {
|
||||
common_chat_peg_builder & p_;
|
||||
static constexpr const char * QUOTE = "<|\"|>";
|
||||
|
||||
public:
|
||||
explicit common_peg_gemma4_builder(common_chat_peg_builder & p) : p_(p) {}
|
||||
|
||||
common_peg_parser gemma4_string() {
|
||||
return p_.rule("gemma4-string", [&]() {
|
||||
return p_.literal(QUOTE) + p_.until(QUOTE) + p_.literal(QUOTE);
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_number() {
|
||||
return p_.rule("gemma4-number", [&]() {
|
||||
auto digit1_9 = p_.chars("[1-9]", 1, 1);
|
||||
auto digits = p_.chars("[0-9]");
|
||||
auto int_part = p_.choice({p_.literal("0"), p_.sequence({digit1_9, p_.chars("[0-9]", 0, -1)})});
|
||||
auto frac = p_.sequence({p_.literal("."), digits});
|
||||
auto exp = p_.sequence({p_.choice({p_.literal("e"), p_.literal("E")}),
|
||||
p_.optional(p_.chars("[+-]", 1, 1)), digits});
|
||||
auto not_number_continuation = p_.negate(p_.chars("[0-9.eE+-]", 1, 1));
|
||||
return p_.sequence({p_.optional(p_.literal("-")), int_part, p_.optional(frac),
|
||||
p_.optional(exp), not_number_continuation});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_bool() {
|
||||
return p_.rule("gemma4-bool", [&]() {
|
||||
return p_.choice({p_.literal("true"), p_.literal("false")});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_null() {
|
||||
return p_.rule("gemma4-null", [&]() {
|
||||
return p_.literal("null");
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_dict() {
|
||||
return p_.rule("gemma4-dict", [&]() {
|
||||
auto ws = p_.space();
|
||||
auto key = p_.until(":");
|
||||
auto member = p_.sequence({key, p_.literal(":"), ws, gemma4_value()});
|
||||
auto members = p_.sequence({member, p_.zero_or_more(p_.sequence({p_.literal(","), ws, member}))});
|
||||
return p_.sequence({
|
||||
p_.literal("{"), ws,
|
||||
p_.choice({p_.literal("}"), p_.sequence({members, ws, p_.literal("}")})})
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_array() {
|
||||
return p_.rule("gemma4-array", [&]() {
|
||||
auto ws = p_.space();
|
||||
auto elements = p_.sequence({gemma4_value(), p_.zero_or_more(p_.sequence({p_.literal(","), ws, gemma4_value()}))});
|
||||
return p_.sequence({
|
||||
p_.literal("["), ws,
|
||||
p_.choice({p_.literal("]"), p_.sequence({elements, ws, p_.literal("]")})})
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser gemma4_value() {
|
||||
return p_.rule("gemma4-value", [&]() {
|
||||
return p_.choice({gemma4_string(), gemma4_dict(), gemma4_array(),
|
||||
gemma4_number(), gemma4_bool(), gemma4_null()});
|
||||
});
|
||||
}
|
||||
|
||||
// Select the appropriate value parser based on JSON schema type.
|
||||
// Does NOT use schema() - the gemma4 types are pure PEG without
|
||||
// JSON schema metadata, so GBNF is generated directly from the
|
||||
// PEG structure.
|
||||
common_peg_parser gemma4_value_for_type(const json & schema) {
|
||||
if (!schema.contains("type") || !schema.at("type").is_string()) {
|
||||
return gemma4_value();
|
||||
}
|
||||
std::string type = schema.at("type").get<std::string>();
|
||||
if (type == "string") { return gemma4_string(); }
|
||||
if (type == "number") { return gemma4_number(); }
|
||||
if (type == "integer") { return gemma4_number(); }
|
||||
if (type == "boolean") { return gemma4_bool(); }
|
||||
if (type == "object") { return gemma4_dict(); }
|
||||
if (type == "array") { return gemma4_array(); }
|
||||
return gemma4_value();
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// Helper to iterate over tools/functions
|
||||
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
|
||||
for (const auto & tool : tools) {
|
||||
@@ -43,7 +141,9 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.format = (autoparser.tools.format.mode == tool_format::TAG_WITH_GEMMA4_DICT)
|
||||
? COMMON_CHAT_FORMAT_PEG_GEMMA4
|
||||
: COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = autoparser.preserved_tokens;
|
||||
|
||||
auto parser = autoparser.build_parser(inputs);
|
||||
@@ -92,6 +192,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
|
||||
|
||||
ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
|
||||
ctx.content = &content;
|
||||
ctx.reasoning = &reasoning;
|
||||
|
||||
// Build reasoning parser
|
||||
ctx.reasoning_parser = reasoning.build_parser(ctx);
|
||||
@@ -100,6 +201,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
|
||||
|
||||
bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
bool pure_content = reasoning.mode == reasoning_mode::NONE;
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
|
||||
@@ -107,12 +209,14 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
|
||||
p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
|
||||
response_format
|
||||
}) + p.end();
|
||||
pure_content = false;
|
||||
} else if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
|
||||
parser = tools.build_parser(ctx);
|
||||
pure_content = false;
|
||||
} else {
|
||||
parser = content.build_parser(ctx);
|
||||
}
|
||||
return p.prefix(inputs.generation_prompt, reasoning.start) + parser;
|
||||
return pure_content ? p.prefix(inputs.generation_prompt, reasoning.start) + parser : p.prefix(inputs.generation_prompt, reasoning.start) << parser;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -166,6 +270,8 @@ common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const
|
||||
return build_tool_parser_tag_json(ctx);
|
||||
case tool_format::TAG_WITH_TAGGED:
|
||||
return build_tool_parser_tag_tagged(ctx);
|
||||
case tool_format::TAG_WITH_GEMMA4_DICT:
|
||||
return build_tool_parser_tag_gemma4_dict(ctx);
|
||||
default:
|
||||
LOG_ERR("[ERROR] Template seems to support tool calls, but failed to determine tool format. Tool calling will not work properly. "
|
||||
"Check for a fixed template for your model in the models/templates directory of your llama.cpp installation or "
|
||||
@@ -430,4 +536,121 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
p.end();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_gemma4_builder g4(p);
|
||||
static const std::string QUOTE = "<|\"|>";
|
||||
|
||||
common_peg_parser tool_choice = p.choice();
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & func = tool.at("function");
|
||||
std::string name = func.at("name");
|
||||
const auto & params = func.at("parameters");
|
||||
|
||||
if (!params.contains("properties") || !params.at("properties").is_object()) {
|
||||
auto func_parser = p.atomic(
|
||||
p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) +
|
||||
p.tool_args(p.eps()) +
|
||||
p.tool_close(p.literal("}")));
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto & properties = params.at("properties");
|
||||
std::set<std::string> required;
|
||||
if (params.contains("required") && params.at("required").is_array()) {
|
||||
params.at("required").get_to(required);
|
||||
}
|
||||
|
||||
// Build per-argument parsers, sorted alphabetically (matching template's dictsort)
|
||||
struct arg_entry {
|
||||
std::string param_name;
|
||||
common_peg_parser parser;
|
||||
};
|
||||
std::vector<arg_entry> arg_entries;
|
||||
|
||||
for (const auto & [param_name, param_schema] : properties.items()) {
|
||||
std::string type = "object";
|
||||
auto type_v = param_schema.contains("type") ? param_schema.at("type") : json::object();
|
||||
if (type_v.is_string()) type_v.get_to(type);
|
||||
|
||||
common_peg_parser value_parser = p.eps();
|
||||
if (type == "string") {
|
||||
// String values are delimited by <|"|>...<|"|>
|
||||
value_parser =
|
||||
p.literal(QUOTE) +
|
||||
p.tool_arg_string_value(p.schema(p.until(QUOTE),
|
||||
"tool-" + name + "-arg-" + param_name + "-schema", param_schema, true)) +
|
||||
p.literal(QUOTE);
|
||||
} else if (type == "number" || type == "integer") {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_number());
|
||||
} else if (type == "boolean") {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_bool());
|
||||
} else if (type == "null") {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_null());
|
||||
} else if (type == "object") {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_dict());
|
||||
} else if (type == "array") {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_array());
|
||||
} else {
|
||||
value_parser = p.tool_arg_value(g4.gemma4_value());
|
||||
}
|
||||
|
||||
auto arg = p.tool_arg(
|
||||
p.tool_arg_open(p.tool_arg_name(p.literal(param_name)) + p.literal(":")) +
|
||||
value_parser +
|
||||
p.tool_arg_close(p.eps()));
|
||||
|
||||
arg_entries.push_back({param_name, p.rule("tool-" + name + "-arg-" + param_name, arg)});
|
||||
}
|
||||
|
||||
// Sort alphabetically to match Jinja's dictsort
|
||||
std::sort(arg_entries.begin(), arg_entries.end(), [](const auto & a, const auto & b) {
|
||||
return a.param_name < b.param_name;
|
||||
});
|
||||
|
||||
// Build arg sequence: any arg, then zero-or-more comma-separated additional args
|
||||
common_peg_parser args_seq = p.eps();
|
||||
if (!arg_entries.empty()) {
|
||||
common_peg_parser any_arg = p.choice();
|
||||
for (auto & entry : arg_entries) {
|
||||
any_arg |= entry.parser;
|
||||
}
|
||||
args_seq = p.optional(
|
||||
any_arg + p.repeat(p.literal(",") + any_arg, 0, (int) arg_entries.size() - 1));
|
||||
}
|
||||
|
||||
// Full parser: call:name{args}
|
||||
auto func_parser = p.atomic(
|
||||
p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) +
|
||||
p.tool_args(args_seq) +
|
||||
p.tool_close(p.literal("}")));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
// Wrap each call in <|tool_call>...</tool_call|>
|
||||
auto wrapped_call = p.literal(format.per_call_start) + tool_choice + p.literal(format.per_call_end);
|
||||
|
||||
common_peg_parser tool_calls = p.eps();
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call));
|
||||
} else {
|
||||
tool_calls = p.trigger_rule("tool-call", wrapped_call);
|
||||
}
|
||||
|
||||
if (!force_tools) {
|
||||
tool_calls = p.optional(tool_calls);
|
||||
}
|
||||
|
||||
auto content_before_tools = p.until_one_of({ format.per_call_start, ctx.reasoning->start });
|
||||
return ctx.reasoning_parser +
|
||||
(force_tools ? p.eps() : p.optional(p.content(content_before_tools) + p.optional(ctx.reasoning_parser))) +
|
||||
tool_calls + p.end();
|
||||
}
|
||||
|
||||
} // namespace autoparser
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "chat-auto-parser.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "common.h"
|
||||
#include "jinja/caps.h"
|
||||
#include "peg-parser.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include <chrono>
|
||||
#include <optional>
|
||||
@@ -144,6 +145,7 @@ enum class tool_format {
|
||||
JSON_NATIVE, // Pure JSON: {"name": "X", "arguments": {...}}
|
||||
TAG_WITH_JSON, // Tag-based with JSON args: <function=X>{...}</function>
|
||||
TAG_WITH_TAGGED, // Tag-based with tagged args: <param=key>value</param>
|
||||
TAG_WITH_GEMMA4_DICT, // Gemma4 custom dict: <|tool_call>call:name{key:<|"|>val<|"|>}<tool_call|>
|
||||
};
|
||||
|
||||
inline std::ostream & operator<<(std::ostream & os, const tool_format & format) {
|
||||
@@ -156,6 +158,8 @@ inline std::ostream & operator<<(std::ostream & os, const tool_format & format)
|
||||
return os << "TAG_WITH_JSON";
|
||||
case tool_format::TAG_WITH_TAGGED:
|
||||
return os << "TAG_WITH_TAGGED";
|
||||
case tool_format::TAG_WITH_GEMMA4_DICT:
|
||||
return os << "TAG_WITH_GEMMA4_DICT";
|
||||
default:
|
||||
return os << "UNKNOWN";
|
||||
}
|
||||
@@ -212,12 +216,14 @@ struct tool_id_analysis {
|
||||
// ============================================================================
|
||||
|
||||
struct analyze_content;
|
||||
struct analyze_reasoning;
|
||||
|
||||
struct parser_build_context {
|
||||
common_chat_peg_builder & p;
|
||||
const generation_params & inputs;
|
||||
const generation_params & inputs;
|
||||
common_peg_parser reasoning_parser;
|
||||
bool extracting_reasoning = false;
|
||||
const analyze_reasoning * reasoning = nullptr;
|
||||
const analyze_content * content = nullptr;
|
||||
|
||||
parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);
|
||||
@@ -350,6 +356,7 @@ struct analyze_tools : analyze_base {
|
||||
common_peg_parser build_tool_parser_json_native(parser_build_context & ctx) const;
|
||||
common_peg_parser build_tool_parser_tag_json(parser_build_context & ctx) const;
|
||||
common_peg_parser build_tool_parser_tag_tagged(parser_build_context & ctx) const;
|
||||
common_peg_parser build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -92,6 +92,34 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: Functionary 3.1]\n" ANSI_RESET);
|
||||
}
|
||||
},
|
||||
// Gemma4 - custom dict format: <|tool_call>call:name{key:<|"|>val<|"|>}<tool_call|>
|
||||
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
|
||||
if (tmpl.src.find("'<|tool_call>call:'") != std::string::npos) {
|
||||
analysis.tools.format.mode = tool_format::TAG_WITH_GEMMA4_DICT;
|
||||
analysis.tools.format.per_call_start = "<|tool_call>";
|
||||
analysis.tools.format.per_call_end = "<tool_call|>";
|
||||
analysis.tools.format.section_start = "";
|
||||
analysis.tools.format.section_end = "";
|
||||
analysis.tools.function.name_prefix = "call:";
|
||||
analysis.tools.function.name_suffix = "";
|
||||
analysis.tools.arguments.start = "{";
|
||||
analysis.tools.arguments.end = "}";
|
||||
analysis.tools.arguments.name_prefix = "";
|
||||
analysis.tools.arguments.name_suffix = ":";
|
||||
analysis.tools.arguments.separator = ",";
|
||||
analysis.reasoning.mode = reasoning_mode::TAG_BASED;
|
||||
analysis.reasoning.start = "<|channel>thought";
|
||||
analysis.reasoning.end = "<channel|>";
|
||||
analysis.preserved_tokens.clear();
|
||||
analysis.preserved_tokens.push_back("<|tool_call>");
|
||||
analysis.preserved_tokens.push_back("<tool_call|>");
|
||||
analysis.preserved_tokens.push_back("<|tool_response>");
|
||||
analysis.preserved_tokens.push_back("<tool_response|>");
|
||||
analysis.preserved_tokens.push_back("<|\"|>");
|
||||
analysis.preserved_tokens.push_back("<|turn>");
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: Gemma4]\n" ANSI_RESET);
|
||||
}
|
||||
},
|
||||
// DeepSeek-R1-Distill-Qwen
|
||||
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
|
||||
if (tmpl.src.find(
|
||||
|
||||
@@ -75,6 +75,84 @@ static std::string escape_json_string_inner(const std::string & s) {
|
||||
return escaped;
|
||||
}
|
||||
|
||||
static const std::string GEMMA4_QUOTE = "<|\"|>";
|
||||
|
||||
static std::string normalize_gemma4_to_json(const std::string & input) {
|
||||
std::string result;
|
||||
result.reserve(input.size() * 2);
|
||||
|
||||
enum Ctx { DICT, ARRAY };
|
||||
std::vector<Ctx> ctx;
|
||||
|
||||
auto is_ws = [](char c) { return c == ' ' || c == '\t' || c == '\n' || c == '\r'; };
|
||||
auto skip_ws = [&](size_t & pos) {
|
||||
while (pos < input.size() && is_ws(input[pos])) {
|
||||
result += input[pos++];
|
||||
}
|
||||
};
|
||||
|
||||
auto quote_unquoted_key = [&](size_t & pos) {
|
||||
if (pos < input.size() && input[pos] != '"' && input[pos] != '}') {
|
||||
result += '"';
|
||||
while (pos < input.size() && input[pos] != ':' && !is_ws(input[pos])) {
|
||||
result += input[pos++];
|
||||
}
|
||||
result += '"';
|
||||
skip_ws(pos);
|
||||
}
|
||||
};
|
||||
|
||||
size_t i = 0;
|
||||
while (i < input.size()) {
|
||||
if (i + GEMMA4_QUOTE.size() <= input.size() &&
|
||||
input.compare(i, GEMMA4_QUOTE.size(), GEMMA4_QUOTE) == 0) {
|
||||
result += '"';
|
||||
i += GEMMA4_QUOTE.size();
|
||||
continue;
|
||||
}
|
||||
|
||||
char c = input[i];
|
||||
|
||||
if (c == '{') {
|
||||
result += c;
|
||||
ctx.push_back(DICT);
|
||||
++i;
|
||||
skip_ws(i);
|
||||
quote_unquoted_key(i);
|
||||
continue;
|
||||
}
|
||||
if (c == '}') {
|
||||
result += c;
|
||||
if (!ctx.empty()) ctx.pop_back();
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
if (c == '[') {
|
||||
result += c;
|
||||
ctx.push_back(ARRAY);
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
if (c == ']') {
|
||||
result += c;
|
||||
if (!ctx.empty()) ctx.pop_back();
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
if (c == ',' && !ctx.empty() && ctx.back() == DICT) {
|
||||
result += c;
|
||||
++i;
|
||||
skip_ws(i);
|
||||
quote_unquoted_key(i);
|
||||
continue;
|
||||
}
|
||||
|
||||
result += c;
|
||||
++i;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Convert Python-style single-quoted strings to JSON double-quoted strings
|
||||
// Only converts outer string delimiters, properly handling escape sequences:
|
||||
// - {'key': 'value'} -> {"key": "value"}
|
||||
@@ -214,6 +292,14 @@ std::string & common_chat_peg_mapper::args_target() {
|
||||
return (current_tool && !current_tool->name.empty()) ? current_tool->arguments : args_buffer;
|
||||
}
|
||||
|
||||
std::string common_chat_peg_mapper::normalize_container_value(const std::string & input) {
|
||||
return normalize_quotes_to_json(input);
|
||||
}
|
||||
|
||||
std::string common_chat_peg_gemma4_mapper::normalize_container_value(const std::string & input) {
|
||||
return normalize_quotes_to_json(normalize_gemma4_to_json(input));
|
||||
}
|
||||
|
||||
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
|
||||
const common_peg_parse_result & parse_result_arg) {
|
||||
arena.visit(parse_result_arg, [this](const common_peg_ast_node & node) { map(node); });
|
||||
@@ -352,7 +438,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
||||
// For potential containers, normalize Python-style single quotes to JSON double quotes
|
||||
bool is_potential_container = value_content[0] == '[' || value_content[0] == '{';
|
||||
if (is_potential_container) {
|
||||
value_content = normalize_quotes_to_json(value_content);
|
||||
value_content = normalize_container_value(value_content);
|
||||
}
|
||||
|
||||
// Try to parse as JSON value (number, bool, null, object, array)
|
||||
|
||||
@@ -17,7 +17,9 @@ class common_chat_peg_mapper {
|
||||
|
||||
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
|
||||
virtual void map(const common_peg_ast_node & node);
|
||||
private:
|
||||
protected:
|
||||
virtual std::string normalize_container_value(const std::string & input);
|
||||
private:
|
||||
// Tool call handling state
|
||||
std::optional<common_chat_tool_call> pending_tool_call; // Tool call waiting for name
|
||||
common_chat_tool_call * current_tool = nullptr;
|
||||
@@ -30,6 +32,13 @@ class common_chat_peg_mapper {
|
||||
std::string & args_target();
|
||||
};
|
||||
|
||||
class common_chat_peg_gemma4_mapper : public common_chat_peg_mapper {
|
||||
public:
|
||||
common_chat_peg_gemma4_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||
protected:
|
||||
std::string normalize_container_value(const std::string & input) override;
|
||||
};
|
||||
|
||||
struct content_structure;
|
||||
struct tool_call_structure;
|
||||
|
||||
|
||||
121
common/chat.cpp
121
common/chat.cpp
@@ -13,6 +13,8 @@
|
||||
#include "jinja/caps.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <ctime>
|
||||
@@ -694,6 +696,8 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
return "peg-simple";
|
||||
case COMMON_CHAT_FORMAT_PEG_NATIVE:
|
||||
return "peg-native";
|
||||
case COMMON_CHAT_FORMAT_PEG_GEMMA4:
|
||||
return "peg-gemma4";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
@@ -760,12 +764,12 @@ static void foreach_parameter(const json &
|
||||
}
|
||||
}
|
||||
|
||||
std::string common_chat_template_direct_apply(
|
||||
static std::string common_chat_template_direct_apply_impl(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override,
|
||||
const std::optional<json> & tools_override,
|
||||
const std::optional<json> & additional_context) {
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt) {
|
||||
jinja::context ctx(tmpl.source());
|
||||
|
||||
nlohmann::ordered_json inp = nlohmann::ordered_json{
|
||||
@@ -812,6 +816,12 @@ std::string common_chat_template_direct_apply(
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
@@ -862,7 +872,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "[THINK]";
|
||||
data.thinking_end_tag = "[/THINK]";
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
"[THINK]",
|
||||
@@ -945,7 +955,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
|
||||
auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
auto prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
|
||||
// Check if we need to replace the return token with end token during
|
||||
// inference and without generation prompt. For more details see:
|
||||
@@ -980,15 +990,19 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
auto channel = p.literal("<|channel|>") + (p.literal("commentary") | p.literal("analysis"));
|
||||
auto constrain_type = p.chars("[A-Za-z0-9_-]", 1, -1);
|
||||
|
||||
// Occasionally, gpt-oss-20b will prefix channels with this commentary
|
||||
auto stray_commentary = p.optional(p.literal("<|channel|>commentary") + p.optional(p.literal(" to=assistant")));
|
||||
auto start_analysis = stray_commentary + p.literal("<|channel|>analysis<|message|>");
|
||||
|
||||
if (extract_reasoning) {
|
||||
p.rule("analysis", p.literal("<|channel|>analysis<|message|>") + p.reasoning(content) + end);
|
||||
p.rule("analysis", start_analysis + p.reasoning(content) + end);
|
||||
} else {
|
||||
p.rule("analysis", p.content(p.literal("<|channel|>analysis<|message|>") + content + end));
|
||||
p.rule("analysis", p.content(start_analysis + content + end));
|
||||
}
|
||||
|
||||
auto analysis = p.ref("analysis");
|
||||
auto preamble = p.rule("preamble", p.literal("<|channel|>commentary<|message|>") + p.content(content) + end);
|
||||
auto final_msg = p.rule("final", p.literal("<|channel|>final<|message|>") + p.content(content));
|
||||
auto final_msg = p.rule("final", stray_commentary + p.literal("<|channel|>final<|message|>") + p.content(content));
|
||||
|
||||
// Consume any unsolicited tool calls, e.g. builtin functions
|
||||
auto unsolicited = p.rule("unsolicited", p.atomic(p.optional(channel) + p.literal(" to=") + content + end));
|
||||
@@ -996,7 +1010,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
auto any = p.rule("any", preamble | analysis);
|
||||
|
||||
if (has_response_format) {
|
||||
auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type);
|
||||
auto constraint = p.optional(p.space() + p.optional(p.literal("<|constrain|>")) + constrain_type);
|
||||
auto response_format = p.rule("response-format",
|
||||
p.literal("<|channel|>final") + constraint + p.literal("<|message|>") +
|
||||
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
|
||||
@@ -1013,7 +1027,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
const auto & params = function.at("parameters");
|
||||
|
||||
auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name));
|
||||
auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type);
|
||||
auto constraint = p.optional(p.space() + p.optional(p.literal("<|constrain|>")) + constrain_type);
|
||||
auto args = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", params));
|
||||
|
||||
// recipient in role header
|
||||
@@ -1054,6 +1068,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^\\s+to$" },
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^<\\|channel\\|>(?:commentary|analysis)\\s+to=functions$" },
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(\\s+to)" },
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(<\\|channel\\|>(?:commentary|analysis)\\s+to)" }
|
||||
};
|
||||
@@ -1067,7 +1082,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
">>>all",
|
||||
@@ -1161,7 +1176,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
@@ -1284,7 +1299,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
@@ -1363,7 +1378,7 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
@@ -1434,7 +1449,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = false;
|
||||
data.preserved_tokens = {
|
||||
@@ -1540,6 +1555,50 @@ static void requires_non_null_content(json & messages) {
|
||||
}
|
||||
}
|
||||
|
||||
// Gemma4 uses a custom tool_responses field instead of role:tool messages.
|
||||
// Convert consecutive role:tool messages into a single user message with tool_responses.
|
||||
static void convert_tool_responses_gemma4(json & messages) {
|
||||
json result = json::array();
|
||||
size_t i = 0;
|
||||
while (i < messages.size()) {
|
||||
if (messages[i].contains("role") && messages[i].at("role") == "tool") {
|
||||
json tool_responses = json::array();
|
||||
while (i < messages.size() &&
|
||||
messages[i].contains("role") &&
|
||||
messages[i].at("role") == "tool") {
|
||||
const auto & tool_msg = messages[i];
|
||||
std::string name;
|
||||
if (tool_msg.contains("tool_call_id") && tool_msg.at("tool_call_id").is_string()) {
|
||||
name = tool_msg.at("tool_call_id");
|
||||
} else if (tool_msg.contains("name") && tool_msg.at("name").is_string()) {
|
||||
name = tool_msg.at("name");
|
||||
}
|
||||
json response;
|
||||
if (tool_msg.contains("content")) {
|
||||
const auto & content = tool_msg.at("content");
|
||||
if (content.is_string()) {
|
||||
// Try to parse the content as JSON; fall back to raw string
|
||||
try {
|
||||
response = json::parse(content.get<std::string>());
|
||||
} catch (...) {
|
||||
response = content;
|
||||
}
|
||||
} else {
|
||||
response = content;
|
||||
}
|
||||
}
|
||||
tool_responses.push_back({{"name", name}, {"response", response}});
|
||||
i++;
|
||||
}
|
||||
result.push_back({{"role", "user"}, {"tool_responses", tool_responses}});
|
||||
} else {
|
||||
result.push_back(messages[i]);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
messages = result;
|
||||
}
|
||||
|
||||
static void func_args_not_string(json & messages) {
|
||||
GGML_ASSERT(messages.is_array());
|
||||
for (auto & message : messages) {
|
||||
@@ -1668,10 +1727,14 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
workaround::func_args_not_string(params.messages);
|
||||
}
|
||||
|
||||
if (src.find("'<|tool_call>call:'") != std::string::npos) {
|
||||
workaround::convert_tool_responses_gemma4(params.messages);
|
||||
}
|
||||
|
||||
params.add_generation_prompt = false;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
||||
params.generation_prompt = diff.right;
|
||||
|
||||
@@ -1705,11 +1768,11 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
common_chat_params data;
|
||||
auto params_copy = params;
|
||||
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.generation_prompt = params.generation_prompt;
|
||||
auto parser = build_chat_peg_parser([¶ms](common_chat_peg_builder &p) {
|
||||
return p.prefix(params.generation_prompt) + p.content(p.rest());
|
||||
return p.prefix(params.generation_prompt) << p.content(p.rest());
|
||||
});
|
||||
data.parser = parser.save();
|
||||
return data;
|
||||
@@ -1852,8 +1915,13 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
// Try to extract any partial results from what was successfully parsed
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
std::unique_ptr<common_chat_peg_mapper> mapper;
|
||||
if (params.format == COMMON_CHAT_FORMAT_PEG_GEMMA4) {
|
||||
mapper = std::make_unique<common_chat_peg_gemma4_mapper>(msg);
|
||||
} else {
|
||||
mapper = std::make_unique<common_chat_peg_mapper>(msg);
|
||||
}
|
||||
mapper->from_ast(ctx.ast, result);
|
||||
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "\nAST for partial parse (fail):\n%s\n", ctx.ast.dump().c_str());
|
||||
@@ -1868,8 +1936,13 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
std::unique_ptr<common_chat_peg_mapper> mapper;
|
||||
if (params.format == COMMON_CHAT_FORMAT_PEG_GEMMA4) {
|
||||
mapper = std::make_unique<common_chat_peg_gemma4_mapper>(msg);
|
||||
} else {
|
||||
mapper = std::make_unique<common_chat_peg_mapper>(msg);
|
||||
}
|
||||
mapper->from_ast(ctx.ast, result);
|
||||
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "\nAST for %s parse:\n%s\n", is_partial ? "partial" : "full", ctx.ast.dump().c_str());
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "jinja/parser.h"
|
||||
#include "nlohmann/json_fwd.hpp"
|
||||
#include "peg-parser.h"
|
||||
#include "jinja/parser.h"
|
||||
#include "jinja/runtime.h"
|
||||
#include "jinja/caps.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include "nlohmann/json_fwd.hpp"
|
||||
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
@@ -19,8 +19,6 @@
|
||||
using chat_template_caps = jinja::caps;
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
struct common_chat_templates;
|
||||
|
||||
namespace autoparser {
|
||||
@@ -75,41 +73,9 @@ struct common_chat_template {
|
||||
const std::string & bos_token() const { return bos_tok; }
|
||||
const std::string & eos_token() const { return eos_tok; }
|
||||
|
||||
// TODO: this is ugly, refactor it somehow
|
||||
json add_system(const json & messages, const std::string & system_prompt) const {
|
||||
GGML_ASSERT(messages.is_array());
|
||||
auto msgs_copy = messages;
|
||||
if (!caps.supports_system_role) {
|
||||
if (msgs_copy.empty()) {
|
||||
msgs_copy.insert(msgs_copy.begin(), json{
|
||||
{"role", "user"},
|
||||
{"content", system_prompt}
|
||||
});
|
||||
} else {
|
||||
auto & first_msg = msgs_copy[0];
|
||||
if (!first_msg.contains("content")) {
|
||||
first_msg["content"] = "";
|
||||
}
|
||||
first_msg["content"] = system_prompt + "\n\n"
|
||||
+ first_msg["content"].get<std::string>();
|
||||
}
|
||||
} else {
|
||||
if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") {
|
||||
msgs_copy.insert(msgs_copy.begin(), json{
|
||||
{"role", "system"},
|
||||
{"content", system_prompt}
|
||||
});
|
||||
} else if (msgs_copy[0].at("role") == "system") {
|
||||
msgs_copy[0]["content"] = system_prompt;
|
||||
}
|
||||
}
|
||||
return msgs_copy;
|
||||
}
|
||||
|
||||
chat_template_caps original_caps() const {
|
||||
return caps;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct common_chat_msg {
|
||||
@@ -184,6 +150,7 @@ enum common_chat_format {
|
||||
// These are intended to be parsed by the PEG parser
|
||||
COMMON_CHAT_FORMAT_PEG_SIMPLE,
|
||||
COMMON_CHAT_FORMAT_PEG_NATIVE,
|
||||
COMMON_CHAT_FORMAT_PEG_GEMMA4,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
@@ -256,8 +223,8 @@ common_chat_templates_ptr common_chat_templates_init(const struct llama_model *
|
||||
const std::string & bos_token_override = "",
|
||||
const std::string & eos_token_override = "");
|
||||
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
||||
|
||||
struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs);
|
||||
@@ -274,9 +241,9 @@ std::string common_chat_format_example(const struct common_chat_templates *
|
||||
bool use_jinja,
|
||||
const std::map<std::string, std::string> & chat_template_kwargs);
|
||||
|
||||
const char * common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
const char * common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
|
||||
// used by arg and server
|
||||
const char * common_reasoning_format_name(common_reasoning_format format);
|
||||
@@ -302,7 +269,4 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
|
||||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt);
|
||||
const autoparser::generation_params & inputs);
|
||||
|
||||
@@ -1442,6 +1442,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||
|
||||
mparams.progress_callback = params.load_progress_callback;
|
||||
mparams.progress_callback_user_data = params.load_progress_callback_user_data;
|
||||
mparams.no_alloc = params.no_alloc;
|
||||
|
||||
return mparams;
|
||||
}
|
||||
|
||||
@@ -679,6 +679,7 @@ struct common_params {
|
||||
// return false from callback to abort model loading or true to continue
|
||||
llama_progress_callback load_progress_callback = NULL;
|
||||
void * load_progress_callback_user_data = NULL;
|
||||
bool no_alloc = false; // Don't allocate model buffers
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
|
||||
@@ -1557,6 +1557,36 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
||||
|
||||
// GBNF generation implementation
|
||||
void common_peg_arena::build_grammar(const common_grammar_builder & builder, bool lazy) const {
|
||||
auto schema_delegates = [](const common_peg_schema_parser & s) -> bool {
|
||||
if (!s.schema) {
|
||||
return true;
|
||||
}
|
||||
if (s.raw && s.schema->contains("type") && s.schema->at("type").is_string() && s.schema->at("type") == "string") {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// Unwrap the parser so we can properly check if it's a sequence or choice
|
||||
auto effective_parser = [&](common_peg_parser_id id) -> const common_peg_parser_variant & {
|
||||
while (true) {
|
||||
const auto & p = parsers_.at(id);
|
||||
if (const auto * tag = std::get_if<common_peg_tag_parser>(&p)) {
|
||||
id = tag->child;
|
||||
} else if (const auto * atomic = std::get_if<common_peg_atomic_parser>(&p)) {
|
||||
id = atomic->child;
|
||||
} else if (const auto * schema = std::get_if<common_peg_schema_parser>(&p)) {
|
||||
if (schema_delegates(*schema)) {
|
||||
id = schema->child;
|
||||
} else {
|
||||
return p;
|
||||
}
|
||||
} else {
|
||||
return p;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Generate GBNF for a parser
|
||||
std::function<std::string(common_peg_parser_id)> to_gbnf = [&](common_peg_parser_id id) -> std::string {
|
||||
const auto & parser = parsers_.at(id);
|
||||
@@ -1577,7 +1607,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
s += " ";
|
||||
}
|
||||
auto child_gbnf = to_gbnf(child);
|
||||
const auto & child_parser = parsers_.at(child);
|
||||
const auto & child_parser = effective_parser(child);
|
||||
if (std::holds_alternative<common_peg_choice_parser>(child_parser) ||
|
||||
std::holds_alternative<common_peg_sequence_parser>(child_parser)) {
|
||||
s += "(" + child_gbnf + ")";
|
||||
@@ -1593,7 +1623,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
s += " | ";
|
||||
}
|
||||
auto child_gbnf = to_gbnf(child);
|
||||
const auto & child_parser = parsers_.at(child);
|
||||
const auto & child_parser = effective_parser(child);
|
||||
if (std::holds_alternative<common_peg_choice_parser>(child_parser)) {
|
||||
s += "(" + child_gbnf + ")";
|
||||
} else {
|
||||
@@ -1603,7 +1633,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return s;
|
||||
} else if constexpr (std::is_same_v<T, common_peg_repetition_parser>) {
|
||||
auto child_gbnf = to_gbnf(p.child);
|
||||
const auto & child_parser = parsers_.at(p.child);
|
||||
const auto & child_parser = effective_parser(p.child);
|
||||
if (std::holds_alternative<common_peg_choice_parser>(child_parser) ||
|
||||
std::holds_alternative<common_peg_sequence_parser>(child_parser)) {
|
||||
child_gbnf = "(" + child_gbnf + ")";
|
||||
@@ -1663,15 +1693,10 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
}
|
||||
return gbnf_excluding_pattern(p.delimiters);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
if (p.schema) {
|
||||
if (p.raw && p.schema->contains("type") && p.schema->at("type").is_string() && p.schema->at("type") == "string") {
|
||||
// TODO: Implement more comprehensive grammar generation for raw strings.
|
||||
// For now, use the grammar emitted from the underlying parser.
|
||||
return to_gbnf(p.child);
|
||||
}
|
||||
return builder.add_schema(p.name, *p.schema);
|
||||
if (schema_delegates(p)) {
|
||||
return to_gbnf(p.child);
|
||||
}
|
||||
return to_gbnf(p.child);
|
||||
return builder.add_schema(p.name, *p.schema);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
return p.name;
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ref_parser>) {
|
||||
|
||||
@@ -1164,7 +1164,7 @@ class TextModel(ModelBase):
|
||||
if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
logger.info(f"gguf: expert count = {n_experts}")
|
||||
if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True)) is not None:
|
||||
if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token", "top_k_experts"], optional=True)) is not None:
|
||||
self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||
logger.info(f"gguf: experts used count = {n_experts_used}")
|
||||
if (n_expert_groups := self.hparams.get("n_group")) is not None:
|
||||
@@ -6878,7 +6878,9 @@ class Gemma2Model(TextModel):
|
||||
@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
|
||||
class Gemma3Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA3
|
||||
norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value
|
||||
|
||||
def norm_shift(self, name: str) -> float:
|
||||
return 1.0 if name.endswith("norm.weight") else 0.0 # Gemma3RMSNorm adds 1.0 to the norm value
|
||||
|
||||
def set_vocab(self):
|
||||
if (self.dir_model / "tokenizer.model").is_file():
|
||||
@@ -6916,17 +6918,22 @@ class Gemma3Model(TextModel):
|
||||
|
||||
# remove OOV (out-of-vocabulary) rows in token_embd
|
||||
if "embed_tokens.weight" in name:
|
||||
n_vocab_real = -1
|
||||
if (self.dir_model / "tokenizer.model").is_file():
|
||||
tokens = self._create_vocab_sentencepiece()[0]
|
||||
n_vocab_real = len(tokens)
|
||||
else:
|
||||
tokens = self.get_vocab_base()[0]
|
||||
data_torch = data_torch[:len(tokens)]
|
||||
with open(self.dir_model / "tokenizer.json", "r", encoding="utf-8") as f:
|
||||
tokenizer_json = json.load(f)
|
||||
n_vocab_real = len(tokenizer_json["model"]["vocab"]) + len(tokenizer_json["added_tokens"])
|
||||
data_torch = data_torch[:n_vocab_real]
|
||||
|
||||
# ref code in Gemma3RMSNorm
|
||||
# output = output * (1.0 + self.weight.float())
|
||||
# note: this is not the case on gemma3n
|
||||
if name.endswith("norm.weight"):
|
||||
data_torch = data_torch + self.norm_shift
|
||||
f_shift = self.norm_shift(name)
|
||||
if f_shift != 0.0:
|
||||
data_torch = data_torch + f_shift
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
@@ -7100,7 +7107,8 @@ class ConformerAudioModel(MmprojModel):
|
||||
assert data_torch.shape[2] == 1
|
||||
data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[1])
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min"))
|
||||
yield (mapped_name, data_torch)
|
||||
|
||||
|
||||
@ModelBase.register("DeepseekOCRForCausalLM")
|
||||
@@ -7289,7 +7297,6 @@ class Gemma3nVisionAudioModel(ConformerAudioModel):
|
||||
@ModelBase.register("Gemma3nForCausalLM", "Gemma3nForConditionalGeneration")
|
||||
class Gemma3NModel(Gemma3Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA3N
|
||||
norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code
|
||||
|
||||
_altup_proj: list[Tensor] = []
|
||||
_altup_unembd: list[Tensor] = []
|
||||
@@ -7308,6 +7315,10 @@ class Gemma3NModel(Gemma3Model):
|
||||
torch.Tensor(), # to be replaced
|
||||
]
|
||||
|
||||
def norm_shift(self, name: str) -> float:
|
||||
del name
|
||||
return 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code
|
||||
|
||||
def set_vocab(self):
|
||||
# For Gemma3n multimodal models, we need the FULL vocab_size (262400)
|
||||
# which includes special tokens from 262144-262399 for vision/audio.
|
||||
@@ -7425,6 +7436,212 @@ class Gemma3NModel(Gemma3Model):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Gemma4ForConditionalGeneration")
|
||||
class Gemma4Model(Gemma3Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA4
|
||||
|
||||
def norm_shift(self, name: str) -> float:
|
||||
del name # unused
|
||||
return 0.0
|
||||
|
||||
def set_vocab(self):
|
||||
vocab = gguf.LlamaHfVocab(self.dir_model)
|
||||
tokens = []
|
||||
scores = []
|
||||
toktypes = []
|
||||
visible_tokens = {"<|channel>", "<channel|>", "<|tool_call>", "<tool_call|>", "<|tool_response>", "<tool_response|>", "<|\"|>"}
|
||||
|
||||
for text, score, toktype in vocab.all_tokens():
|
||||
tokens.append(text)
|
||||
scores.append(score)
|
||||
text_str = text.decode()
|
||||
if text_str in visible_tokens:
|
||||
# always render these tokens, so that the chat parser can read them
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
logger.info(f"Token '{text_str}' is set to USER_DEFINED")
|
||||
else:
|
||||
toktypes.append(toktype)
|
||||
|
||||
assert len(tokens) == vocab.vocab_size
|
||||
|
||||
# TODO @ngxson : there are some known (rare) issues with the tokenizer during development
|
||||
# but I don't have time to dive into them right now;
|
||||
# using a dedicated tokenizer name so that we can fix later without re-converting GGUF
|
||||
self.gguf_writer.add_tokenizer_model("gemma4")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
self.gguf_writer.add_add_space_prefix(False)
|
||||
self.gguf_writer.add_add_bos_token(False) # already added via the chat template
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
num_kv_shared_layers = self.hparams["num_kv_shared_layers"]
|
||||
self.gguf_writer.add_shared_kv_layers(num_kv_shared_layers)
|
||||
|
||||
# per-layer embedding is optional
|
||||
n_pl_embd = self.hparams.get("hidden_size_per_layer_input") or 0
|
||||
self.gguf_writer.add_embedding_length_per_layer_input(n_pl_embd)
|
||||
|
||||
swa_layers = [t == "sliding_attention" for t in self.hparams["layer_types"]]
|
||||
self.gguf_writer.add_sliding_window_pattern(swa_layers)
|
||||
|
||||
head_dim_full = self.hparams["global_head_dim"]
|
||||
head_dim_swa = self.hparams["head_dim"]
|
||||
# correct the head dim for global/swa layers
|
||||
self.gguf_writer.add_key_length(head_dim_full)
|
||||
self.gguf_writer.add_value_length(head_dim_full)
|
||||
self.gguf_writer.add_key_length_swa(head_dim_swa)
|
||||
self.gguf_writer.add_value_length_swa(head_dim_swa)
|
||||
|
||||
expert_intermediate_size = self.find_hparam(["expert_intermediate_size", "moe_intermediate_size"])
|
||||
if expert_intermediate_size is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)
|
||||
|
||||
# if use_double_wide_mlp is set, we need to adjust the value for kv shared layers
|
||||
use_double_wide_mlp = self.hparams.get("use_double_wide_mlp", False)
|
||||
first_kv_shared_layer_idx = self.block_count - num_kv_shared_layers
|
||||
if use_double_wide_mlp:
|
||||
n_ff = self.hparams["intermediate_size"]
|
||||
n_ff_arr = [n_ff if il < first_kv_shared_layer_idx else n_ff * 2 for il in range(self.block_count)]
|
||||
self.gguf_writer.add_feed_forward_length(n_ff_arr)
|
||||
|
||||
# handle num_global_key_value_heads
|
||||
num_key_value_heads_full = self.hparams.get("num_global_key_value_heads")
|
||||
num_key_value_heads_swa = self.hparams.get("num_key_value_heads")
|
||||
if num_key_value_heads_full is not None and num_key_value_heads_swa is not None:
|
||||
value_arr = [num_key_value_heads_swa if is_swa else num_key_value_heads_full for is_swa in swa_layers]
|
||||
self.gguf_writer.add_head_count_kv(value_arr)
|
||||
|
||||
# handle n_rot differently for global vs swa layers
|
||||
partial_rotary_factor_swa = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
n_rot_full = int(head_dim_full) # "proportional" is used, see generate_extra_tensors
|
||||
n_rot_swa = int(head_dim_swa * partial_rotary_factor_swa)
|
||||
self.gguf_writer.add_rope_dimension_count(n_rot_full)
|
||||
self.gguf_writer.add_rope_dimension_count_swa(n_rot_swa)
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# full layer uses "proportional" rope with partial_rotary_factor=0.25
|
||||
# the expected ordering is cc000000ss000000 (c = cos, s = sin, 0 = unrotated),
|
||||
# but ggml neox only supports ccss000000000000, and we cannot rearrange the head because that will break use_alternative_attention
|
||||
# solution is to set specific freq_factors for the unrotated dims
|
||||
|
||||
# IMPORTANT: this ROPE_FREQS tensor is ONLY used by the full_attention layers
|
||||
rope_params_full = self.hparams["rope_parameters"]["full_attention"]
|
||||
assert rope_params_full["rope_type"] == "proportional"
|
||||
head_dim_full = (self.hparams["global_head_dim"])
|
||||
partial_rotary_factor_full = rope_params_full["partial_rotary_factor"]
|
||||
n_rot_full = int(head_dim_full * partial_rotary_factor_full / 2)
|
||||
n_unrot_full = int(head_dim_full / 2) - n_rot_full
|
||||
values = [1.0] * n_rot_full + [1e30] * n_unrot_full
|
||||
rope_freqs_full = torch.tensor(values, dtype=torch.float32)
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), rope_freqs_full)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.endswith("per_dim_scale") or name.endswith("layer_scalar"):
|
||||
name = name + ".weight"
|
||||
|
||||
if "language_model." not in name and "rope_freqs" not in name:
|
||||
return # skip non-language model tensors
|
||||
|
||||
name = name.replace("language_model.", "")
|
||||
if name.endswith("router.scale"):
|
||||
name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_INP, bid, ".scale")
|
||||
yield (name, data_torch)
|
||||
return
|
||||
if ".per_expert_scale" in name:
|
||||
# convert per-expert scale to FFN down scale
|
||||
name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_EXP, bid, ".scale")
|
||||
yield (name, data_torch)
|
||||
return
|
||||
if ".experts." in name and not name.endswith(".weight"):
|
||||
name += ".weight"
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Gemma4ForConditionalGeneration")
|
||||
class Gemma4VisionAudioModel(MmprojModel):
|
||||
has_audio_encoder = True
|
||||
has_vision_encoder = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
self.hparams_vision["image_size"] = 224 # unused, but set to avoid error
|
||||
|
||||
# remap audio hparams
|
||||
if self.hparams_audio:
|
||||
self.hparams_audio["feat_in"] = self.hparams_audio.get("input_feat_size", 128)
|
||||
self.hparams_audio["intermediate_size"] = self.hparams_audio["hidden_size"] * 4
|
||||
else:
|
||||
self.has_audio_encoder = False
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# vision params
|
||||
self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.GEMMA4V)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
|
||||
|
||||
# audio params
|
||||
if self.hparams_audio:
|
||||
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4A)
|
||||
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
|
||||
self.gguf_writer.add_audio_attention_layernorm_eps(1e-5)
|
||||
|
||||
def is_audio_tensor(self, name: str) -> bool:
|
||||
return "audio_tower" in name or "embed_audio" in name
|
||||
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
if self.is_audio_tensor(name):
|
||||
if ".conv" in name or "_conv" in name and ".weight" in name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
if "position_embedding_table" in name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
if name.startswith("model.language_model."):
|
||||
return # skip
|
||||
|
||||
if len(data_torch.shape) == 0:
|
||||
# convert scalar tensors (input/output_mix/max) to 1D tensors
|
||||
data_torch = data_torch.unsqueeze(0)
|
||||
|
||||
if self.is_audio_tensor(name):
|
||||
assert self.hparams_audio is not None
|
||||
name = name.replace("model.audio_tower.", "conformer.")
|
||||
name = name.replace(".linear.", ".")
|
||||
if name.endswith("per_dim_key_scale") or name.endswith("per_dim_scale"):
|
||||
name = name + ".weight"
|
||||
data_torch = torch.nn.functional.softplus(data_torch)
|
||||
if "lconv1d.depthwise_conv1d" in name and name.endswith(".weight"):
|
||||
assert data_torch.shape[1] == 1
|
||||
data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[2])
|
||||
mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min"))
|
||||
yield (mapped_name, data_torch)
|
||||
|
||||
else:
|
||||
name = name.replace("model.vision_tower.encoder.", "vision_model.model.")
|
||||
name = name.replace(".linear.weight", ".weight")
|
||||
if name.endswith("layer_scalar") or name.endswith("position_embedding_table"):
|
||||
name = name + ".weight"
|
||||
if name.endswith("patch_embedder.input_proj.weight"):
|
||||
n_embd, ksize_sq_c = data_torch.shape
|
||||
patch_size = int((ksize_sq_c // 3) ** 0.5)
|
||||
data_torch = data_torch.reshape(n_embd, patch_size, patch_size, 3)
|
||||
data_torch = data_torch.permute(0, 3, 1, 2).contiguous()
|
||||
mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min"))
|
||||
yield (mapped_name, data_torch)
|
||||
|
||||
|
||||
@ModelBase.register("Starcoder2ForCausalLM")
|
||||
class StarCoder2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.STARCODER2
|
||||
|
||||
@@ -15,13 +15,18 @@ static bool run(llama_context * ctx, const common_params & params) {
|
||||
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos, true);
|
||||
|
||||
if (tokens.empty()) {
|
||||
LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
LOG_INF("number of input tokens = %zu\n", tokens.size());
|
||||
for (size_t i = 0; i < tokens.size(); ++i) {
|
||||
LOG_INF(" %d\n", tokens[i]);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
|
||||
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 9)
|
||||
set(GGML_VERSION_PATCH 10)
|
||||
set(GGML_VERSION_PATCH 11)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
@@ -2231,6 +2231,22 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session *
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_UNUSED(sess);
|
||||
return true;
|
||||
}
|
||||
|
||||
enum dspqbuf_type {
|
||||
DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
|
||||
DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
|
||||
@@ -2399,6 +2415,16 @@ static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bu
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_cumsum_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_CUMSUM;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_GET_ROWS;
|
||||
|
||||
@@ -2780,6 +2806,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
ggml_hexagon_dispatch_op<init_ssm_conv_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_CUMSUM:
|
||||
ggml_hexagon_dispatch_op<init_cumsum_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
default:
|
||||
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
|
||||
}
|
||||
@@ -3254,6 +3284,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_ssm_conv(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_CUMSUM:
|
||||
supp = ggml_hexagon_supported_cumsum(sess, op);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ add_library(${HTP_LIB} SHARED
|
||||
repeat-ops.c
|
||||
argsort-ops.c
|
||||
ssm-conv.c
|
||||
cumsum-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
|
||||
267
ggml/src/ggml-hexagon/htp/cumsum-ops.c
Normal file
267
ggml/src/ggml-hexagon/htp/cumsum-ops.c
Normal file
@@ -0,0 +1,267 @@
|
||||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-types.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-dma.h"
|
||||
|
||||
#define htp_cumsum_tensors_preamble \
|
||||
struct htp_tensor * restrict src0 = &octx->src0; \
|
||||
struct htp_tensor * restrict dst = &octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct htp_cumsum_context {
|
||||
struct htp_ops_context * octx;
|
||||
size_t src_row_size;
|
||||
size_t dst_row_size;
|
||||
size_t src_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
uint32_t rows_per_thread;
|
||||
uint32_t total_rows;
|
||||
};
|
||||
|
||||
#define htp_cumsum_preamble \
|
||||
struct htp_cumsum_context * cctx = (struct htp_cumsum_context *) data; \
|
||||
struct htp_ops_context * octx = cctx->octx; \
|
||||
htp_cumsum_tensors_preamble; \
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HVX prefix scan helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#if __HVX_ARCH__ > 75
|
||||
static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) {
|
||||
return Q6_Vsf_vadd_VsfVsf(a, b);
|
||||
}
|
||||
#else
|
||||
static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) {
|
||||
return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b));
|
||||
}
|
||||
#endif // __HVX_ARCH__ > 75
|
||||
|
||||
static inline HVX_Vector hvx_prefix_scan_f32(HVX_Vector v, HVX_Vector carry_in) {
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
|
||||
v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 4));
|
||||
v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 8));
|
||||
v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 16));
|
||||
v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 32));
|
||||
v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 64));
|
||||
v = hvx_cumsum_vadd(v, carry_in);
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_splat_last_f32(HVX_Vector v) {
|
||||
return hvx_vec_repl4(Q6_V_vror_VR(v, 124));
|
||||
}
|
||||
|
||||
static inline void hvx_cumsum_row_f32(const float * restrict src, float * restrict dst, uint32_t n) {
|
||||
const uint32_t nvec = n / VLEN_FP32;
|
||||
const uint32_t nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector carry = Q6_V_vsplat_R(0);
|
||||
|
||||
for (uint32_t i = 0; i < nvec; i++) {
|
||||
HVX_Vector v = *((const HVX_UVector *) (src + i * VLEN_FP32));
|
||||
v = hvx_prefix_scan_f32(v, carry);
|
||||
hvx_vec_store_u(dst + i * VLEN_FP32, VLEN, v);
|
||||
carry = hvx_splat_last_f32(v);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
float acc = hvx_vec_get_f32(carry);
|
||||
const float * src_tail = src + nvec * VLEN_FP32;
|
||||
float * dst_tail = dst + nvec * VLEN_FP32;
|
||||
for (uint32_t i = 0; i < nloe; i++) {
|
||||
acc += src_tail[i];
|
||||
dst_tail[i] = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per thread worker: Double-buffered DMA
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static void cumsum_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_cumsum_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint32_t ir0 = cctx->rows_per_thread * ith;
|
||||
const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows);
|
||||
|
||||
if (ir0 >= ir1) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t src_row_size = cctx->src_row_size;
|
||||
const size_t dst_row_size = cctx->dst_row_size;
|
||||
const size_t src_row_size_aligned = cctx->src_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = cctx->dst_row_size_aligned;
|
||||
|
||||
const uint8_t * src_data = (const uint8_t *) src0->data;
|
||||
uint8_t * dst_data = (uint8_t *) dst->data;
|
||||
|
||||
uint8_t * src_spad = octx->src0_spad.data + (ith * src_row_size_aligned * 2);
|
||||
uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned * 2);
|
||||
|
||||
for (uint32_t ir = ir0, spad_idx = 0; ir < ir1 && spad_idx < 2; ir++, spad_idx++) {
|
||||
// Dummy dst writeback to establish queue ordering
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(dst_data, dst_spad + (spad_idx * dst_row_size_aligned)),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src_spad + (spad_idx * src_row_size_aligned),
|
||||
src_data + (ir * src_row_size)),
|
||||
src_row_size_aligned, src_row_size, 1);
|
||||
}
|
||||
|
||||
for (uint32_t ir = ir0; ir < ir1; ir++) {
|
||||
float * dst_spad_row = (float *) dma_queue_pop(dma_queue).src;
|
||||
float * src_spad_row = (float *) dma_queue_pop(dma_queue).dst;
|
||||
|
||||
hvx_cumsum_row_f32(src_spad_row, dst_spad_row, ne00);
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(dst_data + (ir * dst_row_size), (uint8_t *) dst_spad_row),
|
||||
dst_row_size, dst_row_size_aligned, 1);
|
||||
|
||||
const uint32_t next_row = ir + 2;
|
||||
if (next_row < ir1) {
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr((uint8_t *) src_spad_row, src_data + (next_row * src_row_size)),
|
||||
src_row_size_aligned, src_row_size, 1);
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "cumsum-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n",
|
||||
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per thread worker: Direct HVX (no DMA)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_cumsum_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint8_t * src_data = (const uint8_t *) src0->data;
|
||||
uint8_t * dst_data = (uint8_t *) dst->data;
|
||||
|
||||
const uint32_t ir0 = cctx->rows_per_thread * ith;
|
||||
const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows);
|
||||
|
||||
for (uint32_t ir = ir0; ir < ir1; ir++) {
|
||||
const float * restrict src_row = (const float *) (src_data + ir * cctx->src_row_size);
|
||||
float * restrict dst_row = (float *) (dst_data + ir * cctx->dst_row_size);
|
||||
hvx_cumsum_row_f32(src_row, dst_row, ne00);
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "cumsum-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n",
|
||||
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
int op_cumsum_f32(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
const uint32_t total_rows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, total_rows);
|
||||
|
||||
const size_t src_row_size = src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
// 2 ping-pong buffers per thread for src and dst
|
||||
const size_t spad_per_thread = 2 * (src_row_size_aligned + dst_row_size_aligned);
|
||||
|
||||
octx->src0_spad.size_per_thread = src_row_size_aligned * 2;
|
||||
octx->dst_spad.size_per_thread = dst_row_size_aligned * 2;
|
||||
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
||||
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
struct htp_cumsum_context cctx = {
|
||||
.octx = octx,
|
||||
.src_row_size = src_row_size,
|
||||
.dst_row_size = dst_row_size,
|
||||
.src_row_size_aligned = src_row_size_aligned,
|
||||
.dst_row_size_aligned = dst_row_size_aligned,
|
||||
.rows_per_thread = (total_rows + n_threads - 1) / n_threads,
|
||||
.total_rows = total_rows,
|
||||
};
|
||||
|
||||
if (octx->ctx->vtcm_size < spad_per_thread * n_threads) {
|
||||
worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32, &cctx, n_threads);
|
||||
} else {
|
||||
worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32_dma, &cctx, n_threads);
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
int op_cumsum(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
switch (dst->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = op_cumsum_f32(octx);
|
||||
break;
|
||||
default:
|
||||
err = HTP_STATUS_NO_SUPPORT;
|
||||
break;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
@@ -75,6 +75,7 @@ enum htp_op {
|
||||
HTP_OP_SUM_ROWS,
|
||||
HTP_OP_SSM_CONV,
|
||||
HTP_OP_REPEAT,
|
||||
HTP_OP_CUMSUM,
|
||||
INVALID
|
||||
};
|
||||
|
||||
|
||||
@@ -60,5 +60,6 @@ int op_cpy(struct htp_ops_context * octx);
|
||||
int op_repeat(struct htp_ops_context * octx);
|
||||
int op_argsort(struct htp_ops_context * octx);
|
||||
int op_ssm_conv(struct htp_ops_context * octx);
|
||||
int op_cumsum(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_OPS_H */
|
||||
|
||||
@@ -860,6 +860,41 @@ static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req *
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_cumsum_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We've written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_cumsum(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_activations_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
@@ -1474,6 +1509,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
proc_ssm_conv_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_CUMSUM:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad cumsum-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_cumsum_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unknown Op %u", req.op);
|
||||
break;
|
||||
|
||||
@@ -9612,6 +9612,9 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
|
||||
cl_mem B_image1d;
|
||||
cl_mem B_sub_buffer;
|
||||
cl_mem S_image1d;
|
||||
// for B transpose
|
||||
cl_mem B_image1d_trans = nullptr;
|
||||
cl_mem B_d = nullptr;
|
||||
|
||||
cl_mem D_image1d;
|
||||
cl_mem D_sub_buffer;
|
||||
@@ -9703,9 +9706,6 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
|
||||
global_work_size[2] = 1;
|
||||
} else {
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
cl_mem B_image1d_trans = nullptr;
|
||||
// for B transpose
|
||||
cl_mem B_d = nullptr;
|
||||
int padding;
|
||||
|
||||
//how many extra elements beyond multiple of 8
|
||||
@@ -9800,6 +9800,12 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
|
||||
CL_CHECK(clReleaseMemObject(S_image1d));
|
||||
CL_CHECK(clReleaseMemObject(D_sub_buffer));
|
||||
CL_CHECK(clReleaseMemObject(D_image1d));
|
||||
if (B_image1d_trans) {
|
||||
CL_CHECK(clReleaseMemObject(B_image1d_trans));
|
||||
}
|
||||
if (B_d) {
|
||||
CL_CHECK(clReleaseMemObject(B_d));
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(backend);
|
||||
GGML_UNUSED(src0);
|
||||
|
||||
@@ -1009,8 +1009,8 @@ public:
|
||||
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
|
||||
|
||||
struct stored_graph {
|
||||
ggml_context_ptr ctx_ptr;
|
||||
ggml_cgraph * graph;
|
||||
std::vector<uint8_t> buffer;
|
||||
ggml_cgraph * graph;
|
||||
};
|
||||
|
||||
private:
|
||||
@@ -1518,10 +1518,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
|
||||
LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
|
||||
|
||||
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
|
||||
|
||||
if (stored_graphs[device].buffer.size() < buf_size) {
|
||||
stored_graphs[device].buffer.resize(buf_size);
|
||||
}
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_size,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.mem_buffer =*/ stored_graphs[device].buffer.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
||||
@@ -1551,7 +1553,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
|
||||
}
|
||||
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
||||
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
|
||||
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
|
||||
stored_graphs[device].graph = graph;
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -569,9 +569,15 @@ static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));
|
||||
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((*stream)
|
||||
.memset(ctx->dev_ptr, value, buffer->size)
|
||||
.wait()));
|
||||
constexpr size_t MAX_CHUNK = 2ULL << 30; // 2 GiB
|
||||
for (size_t off = 0; off < buffer->size; off += MAX_CHUNK) {
|
||||
size_t chunk = std::min(buffer->size - off, MAX_CHUNK);
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(
|
||||
(*stream)
|
||||
.memset(static_cast<char*>(ctx->dev_ptr) + off, value, chunk)
|
||||
.wait()
|
||||
));
|
||||
}
|
||||
}
|
||||
catch (sycl::exception const &exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||
|
||||
@@ -95,6 +95,12 @@ struct ggml_webgpu_generic_shader_decisions {
|
||||
uint32_t wg_size = 0;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_processed_shader {
|
||||
std::string wgsl;
|
||||
std::string variant;
|
||||
std::shared_ptr<void> decisions;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_ssm_conv_shader_decisions {
|
||||
uint32_t block_size;
|
||||
uint32_t tokens_per_wg;
|
||||
@@ -384,11 +390,12 @@ struct ggml_webgpu_flash_attn_pipeline_key {
|
||||
bool has_mask;
|
||||
bool has_sinks;
|
||||
bool uses_logit_softcap;
|
||||
bool use_vec;
|
||||
|
||||
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
|
||||
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
|
||||
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
|
||||
uses_logit_softcap == other.uses_logit_softcap;
|
||||
uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -402,6 +409,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.has_mask);
|
||||
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
||||
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
||||
ggml_webgpu_hash_combine(seed, key.use_vec);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
@@ -421,6 +429,115 @@ struct ggml_webgpu_flash_attn_shader_decisions {
|
||||
uint32_t wg_size = 0;
|
||||
};
|
||||
|
||||
inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
|
||||
// Keep conservative defaults unless this is the f16 vec-split shape family.
|
||||
if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
|
||||
return 1u;
|
||||
}
|
||||
|
||||
// Head-dim specializations used by the tuned vec f16 path.
|
||||
switch (key.head_dim_qk) {
|
||||
case 64: return 2u;
|
||||
case 96: return 4u;
|
||||
case 128: return 1u;
|
||||
case 192: return 2u;
|
||||
case 576: return 2u;
|
||||
default: return 1u;
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
|
||||
uint32_t head_dim_v;
|
||||
uint32_t wg_size;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
||||
ggml_webgpu_hash_combine(seed, key.wg_size);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
|
||||
const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
|
||||
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
|
||||
}
|
||||
|
||||
struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
|
||||
ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
|
||||
uint32_t max_wg_size;
|
||||
};
|
||||
|
||||
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
|
||||
pre_wgsl::Preprocessor & preprocessor,
|
||||
const char * shader_src,
|
||||
const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "flash_attn_vec_reduce";
|
||||
|
||||
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
|
||||
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
variant += std::string("_wg") + std::to_string(context.max_wg_size);
|
||||
|
||||
ggml_webgpu_processed_shader result;
|
||||
result.wgsl = preprocessor.preprocess(shader_src, defines);
|
||||
result.variant = variant;
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_webgpu_flash_attn_blk_pipeline_key {
|
||||
uint32_t q_tile;
|
||||
uint32_t kv_tile;
|
||||
|
||||
bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
|
||||
return q_tile == other.q_tile && kv_tile == other.kv_tile;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.q_tile);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_tile);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_blk_shader_lib_context {
|
||||
ggml_webgpu_flash_attn_blk_pipeline_key key;
|
||||
uint32_t max_wg_size;
|
||||
};
|
||||
|
||||
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
|
||||
pre_wgsl::Preprocessor & preprocessor,
|
||||
const char * shader_src,
|
||||
const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "flash_attn_vec_blk";
|
||||
|
||||
defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile));
|
||||
variant += std::string("_qt") + std::to_string(context.key.q_tile);
|
||||
|
||||
defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile));
|
||||
variant += std::string("_kvt") + std::to_string(context.key.kv_tile);
|
||||
|
||||
uint32_t wg_size = 1;
|
||||
while ((wg_size << 1) <= context.max_wg_size) {
|
||||
wg_size <<= 1;
|
||||
}
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
variant += std::string("_wg") + std::to_string(wg_size);
|
||||
|
||||
ggml_webgpu_processed_shader result;
|
||||
result.wgsl = preprocessor.preprocess(shader_src, defines);
|
||||
result.variant = variant;
|
||||
return result;
|
||||
}
|
||||
|
||||
// This is exposed because it's necessary in supports_op
|
||||
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
||||
uint32_t kv_tile,
|
||||
@@ -659,6 +776,14 @@ class ggml_webgpu_shader_lib {
|
||||
repeat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
||||
flash_attn_pipelines;
|
||||
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
|
||||
flash_attn_vec_reduce_pipelines;
|
||||
std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
|
||||
flash_attn_blk_pipelines;
|
||||
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
|
||||
@@ -1673,24 +1798,8 @@ class ggml_webgpu_shader_lib {
|
||||
return repeat_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
const bool has_mask = context.src3 != nullptr;
|
||||
const bool has_sinks = context.src4 != nullptr;
|
||||
|
||||
bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
|
||||
(context.src1->ne[1] % context.sg_mat_n == 0);
|
||||
|
||||
ggml_webgpu_flash_attn_pipeline_key key = {
|
||||
.kv_type = context.src1->type,
|
||||
.head_dim_qk = (uint32_t) context.src0->ne[0],
|
||||
.head_dim_v = (uint32_t) context.src2->ne[0],
|
||||
.kv_direct = kv_direct,
|
||||
.has_mask = has_mask,
|
||||
.has_sinks = has_sinks,
|
||||
.uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
|
||||
};
|
||||
|
||||
auto it = flash_attn_pipelines.find(key);
|
||||
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) {
|
||||
auto it = flash_attn_pipelines.find(context.key);
|
||||
if (it != flash_attn_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
@@ -1698,7 +1807,7 @@ class ggml_webgpu_shader_lib {
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "flash_attn";
|
||||
|
||||
switch (key.kv_type) {
|
||||
switch (context.key.kv_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("KV_F32");
|
||||
break;
|
||||
@@ -1714,41 +1823,52 @@ class ggml_webgpu_shader_lib {
|
||||
default:
|
||||
GGML_ABORT("Unsupported KV type for flash attention shader");
|
||||
}
|
||||
variant += std::string("_") + ggml_type_name(key.kv_type);
|
||||
variant += std::string("_") + ggml_type_name(context.key.kv_type);
|
||||
|
||||
if (key.has_mask) {
|
||||
if (context.key.has_mask) {
|
||||
defines.push_back("MASK");
|
||||
variant += "_mask";
|
||||
}
|
||||
if (key.has_sinks) {
|
||||
if (context.key.has_sinks) {
|
||||
defines.push_back("SINKS");
|
||||
variant += "_sinks";
|
||||
}
|
||||
if (key.uses_logit_softcap) {
|
||||
if (context.key.uses_logit_softcap) {
|
||||
defines.push_back("LOGIT_SOFTCAP");
|
||||
variant += "_lgsc";
|
||||
}
|
||||
if (key.kv_direct) {
|
||||
if (context.key.kv_direct) {
|
||||
defines.push_back("KV_DIRECT");
|
||||
variant += "_kvdirect";
|
||||
}
|
||||
if (context.key.has_mask && context.key.use_vec) {
|
||||
defines.push_back("BLK");
|
||||
variant += "_blk";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
|
||||
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
|
||||
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
|
||||
variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
|
||||
|
||||
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
|
||||
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
|
||||
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
|
||||
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
|
||||
|
||||
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
|
||||
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
|
||||
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
|
||||
|
||||
uint32_t q_tile = context.sg_mat_m;
|
||||
uint32_t q_tile = context.sg_mat_m;
|
||||
uint32_t kv_tile =
|
||||
std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
|
||||
context.wg_mem_limit_bytes, context.max_subgroup_size }),
|
||||
std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
|
||||
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
|
||||
if (key.kv_direct) {
|
||||
if (context.key.use_vec) {
|
||||
q_tile = 1;
|
||||
kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
|
||||
kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
|
||||
const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key);
|
||||
defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
|
||||
}
|
||||
if (context.key.kv_direct) {
|
||||
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
|
||||
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
|
||||
kv_tile -= context.sg_mat_n;
|
||||
}
|
||||
@@ -1757,19 +1877,51 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
|
||||
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
|
||||
|
||||
uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
|
||||
uint32_t wg_size = 0;
|
||||
if (context.key.use_vec) {
|
||||
wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
|
||||
} else {
|
||||
wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
|
||||
}
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_flash_attn, defines);
|
||||
const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
|
||||
webgpu_pipeline pipeline =
|
||||
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
|
||||
auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
|
||||
decisions->q_tile = q_tile;
|
||||
decisions->kv_tile = kv_tile;
|
||||
decisions->wg_size = wg_size;
|
||||
pipeline.context = decisions;
|
||||
flash_attn_pipelines[context.key] = pipeline;
|
||||
return flash_attn_pipelines[context.key];
|
||||
}
|
||||
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
flash_attn_pipelines[key] = pipeline;
|
||||
return flash_attn_pipelines[key];
|
||||
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
|
||||
auto it = flash_attn_blk_pipelines.find(context.key);
|
||||
if (it != flash_attn_blk_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
ggml_webgpu_processed_shader processed =
|
||||
ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
|
||||
flash_attn_blk_pipelines[context.key] = pipeline;
|
||||
return flash_attn_blk_pipelines[context.key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_flash_attn_vec_reduce_pipeline(
|
||||
const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
|
||||
auto it = flash_attn_vec_reduce_pipelines.find(context.key);
|
||||
if (it != flash_attn_vec_reduce_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
ggml_webgpu_processed_shader processed =
|
||||
ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context);
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
|
||||
flash_attn_vec_reduce_pipelines[context.key] = pipeline;
|
||||
return flash_attn_vec_reduce_pipelines[context.key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
|
||||
@@ -658,7 +658,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
for (size_t i = 0; i < params_bufs_list.size(); i++) {
|
||||
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
|
||||
if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
|
||||
@@ -1481,7 +1480,6 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
ggml_tensor * Q,
|
||||
ggml_tensor * K,
|
||||
@@ -1565,30 +1563,248 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = Q,
|
||||
.src1 = K,
|
||||
.src2 = V,
|
||||
.src3 = mask,
|
||||
.src4 = sinks,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
|
||||
const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
|
||||
const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
|
||||
const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
|
||||
|
||||
const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned &&
|
||||
(Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
|
||||
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
||||
|
||||
const bool kv_vec_type_supported =
|
||||
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
|
||||
const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
|
||||
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
|
||||
const uint32_t vec_nwg_cap =
|
||||
std::max(1u, std::min<uint32_t>(32u, ctx->global_ctx->capabilities.max_subgroup_size));
|
||||
const bool use_blk = use_vec && has_mask;
|
||||
|
||||
ggml_webgpu_flash_attn_pipeline_key key = {
|
||||
.kv_type = K->type,
|
||||
.head_dim_qk = (uint32_t) Q->ne[0],
|
||||
.head_dim_v = (uint32_t) V->ne[0],
|
||||
.kv_direct = kv_direct,
|
||||
.has_mask = static_cast<bool>(has_mask),
|
||||
.has_sinks = static_cast<bool>(has_sinks),
|
||||
.uses_logit_softcap = logit_softcap != 0.0f,
|
||||
.use_vec = use_vec,
|
||||
};
|
||||
|
||||
ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
|
||||
.key = key,
|
||||
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
|
||||
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
|
||||
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
|
||||
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
|
||||
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
|
||||
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
|
||||
|
||||
wgpu::Buffer blk_buf = {};
|
||||
uint64_t blk_size_bytes = 0;
|
||||
uint32_t blk_nblk0 = 0;
|
||||
uint32_t blk_nblk1 = 0;
|
||||
uint32_t blk_batch_count = 0;
|
||||
|
||||
if (use_vec) {
|
||||
uint32_t nwg = 1u;
|
||||
const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile);
|
||||
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
|
||||
nwg <<= 1;
|
||||
}
|
||||
nwg = std::min(nwg, vec_nwg_cap);
|
||||
GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size);
|
||||
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
|
||||
const bool use_vec_reduce = nwg > 1u;
|
||||
GGML_ASSERT(nrows <= UINT32_MAX);
|
||||
|
||||
uint64_t tmp_stats_base = 0;
|
||||
uint64_t tmp_size_bytes = 0;
|
||||
wgpu::Buffer tmp_buf = {};
|
||||
uint64_t tmp_bind_offset = 0;
|
||||
uint64_t tmp_bind_size = 0;
|
||||
const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
|
||||
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
|
||||
size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes);
|
||||
|
||||
if (use_vec_reduce) {
|
||||
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
|
||||
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
|
||||
tmp_stats_base = tmp_data_elems;
|
||||
tmp_size_bytes =
|
||||
ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
GGML_ASSERT(tmp_stats_base <= UINT32_MAX);
|
||||
tmp_buf = ggml_webgpu_tensor_buf(dst);
|
||||
tmp_bind_offset = scratch_offset;
|
||||
tmp_bind_size = tmp_size_bytes;
|
||||
scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
|
||||
} else {
|
||||
// nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation.
|
||||
tmp_buf = ggml_webgpu_tensor_buf(dst);
|
||||
tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
|
||||
tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst);
|
||||
}
|
||||
|
||||
webgpu_pipeline blk_pipeline;
|
||||
std::vector<uint32_t> blk_params;
|
||||
std::vector<wgpu::BindGroupEntry> blk_entries;
|
||||
if (use_blk) {
|
||||
GGML_ASSERT(has_mask);
|
||||
|
||||
blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile);
|
||||
blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile);
|
||||
blk_buf = ggml_webgpu_tensor_buf(dst);
|
||||
const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
|
||||
blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
|
||||
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
|
||||
blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = {
|
||||
.key =
|
||||
{
|
||||
.q_tile = decisions->q_tile,
|
||||
.kv_tile = decisions->kv_tile,
|
||||
},
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx);
|
||||
|
||||
blk_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask
|
||||
(uint32_t) Q->ne[1], // seq_len_q
|
||||
(uint32_t) K->ne[1], // seq_len_kv
|
||||
stride_mask3, // stride_mask3
|
||||
blk_nblk0, // nblk0
|
||||
blk_nblk1, // nblk1
|
||||
};
|
||||
blk_entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(mask),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, mask) },
|
||||
{ .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes },
|
||||
};
|
||||
scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes);
|
||||
}
|
||||
|
||||
std::vector<uint32_t> split_params = params;
|
||||
if (use_blk) {
|
||||
split_params.push_back(0u); // blk_base
|
||||
split_params.push_back(blk_nblk0); // blk_nblk0
|
||||
split_params.push_back(blk_nblk1); // blk_nblk1
|
||||
}
|
||||
split_params.push_back(0u); // tmp_data_base
|
||||
split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base
|
||||
split_params.push_back(nwg); // nwg
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> split_entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(Q),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, Q),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, Q) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(K),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, K),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, K) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(V),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, V),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, V) },
|
||||
};
|
||||
uint32_t split_binding_index = 3;
|
||||
if (has_mask) {
|
||||
split_entries.push_back({ .binding = split_binding_index++,
|
||||
.buffer = ggml_webgpu_tensor_buf(mask),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, mask) });
|
||||
}
|
||||
if (has_sinks) {
|
||||
split_entries.push_back({ .binding = split_binding_index++,
|
||||
.buffer = ggml_webgpu_tensor_buf(sinks),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
|
||||
}
|
||||
if (use_blk) {
|
||||
split_entries.push_back(
|
||||
{ .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes });
|
||||
}
|
||||
split_entries.push_back(
|
||||
{ .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size });
|
||||
split_entries.push_back({ .binding = split_binding_index++,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
|
||||
webgpu_pipeline reduce_pipeline;
|
||||
std::vector<uint32_t> reduce_params;
|
||||
std::vector<wgpu::BindGroupEntry> reduce_entries;
|
||||
if (use_vec_reduce) {
|
||||
const uint32_t reduce_wg_size = std::max(
|
||||
32u,
|
||||
std::min<uint32_t>(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
|
||||
ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = {
|
||||
.key =
|
||||
{
|
||||
.head_dim_v = (uint32_t) V->ne[0],
|
||||
.wg_size = reduce_wg_size,
|
||||
},
|
||||
.max_wg_size = reduce_wg_size,
|
||||
};
|
||||
reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
|
||||
|
||||
reduce_params = {
|
||||
(uint32_t) nrows, // nrows
|
||||
(uint32_t) Q->ne[1], // seq_len_q
|
||||
(uint32_t) Q->ne[2], // n_heads
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst
|
||||
nwg, // nwg
|
||||
0u, // tmp_data_base
|
||||
(uint32_t) tmp_stats_base, // tmp_stats_base
|
||||
};
|
||||
|
||||
reduce_entries = {
|
||||
{ .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
||||
};
|
||||
}
|
||||
|
||||
const uint64_t split_wg_total = (uint64_t) wg_x * nwg;
|
||||
GGML_ASSERT(split_wg_total <= UINT32_MAX);
|
||||
std::vector<webgpu_pipeline> pipelines;
|
||||
std::vector<std::vector<uint32_t>> params_list;
|
||||
std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
|
||||
std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
|
||||
|
||||
if (use_blk) {
|
||||
pipelines.push_back(blk_pipeline);
|
||||
params_list.push_back(std::move(blk_params));
|
||||
entries_list.push_back(std::move(blk_entries));
|
||||
workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count });
|
||||
}
|
||||
pipelines.push_back(pipeline);
|
||||
params_list.push_back(std::move(split_params));
|
||||
entries_list.push_back(std::move(split_entries));
|
||||
workgroups_list.push_back({ (uint32_t) split_wg_total, 1u });
|
||||
if (use_vec_reduce) {
|
||||
pipelines.push_back(reduce_pipeline);
|
||||
params_list.push_back(std::move(reduce_params));
|
||||
entries_list.push_back(std::move(reduce_entries));
|
||||
workgroups_list.push_back({ (uint32_t) nrows, 1u });
|
||||
}
|
||||
|
||||
return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
|
||||
entries_list, workgroups_list);
|
||||
}
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
#endif
|
||||
|
||||
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
bool is_unary = dst->op == GGML_OP_UNARY;
|
||||
@@ -2559,7 +2775,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
std::vector<webgpu_submission> subs;
|
||||
uint32_t num_batched_kernels = 0;
|
||||
bool contains_set_rows = false;
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
|
||||
contains_set_rows = true;
|
||||
@@ -2834,6 +3049,86 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const ggml_tensor * Q = tensor->src[0];
|
||||
const ggml_tensor * K = tensor->src[1];
|
||||
const ggml_tensor * V = tensor->src[2];
|
||||
const ggml_tensor * mask = tensor->src[3];
|
||||
const ggml_tensor * sinks = tensor->src[4];
|
||||
if (Q && K && V) {
|
||||
GGML_UNUSED(sinks);
|
||||
const bool kv_direct = (K->type == GGML_TYPE_F16) &&
|
||||
(Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) &&
|
||||
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
||||
const bool kv_vec_type_supported =
|
||||
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
|
||||
const bool use_vec =
|
||||
(Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
|
||||
(V->type == K->type);
|
||||
if (use_vec) {
|
||||
const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
|
||||
const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
|
||||
const size_t limit_bytes =
|
||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
const size_t q_tile = sg_mat_m;
|
||||
const size_t base_q_bytes =
|
||||
(Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
||||
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
size_t bytes_per_kv = 0;
|
||||
if (!kv_direct) {
|
||||
bytes_per_kv += std::max(Q->ne[0], V->ne[0]);
|
||||
}
|
||||
if (mask != nullptr) {
|
||||
bytes_per_kv += q_tile;
|
||||
}
|
||||
bytes_per_kv += q_tile;
|
||||
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
|
||||
uint32_t kv_tile =
|
||||
((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n;
|
||||
kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile));
|
||||
kv_tile = (kv_tile / sg_mat_n) * sg_mat_n;
|
||||
if (kv_direct) {
|
||||
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
|
||||
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
|
||||
kv_tile -= sg_mat_n;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t vec_nwg_cap = std::max(
|
||||
1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
|
||||
uint32_t nwg = 1u;
|
||||
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
|
||||
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
|
||||
nwg <<= 1;
|
||||
}
|
||||
nwg = std::min(nwg, vec_nwg_cap);
|
||||
|
||||
const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
|
||||
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
|
||||
if (nwg > 1u) {
|
||||
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
|
||||
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
|
||||
const size_t tmp_size_bytes = ROUNDUP_POW2(
|
||||
(tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
res += tmp_size_bytes + align;
|
||||
}
|
||||
if (mask != nullptr) {
|
||||
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
|
||||
const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);
|
||||
const uint32_t stride_mask3 =
|
||||
(uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
|
||||
const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
|
||||
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
|
||||
const size_t blk_size_bytes =
|
||||
ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
res += blk_size_bytes + align;
|
||||
}
|
||||
res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
105
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
Normal file
105
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
Normal file
@@ -0,0 +1,105 @@
|
||||
diagnostic(off, subgroup_uniformity);
|
||||
enable f16;
|
||||
|
||||
#define Q_TILE 1
|
||||
#define KV_TILE 32
|
||||
#define WG_SIZE 32
|
||||
|
||||
struct Params {
|
||||
offset_mask: u32,
|
||||
seq_len_q: u32,
|
||||
seq_len_kv: u32,
|
||||
stride_mask3: u32,
|
||||
// Number of KV blocks and Q blocks per batch.
|
||||
// nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE).
|
||||
nblk0: u32,
|
||||
nblk1: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read> mask: array<f16>;
|
||||
@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
|
||||
@group(0) @binding(2) var<uniform> params: Params;
|
||||
|
||||
const MASK_MIN: f32 = -65504.0;
|
||||
const MASK_MAX: f32 = 65504.0;
|
||||
var<workgroup> wg_min: array<f32, WG_SIZE>;
|
||||
var<workgroup> wg_max: array<f32, WG_SIZE>;
|
||||
var<workgroup> wg_any: array<u32, WG_SIZE>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>) {
|
||||
// Dispatch mapping:
|
||||
// - x indexes KV blocks
|
||||
// - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk
|
||||
let kv_blk = wg_id.x;
|
||||
let y = wg_id.y;
|
||||
let q_blk = y % params.nblk1;
|
||||
let batch_idx = y / params.nblk1;
|
||||
if (kv_blk >= params.nblk0) {
|
||||
return;
|
||||
}
|
||||
|
||||
let q_start = q_blk * Q_TILE;
|
||||
let k_start = kv_blk * KV_TILE;
|
||||
|
||||
let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
|
||||
let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3;
|
||||
|
||||
// We keep min/max to classify:
|
||||
// - fully masked (max <= MASK_MIN)
|
||||
// - all-zero mask (min == 0 && max == 0)
|
||||
// - mixed/general mask
|
||||
var local_min = MASK_MAX;
|
||||
var local_max = -MASK_MAX;
|
||||
var local_any = 0u;
|
||||
|
||||
for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) {
|
||||
let q_row = q_start + q_rel;
|
||||
if (q_row >= params.seq_len_q) {
|
||||
continue;
|
||||
}
|
||||
let row_base = mask_batch_base + q_row * params.seq_len_kv;
|
||||
for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) {
|
||||
let k_col = k_start + k_rel;
|
||||
if (k_col >= params.seq_len_kv) {
|
||||
continue;
|
||||
}
|
||||
let mv = f32(mask[row_base + k_col]);
|
||||
local_min = min(local_min, mv);
|
||||
local_max = max(local_max, mv);
|
||||
local_any = 1u;
|
||||
}
|
||||
}
|
||||
|
||||
wg_min[local_id.x] = local_min;
|
||||
wg_max[local_id.x] = local_max;
|
||||
wg_any[local_id.x] = local_any;
|
||||
workgroupBarrier();
|
||||
|
||||
// Thread 0 writes one state per block.
|
||||
if (local_id.x == 0u) {
|
||||
var mmin = wg_min[0];
|
||||
var mmax = wg_max[0];
|
||||
var many = wg_any[0];
|
||||
for (var i = 1u; i < WG_SIZE; i += 1u) {
|
||||
mmin = min(mmin, wg_min[i]);
|
||||
mmax = max(mmax, wg_max[i]);
|
||||
many = max(many, wg_any[i]);
|
||||
}
|
||||
|
||||
var state = 0u;
|
||||
if (many != 0u) {
|
||||
if (mmax <= MASK_MIN) {
|
||||
state = 0u;
|
||||
} else if (mmin == 0.0 && mmax == 0.0) {
|
||||
state = 2u;
|
||||
} else {
|
||||
state = 1u;
|
||||
}
|
||||
}
|
||||
|
||||
let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk;
|
||||
blk[blk_idx] = state;
|
||||
}
|
||||
}
|
||||
78
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl
Normal file
78
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl
Normal file
@@ -0,0 +1,78 @@
|
||||
diagnostic(off, subgroup_uniformity);
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
|
||||
// Default values
|
||||
#define HEAD_DIM_V 64
|
||||
#define WG_SIZE 128
|
||||
|
||||
struct Params {
|
||||
nrows: u32,
|
||||
seq_len_q: u32,
|
||||
n_heads: u32,
|
||||
offset_dst: u32,
|
||||
nwg: u32,
|
||||
tmp_data_base: u32,
|
||||
tmp_stats_base: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
|
||||
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
|
||||
@group(0) @binding(2) var<uniform> params: Params;
|
||||
|
||||
const FLOAT_MIN: f32 = -1.0e9;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
||||
let rid = wg_id.x;
|
||||
if (rid >= params.nrows) {
|
||||
return;
|
||||
}
|
||||
|
||||
let rows_per_batch = params.n_heads * params.seq_len_q;
|
||||
let batch_idx = rid / rows_per_batch;
|
||||
let rem = rid % rows_per_batch;
|
||||
let head_idx = rem / params.seq_len_q;
|
||||
let q_row = rem % params.seq_len_q;
|
||||
|
||||
let dst2_stride = HEAD_DIM_V * params.n_heads;
|
||||
let dst3_stride = dst2_stride * params.seq_len_q;
|
||||
let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V;
|
||||
|
||||
let thread = sg_inv_id;
|
||||
if (params.nwg > subgroup_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
let stats_base = params.tmp_stats_base + rid * (2u * params.nwg);
|
||||
let active_thread = thread < params.nwg;
|
||||
let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread);
|
||||
let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread);
|
||||
let m = subgroupMax(mi);
|
||||
let ms = select(0.0, exp(mi - m), active_thread);
|
||||
let s = subgroupAdd(si * ms);
|
||||
let inv_s = select(0.0, 1.0 / s, s != 0.0);
|
||||
|
||||
let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg);
|
||||
for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) {
|
||||
var weighted = vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
if (active_thread) {
|
||||
let src = row_tmp_base + thread * HEAD_DIM_V + elem_base;
|
||||
weighted = vec4<f32>(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms;
|
||||
}
|
||||
|
||||
let sum_x = subgroupAdd(weighted.x);
|
||||
let sum_y = subgroupAdd(weighted.y);
|
||||
let sum_z = subgroupAdd(weighted.z);
|
||||
let sum_w = subgroupAdd(weighted.w);
|
||||
|
||||
if (thread == 0u) {
|
||||
let dst_vec_index = (row_base + elem_base) >> 2u;
|
||||
dst[dst_vec_index] = vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s;
|
||||
}
|
||||
}
|
||||
}
|
||||
729
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
Normal file
729
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
Normal file
@@ -0,0 +1,729 @@
|
||||
diagnostic(off, chromium.subgroup_matrix_uniformity);
|
||||
diagnostic(off, subgroup_uniformity);
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
enable chromium_experimental_subgroup_matrix;
|
||||
|
||||
#ifdef KV_F32
|
||||
#define KV_TYPE f32
|
||||
#else
|
||||
#define KV_TYPE f16
|
||||
#endif
|
||||
|
||||
#define HEAD_DIM_QK 64
|
||||
#define HEAD_DIM_V 64
|
||||
|
||||
|
||||
#define SG_MAT_M 8
|
||||
#define SG_MAT_N 8
|
||||
#define SG_MAT_K 8
|
||||
|
||||
#define Q_TILE SG_MAT_M
|
||||
#define KV_TILE 16
|
||||
#define WG_SIZE 64
|
||||
#ifndef VEC_NE
|
||||
#define VEC_NE 4u
|
||||
#endif
|
||||
|
||||
#define KV_BLOCKS (KV_TILE / SG_MAT_N)
|
||||
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
#if defined(KV_Q4_0)
|
||||
#define NQ 16
|
||||
#define F16_PER_BLOCK 9
|
||||
#define WEIGHTS_PER_F16 4
|
||||
#elif defined(KV_Q8_0)
|
||||
#define NQ 8
|
||||
#define F16_PER_BLOCK 17
|
||||
#define WEIGHTS_PER_F16 2
|
||||
#endif
|
||||
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
|
||||
|
||||
fn get_byte(value: u32, index: u32) -> u32 {
|
||||
return (value >> (index * 8)) & 0xFF;
|
||||
}
|
||||
|
||||
fn get_byte_i32(value: u32, index: u32) -> i32 {
|
||||
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
|
||||
}
|
||||
|
||||
struct Params {
|
||||
offset_q: u32,
|
||||
offset_k: u32,
|
||||
offset_v: u32,
|
||||
offset_mask: u32,
|
||||
offset_sinks: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
// shapes of Q/K/V
|
||||
n_heads: u32,
|
||||
seq_len_q: u32,
|
||||
seq_len_kv: u32,
|
||||
|
||||
// strides (in elements)
|
||||
stride_q1: u32,
|
||||
stride_q2: u32,
|
||||
stride_q3: u32,
|
||||
stride_k1: u32,
|
||||
stride_k2: u32,
|
||||
stride_k3: u32,
|
||||
stride_v1: u32,
|
||||
stride_v2: u32,
|
||||
stride_v3: u32,
|
||||
stride_mask3: u32,
|
||||
|
||||
// repeat factors for K/V, e.g., MHA vs. MQA vs. GQA
|
||||
q_per_kv: u32,
|
||||
|
||||
// softmax params
|
||||
scale: f32,
|
||||
max_bias: f32,
|
||||
logit_softcap: f32,
|
||||
n_head_log2: f32,
|
||||
m0: f32,
|
||||
m1: f32,
|
||||
|
||||
#ifdef BLK
|
||||
blk_base: u32,
|
||||
blk_nblk0: u32,
|
||||
blk_nblk1: u32,
|
||||
#endif
|
||||
|
||||
tmp_data_base: u32,
|
||||
tmp_stats_base: u32,
|
||||
nwg: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
|
||||
#endif
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
|
||||
#else
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
|
||||
#endif
|
||||
#if defined(MASK) && defined(SINKS)
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
||||
#ifdef BLK
|
||||
#define BLK_BINDING 5
|
||||
#define TMP_BINDING 6
|
||||
#define DST_BINDING 7
|
||||
#define PARAMS_BINDING 8
|
||||
#else
|
||||
#define TMP_BINDING 5
|
||||
#define DST_BINDING 6
|
||||
#define PARAMS_BINDING 7
|
||||
#endif
|
||||
#elif defined(MASK)
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
#ifdef BLK
|
||||
#define BLK_BINDING 4
|
||||
#define TMP_BINDING 5
|
||||
#define DST_BINDING 6
|
||||
#define PARAMS_BINDING 7
|
||||
#else
|
||||
#define TMP_BINDING 4
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#endif
|
||||
#elif defined(SINKS)
|
||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||
#define TMP_BINDING 4
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#else
|
||||
#define TMP_BINDING 3
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#endif
|
||||
|
||||
#ifdef BLK
|
||||
@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
|
||||
#endif
|
||||
@group(0) @binding(TMP_BINDING) var<storage, read_write> tmp: array<f32>;
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
|
||||
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
|
||||
|
||||
// Just a very small float value.
|
||||
const FLOAT_MIN: f32 = -1.0e9;
|
||||
|
||||
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
|
||||
|
||||
#ifndef KV_DIRECT
|
||||
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
|
||||
// we can reuse the same shmem for K and V since we only need one at a time
|
||||
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
|
||||
#endif
|
||||
|
||||
var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;
|
||||
|
||||
#ifdef MASK
|
||||
// storage for mask values
|
||||
var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
|
||||
#endif
|
||||
|
||||
// note that we reuse the same storage for both since we only need one at a time
|
||||
var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
|
||||
|
||||
// Storage for row max and exp sum during online softmax
|
||||
var<workgroup> row_max_shmem: array<f32, Q_TILE>;
|
||||
var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
|
||||
var<workgroup> blk_state_wg: u32;
|
||||
|
||||
fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
|
||||
var v = select(FLOAT_MIN,
|
||||
f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
|
||||
kv_idx < KV_TILE);
|
||||
#ifdef LOGIT_SOFTCAP
|
||||
v = params.logit_softcap * tanh(v);
|
||||
#endif
|
||||
#ifdef MASK
|
||||
if (apply_mask) {
|
||||
var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
|
||||
v += select(mask_val, slope * mask_val, has_bias);
|
||||
}
|
||||
#endif
|
||||
return v;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
||||
|
||||
// initialize row max for online softmax
|
||||
for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
|
||||
row_max_shmem[i] = FLOAT_MIN;
|
||||
exp_sum_shmem[i] = 0.0;
|
||||
}
|
||||
|
||||
for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
|
||||
o_shmem[i] = 0.0;
|
||||
}
|
||||
|
||||
// workgroups per head/batch
|
||||
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
|
||||
let wg_per_batch = wg_per_head * params.n_heads;
|
||||
|
||||
let dst2_stride = HEAD_DIM_V * params.n_heads;
|
||||
let dst3_stride = dst2_stride * params.seq_len_q;
|
||||
|
||||
let iwg = wg_id.x % params.nwg;
|
||||
let base_wg_id = wg_id.x / params.nwg;
|
||||
|
||||
// batch index
|
||||
let batch_idx = base_wg_id / wg_per_batch;
|
||||
let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
|
||||
let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
|
||||
let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
|
||||
let wg_in_batch = base_wg_id % wg_per_batch;
|
||||
|
||||
// head index
|
||||
let head_idx = wg_in_batch / wg_per_head;
|
||||
let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
|
||||
let k_head_idx = head_idx / params.q_per_kv;
|
||||
let v_head_idx = k_head_idx;
|
||||
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
|
||||
let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
|
||||
|
||||
// starting Q row for this workgroup
|
||||
let wg_in_head = wg_in_batch % wg_per_head;
|
||||
let q_row_start = wg_in_head * Q_TILE;
|
||||
|
||||
#ifdef MASK
|
||||
// mask offset
|
||||
let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
|
||||
#endif
|
||||
|
||||
let head = f32(head_idx);
|
||||
let has_bias = params.max_bias > 0.0;
|
||||
let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);
|
||||
|
||||
// load q tile into shared memory
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||
let q_row = elem_idx / HEAD_DIM_QK;
|
||||
let q_col = elem_idx % HEAD_DIM_QK;
|
||||
let head_q_row = q_row_start + q_row;
|
||||
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
|
||||
q_shmem[elem_idx] = f16(select(
|
||||
0.0,
|
||||
Q[global_q_row_offset + q_col],
|
||||
head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
|
||||
}
|
||||
|
||||
for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
|
||||
#ifdef BLK
|
||||
let q_blk = q_row_start / Q_TILE;
|
||||
let kv_blk = kv_tile / KV_TILE;
|
||||
let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
|
||||
let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
|
||||
let blk_state_local = blk[blk_idx];
|
||||
#else
|
||||
let blk_state_local = 1u;
|
||||
#endif
|
||||
if (local_id.x == 0u) {
|
||||
blk_state_wg = blk_state_local;
|
||||
}
|
||||
workgroupBarrier();
|
||||
let blk_state = blk_state_wg;
|
||||
let skip_tile = blk_state == 0u;
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
||||
inter_shmem[elem_idx] = f16(0.0);
|
||||
}
|
||||
|
||||
// load k tile into shared memory
|
||||
#if defined(KV_Q4_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let k_row = blck_idx / BLOCKS_K;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let row_offset = k_row * HEAD_DIM_QK;
|
||||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = K[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = K[base_idx + 1u + block_offset + j];
|
||||
let q_1 = K[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_Q8_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let k_row = blck_idx / BLOCKS_K;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let row_offset = k_row * HEAD_DIM_QK;
|
||||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = K[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = K[base_idx + 1u + block_offset + j];
|
||||
let q_1 = K[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_DIRECT)
|
||||
// Direct global loads for KV
|
||||
#else
|
||||
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) {
|
||||
let k_row = elem_idx / HEAD_DIM_QK;
|
||||
let k_col = elem_idx % HEAD_DIM_QK;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
|
||||
let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
|
||||
let vec_idx = (global_k_row_offset + k_col) >> 2u;
|
||||
let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
|
||||
kv_shmem[elem_idx + 0u] = f16(k4.x);
|
||||
kv_shmem[elem_idx + 1u] = f16(k4.y);
|
||||
kv_shmem[elem_idx + 2u] = f16(k4.z);
|
||||
kv_shmem[elem_idx + 3u] = f16(k4.w);
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// accumulate q block * k block into registers across the entire KV tile
|
||||
if (!skip_tile) {
|
||||
let num_of_threads = subgroup_size / VEC_NE;
|
||||
let tx = sg_inv_id % num_of_threads;
|
||||
let ty = sg_inv_id / num_of_threads;
|
||||
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
continue;
|
||||
}
|
||||
let local_q_row_offset = q_tile_row * HEAD_DIM_QK;
|
||||
|
||||
for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
|
||||
let kv_idx = kv_base + ty;
|
||||
var partial_sum: f32 = 0.0;
|
||||
let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
|
||||
if (kv_valid) {
|
||||
for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
|
||||
let q_off = local_q_row_offset + i * 4u;
|
||||
|
||||
let qv = vec4<f32>(
|
||||
f32(q_shmem[q_off + 0u]),
|
||||
f32(q_shmem[q_off + 1u]),
|
||||
f32(q_shmem[q_off + 2u]),
|
||||
f32(q_shmem[q_off + 3u]));
|
||||
#ifdef KV_DIRECT
|
||||
let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u);
|
||||
let kv = vec4<f32>(K[idx >> 2u]);
|
||||
#else
|
||||
let idx = kv_idx * HEAD_DIM_QK + (i * 4u);
|
||||
let kv = vec4<f32>(
|
||||
f32(kv_shmem[idx + 0u]),
|
||||
f32(kv_shmem[idx + 1u]),
|
||||
f32(kv_shmem[idx + 2u]),
|
||||
f32(kv_shmem[idx + 3u]));
|
||||
#endif
|
||||
partial_sum += dot(qv, kv);
|
||||
}
|
||||
}
|
||||
var sum = partial_sum;
|
||||
// Reduce over tx threads (NL) for this ty stripe.
|
||||
var tx_delta = num_of_threads >> 1u;
|
||||
loop {
|
||||
if (tx_delta == 0u) {
|
||||
break;
|
||||
}
|
||||
let sh = subgroupShuffleDown(sum, tx_delta);
|
||||
if (tx < tx_delta) {
|
||||
sum += sh;
|
||||
}
|
||||
tx_delta >>= 1u;
|
||||
}
|
||||
|
||||
let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
|
||||
if (tx == 0u && kv_valid) {
|
||||
let dst_idx = q_tile_row * KV_TILE + kv_idx;
|
||||
inter_shmem[dst_idx] = f16(sum_bcast);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#ifdef MASK
|
||||
let apply_mask = !skip_tile && (blk_state != 2u);
|
||||
if (apply_mask) {
|
||||
// load mask tile into shared memory for this KV block
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
||||
let mask_row = elem_idx / KV_TILE;
|
||||
let mask_col = elem_idx % KV_TILE;
|
||||
let global_q_row = q_row_start + mask_row;
|
||||
let global_k_col = kv_tile + mask_col;
|
||||
let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
|
||||
let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
|
||||
mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
|
||||
}
|
||||
}
|
||||
#else
|
||||
let apply_mask = false;
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// online softmax
|
||||
if (!skip_tile) {
|
||||
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
break;
|
||||
}
|
||||
|
||||
var prev_max = row_max_shmem[q_tile_row];
|
||||
var final_max = prev_max;
|
||||
// pass 1: compute final max across the full KV tile in chunks
|
||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||
let kv_idx = kv_offset + sg_inv_id;
|
||||
let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
|
||||
let softmax_term = select(FLOAT_MIN,
|
||||
calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask),
|
||||
kv_valid);
|
||||
final_max = subgroupMax(max(final_max, softmax_term));
|
||||
}
|
||||
|
||||
var total_exp_term: f32 = 0.0;
|
||||
// pass 2: compute exp sum and write P using final_max
|
||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||
let kv_idx = kv_offset + sg_inv_id;
|
||||
let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask);
|
||||
let cur_p = select(0.0,
|
||||
exp(softmax_term - final_max),
|
||||
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
|
||||
total_exp_term += subgroupAdd(cur_p);
|
||||
if (kv_idx < KV_TILE) {
|
||||
inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
|
||||
}
|
||||
}
|
||||
|
||||
let cur_exp = exp(prev_max - final_max);
|
||||
|
||||
if (sg_inv_id == 0) {
|
||||
row_max_shmem[q_tile_row] = final_max;
|
||||
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
|
||||
}
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
|
||||
o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// load v tile into shared memory
|
||||
#if defined(KV_Q4_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let v_row = blck_idx / BLOCKS_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let block_k = blck_idx % BLOCKS_V;
|
||||
let row_offset = v_row * HEAD_DIM_V;
|
||||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = V[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = V[base_idx + 1u + block_offset + j];
|
||||
let q_1 = V[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_Q8_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let v_row = blck_idx / BLOCKS_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let block_k = blck_idx % BLOCKS_V;
|
||||
let row_offset = v_row * HEAD_DIM_V;
|
||||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = V[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = V[base_idx + 1u + block_offset + j];
|
||||
let q_1 = V[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_DIRECT)
|
||||
// Direct global loads for KV
|
||||
#else
|
||||
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) {
|
||||
let v_row = elem_idx / HEAD_DIM_V;
|
||||
let v_col = elem_idx % HEAD_DIM_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
|
||||
let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
|
||||
let vec_idx = (global_v_row_offset + v_col) >> 2u;
|
||||
let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
|
||||
kv_shmem[elem_idx + 0u] = f16(v4.x);
|
||||
kv_shmem[elem_idx + 1u] = f16(v4.y);
|
||||
kv_shmem[elem_idx + 2u] = f16(v4.z);
|
||||
kv_shmem[elem_idx + 3u] = f16(v4.w);
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (!skip_tile) {
|
||||
// we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
|
||||
// we want to compute O += P * V across the full KV tile
|
||||
let ne_threads : u32 = VEC_NE;
|
||||
let nl_threads = max(1u, subgroup_size / ne_threads);
|
||||
let tx_pv = sg_inv_id % nl_threads;
|
||||
let ty_pv = sg_inv_id / nl_threads;
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
|
||||
var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
|
||||
let kv_idx = cc * ne_threads + ty_pv;
|
||||
let v_row = kv_tile + kv_idx;
|
||||
if (v_row >= params.seq_len_kv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]);
|
||||
#ifdef KV_DIRECT
|
||||
let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
|
||||
let v4 = vec4<f32>(V[v_idx >> 2u]);
|
||||
#else
|
||||
let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u;
|
||||
let v4 = vec4<f32>(
|
||||
f32(kv_shmem[v_idx + 0u]),
|
||||
f32(kv_shmem[v_idx + 1u]),
|
||||
f32(kv_shmem[v_idx + 2u]),
|
||||
f32(kv_shmem[v_idx + 3u]));
|
||||
#endif
|
||||
lo += p * v4;
|
||||
}
|
||||
|
||||
var lo_x = lo.x;
|
||||
var lo_y = lo.y;
|
||||
var lo_z = lo.z;
|
||||
var lo_w = lo.w;
|
||||
// Reduce over ty threads (NE) for this tx thread.
|
||||
var ty_delta = ne_threads >> 1u;
|
||||
loop {
|
||||
if (ty_delta == 0u) {
|
||||
break;
|
||||
}
|
||||
let thread_delta = ty_delta * nl_threads;
|
||||
let shx = subgroupShuffleDown(lo_x, thread_delta);
|
||||
let shy = subgroupShuffleDown(lo_y, thread_delta);
|
||||
let shz = subgroupShuffleDown(lo_z, thread_delta);
|
||||
let shw = subgroupShuffleDown(lo_w, thread_delta);
|
||||
if (ty_pv < ty_delta) {
|
||||
lo_x += shx;
|
||||
lo_y += shy;
|
||||
lo_z += shz;
|
||||
lo_w += shw;
|
||||
}
|
||||
ty_delta >>= 1u;
|
||||
}
|
||||
|
||||
if (ty_pv == 0u) {
|
||||
let elem_base = vec_col * 4u;
|
||||
let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base;
|
||||
o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x);
|
||||
o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y);
|
||||
o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z);
|
||||
o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
|
||||
#ifdef SINKS
|
||||
// Sinks are global terms and must be applied exactly once across split workgroups.
|
||||
if (iwg == 0u) {
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
break;
|
||||
}
|
||||
|
||||
var prev_max = row_max_shmem[q_tile_row];
|
||||
|
||||
// for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
|
||||
let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
|
||||
let new_max = subgroupMax(max(prev_max, sink_val));
|
||||
let max_exp = exp(prev_max - new_max);
|
||||
let sink_exp = exp(sink_val - new_max);
|
||||
|
||||
let sink_exp_sum = subgroupAdd(sink_exp);
|
||||
|
||||
if (sg_inv_id == 0) {
|
||||
row_max_shmem[q_tile_row] = new_max;
|
||||
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
|
||||
}
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
|
||||
o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
|
||||
}
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
#endif
|
||||
let rows_per_batch = params.n_heads * params.seq_len_q;
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) { break; }
|
||||
|
||||
if (params.nwg == 1u) {
|
||||
let exp_sum = exp_sum_shmem[q_tile_row];
|
||||
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
|
||||
let row_base: u32 =
|
||||
params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;
|
||||
|
||||
for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
|
||||
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
|
||||
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
|
||||
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
|
||||
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
|
||||
|
||||
let v = vec4<f32>(
|
||||
f32(o_shmem[i0]) * scale,
|
||||
f32(o_shmem[i1]) * scale,
|
||||
f32(o_shmem[i2]) * scale,
|
||||
f32(o_shmem[i3]) * scale
|
||||
);
|
||||
|
||||
let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
|
||||
dst[dst_vec_index] = v;
|
||||
}
|
||||
} else {
|
||||
let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
|
||||
let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
|
||||
let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;
|
||||
|
||||
for (var elem_base = sg_inv_id * 4u;
|
||||
elem_base < HEAD_DIM_V;
|
||||
elem_base += subgroup_size * 4u) {
|
||||
|
||||
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
|
||||
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
|
||||
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
|
||||
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
|
||||
|
||||
let tbase = tmp_row_data_base + elem_base;
|
||||
tmp[tbase + 0u] = f32(o_shmem[i0]);
|
||||
tmp[tbase + 1u] = f32(o_shmem[i1]);
|
||||
tmp[tbase + 2u] = f32(o_shmem[i2]);
|
||||
tmp[tbase + 3u] = f32(o_shmem[i3]);
|
||||
}
|
||||
|
||||
if (sg_inv_id == 0u) {
|
||||
tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
|
||||
tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -419,6 +419,7 @@ class MODEL_ARCH(IntEnum):
|
||||
GEMMA2 = auto()
|
||||
GEMMA3 = auto()
|
||||
GEMMA3N = auto()
|
||||
GEMMA4 = auto()
|
||||
GEMMA_EMBEDDING = auto()
|
||||
STARCODER2 = auto()
|
||||
RWKV6 = auto()
|
||||
@@ -535,8 +536,11 @@ class MODEL_TENSOR(IntEnum):
|
||||
FFN_GATE_INP = auto()
|
||||
FFN_GATE_INP_SHEXP = auto()
|
||||
FFN_NORM = auto()
|
||||
FFN_PRE_NORM = auto()
|
||||
FFN_PRE_NORM = auto() # alias of FFN_NORM
|
||||
FFN_PRE_NORM_2 = auto() # gemma4
|
||||
FFN_POST_NORM = auto()
|
||||
FFN_POST_NORM_1 = auto() # gemma4
|
||||
FFN_POST_NORM_2 = auto() # gemma4
|
||||
FFN_GATE = auto()
|
||||
FFN_DOWN = auto()
|
||||
FFN_UP = auto()
|
||||
@@ -558,6 +562,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
ATTN_Q_NORM = auto()
|
||||
ATTN_K_NORM = auto()
|
||||
LAYER_OUT_NORM = auto()
|
||||
LAYER_OUT_SCALE = auto()
|
||||
PER_LAYER_TOKEN_EMBD = auto() # gemma3n
|
||||
PER_LAYER_MODEL_PROJ = auto() # gemma3n
|
||||
PER_LAYER_INP_GATE = auto() # gemma3n
|
||||
@@ -722,8 +727,11 @@ class MODEL_TENSOR(IntEnum):
|
||||
V_ENC_FFN_UP = auto()
|
||||
V_ENC_FFN_GATE = auto()
|
||||
V_ENC_FFN_DOWN = auto()
|
||||
V_ENC_ATTN_POST_NORM = auto() # gemma4
|
||||
V_ENC_FFN_POST_NORM = auto()
|
||||
V_LAYER_SCALE_1 = auto()
|
||||
V_LAYER_SCALE_2 = auto()
|
||||
V_LAYER_OUT_SCALE = auto()
|
||||
V_PRE_NORM = auto()
|
||||
V_POST_NORM = auto()
|
||||
V_MM_POST_NORM = auto()
|
||||
@@ -761,6 +769,8 @@ class MODEL_TENSOR(IntEnum):
|
||||
V_MM_GATE = auto() # cogvlm
|
||||
V_TOK_BOI = auto() # cogvlm
|
||||
V_TOK_EOI = auto() # cogvlm
|
||||
V_STD_BIAS = auto() # gemma4
|
||||
V_STD_SCALE = auto() # gemma4
|
||||
V_SAM_POS_EMBD = auto() # Deepseek-OCR
|
||||
V_SAM_PATCH_EMBD = auto() # Deepseek-OCR
|
||||
V_SAM_PRE_NORM = auto() # Deepseek-OCR
|
||||
@@ -781,6 +791,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
A_ENC_EMBD_POS = auto()
|
||||
A_ENC_EMBD_NORM = auto()
|
||||
A_ENC_EMBD_TO_LOGITS = auto() # lfm2
|
||||
A_ENC_INP_PROJ = auto() # gemma4
|
||||
A_ENC_CONV1D = auto()
|
||||
A_ENC_CONV1D_NORM = auto() # gemma3n
|
||||
A_PRE_NORM = auto()
|
||||
@@ -789,10 +800,13 @@ class MODEL_TENSOR(IntEnum):
|
||||
A_ENC_ATTN_Q = auto()
|
||||
A_ENC_ATTN_K = auto()
|
||||
A_ENC_ATTN_V = auto()
|
||||
A_ENC_ATTN_POST_NORM = auto()
|
||||
A_ENC_ATTN_PRE_NORM = auto()
|
||||
A_ENC_ATTN_K_REL = auto() # gemma4
|
||||
A_ENC_PER_DIM_SCALE = auto() # gemma3n
|
||||
A_ENC_INPUT_NORM = auto()
|
||||
A_ENC_OUTPUT = auto()
|
||||
A_ENC_OUTPUT_NORM = auto()
|
||||
A_ENC_OUTPUT = auto() # TODO @ngxson: rename to ATTN_OUT
|
||||
A_ENC_OUTPUT_NORM = auto() # TODO @ngxson: rename to ATTN_OUT
|
||||
A_ENC_FFN_UP = auto()
|
||||
A_ENC_FFN_NORM = auto()
|
||||
A_ENC_FFN_POST_NORM = auto() # gemma3n
|
||||
@@ -813,6 +827,8 @@ class MODEL_TENSOR(IntEnum):
|
||||
A_MM_HARD_EMB_NORM = auto() # gemma3n
|
||||
A_MM_SOFT_EMB_NORM = auto() # gemma3n
|
||||
A_MM_INP_PROJ = auto() # gemma3n
|
||||
A_PER_DIM_K_SCALE = auto() # gemma4
|
||||
A_PER_DIM_SCALE = auto() # gemma4
|
||||
# nextn/mtp
|
||||
NEXTN_EH_PROJ = auto()
|
||||
NEXTN_EMBED_TOKENS = auto()
|
||||
@@ -882,6 +898,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.GEMMA2: "gemma2",
|
||||
MODEL_ARCH.GEMMA3: "gemma3",
|
||||
MODEL_ARCH.GEMMA3N: "gemma3n",
|
||||
MODEL_ARCH.GEMMA4: "gemma4",
|
||||
MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.RWKV6: "rwkv6",
|
||||
@@ -1000,6 +1017,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm",
|
||||
MODEL_TENSOR.FFN_PRE_NORM_2: "blk.{bid}.pre_ffw_norm_2", # gemma4
|
||||
MODEL_TENSOR.FFN_POST_NORM_1: "blk.{bid}.post_ffw_norm_1", # gemma4
|
||||
MODEL_TENSOR.FFN_POST_NORM_2: "blk.{bid}.post_ffw_norm_2", # gemma4
|
||||
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
|
||||
@@ -1019,6 +1039,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.MOE_LATENT_DOWN: "blk.{bid}.ffn_latent_down", # nemotron 3 super
|
||||
MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up", # nemotron 3 super
|
||||
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE: "blk.{bid}.layer_output_scale",
|
||||
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
|
||||
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n
|
||||
MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", # gemma3n
|
||||
@@ -1183,8 +1204,11 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.V_ENC_ATTN_POST_NORM: "v.blk.{bid}.attn_post_norm",
|
||||
MODEL_TENSOR.V_ENC_FFN_POST_NORM: "v.blk.{bid}.ffn_post_norm",
|
||||
MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1",
|
||||
MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
|
||||
MODEL_TENSOR.V_LAYER_OUT_SCALE: "v.blk.{bid}.out_scale",
|
||||
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
|
||||
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
|
||||
MODEL_TENSOR.V_MM_POST_NORM: "mm.post_norm",
|
||||
@@ -1222,6 +1246,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.V_MM_GATE: "mm.gate",
|
||||
MODEL_TENSOR.V_TOK_BOI: "v.boi",
|
||||
MODEL_TENSOR.V_TOK_EOI: "v.eoi",
|
||||
MODEL_TENSOR.V_STD_BIAS: "v.std_bias", # gemma4
|
||||
MODEL_TENSOR.V_STD_SCALE: "v.std_scale", # gemma4
|
||||
# DeepSeek-OCR SAM
|
||||
MODEL_TENSOR.V_SAM_POS_EMBD: "v.sam.pos_embd",
|
||||
MODEL_TENSOR.V_SAM_PATCH_EMBD: "v.sam.patch_embd",
|
||||
@@ -1243,6 +1269,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
|
||||
MODEL_TENSOR.A_ENC_EMBD_NORM: "a.position_embd_norm",
|
||||
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
|
||||
MODEL_TENSOR.A_ENC_INP_PROJ: "a.input_projection",
|
||||
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
|
||||
MODEL_TENSOR.A_ENC_CONV1D_NORM: "a.conv1d.{bid}.norm",
|
||||
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
|
||||
@@ -1251,6 +1278,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.A_ENC_ATTN_POST_NORM: "a.blk.{bid}.attn_post_norm",
|
||||
MODEL_TENSOR.A_ENC_ATTN_PRE_NORM: "a.blk.{bid}.attn_pre_norm",
|
||||
MODEL_TENSOR.A_ENC_ATTN_K_REL: "a.blk.{bid}.attn_k_rel",
|
||||
MODEL_TENSOR.A_ENC_PER_DIM_SCALE: "a.blk.{bid}.per_dim_scale",
|
||||
MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
|
||||
MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
|
||||
@@ -1275,6 +1305,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.A_MM_SOFT_EMB_NORM: "mm.a.soft_emb_norm", # gemma3n
|
||||
MODEL_TENSOR.A_MM_EMBEDDING: "mm.a.embedding", # gemma3n
|
||||
MODEL_TENSOR.A_MM_HARD_EMB_NORM: "mm.a.hard_emb_norm", # gemma3n
|
||||
MODEL_TENSOR.A_PER_DIM_K_SCALE: "a.blk.{bid}.per_dim_k_scale", # gemma4
|
||||
MODEL_TENSOR.A_PER_DIM_SCALE: "a.blk.{bid}.per_dim_scale", # gemma4
|
||||
# lfm2 audio
|
||||
MODEL_TENSOR.A_ENC_NORM_CONV: "a.blk.{bid}.norm_conv",
|
||||
MODEL_TENSOR.A_ENC_LINEAR_POS: "a.blk.{bid}.linear_pos",
|
||||
@@ -1319,8 +1351,11 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.V_ENC_FFN_UP,
|
||||
MODEL_TENSOR.V_ENC_FFN_GATE,
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN,
|
||||
MODEL_TENSOR.V_ENC_ATTN_POST_NORM,
|
||||
MODEL_TENSOR.V_ENC_FFN_POST_NORM,
|
||||
MODEL_TENSOR.V_LAYER_SCALE_1,
|
||||
MODEL_TENSOR.V_LAYER_SCALE_2,
|
||||
MODEL_TENSOR.V_LAYER_OUT_SCALE,
|
||||
MODEL_TENSOR.V_PRE_NORM,
|
||||
MODEL_TENSOR.V_POST_NORM,
|
||||
MODEL_TENSOR.V_MM_POST_NORM,
|
||||
@@ -1358,6 +1393,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.V_MM_GATE,
|
||||
MODEL_TENSOR.V_TOK_BOI,
|
||||
MODEL_TENSOR.V_TOK_EOI,
|
||||
MODEL_TENSOR.V_STD_BIAS,
|
||||
MODEL_TENSOR.V_STD_SCALE,
|
||||
MODEL_TENSOR.V_SAM_POS_EMBD,
|
||||
MODEL_TENSOR.V_SAM_PATCH_EMBD,
|
||||
MODEL_TENSOR.V_SAM_PRE_NORM,
|
||||
@@ -1375,6 +1412,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS,
|
||||
MODEL_TENSOR.A_ENC_EMBD_NORM,
|
||||
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
|
||||
MODEL_TENSOR.A_ENC_INP_PROJ,
|
||||
MODEL_TENSOR.A_ENC_CONV1D,
|
||||
MODEL_TENSOR.A_ENC_CONV1D_NORM,
|
||||
MODEL_TENSOR.A_PRE_NORM,
|
||||
@@ -1383,6 +1421,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.A_ENC_ATTN_Q,
|
||||
MODEL_TENSOR.A_ENC_ATTN_K,
|
||||
MODEL_TENSOR.A_ENC_ATTN_V,
|
||||
MODEL_TENSOR.A_ENC_ATTN_POST_NORM,
|
||||
MODEL_TENSOR.A_ENC_ATTN_PRE_NORM,
|
||||
MODEL_TENSOR.A_ENC_ATTN_K_REL,
|
||||
MODEL_TENSOR.A_ENC_PER_DIM_SCALE,
|
||||
MODEL_TENSOR.A_ENC_INPUT_NORM,
|
||||
MODEL_TENSOR.A_ENC_OUTPUT,
|
||||
@@ -1416,6 +1457,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.A_MM_SOFT_EMB_NORM,
|
||||
MODEL_TENSOR.A_MM_EMBEDDING,
|
||||
MODEL_TENSOR.A_MM_HARD_EMB_NORM,
|
||||
MODEL_TENSOR.A_PER_DIM_K_SCALE,
|
||||
MODEL_TENSOR.A_PER_DIM_SCALE,
|
||||
],
|
||||
MODEL_ARCH.LLAMA: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
@@ -2273,6 +2316,38 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.LAUREL_R,
|
||||
MODEL_TENSOR.LAUREL_POST_NORM,
|
||||
],
|
||||
MODEL_ARCH.GEMMA4: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_UP_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
MODEL_TENSOR.FFN_PRE_NORM_2,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM_1,
|
||||
MODEL_TENSOR.FFN_POST_NORM_2,
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE,
|
||||
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
|
||||
MODEL_TENSOR.PER_LAYER_MODEL_PROJ,
|
||||
MODEL_TENSOR.PER_LAYER_INP_GATE,
|
||||
MODEL_TENSOR.PER_LAYER_PROJ,
|
||||
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
|
||||
MODEL_TENSOR.PER_LAYER_POST_NORM,
|
||||
],
|
||||
MODEL_ARCH.GEMMA_EMBEDDING: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
@@ -4010,6 +4085,8 @@ class VisionProjectorType:
|
||||
GEMMA3 = "gemma3"
|
||||
GEMMA3NV = "gemma3nv"
|
||||
GEMMA3NA = "gemma3na"
|
||||
GEMMA4V = "gemma4v"
|
||||
GEMMA4A = "gemma4a"
|
||||
PHI4 = "phi4"
|
||||
IDEFICS3 = "idefics3"
|
||||
PIXTRAL = "pixtral"
|
||||
|
||||
@@ -799,6 +799,7 @@ class GGUFWriter:
|
||||
def add_shared_kv_layers(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
|
||||
|
||||
# if input is array, true means SWA and false means full_attention for each layer
|
||||
def add_sliding_window_pattern(self, value: int | Sequence[bool]) -> None:
|
||||
key = Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch)
|
||||
if isinstance(value, int):
|
||||
|
||||
@@ -401,6 +401,10 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.pre_mlp_layernorm", # afmoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_PRE_NORM_2: (
|
||||
"model.layers.{bid}.pre_feedforward_layernorm_2", # gemma4
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
MODEL_TENSOR.FFN_POST_NORM: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
|
||||
@@ -411,6 +415,14 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.post_moe_norm", # grok-2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_POST_NORM_1: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm_1", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_POST_NORM_2: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm_2", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
"layers.{bid}.feed_forward.gate", # mixtral
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe
|
||||
@@ -428,6 +440,7 @@ class TensorNameMap:
|
||||
"layers.{bid}.gate", # mistral-large
|
||||
"backbone.layers.{bid}.mixer.gate", # nemotron-h-moe
|
||||
"model.layers.{bid}.moe.gate", # step3.5
|
||||
"model.layers.{bid}.router.proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
@@ -570,6 +583,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_UP_EXP: (
|
||||
"model.layers.{bid}.mlp.experts.gate_up_proj",
|
||||
"model.layers.{bid}.experts.gate_up_proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.MOE_LATENT_DOWN: (
|
||||
@@ -629,6 +643,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
|
||||
"model.layers.{bid}.block_sparse_moe.experts.down", # smallthinker
|
||||
"model.layers.{bid}.moe.down_proj", # step3.5
|
||||
"model.layers.{bid}.experts.down_proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
@@ -693,6 +708,10 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.final_layernorm", # bailingmoe2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE: (
|
||||
"model.layers.{bid}.layer_scalar", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
|
||||
"model.embed_tokens_per_layer", # gemma3n
|
||||
),
|
||||
@@ -1383,6 +1402,7 @@ class TensorNameMap:
|
||||
"model.vision_model.embeddings.patch_embedding", # Deepseek-OCR CLIP
|
||||
"siglip2.vision_model.embeddings.patch_embedding",
|
||||
"vision_model.radio_model.model.patch_generator.embedder", # Nemotron Nano v2 VL
|
||||
"model.vision_tower.patch_embedder.input_proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_NORM: (
|
||||
@@ -1400,6 +1420,7 @@ class TensorNameMap:
|
||||
"model.vision.patch_embedding.position_embedding", # cogvlm
|
||||
"visual.embeddings.position_embedding", # glm4v
|
||||
"vision_model.radio_model.model.patch_generator.pos_embed", # Nemotron Nano v2 VL
|
||||
"model.vision_tower.patch_embedder.position_embedding_table", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_IMGNL: (
|
||||
@@ -1430,12 +1451,14 @@ class TensorNameMap:
|
||||
"vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
|
||||
"siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl
|
||||
"model.vision_model.transformer.layers.{bid}.self_attn.q_proj", # Deepseek-OCR CLIP, generated
|
||||
"vision_model.model.layers.{bid}.self_attn.q_proj.linear", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL
|
||||
"model.vision_tower.encoder.layer.{bid}.attention.q_norm", # Intern-S1
|
||||
"visual.blocks.{bid}.attn.q_norm", # GLM-OCR
|
||||
"vision_model.model.layers.{bid}.self_attn.q_norm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_K: (
|
||||
@@ -1450,12 +1473,14 @@ class TensorNameMap:
|
||||
"vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
|
||||
"model.vision_model.transformer.layers.{bid}.self_attn.k_proj", # Deepseek-OCR CLIP, generated
|
||||
"siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj",
|
||||
"vision_model.model.layers.{bid}.self_attn.k_proj.linear", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL
|
||||
"model.vision_tower.encoder.layer.{bid}.attention.k_norm", # Intern-S1
|
||||
"visual.blocks.{bid}.attn.k_norm", # GLM-OCR
|
||||
"vision_model.model.layers.{bid}.self_attn.k_norm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_V: (
|
||||
@@ -1470,6 +1495,7 @@ class TensorNameMap:
|
||||
"vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
|
||||
"siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj",
|
||||
"model.vision_model.transformer.layers.{bid}.self_attn.v_proj", # Deepseek-OCR CLIP, generated
|
||||
"vision_model.model.layers.{bid}.self_attn.v_proj.linear", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM: (
|
||||
@@ -1480,7 +1506,7 @@ class TensorNameMap:
|
||||
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
|
||||
"vision_tower.transformer.layers.{bid}.attention_norm", # pixtral-hf
|
||||
"vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral
|
||||
"vision_model.model.layers.{bid}.input_layernorm", # llama4
|
||||
"vision_model.model.layers.{bid}.input_layernorm", # llama4, gemma4
|
||||
"visual.blocks.{bid}.norm1", # qwen2vl
|
||||
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
|
||||
"model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
|
||||
@@ -1505,6 +1531,7 @@ class TensorNameMap:
|
||||
"model.vision_model.transformer.layers.{bid}.self_attn.out_proj", # Deepseek-OCR CLIP
|
||||
"siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl
|
||||
"vision_model.radio_model.model.blocks.{bid}.attn.proj", # Nemotron Nano v2 VL
|
||||
"vision_model.model.layers.{bid}.self_attn.o_proj.linear", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
|
||||
@@ -1522,6 +1549,7 @@ class TensorNameMap:
|
||||
"model.vision_model.transformer.layers.{bid}.layer_norm2", # Deepseek-OCR CLIP
|
||||
"siglip2.vision_model.encoder.layers.{bid}.layer_norm2",
|
||||
"vision_model.radio_model.model.blocks.{bid}.norm2", # Nemotron Nano v2 VL
|
||||
"vision_model.model.layers.{bid}.pre_feedforward_layernorm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: (
|
||||
@@ -1540,12 +1568,14 @@ class TensorNameMap:
|
||||
"model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
|
||||
"siglip2.vision_model.encoder.layers.{bid}.mlp.fc1",
|
||||
"vision_model.radio_model.model.blocks.{bid}.mlp.fc1", # Nemotron Nano v2 VL
|
||||
"vision_model.model.layers.{bid}.mlp.up_proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_GATE: (
|
||||
"vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral-hf
|
||||
"vision_encoder.transformer.layers.{bid}.feed_forward.w1", # pixtral
|
||||
"visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl
|
||||
"vision_model.model.layers.{bid}.mlp.gate_proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN: (
|
||||
@@ -1564,6 +1594,15 @@ class TensorNameMap:
|
||||
"model.vision_model.transformer.layers.{bid}.mlp.fc2", # Deepseek-OCR CLIP
|
||||
"siglip2.vision_model.encoder.layers.{bid}.mlp.fc2",
|
||||
"vision_model.radio_model.model.blocks.{bid}.mlp.fc2", # Nemotron Nano v2 VL
|
||||
"vision_model.model.layers.{bid}.mlp.down_proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_POST_NORM: (
|
||||
"vision_model.model.layers.{bid}.post_attention_layernorm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_POST_NORM: (
|
||||
"vision_model.model.layers.{bid}.post_feedforward_layernorm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_LAYER_SCALE_1: (
|
||||
@@ -1576,6 +1615,10 @@ class TensorNameMap:
|
||||
"model.vision_tower.encoder.layer.{bid}.lambda_2", # Intern-S1
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_LAYER_OUT_SCALE: (
|
||||
"vision_model.model.layers.{bid}.layer_scalar", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_PRE_NORM: (
|
||||
"vision_tower.vision_model.pre_layrnorm",
|
||||
"vision_tower.ln_pre", # pixtral-hf
|
||||
@@ -1763,6 +1806,14 @@ class TensorNameMap:
|
||||
"model.vision.eoi", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_STD_BIAS: (
|
||||
"model.vision_tower.std_bias", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_STD_SCALE: (
|
||||
"model.vision_tower.std_scale", # gemma4
|
||||
),
|
||||
|
||||
# audio (mtmd)
|
||||
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS: (
|
||||
@@ -1782,10 +1833,15 @@ class TensorNameMap:
|
||||
"audio_tower.conv{bid}", # ultravox
|
||||
"conformer.pre_encode.conv.{bid}", # lfm2
|
||||
"model.audio_tower.subsample_conv_projection.conv_{bid}.conv", # gemma3n
|
||||
"conformer.subsample_conv_projection.layer{bid}.conv", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_CONV1D_NORM: (
|
||||
"model.audio_tower.subsample_conv_projection.conv_{bid}.norm", # gemma3n
|
||||
"conformer.subsample_conv_projection.layer{bid}.norm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_INP_PROJ: (
|
||||
"conformer.subsample_conv_projection.input_proj_linear", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_PRE_NORM: (),
|
||||
@@ -1799,22 +1855,38 @@ class TensorNameMap:
|
||||
"audio_tower.layers.{bid}.self_attn.q_proj", # ultravox
|
||||
"conformer.layers.{bid}.self_attn.linear_q", # lfm2
|
||||
"conformer.layers.{bid}.attention.attn.q_proj", # gemma3n
|
||||
"conformer.layers.{bid}.self_attn.q_proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_ATTN_K: (
|
||||
"audio_tower.layers.{bid}.self_attn.k_proj", # ultravox
|
||||
"conformer.layers.{bid}.self_attn.linear_k", # lfm2
|
||||
"conformer.layers.{bid}.attention.attn.k_proj", # gemma3n
|
||||
"conformer.layers.{bid}.self_attn.k_proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_ATTN_V: (
|
||||
"audio_tower.layers.{bid}.self_attn.v_proj", # ultravox
|
||||
"conformer.layers.{bid}.self_attn.linear_v", # lfm2
|
||||
"conformer.layers.{bid}.attention.attn.v_proj", # gemma3n
|
||||
"conformer.layers.{bid}.self_attn.v_proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_ATTN_K_REL: (
|
||||
"conformer.layers.{bid}.self_attn.relative_k_proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_ATTN_POST_NORM: (
|
||||
"conformer.layers.{bid}.norm_post_attn", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_ATTN_PRE_NORM: (
|
||||
"conformer.layers.{bid}.norm_pre_attn", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_PER_DIM_SCALE: (
|
||||
"conformer.layers.{bid}.attention.attn.per_dim_scale", # gemma3n
|
||||
"conformer.layers.{bid}.self_attn.per_dim_scale", # gemma3n
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_LAYER_PRE_NORM: (
|
||||
@@ -1831,6 +1903,7 @@ class TensorNameMap:
|
||||
"audio_tower.layers.{bid}.self_attn.out_proj", # ultravox
|
||||
"conformer.layers.{bid}.self_attn.linear_out", # lfm2
|
||||
"conformer.layers.{bid}.attention.post", # gemma3n
|
||||
"conformer.layers.{bid}.self_attn.post", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_OUTPUT_NORM: (
|
||||
@@ -1842,10 +1915,12 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.A_ENC_FFN_NORM: (
|
||||
"conformer.layers.{bid}.norm_feed_forward1", # lfm2
|
||||
"conformer.layers.{bid}.ffw_layer_start.pre_layer_norm", # gemma3n
|
||||
"conformer.layers.{bid}.feed_forward1.pre_layer_norm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_POST_NORM: (
|
||||
"conformer.layers.{bid}.ffw_layer_start.post_layer_norm", # gemma3n
|
||||
"conformer.layers.{bid}.feed_forward1.post_layer_norm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_SCALE: (
|
||||
@@ -1856,6 +1931,7 @@ class TensorNameMap:
|
||||
"audio_tower.layers.{bid}.fc1", # ultravox
|
||||
"conformer.layers.{bid}.feed_forward1.linear1", # lfm2
|
||||
"conformer.layers.{bid}.ffw_layer_start.ffw_layer_1", # gemma3n
|
||||
"conformer.layers.{bid}.feed_forward1.ffw_layer_1", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_GATE: (),
|
||||
@@ -1864,25 +1940,30 @@ class TensorNameMap:
|
||||
"audio_tower.layers.{bid}.fc2", # ultravox
|
||||
"conformer.layers.{bid}.feed_forward1.linear2", # lfm2
|
||||
"conformer.layers.{bid}.ffw_layer_start.ffw_layer_2", # gemma3n
|
||||
"conformer.layers.{bid}.feed_forward1.ffw_layer_2", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_UP_1: (
|
||||
"conformer.layers.{bid}.feed_forward2.linear1", # lfm2
|
||||
"conformer.layers.{bid}.ffw_layer_end.ffw_layer_1", # gemma3n
|
||||
"conformer.layers.{bid}.feed_forward2.ffw_layer_1", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_DOWN_1: (
|
||||
"conformer.layers.{bid}.feed_forward2.linear2", # lfm2
|
||||
"conformer.layers.{bid}.ffw_layer_end.ffw_layer_2", # gemma3n
|
||||
"conformer.layers.{bid}.feed_forward2.ffw_layer_2", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_NORM_1: (
|
||||
"conformer.layers.{bid}.norm_feed_forward2", # lfm2
|
||||
"conformer.layers.{bid}.ffw_layer_end.pre_layer_norm", # gemma3n
|
||||
"conformer.layers.{bid}.feed_forward2.pre_layer_norm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_POST_NORM_1: (
|
||||
"conformer.layers.{bid}.ffw_layer_end.post_layer_norm", # gemma3n
|
||||
"conformer.layers.{bid}.feed_forward2.post_layer_norm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_SCALE_1: (
|
||||
@@ -1904,7 +1985,8 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.A_ENC_OUT: (
|
||||
"conformer.pre_encode.out", # lfm2
|
||||
"model.audio_tower.subsample_conv_projection.input_proj_linear", # gemma3n
|
||||
"model.audio_tower.subsample_conv_projection.input_proj_linear", # gemma3n (note: it should be A_ENC_INP_PROJ, this is a mistake; it should be corrected in C++ code when it's supported)
|
||||
"conformer.output_proj", # gemma4
|
||||
),
|
||||
|
||||
# note: some tensors below has "audio." pseudo-prefix, to prevent conflicts with vision tensors
|
||||
@@ -1918,6 +2000,7 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.A_MMPROJ_FC: (
|
||||
"audio.multi_modal_projector.linear", # qwen2audio
|
||||
"audio_tower.proj", # qwen2omni
|
||||
"model.audio_tower.output_proj" # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MM_NORM_PRE: (
|
||||
@@ -1953,6 +2036,14 @@ class TensorNameMap:
|
||||
"conformer.layers.{bid}.lconv1d.conv_norm", # gemma3n
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_PER_DIM_K_SCALE: (
|
||||
"conformer.layers.{bid}.attention.attn.per_dim_key_scale", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_PER_DIM_SCALE: (
|
||||
"conformer.layers.{bid}.attention.attn.per_dim_scale", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MM_EMBEDDING: (
|
||||
"model.embed_audio.embedding", # gemma3n
|
||||
),
|
||||
|
||||
266
models/templates/gemma4.jinja
Normal file
266
models/templates/gemma4.jinja
Normal file
@@ -0,0 +1,266 @@
|
||||
{%- macro format_parameters(properties, required) -%}
|
||||
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in properties | dictsort -%}
|
||||
{%- set add_comma = false -%}
|
||||
{%- if key not in standard_keys -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{{ key }}:{
|
||||
{%- if value['description'] -%}
|
||||
description:<|"|>{{ value['description'] }}<|"|>
|
||||
{%- set add_comma = true -%}
|
||||
{%- endif -%}
|
||||
{%- if value['nullable'] %}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
nullable:true
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'STRING' -%}
|
||||
{%- if value['enum'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
enum:{{ format_argument(value['enum']) }}
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'OBJECT' -%}
|
||||
,properties:{
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
{%- elif value is mapping -%}
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- if value['required'] -%}
|
||||
,required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||
{%- if value['items'] is mapping and value['items'] -%}
|
||||
,items:{
|
||||
{%- set ns_items = namespace(found_first=false) -%}
|
||||
{%- for item_key, item_value in value['items'] | dictsort -%}
|
||||
{%- if item_value is not none -%}
|
||||
{%- if ns_items.found_first %},{% endif -%}
|
||||
{%- set ns_items.found_first = true -%}
|
||||
{%- if item_key == 'properties' -%}
|
||||
properties:{
|
||||
{%- if item_value is mapping -%}
|
||||
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- elif item_key == 'required' -%}
|
||||
required:[
|
||||
{%- for req_item in item_value -%}
|
||||
<|"|>{{- req_item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- elif item_key == 'type' -%}
|
||||
{%- if item_value is string -%}
|
||||
type:{{ format_argument(item_value | upper) }}
|
||||
{%- else -%}
|
||||
type:{{ format_argument(item_value | map('upper') | list) }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{ item_key }}:{{ format_argument(item_value) }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
type:<|"|>{{ value['type'] | upper }}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_function_declaration(tool_data) -%}
|
||||
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
|
||||
{%- set params = tool_data['function']['parameters'] -%}
|
||||
{%- if params -%}
|
||||
,parameters:{
|
||||
{%- if params['properties'] -%}
|
||||
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
|
||||
{%- endif -%}
|
||||
{%- if params['required'] -%}
|
||||
required:[
|
||||
{%- for item in params['required'] -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{{- ',' if not loop.last -}}
|
||||
{%- endfor -%}
|
||||
],
|
||||
{%- endif -%}
|
||||
{%- if params['type'] -%}
|
||||
type:<|"|>{{- params['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if 'response' in tool_data['function'] -%}
|
||||
{%- set response_declaration = tool_data['function']['response'] -%}
|
||||
,response:{
|
||||
{%- if response_declaration['description'] -%}
|
||||
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
|
||||
{%- endif -%}
|
||||
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
|
||||
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_argument(argument, escape_keys=True) -%}
|
||||
{%- if argument is string -%}
|
||||
{{- '<|"|>' + argument + '<|"|>' -}}
|
||||
{%- elif argument is boolean -%}
|
||||
{{- 'true' if argument else 'false' -}}
|
||||
{%- elif argument is mapping -%}
|
||||
{{- '{' -}}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in argument | dictsort -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{%- if escape_keys -%}
|
||||
{{- '<|"|>' + key + '<|"|>' -}}
|
||||
{%- else -%}
|
||||
{{- key -}}
|
||||
{%- endif -%}
|
||||
:{{- format_argument(value, escape_keys=escape_keys) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- elif argument is sequence -%}
|
||||
{{- '[' -}}
|
||||
{%- for item in argument -%}
|
||||
{{- format_argument(item, escape_keys=escape_keys) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ']' -}}
|
||||
{%- else -%}
|
||||
{{- argument -}}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro strip_thinking(text) -%}
|
||||
{%- set ns = namespace(result='') -%}
|
||||
{%- for part in text.split('<channel|>') -%}
|
||||
{%- if '<|channel>' in part -%}
|
||||
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
|
||||
{%- else -%}
|
||||
{%- set ns.result = ns.result + part -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ns.result | trim -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- set ns = namespace(prev_message_type=None) -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{{ bos_token }}
|
||||
{#- Handle System/Tool Definitions Block -#}
|
||||
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- '<|turn>system\n' -}}
|
||||
|
||||
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{- '<|think|>' -}}
|
||||
{%- set ns.prev_message_type = 'think' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- messages[0]['content'] | trim -}}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if tools -%}
|
||||
{%- for tool in tools %}
|
||||
{{- '<|tool>' -}}
|
||||
{{- format_function_declaration(tool) | trim -}}
|
||||
{{- '<tool|>' -}}
|
||||
{%- endfor %}
|
||||
{%- set ns.prev_message_type = 'tool' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif %}
|
||||
|
||||
{#- Loop through messages -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- set ns.prev_message_type = None -%}
|
||||
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
|
||||
{{- '<|turn>' + role + '\n' }}
|
||||
|
||||
{%- if message['tool_calls'] -%}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- set function = tool_call['function'] -%}
|
||||
{{- '<|tool_call>call:' + function['name'] + '{' -}}
|
||||
{%- if function['arguments'] is mapping -%}
|
||||
{%- set ns_args = namespace(found_first=false) -%}
|
||||
{%- for key, value in function['arguments'] | dictsort -%}
|
||||
{%- if ns_args.found_first %},{% endif -%}
|
||||
{%- set ns_args.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{%- elif function['arguments'] is string -%}
|
||||
{{- function['arguments'] -}}
|
||||
{%- endif -%}
|
||||
{{- '}<tool_call|>' -}}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['tool_responses'] -%}
|
||||
{#- Tool Response handling -#}
|
||||
{%- for tool_response in message['tool_responses'] -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- if tool_response['response'] is mapping -%}
|
||||
{{- 'response:' + tool_response['name'] | default('unknown') + '{' -}}
|
||||
{%- for key, value in tool_response['response'] | dictsort -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- else -%}
|
||||
{{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}}
|
||||
{%- endif -%}
|
||||
{{- '<tool_response|>' -}}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['content'] is string -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(message['content']) -}}
|
||||
{%- else -%}
|
||||
{{- message['content'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is sequence -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'text' -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(item['text']) -}}
|
||||
{%- else -%}
|
||||
{{- item['text'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{- '\n\n<|image|>\n\n' -}}
|
||||
{%- set ns.prev_message_type = 'image' -%}
|
||||
{%- elif item['type'] == 'audio' -%}
|
||||
{{- '<|audio|>' -}}
|
||||
{%- set ns.prev_message_type = 'audio' -%}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{- '\n\n<|video|>\n\n' -}}
|
||||
{%- set ns.prev_message_type = 'video' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if not (message['tool_responses'] and not message['content']) -%}
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' -%}
|
||||
{{- '<|turn>model\n' -}}
|
||||
{%- endif -%}
|
||||
{%- if not enable_thinking | default(false) -%}
|
||||
{{- '<|channel>thought\n<channel|>' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
118
models/templates/ibm-granite-granite-4.0.jinja
Normal file
118
models/templates/ibm-granite-granite-4.0.jinja
Normal file
@@ -0,0 +1,118 @@
|
||||
{%- set tools_system_message_prefix = 'You are a helpful assistant with access to the following tools. You may call one or more tools to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>' %}
|
||||
{%- set tools_system_message_suffix = '\n</tools>\n\nFor each tool call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.' %}
|
||||
{%- set documents_system_message_prefix = 'You are a helpful assistant with access to the following documents. You may use one or more documents to assist with the user query.\n\nYou are given a list of documents within <documents></documents> XML tags:\n<documents>' %}
|
||||
{%- set documents_system_message_suffix = '\n</documents>\n\nWrite the response to the user\'s input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.' %}
|
||||
{%- set g4_default_system_message = 'You are a helpful assistant. Please ensure responses are professional, accurate, and safe.' %}
|
||||
{%- if available_tools is defined and available_tools %}
|
||||
{%- set tools = available_tools %}
|
||||
{%- endif %}
|
||||
{%- set ns = namespace(tools_system_message=tools_system_message_prefix,
|
||||
documents_system_message=documents_system_message_prefix,
|
||||
default_system_message=g4_default_system_message,
|
||||
system_message=''
|
||||
) %}
|
||||
{%- if tools %}
|
||||
{%- for tool in tools %}
|
||||
{%- set ns.tools_system_message = ns.tools_system_message + '\n' + (tool | tojson) %}
|
||||
{%- endfor %}
|
||||
{%- set ns.tools_system_message = ns.tools_system_message + tools_system_message_suffix %}
|
||||
{%- else %}
|
||||
{%- set ns.tools_system_message = '' %}
|
||||
{%- endif %}
|
||||
{%- if documents %}
|
||||
{%- for document in documents %}
|
||||
{%- set ns.documents_system_message = ns.documents_system_message + '\n' + (document | tojson) %}
|
||||
{%- endfor %}
|
||||
{%- set ns.documents_system_message = ns.documents_system_message + documents_system_message_suffix %}
|
||||
{%- else %}
|
||||
{%- set ns.documents_system_message = '' %}
|
||||
{%- endif %}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{%- if messages[0].content is string %}
|
||||
{%- set ns.system_message = messages[0].content %}
|
||||
{%- elif messages[0].content is iterable %}
|
||||
{%- for entry in messages[0].content %}
|
||||
{%- if entry.type== 'text' %}
|
||||
{%- if ns.system_message != '' %}
|
||||
{%- set ns.system_message = ns.system_message + '\n' %}
|
||||
{%- endif %}
|
||||
{%- set ns.system_message = ns.system_message + entry.text %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- if tools and documents %}
|
||||
{%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message + '\n\n' + ns.documents_system_message %}
|
||||
{%- elif tools %}
|
||||
{%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message %}
|
||||
{%- elif documents %}
|
||||
{%- set ns.system_message = ns.system_message + '\n\n' + ns.documents_system_message %}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{%- if tools and documents %}
|
||||
{%- set ns.system_message = ns.tools_system_message + '\n\n' + ns.documents_system_message %}
|
||||
{%- elif tools %}
|
||||
{%- set ns.system_message = ns.tools_system_message %}
|
||||
{%- elif documents %}
|
||||
{%- set ns.system_message = ns.documents_system_message %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if ns.system_message %}
|
||||
{{- '<|start_of_role|>system<|end_of_role|>' + ns.system_message + '<|end_of_text|>\n' }}
|
||||
{%- else %}
|
||||
{{- '<|start_of_role|>system<|end_of_role|>' + ns.default_system_message + '<|end_of_text|>\n' }}
|
||||
{%- endif %}
|
||||
{%- for message in messages %}
|
||||
{%- set content = namespace(val='') %}
|
||||
{%- if message.content is string %}
|
||||
{%- set content.val = message.content %}
|
||||
{%- else %}
|
||||
{%- if message.content is iterable %}
|
||||
{%- for entry in message.content %}
|
||||
{%- if entry.type== 'text' %}
|
||||
{%- if content.val != '' %}
|
||||
{%- set content.val = content.val + '\n' %}
|
||||
{%- endif %}
|
||||
{%- set content.val = content.val + entry.text %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if (message.role == 'user') or (message.role == 'system' and not loop.first) %}
|
||||
{{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val + '<|end_of_text|>\n' }}
|
||||
{%- elif message.role == 'assistant' %}
|
||||
{{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if (loop.first and content.val) or (not loop.first) %}
|
||||
{{- '\n' }}
|
||||
{%- endif %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<tool_call>\n{"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '", "arguments": ' }}
|
||||
{%- if tool_call.arguments is string %}
|
||||
{{- tool_call.arguments }}
|
||||
{%- else %}
|
||||
{{- tool_call.arguments | tojson }}
|
||||
{%- endif %}
|
||||
{{- '}\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|end_of_text|>\n' }}
|
||||
{%- elif message.role == 'tool' %}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != 'tool') %}
|
||||
{{- '<|start_of_role|>user<|end_of_role|>' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- content.val }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != 'tool') %}
|
||||
{{- '<|end_of_text|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|start_of_role|>assistant<|end_of_role|>' }}
|
||||
{%- endif %}
|
||||
@@ -1 +1 @@
|
||||
50634c28837c24ac68b380b5750b41e701c87d73
|
||||
49f84a924f6ea4fc2ef73dbbd8cc4d734b54bd6d
|
||||
|
||||
@@ -73,6 +73,7 @@ add_library(llama
|
||||
models/gemma2-iswa.cpp
|
||||
models/gemma3.cpp
|
||||
models/gemma3n-iswa.cpp
|
||||
models/gemma4-iswa.cpp
|
||||
models/glm4-moe.cpp
|
||||
models/glm4.cpp
|
||||
models/gpt2.cpp
|
||||
|
||||
@@ -56,6 +56,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
||||
{ LLM_ARCH_GEMMA3, "gemma3" },
|
||||
{ LLM_ARCH_GEMMA3N, "gemma3n" },
|
||||
{ LLM_ARCH_GEMMA4, "gemma4" },
|
||||
{ LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
@@ -165,6 +166,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
|
||||
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
|
||||
{ LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" },
|
||||
{ LLM_KV_EMBEDDING_LENGTH_PER_LAYER, "%s.embedding_length_per_layer_input" },
|
||||
{ LLM_KV_FEATURES_LENGTH, "%s.features_length" },
|
||||
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
|
||||
{ LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
|
||||
@@ -238,6 +240,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
|
||||
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
|
||||
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
|
||||
{ LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },
|
||||
@@ -364,6 +367,9 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM_1, "blk.%d.post_ffw_norm_1" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM_2, "blk.%d.post_ffw_norm_2" },
|
||||
{ LLM_TENSOR_FFN_PRE_NORM_2, "blk.%d.pre_ffw_norm_2" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
@@ -373,6 +379,7 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
|
||||
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
|
||||
{ LLM_TENSOR_LAYER_OUT_SCALE, "blk.%d.layer_output_scale" },
|
||||
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
|
||||
{ LLM_TENSOR_POS_EMBD, "position_embd" },
|
||||
{ LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" },
|
||||
@@ -1342,6 +1349,38 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||
LLM_TENSOR_LAUREL_R,
|
||||
LLM_TENSOR_LAUREL_POST_NORM,
|
||||
};
|
||||
case LLM_ARCH_GEMMA4:
|
||||
return {
|
||||
LLM_TENSOR_ROPE_FREQS,
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_Q,
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
LLM_TENSOR_ATTN_K,
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
LLM_TENSOR_ATTN_V,
|
||||
LLM_TENSOR_ATTN_OUT,
|
||||
LLM_TENSOR_ATTN_POST_NORM,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_FFN_GATE,
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
LLM_TENSOR_FFN_GATE_UP_EXPS,
|
||||
LLM_TENSOR_FFN_DOWN_EXPS,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_POST_NORM,
|
||||
LLM_TENSOR_FFN_POST_NORM_1,
|
||||
LLM_TENSOR_FFN_POST_NORM_2,
|
||||
LLM_TENSOR_FFN_PRE_NORM_2,
|
||||
LLM_TENSOR_LAYER_OUT_SCALE,
|
||||
LLM_TENSOR_PER_LAYER_TOKEN_EMBD,
|
||||
LLM_TENSOR_PER_LAYER_MODEL_PROJ,
|
||||
LLM_TENSOR_PER_LAYER_PROJ_NORM,
|
||||
LLM_TENSOR_PER_LAYER_INP_GATE,
|
||||
LLM_TENSOR_PER_LAYER_PROJ,
|
||||
LLM_TENSOR_PER_LAYER_POST_NORM,
|
||||
};
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
@@ -2654,11 +2693,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_FFN_PRE_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_FFN_POST_NORM_1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_FFN_POST_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_LAYER_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_LAYER_OUT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_Q_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_KV_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
|
||||
@@ -60,6 +60,7 @@ enum llm_arch {
|
||||
LLM_ARCH_GEMMA2,
|
||||
LLM_ARCH_GEMMA3,
|
||||
LLM_ARCH_GEMMA3N,
|
||||
LLM_ARCH_GEMMA4,
|
||||
LLM_ARCH_GEMMA_EMBEDDING,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
@@ -169,6 +170,7 @@ enum llm_kv {
|
||||
LLM_KV_CONTEXT_LENGTH,
|
||||
LLM_KV_EMBEDDING_LENGTH,
|
||||
LLM_KV_EMBEDDING_LENGTH_OUT,
|
||||
LLM_KV_EMBEDDING_LENGTH_PER_LAYER,
|
||||
LLM_KV_FEATURES_LENGTH,
|
||||
LLM_KV_BLOCK_COUNT,
|
||||
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
|
||||
@@ -242,6 +244,7 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
|
||||
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
|
||||
LLM_KV_ATTENTION_INDEXER_TOP_K,
|
||||
LLM_KV_ATTENTION_SHARED_KV_LAYERS,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_COUNT_SWA,
|
||||
@@ -369,6 +372,9 @@ enum llm_tensor {
|
||||
LLM_TENSOR_FFN_GATE_INP_SHEXP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_FFN_POST_NORM,
|
||||
LLM_TENSOR_FFN_POST_NORM_1,
|
||||
LLM_TENSOR_FFN_POST_NORM_2,
|
||||
LLM_TENSOR_FFN_PRE_NORM_2,
|
||||
LLM_TENSOR_FFN_GATE,
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
@@ -393,6 +399,7 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
LLM_TENSOR_LAYER_OUT_NORM,
|
||||
LLM_TENSOR_LAYER_OUT_SCALE,
|
||||
LLM_TENSOR_POST_ATTN_NORM,
|
||||
LLM_TENSOR_POST_MLP_NORM,
|
||||
LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
|
||||
|
||||
@@ -60,7 +60,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
|
||||
{ "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE },
|
||||
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
|
||||
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
|
||||
{ "granite", LLM_CHAT_TEMPLATE_GRANITE_3_X },
|
||||
{ "granite-4.0", LLM_CHAT_TEMPLATE_GRANITE_4_0 },
|
||||
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
|
||||
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
|
||||
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
|
||||
@@ -191,7 +192,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
} else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
|
||||
return LLM_CHAT_TEMPLATE_RWKV_WORLD;
|
||||
} else if (tmpl_contains("<|start_of_role|>")) {
|
||||
return LLM_CHAT_TEMPLATE_GRANITE;
|
||||
if (tmpl_contains("<tool_call>") || tmpl_contains("<tools>")) {
|
||||
return LLM_CHAT_TEMPLATE_GRANITE_4_0;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_GRANITE_3_X;
|
||||
} else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) {
|
||||
return LLM_CHAT_TEMPLATE_GIGACHAT;
|
||||
} else if (tmpl_contains("<|role_start|>")) {
|
||||
@@ -617,8 +621,8 @@ int32_t llm_chat_apply_template(
|
||||
ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
|
||||
// IBM Granite template
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_3_X) {
|
||||
// IBM Granite 3.x template
|
||||
for (const auto & message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|start_of_role|>" << role << "<|end_of_role|>";
|
||||
@@ -630,6 +634,20 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|start_of_role|>assistant<|end_of_role|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_4_0) {
|
||||
// IBM Granite 4.0 template
|
||||
for (const auto & message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "assistant_tool_call") {
|
||||
ss << "<|start_of_role|>assistant<|end_of_role|><|tool_call|>";
|
||||
} else {
|
||||
ss << "<|start_of_role|>" << role << "<|end_of_role|>";
|
||||
}
|
||||
ss << message->content << "<|end_of_text|>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|start_of_role|>assistant<|end_of_role|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
|
||||
// GigaChat template
|
||||
bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
|
||||
|
||||
@@ -39,7 +39,8 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_EXAONE_4,
|
||||
LLM_CHAT_TEMPLATE_EXAONE_MOE,
|
||||
LLM_CHAT_TEMPLATE_RWKV_WORLD,
|
||||
LLM_CHAT_TEMPLATE_GRANITE,
|
||||
LLM_CHAT_TEMPLATE_GRANITE_3_X,
|
||||
LLM_CHAT_TEMPLATE_GRANITE_4_0,
|
||||
LLM_CHAT_TEMPLATE_GIGACHAT,
|
||||
LLM_CHAT_TEMPLATE_MEGREZ,
|
||||
LLM_CHAT_TEMPLATE_YANDEX,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama-context.h"
|
||||
#include "ggml.h"
|
||||
#include "stdint.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
// Reserve a new compute graph. It is valid until the next call to llama_graph_reserve.
|
||||
LLAMA_API struct ggml_cgraph * llama_graph_reserve(
|
||||
@@ -10,3 +10,47 @@ LLAMA_API struct ggml_cgraph * llama_graph_reserve(
|
||||
uint32_t n_tokens,
|
||||
uint32_t n_seqs,
|
||||
uint32_t n_outputs);
|
||||
|
||||
// Get the default ggml_type for a given ftype.
|
||||
LLAMA_API ggml_type llama_ftype_get_default_type(llama_ftype ftype);
|
||||
|
||||
// Quantization state.
|
||||
struct quantize_state_impl;
|
||||
|
||||
LLAMA_API quantize_state_impl * llama_quant_init(
|
||||
const llama_model * model,
|
||||
const llama_model_quantize_params * params);
|
||||
|
||||
LLAMA_API void llama_quant_free(quantize_state_impl * qs);
|
||||
|
||||
// Descriptor for constructing a mock model for quantization testing.
|
||||
struct llama_quant_model_desc {
|
||||
const char * architecture;
|
||||
uint32_t n_embd;
|
||||
uint32_t n_ff;
|
||||
uint32_t n_layer;
|
||||
uint32_t n_head;
|
||||
uint32_t n_head_kv;
|
||||
uint32_t n_expert;
|
||||
uint32_t n_embd_head_k;
|
||||
uint32_t n_embd_head_v;
|
||||
};
|
||||
|
||||
// Create a mock model from a metadata descriptor (for testing).
|
||||
// The returned model must be freed with llama_model_free().
|
||||
LLAMA_API llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc);
|
||||
|
||||
// Returns true if this tensor should be quantized (based on name, dims, params).
|
||||
LLAMA_API bool llama_quant_tensor_allows_quantization(
|
||||
const quantize_state_impl * qs,
|
||||
const ggml_tensor * tensor);
|
||||
|
||||
// Compute quantization type assignments for a list of tensors.
|
||||
// All tensors should be quantizable (use llama_quant_tensor_allows_quantization to filter).
|
||||
// result_types: caller-allocated array of n_tensors elements, filled with assigned types.
|
||||
LLAMA_API void llama_quant_compute_types(
|
||||
quantize_state_impl * qs,
|
||||
llama_ftype ftype,
|
||||
ggml_tensor ** tensors,
|
||||
ggml_type * result_types,
|
||||
size_t n_tensors);
|
||||
|
||||
@@ -209,6 +209,9 @@ struct llama_hparams {
|
||||
// qwen3vl deepstack
|
||||
uint32_t n_deepstack_layers = 0;
|
||||
|
||||
// gemma4 per-layer embedding
|
||||
uint32_t n_embd_per_layer = 0;
|
||||
|
||||
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/8141
|
||||
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
@@ -1261,6 +1261,31 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA4:
|
||||
{
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
|
||||
|
||||
uint32_t n_kv_shared_layers = 0;
|
||||
ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false);
|
||||
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers;
|
||||
hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling)
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, hparams.n_embd_per_layer);
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 35: type = LLM_TYPE_E2B; break;
|
||||
case 42: type = LLM_TYPE_E4B; break; // to confirm: E4B or E5B?
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
{
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
|
||||
@@ -4229,6 +4254,100 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA4:
|
||||
{
|
||||
const uint32_t n_embd_per_layer = hparams.n_embd_per_layer;
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
|
||||
if (n_embd_head_k != n_embd_head_v) {
|
||||
throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v");
|
||||
}
|
||||
if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) {
|
||||
throw std::runtime_error("Gemma 4 requires n_embd_head_k_swa == n_embd_head_v_swa");
|
||||
}
|
||||
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
if (n_embd_per_layer > 0) {
|
||||
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_per_layer * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_per_layer}, 0);
|
||||
}
|
||||
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
int rope_freqs_flag = 0;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
const int64_t n_head = hparams.n_head(i);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(i);
|
||||
const int64_t n_embd_k = hparams.n_embd_k_gqa(i);
|
||||
const int64_t n_embd_v = hparams.n_embd_v_gqa(i);
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
// note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj)
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
if (!hparams.is_swa(i)) {
|
||||
// full_attention layers use rope_freqs for proportional rope
|
||||
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_embd_head/2}, rope_freqs_flag);
|
||||
rope_freqs_flag = TENSOR_DUPLICATED;
|
||||
}
|
||||
|
||||
// handle use_double_wide_mlp
|
||||
int64_t n_ff_cur = hparams.n_ff(i);
|
||||
|
||||
// for expert layers, we use normal FFN as shared expert (same as python code)
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0);
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
// MoE router
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED);
|
||||
bool has_expert = layer.ffn_gate_inp != nullptr;
|
||||
|
||||
// norm
|
||||
if (has_expert) {
|
||||
layer.ffn_gate_inp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "scale", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_pre_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM_2, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_post_norm_1 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_1, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_post_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_2, "weight", i), {n_embd}, 0);
|
||||
|
||||
// MoE FFN
|
||||
layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||
|
||||
// per-expert scale will be loaded as down_exps_s at the end of the current switch case
|
||||
}
|
||||
|
||||
// per-layer embeddings
|
||||
if (n_embd_per_layer > 0) {
|
||||
layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_per_layer}, 0);
|
||||
layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_per_layer, n_embd}, 0);
|
||||
layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -8233,7 +8352,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
} else {
|
||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||
|
||||
if (arch == LLM_ARCH_GEMMA3N) {
|
||||
if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) {
|
||||
reuse = [&](int32_t il) {
|
||||
if (il >= (int32_t) hparams.n_layer_kv_from_start) {
|
||||
return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1);
|
||||
@@ -8486,6 +8605,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
{
|
||||
llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA4:
|
||||
{
|
||||
llm = std::make_unique<llm_build_gemma4_iswa>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
{
|
||||
llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
|
||||
@@ -9006,6 +9129,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_GEMMA2:
|
||||
case LLM_ARCH_GEMMA3:
|
||||
case LLM_ARCH_GEMMA3N:
|
||||
case LLM_ARCH_GEMMA4:
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
case LLM_ARCH_STARCODER2:
|
||||
case LLM_ARCH_OPENELM:
|
||||
|
||||
@@ -270,6 +270,9 @@ struct llama_layer {
|
||||
struct ggml_tensor * ffn_norm = nullptr;
|
||||
struct ggml_tensor * ffn_norm_b = nullptr;
|
||||
struct ggml_tensor * ffn_post_norm = nullptr;
|
||||
struct ggml_tensor * ffn_post_norm_1 = nullptr; // gemma4
|
||||
struct ggml_tensor * ffn_post_norm_2 = nullptr; // gemma4
|
||||
struct ggml_tensor * ffn_pre_norm_2 = nullptr; // gemma4
|
||||
struct ggml_tensor * layer_out_norm = nullptr;
|
||||
struct ggml_tensor * layer_out_norm_b = nullptr;
|
||||
struct ggml_tensor * ffn_norm_exps = nullptr;
|
||||
@@ -285,6 +288,7 @@ struct llama_layer {
|
||||
|
||||
// ff MoE
|
||||
struct ggml_tensor * ffn_gate_inp = nullptr;
|
||||
struct ggml_tensor * ffn_gate_inp_s = nullptr; // gemma4
|
||||
struct ggml_tensor * ffn_gate_exps = nullptr;
|
||||
struct ggml_tensor * ffn_down_exps = nullptr;
|
||||
struct ggml_tensor * ffn_up_exps = nullptr;
|
||||
@@ -483,6 +487,9 @@ struct llama_layer {
|
||||
struct ggml_tensor * indexer_attn_k = nullptr;
|
||||
struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias
|
||||
|
||||
// gemma4 layer output scale
|
||||
struct ggml_tensor * out_scale = nullptr;
|
||||
|
||||
struct llama_layer_posnet posnet;
|
||||
|
||||
struct llama_layer_convnext convnext;
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
#include "llama.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-model-loader.h"
|
||||
#include "llama-ext.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <cinttypes>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
@@ -197,6 +197,7 @@ struct quantize_state_impl {
|
||||
|
||||
// per-tensor metadata, computed in the preliminary loop and used in the main loop
|
||||
struct tensor_metadata {
|
||||
std::string name;
|
||||
ggml_type target_type;
|
||||
tensor_category category;
|
||||
std::string remapped_imatrix_name;
|
||||
@@ -788,7 +789,7 @@ static bool tensor_requires_imatrix(const char * tensor_name, const ggml_type ds
|
||||
// given a file type, get the default tensor type
|
||||
//
|
||||
|
||||
static ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
|
||||
ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
|
||||
switch (ftype) {
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0: return GGML_TYPE_Q4_0;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_1: return GGML_TYPE_Q4_1;
|
||||
@@ -827,16 +828,32 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_S:
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_M: return GGML_TYPE_IQ3_S;
|
||||
|
||||
default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
|
||||
default: return GGML_TYPE_COUNT;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void init_quantize_state_counters(quantize_state_impl & qs, std::vector<tensor_metadata> & metadata) {
|
||||
for (auto & tm : metadata) {
|
||||
tensor_category cat = tensor_get_category(tm.name);
|
||||
tm.category = cat;
|
||||
|
||||
if (category_is_attn_v(cat)) {
|
||||
++qs.n_attention_wv;
|
||||
}
|
||||
|
||||
if (cat == tensor_category::OUTPUT) {
|
||||
qs.has_tied_embeddings = false;
|
||||
}
|
||||
}
|
||||
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer;
|
||||
}
|
||||
|
||||
//
|
||||
// main quantization driver
|
||||
//
|
||||
|
||||
static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
|
||||
ggml_type default_type;
|
||||
llama_ftype ftype = params->ftype;
|
||||
|
||||
int nthread = params->nthread;
|
||||
@@ -845,7 +862,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
nthread = std::thread::hardware_concurrency();
|
||||
}
|
||||
|
||||
default_type = llama_ftype_get_default_type(ftype);
|
||||
ggml_type default_type = llama_ftype_get_default_type(ftype);
|
||||
if (default_type == GGML_TYPE_COUNT) {
|
||||
throw std::runtime_error(format("invalid output file type %d\n", ftype));
|
||||
}
|
||||
|
||||
// mmap consistently increases speed on Linux, and also increases speed on Windows with
|
||||
// hot cache. It may cause a slowdown on macOS, possibly related to free memory.
|
||||
@@ -964,6 +984,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
});
|
||||
}
|
||||
|
||||
// compute tensor metadata once and cache it
|
||||
std::vector<tensor_metadata> metadata(tensors.size());
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
metadata[i].name = ggml_get_name(tensors[i]->tensor);
|
||||
}
|
||||
|
||||
// initialize quantization state counters and metadata categories
|
||||
init_quantize_state_counters(qs, metadata);
|
||||
|
||||
int idx = 0;
|
||||
uint16_t n_split = 1;
|
||||
|
||||
@@ -976,25 +1005,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
std::vector<gguf_context_ptr> ctx_outs(n_split);
|
||||
ctx_outs[0] = std::move(ctx_out);
|
||||
|
||||
// compute tensor metadata once and cache it
|
||||
std::vector<tensor_metadata> metadata(tensors.size());
|
||||
|
||||
// initialize quantization state before preliminary loop (counters for use_more_bits)
|
||||
{
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
const auto cat = tensor_get_category(tensors[i]->tensor->name);
|
||||
if (category_is_attn_v(cat)) {
|
||||
++qs.n_attention_wv;
|
||||
}
|
||||
if (cat == tensor_category::OUTPUT) {
|
||||
qs.has_tied_embeddings = false;
|
||||
}
|
||||
metadata[i].category = cat; // save and re-use the category while we're at it
|
||||
}
|
||||
// these also need to be set to n_layer by default
|
||||
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer;
|
||||
}
|
||||
|
||||
// flag for --dry-run
|
||||
bool will_require_imatrix = false;
|
||||
|
||||
@@ -1005,7 +1015,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
const auto * it = tensors[i];
|
||||
const struct ggml_tensor * tensor = it->tensor;
|
||||
const std::string name = ggml_get_name(tensor);
|
||||
|
||||
uint16_t i_split = params->keep_split ? it->idx : 0;
|
||||
if (!ctx_outs[i_split]) {
|
||||
@@ -1034,7 +1043,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
" - offending tensor: %s\n"
|
||||
" - target type: %s\n"
|
||||
"============================================================================\n\n",
|
||||
name.c_str(), ggml_type_name(metadata[i].target_type));
|
||||
metadata[i].name.c_str(), ggml_type_name(metadata[i].target_type));
|
||||
throw std::runtime_error("this quantization requires an imatrix!");
|
||||
}
|
||||
}
|
||||
@@ -1107,7 +1116,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
new_ofstream(weight.idx);
|
||||
}
|
||||
|
||||
const std::string name = ggml_get_name(tensor);
|
||||
const size_t tensor_size = ggml_nbytes(tensor);
|
||||
|
||||
if (!params->dry_run) {
|
||||
@@ -1238,9 +1246,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
total_size_new += new_size;
|
||||
|
||||
// update the gguf meta data as we go
|
||||
gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
|
||||
GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
|
||||
gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
|
||||
gguf_set_tensor_type(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_type);
|
||||
GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), metadata[i].name.c_str())) == new_size);
|
||||
gguf_set_tensor_data(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_data);
|
||||
|
||||
// write tensor data + padding
|
||||
fout.write((const char *) new_data, new_size);
|
||||
@@ -1305,3 +1313,89 @@ uint32_t llama_model_quantize(
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Helper functions for external tools exposed in llama-ext.h
|
||||
//
|
||||
|
||||
quantize_state_impl * llama_quant_init(
|
||||
const llama_model * model,
|
||||
const llama_model_quantize_params * params) {
|
||||
return new quantize_state_impl(*model, params);
|
||||
}
|
||||
|
||||
void llama_quant_free(quantize_state_impl * qs) {
|
||||
delete qs;
|
||||
}
|
||||
|
||||
llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc) {
|
||||
struct llama_model_params mparams = llama_model_default_params();
|
||||
auto * model = new llama_model(mparams);
|
||||
|
||||
model->arch = llm_arch_from_string(desc->architecture);
|
||||
|
||||
// infer llm_type: only LLM_TYPE_70B matters for quantization logic
|
||||
if (model->arch == LLM_ARCH_LLAMA && desc->n_layer == 80 && desc->n_head != desc->n_head_kv) {
|
||||
model->type = LLM_TYPE_70B;
|
||||
}
|
||||
|
||||
model->hparams.n_embd = desc->n_embd;
|
||||
model->hparams.n_embd_head_k_full = desc->n_embd_head_k;
|
||||
model->hparams.n_embd_head_v_full = desc->n_embd_head_v;
|
||||
model->hparams.n_layer = desc->n_layer;
|
||||
model->hparams.n_expert = desc->n_expert;
|
||||
|
||||
for (uint32_t i = 0; i < desc->n_layer; i++) {
|
||||
model->hparams.n_head_arr[i] = desc->n_head;
|
||||
model->hparams.n_head_kv_arr[i] = desc->n_head_kv;
|
||||
model->hparams.n_ff_arr[i] = desc->n_ff;
|
||||
}
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
bool llama_quant_tensor_allows_quantization(
|
||||
const quantize_state_impl * qs,
|
||||
const ggml_tensor * tensor) {
|
||||
return tensor_allows_quantization(qs->params, qs->model.arch, tensor);
|
||||
}
|
||||
|
||||
void llama_quant_compute_types(
|
||||
quantize_state_impl * qs,
|
||||
llama_ftype ftype,
|
||||
ggml_tensor ** tensors,
|
||||
ggml_type * result_types,
|
||||
size_t n_tensors) {
|
||||
// reset per-computation state
|
||||
qs->n_attention_wv = 0;
|
||||
qs->n_ffn_down = 0;
|
||||
qs->n_ffn_gate = 0;
|
||||
qs->n_ffn_up = 0;
|
||||
qs->i_attention_wv = 0;
|
||||
qs->i_ffn_down = 0;
|
||||
qs->i_ffn_gate = 0;
|
||||
qs->i_ffn_up = 0;
|
||||
qs->n_fallback = 0;
|
||||
qs->has_imatrix = false;
|
||||
qs->has_tied_embeddings = true;
|
||||
|
||||
// build metadata from tensor names
|
||||
std::vector<tensor_metadata> metadata(n_tensors);
|
||||
for (size_t i = 0; i < n_tensors; i++) {
|
||||
metadata[i].name = ggml_get_name(tensors[i]);
|
||||
}
|
||||
|
||||
// initialize counters and categories
|
||||
init_quantize_state_counters(*qs, metadata);
|
||||
|
||||
// use a local copy of params with the requested ftype
|
||||
llama_model_quantize_params local_params = *qs->params;
|
||||
local_params.ftype = ftype;
|
||||
|
||||
ggml_type default_type = llama_ftype_get_default_type(ftype);
|
||||
|
||||
// compute types
|
||||
for (size_t i = 0; i < n_tensors; i++) {
|
||||
result_types[i] = llama_tensor_get_type(*qs, &local_params, tensors[i], default_type, metadata[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1863,6 +1863,18 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
special_sep_id = LLAMA_TOKEN_NULL;
|
||||
special_pad_id = 3; // <|plamo:pad|>
|
||||
special_mask_id = LLAMA_TOKEN_NULL;
|
||||
} else if (tokenizer_model == "gemma4") {
|
||||
type = LLAMA_VOCAB_TYPE_SPM;
|
||||
|
||||
// default special tokens (to be read from GGUF)
|
||||
special_bos_id = LLAMA_TOKEN_NULL;
|
||||
special_eos_id = LLAMA_TOKEN_NULL;
|
||||
special_unk_id = LLAMA_TOKEN_NULL;
|
||||
special_sep_id = LLAMA_TOKEN_NULL;
|
||||
special_pad_id = LLAMA_TOKEN_NULL;
|
||||
special_mask_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
tokenizer_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
|
||||
}
|
||||
@@ -2490,6 +2502,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "[EOS]" // Kimi-K2
|
||||
|| t.first == "<|end_of_text|>"
|
||||
|| t.first == "<end_of_utterance>" // smoldocling
|
||||
|| t.first == "<turn|>" // gemma4
|
||||
|| t.first == "<|end▁of▁sentence|>" // deepseek-ocr
|
||||
) {
|
||||
special_eog_ids.insert(t.second);
|
||||
|
||||
311
src/models/gemma4-iswa.cpp
Normal file
311
src/models/gemma4-iswa.cpp
Normal file
@@ -0,0 +1,311 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params),
|
||||
model(model),
|
||||
n_embd_per_layer(model.hparams.n_embd_per_layer) {
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
|
||||
inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
|
||||
cb(inpL, "inp_scaled", -1);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// TODO: is causal == true correct? might need some changes
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
ggml_tensor * inp_per_layer = nullptr;
|
||||
if (model.tok_embd_per_layer) {
|
||||
inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
|
||||
}
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(il);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v(il));
|
||||
|
||||
const int64_t n_head = hparams.n_head(il);
|
||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||
|
||||
const float freq_base_l = model.get_rope_freq_base(cparams, il);
|
||||
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
||||
const int n_rot_l = hparams.n_rot(il);
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
ggml_tensor * freq_factors = nullptr;
|
||||
if (!hparams.is_swa(il)) {
|
||||
// full_attention layers use rope_freqs for proportional rope
|
||||
freq_factors = model.layers[il].rope_freqs;
|
||||
}
|
||||
|
||||
// Q projection (shared for both non-KV and KV layers)
|
||||
// this is to mirror Gemma4Attention in pytorch code
|
||||
ggml_tensor * Qcur;
|
||||
{
|
||||
Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur_pos", il);
|
||||
}
|
||||
|
||||
// self-attention
|
||||
if (hparams.has_kv(il)) {
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = model.layers[il].wv
|
||||
? build_lora_mm(model.layers[il].wv, cur)
|
||||
: Kcur; // if v_proj is not present, use Kcur as Vcur
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
|
||||
Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
|
||||
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
cb(Vcur, "Vcur_normed", il);
|
||||
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
cb(Kcur, "Kcur_pos", il);
|
||||
|
||||
cur = build_attn(inp_attn, model.layers[il].wo,
|
||||
nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
|
||||
hparams.f_attention_scale, il);
|
||||
} else {
|
||||
// reuse KV cache of earlier layers
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, nullptr,
|
||||
Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
|
||||
}
|
||||
|
||||
// TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||
}
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].attn_post_norm, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL);
|
||||
cb(attn_out, "attn_out", il);
|
||||
|
||||
// feed-forward network
|
||||
const bool is_moe_layer = model.layers[il].ffn_gate_inp != nullptr;
|
||||
if (is_moe_layer) {
|
||||
// MLP (shared exp)
|
||||
ggml_tensor * cur_mlp = build_norm(attn_out,
|
||||
model.layers[il].ffn_norm, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur_mlp, "ffn_norm_1", il);
|
||||
|
||||
cur_mlp = build_ffn(cur_mlp,
|
||||
model.layers[il].ffn_up, nullptr, nullptr,
|
||||
model.layers[il].ffn_gate, nullptr, nullptr,
|
||||
model.layers[il].ffn_down, nullptr, nullptr,
|
||||
nullptr,
|
||||
LLM_FFN_GELU, LLM_FFN_PAR, il);
|
||||
cur_mlp = build_norm(cur_mlp,
|
||||
model.layers[il].ffn_post_norm_1, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur_mlp, "ffn_mlp", il);
|
||||
|
||||
// Expert FFN
|
||||
ggml_tensor * cur_moe = build_norm(attn_out,
|
||||
model.layers[il].ffn_pre_norm_2, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur_moe, "ffn_norm_2", il);
|
||||
|
||||
// custom MoE logits calculation (router operates on attn_out, not cur)
|
||||
ggml_tensor * tmp = ggml_rms_norm(ctx0, attn_out, hparams.f_norm_rms_eps);
|
||||
tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) n_embd));
|
||||
tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_gate_inp_s);
|
||||
ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, tmp); // [n_expert, n_tokens]
|
||||
cb(logits, "ffn_moe_logits", il);
|
||||
|
||||
cur_moe = build_moe_ffn(cur_moe,
|
||||
nullptr, // gate_inp
|
||||
nullptr, // up_exps
|
||||
nullptr, // gate_exps
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr, // exp_probs_b (not used for gemma4)
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_GELU, true,
|
||||
1.0f,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il, logits,
|
||||
model.layers[il].ffn_gate_up_exps,
|
||||
nullptr, // up_exps_s
|
||||
nullptr, // gate_exps_s
|
||||
model.layers[il].ffn_down_exps_s);
|
||||
cur_moe = build_norm(cur_moe,
|
||||
model.layers[il].ffn_post_norm_2, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur_moe, "ffn_moe", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur_mlp, cur_moe);
|
||||
cb(cur, "ffn_moe_combined", il);
|
||||
} else {
|
||||
cur = build_norm(attn_out,
|
||||
model.layers[il].ffn_norm, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, nullptr, nullptr,
|
||||
model.layers[il].ffn_gate, nullptr, nullptr,
|
||||
model.layers[il].ffn_down, nullptr, nullptr,
|
||||
nullptr,
|
||||
LLM_FFN_GELU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].ffn_post_norm, nullptr,
|
||||
LLM_NORM_RMS, -1);
|
||||
cb(cur, "ffn_post_norm", il);
|
||||
|
||||
// residual connection
|
||||
cur = ggml_add(ctx0, cur, attn_out);
|
||||
|
||||
// per-layer embedding
|
||||
if (inp_per_layer) {
|
||||
ggml_tensor * pe_in = cur;
|
||||
cb(cur, "pe_in", il);
|
||||
|
||||
cur = build_lora_mm(model.layers[il].per_layer_inp_gate, cur); // [n_embd_per_layer, n_tokens]
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_per_layer, n_tokens]
|
||||
|
||||
// TODO @ngxson : improve this
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids);
|
||||
}
|
||||
|
||||
cur = ggml_mul(ctx0, cur, inp_this_layer);
|
||||
cur = build_lora_mm(model.layers[il].per_layer_proj, cur); // [n_embd, n_tokens]
|
||||
cur = build_norm(cur, model.layers[il].per_layer_post_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "per_layer_embd_out", il);
|
||||
|
||||
// residual connection
|
||||
cur = ggml_add(ctx0, pe_in, cur);
|
||||
}
|
||||
|
||||
// layer_scalar
|
||||
if (model.layers[il].out_scale) {
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].out_scale);
|
||||
cb(cur, "out_scaled", il);
|
||||
}
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, nullptr,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
if (hparams.f_final_logit_softcapping) {
|
||||
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
|
||||
cur = ggml_tanh(ctx0, cur);
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
|
||||
}
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
ggml_tensor * llm_build_gemma4_iswa::view_2d_slice(ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
// equivalent to get_per_layer_inputs() in python code
|
||||
// output shape: [n_embd_per_layer, n_layer, n_tokens]
|
||||
ggml_tensor * llm_build_gemma4_iswa::get_per_layer_inputs() {
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
|
||||
ggml_tensor * inp_per_layer;
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, n_tokens);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_per_layer));
|
||||
cb(inp_per_layer, "inp_per_layer_selected", -1);
|
||||
res->add_input(std::move(inp));
|
||||
} else {
|
||||
// Vision embedding path: use padding token (ID=0) embedding
|
||||
// TODO: verify if this is the correct behavior in transformers implementation
|
||||
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_per_layer * n_layer
|
||||
|
||||
// Extract and dequantize padding token embedding (row 0)
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
|
||||
inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
|
||||
|
||||
// Reshape to [n_embd_per_layer, n_layer, 1]
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, 1);
|
||||
cb(inp_per_layer, "inp_per_layer_vision", -1);
|
||||
}
|
||||
return inp_per_layer;
|
||||
}
|
||||
|
||||
// equivalent to project_per_layer_inputs() in python code
|
||||
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
|
||||
// inputs_embeds shape: [n_embd, n_tokens]
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from get_per_layer_inputs)
|
||||
// output shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
|
||||
const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
|
||||
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
|
||||
|
||||
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
|
||||
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens);
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS,
|
||||
-1); // [n_embd_per_layer, n_layer, n_tokens]
|
||||
cb(per_layer_proj, "per_layer_proj", -1);
|
||||
|
||||
inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
|
||||
cb(inp_per_layer, "inp_per_layer", -1);
|
||||
|
||||
// permute to shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
|
||||
return inp_per_layer;
|
||||
}
|
||||
@@ -266,6 +266,17 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il);
|
||||
};
|
||||
|
||||
struct llm_build_gemma4_iswa : public llm_graph_context {
|
||||
const llama_model & model;
|
||||
|
||||
const int64_t n_embd_per_layer;
|
||||
|
||||
llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params);
|
||||
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx);
|
||||
ggml_tensor * get_per_layer_inputs();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer);
|
||||
};
|
||||
|
||||
struct llm_build_gemma_embedding : public llm_graph_context {
|
||||
llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
1
tests/.gitignore
vendored
1
tests/.gitignore
vendored
@@ -1,5 +1,6 @@
|
||||
*
|
||||
!*.*
|
||||
!snapshots/
|
||||
*.o
|
||||
ggml-common.h
|
||||
**/*.swp
|
||||
|
||||
@@ -274,6 +274,12 @@ if (TARGET cpp-httplib)
|
||||
add_executable(test-gguf-model-data test-gguf-model-data.cpp)
|
||||
target_link_libraries(test-gguf-model-data PRIVATE gguf-model-data common)
|
||||
llama_test(test-gguf-model-data LABEL "model")
|
||||
|
||||
# test-quant-type-selection requires gguf-model-data for remote model metadata
|
||||
llama_build_and_test(test-quant-type-selection.cpp LABEL "model")
|
||||
target_link_libraries(test-quant-type-selection PRIVATE gguf-model-data)
|
||||
target_compile_definitions(test-quant-type-selection PRIVATE
|
||||
SNAPSHOT_DIR="${CMAKE_CURRENT_SOURCE_DIR}/snapshots")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -287,3 +293,7 @@ target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
|
||||
|
||||
llama_build(export-graph-ops.cpp)
|
||||
target_include_directories(export-graph-ops PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
|
||||
if (TARGET gguf-model-data)
|
||||
target_link_libraries(export-graph-ops PRIVATE gguf-model-data)
|
||||
target_compile_definitions(export-graph-ops PRIVATE LLAMA_HF_FETCH)
|
||||
endif()
|
||||
|
||||
@@ -1,15 +1,26 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "llama-cpp.h"
|
||||
#include "../src/llama-ext.h"
|
||||
#include "ggml.h"
|
||||
#include "gguf-model-data.h"
|
||||
#include "gguf.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "download.h"
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
|
||||
// Noop because weights are not needed
|
||||
static void set_tensor_data(struct ggml_tensor * tensor, void * userdata) {
|
||||
GGML_UNUSED(tensor);
|
||||
GGML_UNUSED(userdata);
|
||||
}
|
||||
|
||||
struct input_tensor {
|
||||
ggml_type type;
|
||||
@@ -132,9 +143,52 @@ int main(int argc, char ** argv) {
|
||||
|
||||
params.warmup = false;
|
||||
|
||||
auto init_result = common_init_from_params(params);
|
||||
llama_context * ctx;
|
||||
common_init_result_ptr init_result;
|
||||
llama_context_ptr ctx2;
|
||||
llama_model_ptr model;
|
||||
|
||||
llama_context * ctx = init_result->context();
|
||||
if (params.model.hf_repo.empty()) {
|
||||
init_result = common_init_from_params(params);
|
||||
|
||||
ctx = init_result->context();
|
||||
} else {
|
||||
#ifdef LLAMA_HF_FETCH
|
||||
auto [hf_repo, hf_quant] = common_download_split_repo_tag(params.model.hf_repo);
|
||||
if (hf_quant.empty() || hf_quant == "latest") {
|
||||
hf_quant = "Q4_K_M";
|
||||
}
|
||||
|
||||
gguf_context_ptr gguf_ctx = gguf_fetch_gguf_ctx(hf_repo, hf_quant);
|
||||
if (!gguf_ctx) {
|
||||
LOG_ERR("failed to fetch GGUF metadata from %s\n", hf_repo.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
model_params.devices = params.devices.data();
|
||||
model_params.no_alloc = true;
|
||||
|
||||
model.reset(llama_model_init_from_user(gguf_ctx.get(), set_tensor_data, nullptr, model_params));
|
||||
|
||||
if (!model) {
|
||||
LOG_ERR("failed to create llama_model from %s\n", hf_repo.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
ctx2.reset(llama_init_from_model(model.get(), ctx_params));
|
||||
ctx = ctx2.get();
|
||||
|
||||
if (!ctx) {
|
||||
LOG_ERR("failed to create llama_context\n");
|
||||
return 1;
|
||||
}
|
||||
#else
|
||||
LOG_ERR("export-graph-ops compiled without HF fetch support\n");
|
||||
return 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
const uint32_t n_seqs = llama_n_seq_max(ctx);
|
||||
const uint32_t n_tokens = std::min(llama_n_ctx(ctx), llama_n_ubatch(ctx));
|
||||
@@ -143,13 +197,15 @@ int main(int argc, char ** argv) {
|
||||
|
||||
auto * gf_pp = llama_graph_reserve(ctx, n_tokens, n_seqs, n_tokens);
|
||||
if (!gf_pp) {
|
||||
throw std::runtime_error("failed to reserve prompt processing graph");
|
||||
LOG_ERR("failed to reserve prompt processing graph\n");
|
||||
return 1;
|
||||
}
|
||||
extract_graph_ops(gf_pp, "pp", tests);
|
||||
|
||||
auto * gf_tg = llama_graph_reserve(ctx, n_seqs, n_seqs, n_seqs);
|
||||
if (!gf_tg) {
|
||||
throw std::runtime_error("failed to reserve token generation graph");
|
||||
LOG_ERR("failed to reserve token generation graph\n");
|
||||
return 1;
|
||||
}
|
||||
extract_graph_ops(gf_tg, "tg", tests);
|
||||
|
||||
@@ -158,7 +214,8 @@ int main(int argc, char ** argv) {
|
||||
std::ofstream f(params.out_file);
|
||||
|
||||
if (!f.is_open()) {
|
||||
throw std::runtime_error("Unable to open output file");
|
||||
LOG_ERR("unable to open output file: %s\n", params.out_file.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (const auto& test : tests) {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "gguf-model-data.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "ggml-cpp.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -124,6 +125,35 @@ static bool gguf_skip_value(gguf_buf_reader & r, int32_t vtype) {
|
||||
}
|
||||
|
||||
static bool gguf_read_uint32_val(gguf_buf_reader & r, int32_t vtype, uint32_t & out) {
|
||||
// Handle array-valued fields (e.g. per-layer head counts in hybrid models)
|
||||
// by reading the first element as a representative value.
|
||||
if (vtype == GGUF_TYPE_ARRAY) {
|
||||
int32_t elem_type;
|
||||
uint64_t count;
|
||||
if (!r.read_val(elem_type)) {
|
||||
return false;
|
||||
}
|
||||
if (!r.read_val(count)) {
|
||||
return false;
|
||||
}
|
||||
if (count == 0) {
|
||||
return false;
|
||||
}
|
||||
// Read first element, skip the rest
|
||||
if (!gguf_read_uint32_val(r, elem_type, out)) {
|
||||
return false;
|
||||
}
|
||||
for (uint64_t i = 1; i < count; i++) {
|
||||
size_t sz = gguf_val_type_size(elem_type);
|
||||
if (sz == 0) {
|
||||
return false;
|
||||
}
|
||||
if (!r.skip(sz)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_UINT8) {
|
||||
uint8_t v;
|
||||
if (!r.read_val(v)) {
|
||||
@@ -486,7 +516,8 @@ static std::string detect_gguf_filename(const std::string & repo, const std::str
|
||||
static std::optional<gguf_remote_model> fetch_and_parse(
|
||||
const std::string & repo,
|
||||
const std::string & filename,
|
||||
const std::string & cache_path) {
|
||||
const std::string & cache_path,
|
||||
bool verbose) {
|
||||
std::string url = "https://huggingface.co/" + repo + "/resolve/main/" + filename;
|
||||
|
||||
// Progressive download inspired by RangeView.fetchChunk()
|
||||
@@ -495,7 +526,9 @@ static std::optional<gguf_remote_model> fetch_and_parse(
|
||||
const size_t max_chunk = 64 * 1024 * 1024;
|
||||
|
||||
while (chunk_size <= max_chunk) {
|
||||
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
|
||||
if (verbose) {
|
||||
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
|
||||
}
|
||||
|
||||
char range_buf[64];
|
||||
snprintf(range_buf, sizeof(range_buf), "bytes=0-%zu", chunk_size - 1);
|
||||
@@ -531,34 +564,42 @@ static std::optional<gguf_remote_model> fetch_and_parse(
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::string get_cache_file_path(const std::string& cdir, const std::string& repo_part, const std::string& filename) {
|
||||
std::string fname_part = sanitize_for_path(filename);
|
||||
return cdir + "/" + repo_part + "--" + fname_part + ".partial";
|
||||
}
|
||||
|
||||
// Try cache first, then fetch and parse a single GGUF shard.
|
||||
static std::optional<gguf_remote_model> fetch_or_cached(
|
||||
const std::string & repo,
|
||||
const std::string & filename,
|
||||
const std::string & cdir,
|
||||
const std::string & repo_part) {
|
||||
std::string fname_part = sanitize_for_path(filename);
|
||||
std::string cache_path = cdir + "/" + repo_part + "--" + fname_part + ".partial";
|
||||
const std::string & repo_part,
|
||||
bool verbose) {
|
||||
std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
|
||||
|
||||
{
|
||||
std::vector<char> cached;
|
||||
if (std::filesystem::exists(cache_path) && read_file(cache_path, cached)) {
|
||||
auto result = gguf_parse_meta(cached);
|
||||
if (result.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
|
||||
if (verbose) {
|
||||
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fs_create_directory_with_parents(cdir);
|
||||
return fetch_and_parse(repo, filename, cache_path);
|
||||
return fetch_and_parse(repo, filename, cache_path, verbose);
|
||||
}
|
||||
|
||||
std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
const std::string & repo,
|
||||
const std::string & quant,
|
||||
const std::string & cache_dir) {
|
||||
const std::string & cache_dir,
|
||||
bool verbose) {
|
||||
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
|
||||
std::string repo_part = sanitize_for_path(repo);
|
||||
|
||||
@@ -568,7 +609,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
|
||||
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
|
||||
if (!model_opt.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
|
||||
return std::nullopt;
|
||||
@@ -583,8 +624,10 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
|
||||
model.n_split, model.n_split - 1);
|
||||
if (verbose) {
|
||||
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
|
||||
model.n_split, model.n_split - 1);
|
||||
}
|
||||
|
||||
for (int i = 2; i <= model.n_split; i++) {
|
||||
char num_buf[6], total_buf[6];
|
||||
@@ -592,7 +635,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
|
||||
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
|
||||
|
||||
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
|
||||
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
|
||||
if (!shard.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
|
||||
return std::nullopt;
|
||||
@@ -611,3 +654,87 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
|
||||
return model_opt;
|
||||
}
|
||||
|
||||
gguf_context_ptr gguf_fetch_gguf_ctx(
|
||||
const std::string & repo,
|
||||
const std::string & quant,
|
||||
const std::string & cache_dir,
|
||||
bool verbose) {
|
||||
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
|
||||
std::string repo_part = sanitize_for_path(repo);
|
||||
|
||||
std::string split_prefix;
|
||||
std::string filename = detect_gguf_filename(repo, quant, split_prefix);
|
||||
|
||||
if (filename.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
|
||||
if (!model_opt.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto & model = model_opt.value();
|
||||
|
||||
const std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
|
||||
|
||||
ggml_context_ptr ggml_ctx_ptr;
|
||||
ggml_context * ggml_ctx{};
|
||||
gguf_init_params params{true, &ggml_ctx};
|
||||
gguf_context_ptr ctx{gguf_init_from_file(cache_path.c_str(), params)};
|
||||
ggml_ctx_ptr.reset(ggml_ctx);
|
||||
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "gguf_fetch: gguf_init_from_file failed\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// If the model is split across multiple files we need to fetch the remaining shards metadata
|
||||
if (model.n_split > 1) {
|
||||
if (split_prefix.empty()) {
|
||||
fprintf(stderr, "gguf_fetch: model reports %u splits but filename has no split pattern\n", model.n_split);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
|
||||
model.n_split, model.n_split - 1);
|
||||
}
|
||||
|
||||
for (int i = 2; i <= model.n_split; i++) {
|
||||
char num_buf[6], total_buf[6];
|
||||
snprintf(num_buf, sizeof(num_buf), "%05d", i);
|
||||
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
|
||||
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
|
||||
|
||||
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
|
||||
if (!shard.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Load tensors from shard and add to main gguf_context
|
||||
const std::string shard_path = get_cache_file_path(cdir, repo_part, shard_name);
|
||||
ggml_context_ptr shard_ggml_ctx_ptr;
|
||||
ggml_context * shard_ggml_ctx{};
|
||||
gguf_init_params shard_params{true, &shard_ggml_ctx};
|
||||
gguf_context_ptr shard_ctx{gguf_init_from_file(shard_path.c_str(), shard_params)};
|
||||
shard_ggml_ctx_ptr.reset(shard_ggml_ctx);
|
||||
|
||||
if (shard_ctx == nullptr) {
|
||||
fprintf(stderr, "gguf_fetch: shard gguf_init_from_file failed\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(shard_ggml_ctx); t; t = ggml_get_next_tensor(shard_ggml_ctx, t)) {
|
||||
gguf_add_tensor(ctx.get(), t);
|
||||
}
|
||||
}
|
||||
|
||||
gguf_set_val_u16(ctx.get(), "split.count", 1);
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpp.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
@@ -39,4 +40,11 @@ struct gguf_remote_model {
|
||||
std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
const std::string & repo,
|
||||
const std::string & quant = "Q8_0",
|
||||
const std::string & cache_dir = ""); // empty = default
|
||||
const std::string & cache_dir = "", // empty = default
|
||||
bool verbose = true);
|
||||
|
||||
gguf_context_ptr gguf_fetch_gguf_ctx(
|
||||
const std::string & repo,
|
||||
const std::string & quant = "Q8_0",
|
||||
const std::string & cache_dir = "",
|
||||
bool verbose = true);
|
||||
|
||||
@@ -213,6 +213,66 @@ void test_gbnf_generation(testing &t) {
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("tagged choice inside sequence gets parenthesized", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("a") + p.tag("t", p.literal("b") | p.literal("c"));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a" ("b" | "c")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("tagged sequence inside choice gets parenthesized", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.tag("t", p.literal("a") + p.literal("b")) | p.literal("c");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a" "b" | "c"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("atomic choice inside repetition gets parenthesized", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.one_or_more(p.atomic(p.literal("a") | p.literal("b")));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= ("a" | "b")+
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("nested transparent wrappers get parenthesized", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("x") + p.tag("outer", p.atomic(p.literal("a") | p.literal("b")));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "x" ("a" | "b")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("emit only trigger rules (and references)", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
auto rule1 = p.rule("rule-1", p.literal("a") + p.ref("rule-2"));
|
||||
|
||||
3356
tests/snapshots/deepseek-v3.1.schema
Normal file
3356
tests/snapshots/deepseek-v3.1.schema
Normal file
File diff suppressed because it is too large
Load Diff
1452
tests/snapshots/gemma-3-4b-it.schema
Normal file
1452
tests/snapshots/gemma-3-4b-it.schema
Normal file
File diff suppressed because it is too large
Load Diff
4052
tests/snapshots/glm-4.6v.schema
Normal file
4052
tests/snapshots/glm-4.6v.schema
Normal file
File diff suppressed because it is too large
Load Diff
5597
tests/snapshots/gpt-oss-120b.schema
Normal file
5597
tests/snapshots/gpt-oss-120b.schema
Normal file
File diff suppressed because it is too large
Load Diff
3896
tests/snapshots/meta-llama-3.1-70b-instruct.schema
Normal file
3896
tests/snapshots/meta-llama-3.1-70b-instruct.schema
Normal file
File diff suppressed because it is too large
Load Diff
3354
tests/snapshots/nemotron-nano-3-30b-a3b.schema
Normal file
3354
tests/snapshots/nemotron-nano-3-30b-a3b.schema
Normal file
File diff suppressed because it is too large
Load Diff
1221
tests/snapshots/qwen3-0.6b.schema
Normal file
1221
tests/snapshots/qwen3-0.6b.schema
Normal file
File diff suppressed because it is too large
Load Diff
1905
tests/snapshots/qwen3-14b.schema
Normal file
1905
tests/snapshots/qwen3-14b.schema
Normal file
File diff suppressed because it is too large
Load Diff
2138
tests/snapshots/qwen3-coder-next.schema
Normal file
2138
tests/snapshots/qwen3-coder-next.schema
Normal file
File diff suppressed because it is too large
Load Diff
2406
tests/snapshots/qwen3.5-27b.schema
Normal file
2406
tests/snapshots/qwen3.5-27b.schema
Normal file
File diff suppressed because it is too large
Load Diff
2682
tests/snapshots/qwen3.5-397b-a17b.schema
Normal file
2682
tests/snapshots/qwen3.5-397b-a17b.schema
Normal file
File diff suppressed because it is too large
Load Diff
2450
tests/snapshots/step-3.5-flash.schema
Normal file
2450
tests/snapshots/step-3.5-flash.schema
Normal file
File diff suppressed because it is too large
Load Diff
@@ -354,6 +354,7 @@ int main_automated_tests(void) {
|
||||
std::string bos_token = "";
|
||||
std::string eos_token = "";
|
||||
bool supported_with_jinja = true;
|
||||
std::vector<llama_chat_message> extra_conversation = {};
|
||||
};
|
||||
std::vector<TestCase> test_cases {
|
||||
{
|
||||
@@ -604,6 +605,26 @@ int main_automated_tests(void) {
|
||||
/* .expected_output_jinja= */ "<seed:bos>system\nYou are a helpful assistant<seed:eos><seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>user\nWho are you<seed:eos><seed:bos>assistant\nI am an assistant<seed:eos><seed:bos>user\nAnother question<seed:eos><seed:bos>assistant\n",
|
||||
/* .bos_token= */ "<seed:bos>",
|
||||
/* .eos_token= */ "<seed:eos>",
|
||||
},
|
||||
{
|
||||
/* .name= */ "ibm-granite/granite-3.x (tool call)",
|
||||
/* .template_str= */ "{%- for message in messages %}\n {%- if message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- else %}\n {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}",
|
||||
/* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What is the weather?<|end_of_text|>\n<|start_of_role|>assistant_tool_call<|end_of_role|><|tool_call|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}]<|end_of_text|>\n<|start_of_role|>tool_response<|end_of_role|>{\"temperature\": 72}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
|
||||
/* .expected_output_jinja= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What is the weather?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|><|tool_call|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}]<|end_of_text|>\n<|start_of_role|>tool_response<|end_of_role|>{\"temperature\": 72}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
|
||||
/* .bos_token= */ "",
|
||||
/* .eos_token= */ "",
|
||||
/* .supported_with_jinja= */ true,
|
||||
/* .extra_conversation= */ {{"user", "What is the weather?"}, {"assistant_tool_call", "[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}]"}, {"tool_response", "{\"temperature\": 72}"}},
|
||||
},
|
||||
{
|
||||
/* .name= */ "ibm-granite/granite-4.0 (tool call)",
|
||||
/* .template_str= */ "{%- for message in messages %}\n {%- if message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- else %}\n {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}\n{# <tool_call> <tools> #}",
|
||||
/* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What is the weather?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|><|tool_call|><tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}\n</tool_call><|end_of_text|>\n<|start_of_role|>tool_response<|end_of_role|>{\"temperature\": 72}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
|
||||
/* .expected_output_jinja= */ "",
|
||||
/* .bos_token= */ "",
|
||||
/* .eos_token= */ "",
|
||||
/* .supported_with_jinja= */ true,
|
||||
/* .extra_conversation= */ {{"user", "What is the weather?"}, {"assistant_tool_call", "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}\n</tool_call>"}, {"tool_response", "{\"temperature\": 72}"}},
|
||||
}
|
||||
};
|
||||
std::vector<char> formatted_chat(1024);
|
||||
@@ -627,11 +648,13 @@ int main_automated_tests(void) {
|
||||
|
||||
for (const auto & test_case : test_cases) {
|
||||
std::cout << "\n\n=== " << test_case.name << " ===\n\n";
|
||||
formatted_chat.resize(1024);
|
||||
auto conv = conversation;
|
||||
conv.insert(conv.end(), test_case.extra_conversation.begin(), test_case.extra_conversation.end());
|
||||
formatted_chat.resize(2048);
|
||||
res = llama_chat_apply_template(
|
||||
test_case.template_str.c_str(),
|
||||
conversation.data(),
|
||||
conversation.size(),
|
||||
conv.data(),
|
||||
conv.size(),
|
||||
add_generation_prompt,
|
||||
formatted_chat.data(),
|
||||
formatted_chat.size()
|
||||
@@ -658,11 +681,15 @@ int main_automated_tests(void) {
|
||||
}
|
||||
std::cout << "\n\n=== " << test_case.name << " (jinja) ===\n\n";
|
||||
try {
|
||||
auto msgs = messages;
|
||||
for (const auto & msg : test_case.extra_conversation) {
|
||||
msgs.push_back(simple_msg(msg.role, msg.content));
|
||||
}
|
||||
auto output = format_using_common(
|
||||
test_case.template_str,
|
||||
test_case.bos_token,
|
||||
test_case.eos_token,
|
||||
messages);
|
||||
msgs);
|
||||
auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
|
||||
if (output != expected_output) {
|
||||
std::cout << "Template:```\n" << test_case.template_str << "\n```";
|
||||
|
||||
@@ -589,6 +589,51 @@ static common_chat_tool amount_tool{
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool toggle_tool{
|
||||
/* .name = */ "toggle",
|
||||
/* .description = */ "Toggle a feature",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to enable the feature"
|
||||
}
|
||||
},
|
||||
"required": ["enabled"]
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool nullable_tool{
|
||||
/* .name = */ "set_nullable",
|
||||
/* .description = */ "Set a nullable value",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"type": "null",
|
||||
"description": "A null value"
|
||||
}
|
||||
},
|
||||
"required": ["value"]
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool config_tool{
|
||||
/* .name = */ "set_config",
|
||||
/* .description = */ "Set configuration",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"description": "Configuration dict"
|
||||
}
|
||||
},
|
||||
"required": ["config"]
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool imaginary_number_tool{
|
||||
/* .name = */ "imaginary_number",
|
||||
/* .description = */ "Imaginary number converter",
|
||||
@@ -1869,6 +1914,130 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).expect_reconstruction().run();
|
||||
}
|
||||
|
||||
{
|
||||
// Google Gemma 4 (tool calling with Gemma4 dict format)
|
||||
auto tst = peg_tester("models/templates/gemma4.jinja");
|
||||
|
||||
tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run();
|
||||
|
||||
// Simple tool call with string argument
|
||||
tst.test(
|
||||
"<|tool_call>call:get_time{city:<|\"|>London<|\"|>}<tool_call|>")
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_tool_calls("get_time", R"({"city": "London"})"))
|
||||
.run();
|
||||
|
||||
// Tool call with string argument containing special chars
|
||||
tst.test(
|
||||
"<|tool_call>call:get_time{city:<|\"|>San Francisco<|\"|>}<tool_call|>")
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_tool_calls("get_time", R"({"city": "San Francisco"})"))
|
||||
.run();
|
||||
|
||||
// Tool call with empty args
|
||||
tst.test(
|
||||
"<|tool_call>call:empty_args{}<tool_call|>")
|
||||
.tools({ empty_args_tool })
|
||||
.expect(message_with_tool_calls("empty_args", "{}"))
|
||||
.run();
|
||||
|
||||
// Tool call with string and content
|
||||
tst.test(
|
||||
"Hello, world!\nWhat's up?<|tool_call>call:get_time{city:<|\"|>Paris<|\"|>}<tool_call|>")
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_content_and_tool_call("Hello, world!\nWhat's up?", "get_time", R"({"city": "Paris"})"))
|
||||
.run();
|
||||
|
||||
// Parallel tool calls
|
||||
tst.test(
|
||||
"<|tool_call>call:get_time{city:<|\"|>London<|\"|>}<tool_call|>"
|
||||
"<|tool_call>call:get_weather{city:<|\"|>Paris<|\"|>}<tool_call|>")
|
||||
.tools({ get_time_tool, get_weather_tool })
|
||||
.parallel_tool_calls(true)
|
||||
.expect_tool_calls({
|
||||
{ "get_time", R"({"city": "London"})", "" },
|
||||
{ "get_weather", R"({"city": "Paris"})", "" },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with integer argument (number type)
|
||||
tst.test(
|
||||
"<|tool_call>call:special_function{arg1:42}<tool_call|>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_with_tool_calls("special_function", R"({"arg1": 42})"))
|
||||
.run();
|
||||
|
||||
// Tool call with negative number argument
|
||||
tst.test(
|
||||
"<|tool_call>call:special_function{arg1:-7}<tool_call|>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_with_tool_calls("special_function", R"({"arg1": -7})"))
|
||||
.run();
|
||||
|
||||
// Tool call with decimal number argument
|
||||
tst.test(
|
||||
"<|tool_call>call:amount{orig:3.14}<tool_call|>")
|
||||
.tools({ amount_tool })
|
||||
.expect(message_with_tool_calls("amount", R"({"orig": 3.14})"))
|
||||
.run();
|
||||
|
||||
// Tool call with boolean argument (true)
|
||||
tst.test(
|
||||
"<|tool_call>call:toggle{enabled:true}<tool_call|>")
|
||||
.tools({ toggle_tool })
|
||||
.expect(message_with_tool_calls("toggle", R"({"enabled": true})"))
|
||||
.run();
|
||||
|
||||
// Tool call with boolean argument (false)
|
||||
tst.test(
|
||||
"<|tool_call>call:toggle{enabled:false}<tool_call|>")
|
||||
.tools({ toggle_tool })
|
||||
.expect(message_with_tool_calls("toggle", R"({"enabled": false})"))
|
||||
.run();
|
||||
|
||||
// Tool call with null argument
|
||||
tst.test(
|
||||
"<|tool_call>call:set_nullable{value:null}<tool_call|>")
|
||||
.tools({ nullable_tool })
|
||||
.expect(message_with_tool_calls("set_nullable", R"({"value": null})"))
|
||||
.run();
|
||||
|
||||
// Tool call with array argument (todo list)
|
||||
tst.test(
|
||||
"<|tool_call>call:todo_list{todos:[<|\"|>buy milk<|\"|>,<|\"|>walk dog<|\"|>]}<tool_call|>")
|
||||
.tools({ todo_list })
|
||||
.expect(message_with_tool_calls("todo_list", R"({"todos":["buy milk","walk dog"]})"))
|
||||
.run();
|
||||
|
||||
// Tool call with object/dict argument
|
||||
tst.test(
|
||||
"<|tool_call>call:set_config{config:{theme:<|\"|>dark<|\"|>,count:3}}<tool_call|>")
|
||||
.tools({ config_tool })
|
||||
.expect(message_with_tool_calls("set_config", R"({"config":{"theme":"dark","count":3}})"))
|
||||
.run();
|
||||
|
||||
// Tool call with empty array
|
||||
tst.test(
|
||||
"<|tool_call>call:todo_list{todos:[]}<tool_call|>")
|
||||
.tools({ todo_list })
|
||||
.expect(message_with_tool_calls("todo_list", R"({"todos":[]})"))
|
||||
.run();
|
||||
|
||||
// Tool call with empty dict
|
||||
tst.test(
|
||||
"<|tool_call>call:set_config{config:{}}<tool_call|>")
|
||||
.tools({ config_tool })
|
||||
.expect(message_with_tool_calls("set_config", R"({"config":{}})"))
|
||||
.run();
|
||||
|
||||
// Tool call with scientific notation number
|
||||
tst.test(
|
||||
"<|tool_call>call:amount{orig:1.5e10}<tool_call|>")
|
||||
.tools({ amount_tool })
|
||||
.expect(message_with_tool_calls("amount", R"({"orig": 1.5e10})"))
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
// Qwen-QwQ-32B (reasoning model)
|
||||
auto tst = peg_tester("models/templates/Qwen-QwQ-32B.jinja");
|
||||
@@ -1929,6 +2098,22 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// .run();
|
||||
}
|
||||
|
||||
{
|
||||
// IBM Granite 4.0 (production template shared by h-tiny, h-small, micro)
|
||||
// Uses <tool_call> XML tags for tool calls, tools in system message
|
||||
auto tst = peg_tester("models/templates/ibm-granite-granite-4.0.jinja", detailed_debug);
|
||||
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
// ByteDance-Seed-OSS (reasoning and tool calling model)
|
||||
auto tst = peg_tester("models/templates/ByteDance-Seed-OSS.jinja", detailed_debug);
|
||||
@@ -3159,6 +3344,24 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_reasoning("I will execute python to say hello")
|
||||
.expect_content("")
|
||||
.run();
|
||||
|
||||
// Edge cases
|
||||
|
||||
// "<|channel|>commentary to=assistant" before reasoning
|
||||
tst.test(
|
||||
"<|channel|>commentary to=assistant<|channel|>analysis<|message|>I'm\nthinking<|end|><|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's "
|
||||
"up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist_thoughts)
|
||||
.run();
|
||||
|
||||
// "<|channel|>commentary to=assistant" before final message
|
||||
tst.test(
|
||||
"<|channel|>analysis<|message|>I'm\nthinking<|end|><|start|>assistant<|channel|>commentary to=assistant<|channel|>final<|message|>Hello, world!\nWhat's "
|
||||
"up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist_thoughts)
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
@@ -116,6 +116,39 @@ int main() {
|
||||
// Verify tensor count
|
||||
TEST_ASSERT(model3.tensors.size() == 780, "expected tensor count == 780");
|
||||
|
||||
// Test a hybrid-attention model with array-valued head counts
|
||||
auto result4 = gguf_fetch_model_meta("ggml-org/Step-3.5-Flash-GGUF", "Q4_K");
|
||||
if (!result4.has_value()) {
|
||||
fprintf(stderr, "FAIL: could not fetch Step-3.5-Flash metadata\n");
|
||||
return 1;
|
||||
}
|
||||
const auto & model4 = result4.value();
|
||||
|
||||
fprintf(stderr, "Architecture: %s\n", model4.architecture.c_str());
|
||||
fprintf(stderr, "n_embd: %u\n", model4.n_embd);
|
||||
fprintf(stderr, "n_ff: %u\n", model4.n_ff);
|
||||
fprintf(stderr, "n_vocab: %u\n", model4.n_vocab);
|
||||
fprintf(stderr, "n_layer: %u\n", model4.n_layer);
|
||||
fprintf(stderr, "n_head: %u\n", model4.n_head);
|
||||
fprintf(stderr, "n_head_kv: %u\n", model4.n_head_kv);
|
||||
fprintf(stderr, "n_expert: %u\n", model4.n_expert);
|
||||
fprintf(stderr, "n_embd_head_k: %u\n", model4.n_embd_head_k);
|
||||
fprintf(stderr, "n_embd_head_v: %u\n", model4.n_embd_head_v);
|
||||
fprintf(stderr, "tensors: %zu\n", model4.tensors.size());
|
||||
|
||||
TEST_ASSERT(model4.architecture == "step35", "expected architecture 'step35'");
|
||||
|
||||
TEST_ASSERT(model4.n_layer == 45, "expected n_layer == 45");
|
||||
TEST_ASSERT(model4.n_embd == 4096, "expected n_embd == 4096");
|
||||
TEST_ASSERT(model4.n_ff == 11264, "expected n_ff == 11264");
|
||||
TEST_ASSERT(model4.n_head == 64, "expected n_head == 64 (first element of per-layer array)");
|
||||
TEST_ASSERT(model4.n_head_kv == 8, "expected n_head_kv == 8 (first element of per-layer array)");
|
||||
TEST_ASSERT(model4.n_expert == 288, "expected n_expert == 288");
|
||||
TEST_ASSERT(model4.n_embd_head_k == 128, "expected n_embd_head_k == 128");
|
||||
TEST_ASSERT(model4.n_embd_head_v == 128, "expected n_embd_head_v == 128");
|
||||
TEST_ASSERT(model4.n_vocab == 128896, "expected n_vocab == 128896");
|
||||
TEST_ASSERT(model4.tensors.size() == 754, "expected tensor count == 754");
|
||||
|
||||
fprintf(stderr, "=== ALL TESTS PASSED ===\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -385,6 +385,9 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml
|
||||
if (arch == LLM_ARCH_CHAMELEON) {
|
||||
continue; // Only half-implemented and to be removed in the future.
|
||||
}
|
||||
if (arch == LLM_ARCH_GEMMA4) {
|
||||
continue; // FIXME @ngxson
|
||||
}
|
||||
if (arch == LLM_ARCH_RWKV6 || arch == LLM_ARCH_RWKV6QWEN2 || arch == LLM_ARCH_RWKV7 || arch == LLM_ARCH_ARWKV7) {
|
||||
continue; // FIXME
|
||||
}
|
||||
@@ -451,6 +454,9 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
|
||||
if (arch == LLM_ARCH_CHAMELEON) {
|
||||
continue; // Only half-implemented and to be removed in the future.
|
||||
}
|
||||
if (arch == LLM_ARCH_GEMMA4) {
|
||||
continue; // FIXME @ngxson
|
||||
}
|
||||
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
|
||||
continue; // FIXME CUDA backend crashes.
|
||||
}
|
||||
|
||||
520
tests/test-quant-type-selection.cpp
Normal file
520
tests/test-quant-type-selection.cpp
Normal file
@@ -0,0 +1,520 @@
|
||||
#include "../src/llama-ext.h"
|
||||
#include "ggml-cpp.h"
|
||||
#include "gguf-model-data.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ftype name <-> enum mapping
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct ftype_name_entry {
|
||||
const char * name;
|
||||
llama_ftype ftype;
|
||||
};
|
||||
|
||||
static const ftype_name_entry ftype_name_table[] = {
|
||||
{ "F32", LLAMA_FTYPE_ALL_F32 },
|
||||
{ "F16", LLAMA_FTYPE_MOSTLY_F16 },
|
||||
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16 },
|
||||
{ "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0 },
|
||||
{ "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1 },
|
||||
{ "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0 },
|
||||
{ "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1 },
|
||||
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0 },
|
||||
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K },
|
||||
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S },
|
||||
{ "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S },
|
||||
{ "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M },
|
||||
{ "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L },
|
||||
{ "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S },
|
||||
{ "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M },
|
||||
{ "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S },
|
||||
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M },
|
||||
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K },
|
||||
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S },
|
||||
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M },
|
||||
{ "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS },
|
||||
{ "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS },
|
||||
{ "IQ2_S", LLAMA_FTYPE_MOSTLY_IQ2_S },
|
||||
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M },
|
||||
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS },
|
||||
{ "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS },
|
||||
{ "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S },
|
||||
{ "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M },
|
||||
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL },
|
||||
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS },
|
||||
{ "TQ1_0", LLAMA_FTYPE_MOSTLY_TQ1_0 },
|
||||
{ "TQ2_0", LLAMA_FTYPE_MOSTLY_TQ2_0 },
|
||||
{ "MXFP4_MOE", LLAMA_FTYPE_MOSTLY_MXFP4_MOE },
|
||||
{ "NVFP4", LLAMA_FTYPE_MOSTLY_NVFP4 },
|
||||
};
|
||||
|
||||
static llama_ftype llama_ftype_from_name(const char * name) {
|
||||
for (const auto & e : ftype_name_table) {
|
||||
if (strcmp(name, e.name) == 0) {
|
||||
return e.ftype;
|
||||
}
|
||||
}
|
||||
return (llama_ftype) -1;
|
||||
}
|
||||
|
||||
static const char * llama_ftype_to_name(llama_ftype ftype) {
|
||||
for (const auto & e : ftype_name_table) {
|
||||
if (e.ftype == ftype) {
|
||||
return e.name;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ggml_type name lookup
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static ggml_type ggml_type_from_name(const std::string & name) {
|
||||
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
|
||||
const char * tname = ggml_type_name((ggml_type) i);
|
||||
if (tname && name == tname) {
|
||||
return (ggml_type) i;
|
||||
}
|
||||
}
|
||||
return GGML_TYPE_COUNT;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// File parser for snapshot files (quant type schemas)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct snapshot_section {
|
||||
llama_ftype ftype;
|
||||
ggml_type default_type;
|
||||
std::vector<std::pair<std::string, ggml_type>> overrides;
|
||||
};
|
||||
|
||||
// This function is pretty ugly, but it's a trade-off of readable snapshot files
|
||||
// versus readable parsing code
|
||||
static bool parse_snapshot_file(const std::string & path, std::vector<snapshot_section> & sections) {
|
||||
std::ifstream f(path);
|
||||
if (!f.good()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
snapshot_section * cur = nullptr;
|
||||
std::string line;
|
||||
|
||||
while (std::getline(f, line)) {
|
||||
if (line.empty() || line[0] == '#') {
|
||||
continue;
|
||||
}
|
||||
|
||||
// section header: [FTYPE_NAME] default_type
|
||||
if (line[0] == '[') {
|
||||
auto close = line.find(']');
|
||||
if (close == std::string::npos) {
|
||||
fprintf(stderr, "parse error: missing ] in '%s'\n", line.c_str());
|
||||
return false;
|
||||
}
|
||||
std::string ftype_str = line.substr(1, close - 1);
|
||||
std::string default_str;
|
||||
size_t pos = close + 1;
|
||||
while (pos < line.size() && line[pos] == ' ') {
|
||||
pos++;
|
||||
}
|
||||
default_str = line.substr(pos);
|
||||
|
||||
llama_ftype ftype = llama_ftype_from_name(ftype_str.c_str());
|
||||
if ((int) ftype < 0) {
|
||||
fprintf(stderr, "parse error: unknown ftype '%s'\n", ftype_str.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_type dtype = ggml_type_from_name(default_str);
|
||||
if (dtype == GGML_TYPE_COUNT) {
|
||||
fprintf(stderr, "parse error: unknown default type '%s'\n", default_str.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
sections.push_back({ ftype, dtype, {} });
|
||||
cur = §ions.back();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!cur) {
|
||||
fprintf(stderr, "parse error: tensor line before any section: '%s'\n", line.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto sp = line.rfind(' ');
|
||||
if (sp == std::string::npos) {
|
||||
fprintf(stderr, "parse error: no space in tensor line: '%s'\n", line.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string tname = line.substr(0, sp);
|
||||
std::string ttype = line.substr(sp + 1);
|
||||
|
||||
ggml_type gt = ggml_type_from_name(ttype);
|
||||
if (gt == GGML_TYPE_COUNT) {
|
||||
fprintf(stderr, "parse error: unknown type '%s' for tensor '%s'\n", ttype.c_str(), tname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
cur->overrides.push_back({ tname, gt });
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Remote model support using gguf-model-data.cpp
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct remote_model_spec {
|
||||
const char * repo;
|
||||
const char * quant;
|
||||
};
|
||||
|
||||
// Get model name from repo: strip org prefix, strip -GGUF suffix,
|
||||
// and strip anything up to and including first '_' (e.g. "deepseek-ai_DeepSeek-V3.1").
|
||||
static std::string model_name_from_repo(const char * repo) {
|
||||
std::string s(repo);
|
||||
|
||||
auto slash = s.find('/');
|
||||
if (slash != std::string::npos) {
|
||||
s = s.substr(slash + 1);
|
||||
}
|
||||
|
||||
const std::string suffix = "-GGUF";
|
||||
if (s.size() >= suffix.size() && s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0) {
|
||||
s = s.substr(0, s.size() - suffix.size());
|
||||
}
|
||||
|
||||
auto underscore = s.find('_');
|
||||
if (underscore != std::string::npos) {
|
||||
s = s.substr(underscore + 1);
|
||||
}
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
static std::string snapshot_file_from_name(const std::string & name) {
|
||||
std::string lower = name;
|
||||
for (auto & c : lower) {
|
||||
c = std::tolower(c);
|
||||
}
|
||||
return lower;
|
||||
}
|
||||
|
||||
static const remote_model_spec model_specs[] = {
|
||||
{ "ggml-org/Qwen3-0.6B-GGUF", "Q8_0" },
|
||||
{ "ggml-org/GLM-4.6V-GGUF", "Q8_0" },
|
||||
{ "ggml-org/Step-3.5-Flash-GGUF", "Q4_K" },
|
||||
{ "ggml-org/Qwen3-Coder-Next-GGUF", "Q8_0" },
|
||||
{ "ggml-org/Qwen3-14B-GGUF", "Q8_0" },
|
||||
{ "ggml-org/Nemotron-Nano-3-30B-A3B-GGUF", "Q8_0" },
|
||||
{ "ggml-org/gpt-oss-120b-GGUF", "mxfp4" },
|
||||
{ "ggml-org/gemma-3-4b-it-GGUF", "Q8_0" },
|
||||
{ "bartowski/Meta-Llama-3.1-70B-Instruct-GGUF", "Q4_K_M" },
|
||||
{ "bartowski/deepseek-ai_DeepSeek-V3.1-GGUF", "IQ1_M" },
|
||||
{ "bartowski/Qwen_Qwen3.5-397B-A17B-GGUF", "IQ1_S" }, // TODO: swap with ggml-org if/when it's released
|
||||
{ "bartowski/Qwen_Qwen3.5-27B-GGUF", "Q8_0" }, // TODO: swap with ggml-org if/when it's released
|
||||
};
|
||||
|
||||
static const int n_model_specs = (int) (sizeof(model_specs) / sizeof(model_specs[0]));
|
||||
|
||||
static llama_model * build_mock_model_from_remote(const gguf_remote_model & remote) {
|
||||
llama_quant_model_desc desc = {};
|
||||
desc.architecture = remote.architecture.c_str();
|
||||
desc.n_embd = remote.n_embd;
|
||||
desc.n_ff = remote.n_ff;
|
||||
desc.n_layer = remote.n_layer;
|
||||
desc.n_head = remote.n_head;
|
||||
desc.n_head_kv = remote.n_head_kv;
|
||||
desc.n_expert = remote.n_expert;
|
||||
desc.n_embd_head_k = remote.n_embd_head_k;
|
||||
desc.n_embd_head_v = remote.n_embd_head_v;
|
||||
return llama_quant_model_from_metadata(&desc);
|
||||
}
|
||||
|
||||
// Single ggml context holding all quantizable tensors for a model.
|
||||
struct mock_tensors {
|
||||
ggml_context_ptr ctx;
|
||||
std::vector<ggml_tensor *> tensors;
|
||||
};
|
||||
|
||||
static mock_tensors build_mock_tensors(const quantize_state_impl * qs, const gguf_remote_model & remote) {
|
||||
const size_t ctx_size = remote.tensors.size() * ggml_tensor_overhead();
|
||||
struct ggml_init_params params = { ctx_size, nullptr, true };
|
||||
ggml_context_ptr ctx(ggml_init(params));
|
||||
|
||||
std::vector<ggml_tensor *> result;
|
||||
|
||||
for (const auto & t : remote.tensors) {
|
||||
ggml_tensor * gt = ggml_new_tensor_4d(ctx.get(), GGML_TYPE_F32, t.ne[0], t.ne[1], t.ne[2], t.ne[3]);
|
||||
ggml_set_name(gt, t.name.c_str());
|
||||
if (llama_quant_tensor_allows_quantization(qs, gt)) {
|
||||
result.push_back(gt);
|
||||
}
|
||||
}
|
||||
|
||||
// sort by layer index then name, matching llama_model_loader::weight_name_comparer
|
||||
std::sort(result.begin(), result.end(), [](const ggml_tensor * a, const ggml_tensor * b) {
|
||||
int a_layer = -1, b_layer = -1;
|
||||
sscanf(a->name, "blk.%d.", &a_layer);
|
||||
sscanf(b->name, "blk.%d.", &b_layer);
|
||||
if (a_layer != b_layer) {
|
||||
return a_layer < b_layer;
|
||||
}
|
||||
return strcmp(a->name, b->name) < 0;
|
||||
});
|
||||
|
||||
return { std::move(ctx), std::move(result) };
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Generate mode: regenerate all snapshot files
|
||||
// Use this when either adding new models or modifying quants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static std::string generate_snapshot(const std::string & name,
|
||||
const gguf_remote_model & remote,
|
||||
quantize_state_impl * qs,
|
||||
mock_tensors & mt) {
|
||||
std::ostringstream out;
|
||||
|
||||
out << "# Model: " << name << "\n";
|
||||
out << "# n_embd=" << remote.n_embd << ", n_ff=" << remote.n_ff << ", n_vocab=" << remote.n_vocab
|
||||
<< ", n_layer=" << remote.n_layer << ", n_head=" << remote.n_head << ", n_head_kv=" << remote.n_head_kv;
|
||||
if (remote.n_expert > 0) {
|
||||
out << ", n_expert=" << remote.n_expert;
|
||||
}
|
||||
out << "\n";
|
||||
|
||||
for (int i = 0; i < LLAMA_FTYPE_GUESSED; i++) {
|
||||
llama_ftype ft = (llama_ftype) i;
|
||||
ggml_type default_type = llama_ftype_get_default_type(ft);
|
||||
if (default_type == GGML_TYPE_COUNT) {
|
||||
continue;
|
||||
}
|
||||
const char * fname = llama_ftype_to_name(ft);
|
||||
if (!fname) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<ggml_type> result_types(mt.tensors.size());
|
||||
llama_quant_compute_types(qs, ft, mt.tensors.data(), result_types.data(), mt.tensors.size());
|
||||
|
||||
out << "\n[" << fname << "] " << ggml_type_name(default_type) << "\n";
|
||||
for (size_t j = 0; j < mt.tensors.size(); j++) {
|
||||
if (result_types[j] != default_type) {
|
||||
out << ggml_get_name(mt.tensors[j]) << " " << ggml_type_name(result_types[j]) << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out.str();
|
||||
}
|
||||
|
||||
static int run_generate(const std::string & snapshot_dir) {
|
||||
fprintf(stderr, "This will overwrite all snapshot files in:\n %s\n", snapshot_dir.c_str());
|
||||
fprintf(stderr, "Continue? [y/N] ");
|
||||
int ch = fgetc(stdin);
|
||||
if (ch != 'y' && ch != 'Y') {
|
||||
fprintf(stderr, "Aborted.\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
int n_written = 0;
|
||||
|
||||
for (int m = 0; m < n_model_specs; m++) {
|
||||
const auto & spec = model_specs[m];
|
||||
std::string name = model_name_from_repo(spec.repo);
|
||||
|
||||
fprintf(stderr, "Fetching model metadata for %s from %s...\n", name.c_str(), spec.repo);
|
||||
auto result = gguf_fetch_model_meta(spec.repo, spec.quant);
|
||||
if (!result.has_value()) {
|
||||
fprintf(stderr, "ERROR: could not fetch model metadata for %s\n", name.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto & remote = result.value();
|
||||
llama_model * model = build_mock_model_from_remote(remote);
|
||||
llama_model_quantize_params qparams = llama_model_quantize_default_params();
|
||||
quantize_state_impl * qs = llama_quant_init(model, &qparams);
|
||||
auto mt = build_mock_tensors(qs, remote);
|
||||
|
||||
std::string content = generate_snapshot(name, remote, qs, mt);
|
||||
std::string path = snapshot_dir + "/" + snapshot_file_from_name(name) + ".schema";
|
||||
|
||||
std::ofstream f(path);
|
||||
if (!f.good()) {
|
||||
fprintf(stderr, "ERROR: could not write %s\n", path.c_str());
|
||||
llama_quant_free(qs);
|
||||
llama_model_free(model);
|
||||
return 1;
|
||||
}
|
||||
f << content;
|
||||
n_written++;
|
||||
fprintf(stderr, " wrote %s\n", path.c_str());
|
||||
llama_quant_free(qs);
|
||||
llama_model_free(model);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%d files written\n", n_written);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test mode: compare against snapshot files
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static bool run_test_section(quantize_state_impl * qs, mock_tensors & mt, const snapshot_section & section) {
|
||||
// verify default_type matches what llama_ftype_get_default_type returns
|
||||
ggml_type computed_default = llama_ftype_get_default_type(section.ftype);
|
||||
if (computed_default != section.default_type) {
|
||||
printf(" FAIL [%s] default type mismatch: file says %s, code says %s\n", llama_ftype_to_name(section.ftype),
|
||||
ggml_type_name(section.default_type), ggml_type_name(computed_default));
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ggml_type> result_types(mt.tensors.size());
|
||||
llama_quant_compute_types(qs, section.ftype, mt.tensors.data(), result_types.data(), mt.tensors.size());
|
||||
|
||||
std::map<std::string, ggml_type> override_map(section.overrides.begin(), section.overrides.end());
|
||||
|
||||
bool all_pass = true;
|
||||
int n_override_found = 0;
|
||||
|
||||
for (size_t i = 0; i < mt.tensors.size(); i++) {
|
||||
const char * name = ggml_get_name(mt.tensors[i]);
|
||||
ggml_type got = result_types[i];
|
||||
|
||||
ggml_type expected = section.default_type;
|
||||
auto it = override_map.find(name);
|
||||
if (it != override_map.end()) {
|
||||
expected = it->second;
|
||||
n_override_found++;
|
||||
}
|
||||
|
||||
if (got != expected) {
|
||||
printf(" FAIL %-50s %-10s expected %s, got %s\n", name, llama_ftype_to_name(section.ftype),
|
||||
ggml_type_name(expected), ggml_type_name(got));
|
||||
all_pass = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (n_override_found != (int) section.overrides.size()) {
|
||||
printf(" FAIL [%s] override count mismatch: listed %d, matched %d\n", llama_ftype_to_name(section.ftype),
|
||||
(int) section.overrides.size(), n_override_found);
|
||||
all_pass = false;
|
||||
}
|
||||
|
||||
return all_pass;
|
||||
}
|
||||
|
||||
static int run_remote_tests(const std::string & snapshot_dir, const char * argv0) {
|
||||
int total_pass = 0;
|
||||
int total_fail = 0;
|
||||
int total_skip = 0;
|
||||
|
||||
for (int m = 0; m < n_model_specs; m++) {
|
||||
const auto & spec = model_specs[m];
|
||||
std::string name = model_name_from_repo(spec.repo);
|
||||
printf("=== %s ===\n", name.c_str());
|
||||
|
||||
auto result = gguf_fetch_model_meta(spec.repo, spec.quant, "", false);
|
||||
if (!result.has_value()) {
|
||||
printf(" SKIP (could not fetch model metadata)\n\n");
|
||||
total_skip++;
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto & remote = result.value();
|
||||
llama_model * model = build_mock_model_from_remote(remote);
|
||||
llama_model_quantize_params qparams = llama_model_quantize_default_params();
|
||||
quantize_state_impl * qs = llama_quant_init(model, &qparams);
|
||||
auto mt = build_mock_tensors(qs, remote);
|
||||
|
||||
std::string snapshot_path = snapshot_dir + "/" + snapshot_file_from_name(name) + ".schema";
|
||||
std::vector<snapshot_section> sections;
|
||||
if (!parse_snapshot_file(snapshot_path, sections)) {
|
||||
printf(" SKIP (could not read snapshot file: %s)\n\n", snapshot_path.c_str());
|
||||
llama_quant_free(qs);
|
||||
llama_model_free(model);
|
||||
total_skip++;
|
||||
continue;
|
||||
}
|
||||
|
||||
int model_pass = 0;
|
||||
int model_fail = 0;
|
||||
|
||||
for (const auto & section : sections) {
|
||||
bool pass = run_test_section(qs, mt, section);
|
||||
if (pass) {
|
||||
model_pass++;
|
||||
} else {
|
||||
model_fail++;
|
||||
}
|
||||
}
|
||||
|
||||
printf(" %s %s: %d/%d ftype sections passed (%d tensors)\n", model_fail == 0 ? "PASS" : "FAIL", name.c_str(),
|
||||
model_pass, model_pass + model_fail, (int) mt.tensors.size());
|
||||
printf("\n");
|
||||
|
||||
if (model_fail == 0) {
|
||||
total_pass++;
|
||||
} else {
|
||||
total_fail++;
|
||||
}
|
||||
|
||||
llama_quant_free(qs);
|
||||
llama_model_free(model);
|
||||
}
|
||||
|
||||
printf("%d/%d models passed", total_pass, total_pass + total_fail);
|
||||
if (total_skip > 0) {
|
||||
printf(", %d skipped", total_skip);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
if (total_fail > 0) {
|
||||
printf("\nIf these changes are intentional, regenerate snapshot files with:\n");
|
||||
printf(" %s --generate\n", argv0);
|
||||
}
|
||||
|
||||
return total_fail > 0 ? 1 : 0;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::string snapshot_dir = SNAPSHOT_DIR;
|
||||
bool generate = false;
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
if (strcmp(argv[i], "--generate") == 0) {
|
||||
generate = true;
|
||||
} else if (strcmp(argv[i], "--snapshot-dir") == 0 && i + 1 < argc) {
|
||||
snapshot_dir = argv[++i];
|
||||
}
|
||||
}
|
||||
|
||||
if (generate) {
|
||||
return run_generate(snapshot_dir);
|
||||
}
|
||||
|
||||
// suppress llama log warnings during test (e.g. tensor type fallback messages)
|
||||
llama_log_set([](enum ggml_log_level, const char *, void *) {}, nullptr);
|
||||
|
||||
return run_remote_tests(snapshot_dir, argv[0]);
|
||||
}
|
||||
@@ -17,6 +17,7 @@ add_library(mtmd
|
||||
models/models.h
|
||||
models/cogvlm.cpp
|
||||
models/conformer.cpp
|
||||
models/gemma4v.cpp
|
||||
models/glm4v.cpp
|
||||
models/internvl.cpp
|
||||
models/kimivl.cpp
|
||||
|
||||
@@ -29,7 +29,7 @@ struct clip_graph {
|
||||
const int n_layer;
|
||||
const int n_mmproj_embd;
|
||||
const float eps;
|
||||
const float kq_scale;
|
||||
float kq_scale; // TODO: maybe move this to hparams
|
||||
const clip_flash_attn_type flash_attn_type;
|
||||
|
||||
ggml_context_ptr ctx0_ptr;
|
||||
|
||||
@@ -88,8 +88,11 @@
|
||||
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
|
||||
#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm
|
||||
#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm
|
||||
#define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale
|
||||
#define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale
|
||||
#define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale
|
||||
#define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale
|
||||
#define TN_LS_OUT "%s.blk.%d.out_scale.%s" // layer out scale (gemma4)
|
||||
#define TN_ATTN_POST_NORM "%s.blk.%d.attn_post_norm.%s" // post-attn norm (gemma4)
|
||||
#define TN_FFN_POST_NORM "%s.blk.%d.ffn_post_norm.%s" // post-FFN norm (gemma4)
|
||||
#define TN_LN_PRE "%s.pre_ln.%s"
|
||||
#define TN_LN_POST "%s.post_ln.%s"
|
||||
#define TN_LLAVA_PROJ "mm.%d.%s"
|
||||
@@ -213,6 +216,10 @@
|
||||
#define TN_MNV5_MSFA_FFN_PROJ_BN "v.msfa.ffn.pw_proj.bn.weight"
|
||||
#define TN_MNV5_MSFA_NORM "v.msfa.norm.weight"
|
||||
|
||||
// gemma4
|
||||
#define TN_STD_BIAS "v.std_bias"
|
||||
#define TN_STD_SCALE "v.std_scale"
|
||||
|
||||
|
||||
// align x to upper multiple of n
|
||||
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
|
||||
@@ -233,6 +240,8 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_GEMMA3,
|
||||
PROJECTOR_TYPE_GEMMA3NV,
|
||||
PROJECTOR_TYPE_GEMMA3NA,
|
||||
PROJECTOR_TYPE_GEMMA4V,
|
||||
PROJECTOR_TYPE_GEMMA4A,
|
||||
PROJECTOR_TYPE_PHI4,
|
||||
PROJECTOR_TYPE_IDEFICS3,
|
||||
PROJECTOR_TYPE_PIXTRAL,
|
||||
@@ -272,6 +281,8 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
||||
{ PROJECTOR_TYPE_GEMMA3NV, "gemma3nv"},
|
||||
{ PROJECTOR_TYPE_GEMMA3NA, "gemma3na"},
|
||||
{ PROJECTOR_TYPE_GEMMA4V, "gemma4v"},
|
||||
{ PROJECTOR_TYPE_GEMMA4A, "gemma4a"},
|
||||
{ PROJECTOR_TYPE_PHI4, "phi4"},
|
||||
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
|
||||
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
|
||||
@@ -476,6 +487,18 @@ static std::vector<std::string> string_split_str(std::string s, const std::strin
|
||||
return tokens;
|
||||
}
|
||||
|
||||
// remove when moving to c++20
|
||||
inline bool string_starts_with(std::string_view str, std::string_view prefix) {
|
||||
return str.size() >= prefix.size() &&
|
||||
str.compare(0, prefix.size(), prefix) == 0;
|
||||
}
|
||||
|
||||
// remove when moving to c++20
|
||||
inline bool string_ends_with(std::string_view str, std::string_view suffix) {
|
||||
return str.size() >= suffix.size() &&
|
||||
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
|
||||
//
|
||||
// gguf utils
|
||||
//
|
||||
|
||||
@@ -143,6 +143,10 @@ struct clip_hparams {
|
||||
};
|
||||
|
||||
struct clip_layer {
|
||||
// layernorm 1 (or layer input norm, or pre-attention norm)
|
||||
ggml_tensor * ln_1_w = nullptr;
|
||||
ggml_tensor * ln_1_b = nullptr;
|
||||
|
||||
// attention
|
||||
ggml_tensor * k_w = nullptr;
|
||||
ggml_tensor * k_b = nullptr;
|
||||
@@ -159,9 +163,7 @@ struct clip_layer {
|
||||
ggml_tensor * k_norm = nullptr;
|
||||
ggml_tensor * q_norm = nullptr;
|
||||
|
||||
// layernorm 1
|
||||
ggml_tensor * ln_1_w = nullptr;
|
||||
ggml_tensor * ln_1_b = nullptr;
|
||||
ggml_tensor * attn_post_norm_w = nullptr;
|
||||
|
||||
ggml_tensor * ff_up_w = nullptr;
|
||||
ggml_tensor * ff_up_b = nullptr;
|
||||
@@ -170,13 +172,16 @@ struct clip_layer {
|
||||
ggml_tensor * ff_down_w = nullptr;
|
||||
ggml_tensor * ff_down_b = nullptr;
|
||||
|
||||
// layernorm 2
|
||||
// layernorm 2 (or pre-FFN norm)
|
||||
ggml_tensor * ln_2_w = nullptr;
|
||||
ggml_tensor * ln_2_b = nullptr;
|
||||
|
||||
ggml_tensor * ff_post_norm_w = nullptr;
|
||||
|
||||
// layer scale (no bias)
|
||||
ggml_tensor * ls_1_w = nullptr;
|
||||
ggml_tensor * ls_2_w = nullptr;
|
||||
ggml_tensor * ls_1_w = nullptr;
|
||||
ggml_tensor * ls_2_w = nullptr;
|
||||
ggml_tensor * ls_out_w = nullptr; // gemma4
|
||||
|
||||
// qwen3vl deepstack merger
|
||||
ggml_tensor * deepstack_norm_w = nullptr;
|
||||
@@ -437,6 +442,18 @@ struct clip_model {
|
||||
ggml_tensor * pre_encode_out_w = nullptr;
|
||||
ggml_tensor * pre_encode_out_b = nullptr;
|
||||
|
||||
// gemma4
|
||||
ggml_tensor * std_bias = nullptr;
|
||||
ggml_tensor * std_scale = nullptr;
|
||||
// Gemma4ClippableLinear
|
||||
struct clamp_info {
|
||||
float inp_max;
|
||||
float inp_min;
|
||||
float out_max;
|
||||
float out_min;
|
||||
};
|
||||
std::map<std::string, clamp_info> clamp_info_map;
|
||||
|
||||
bool audio_has_avgpool() const {
|
||||
return proj_type == PROJECTOR_TYPE_QWEN2A
|
||||
|| proj_type == PROJECTOR_TYPE_VOXTRAL
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#include <limits>
|
||||
#include <array>
|
||||
#include <functional>
|
||||
#include <float.h>
|
||||
|
||||
struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL};
|
||||
|
||||
@@ -379,19 +380,34 @@ ggml_tensor * clip_graph::build_vit(
|
||||
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
|
||||
}
|
||||
|
||||
if (layer.q_norm) {
|
||||
Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
|
||||
cb(Qcur, "Qcur_norm", il);
|
||||
}
|
||||
// if true, norm must be applied after reshaping to (d_head, n_head, n_pos)
|
||||
bool norm_per_head = layer.q_norm && layer.q_norm->ne[0] == d_head;
|
||||
|
||||
if (layer.k_norm) {
|
||||
Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
|
||||
cb(Kcur, "Kcur_norm", il);
|
||||
if (!norm_per_head) {
|
||||
if (layer.q_norm) {
|
||||
Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
|
||||
cb(Qcur, "Qcur_norm", il);
|
||||
}
|
||||
if (layer.k_norm) {
|
||||
Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
|
||||
cb(Kcur, "Kcur_norm", il);
|
||||
}
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
|
||||
|
||||
if (norm_per_head) {
|
||||
if (layer.q_norm) {
|
||||
Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
|
||||
cb(Qcur, "Qcur_norm_per_head", il);
|
||||
}
|
||||
if (layer.k_norm) {
|
||||
Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
|
||||
cb(Kcur, "Kcur_norm_per_head", il);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
@@ -405,6 +421,11 @@ ggml_tensor * clip_graph::build_vit(
|
||||
cb(Kcur, "Kcur_pos", il);
|
||||
}
|
||||
|
||||
if (proj_type == PROJECTOR_TYPE_GEMMA4V) {
|
||||
Vcur = ggml_rms_norm(ctx0, Vcur, eps);
|
||||
cb(Vcur, "Vcur_normed", il);
|
||||
}
|
||||
|
||||
cur = build_attn(layer.o_w, layer.o_b,
|
||||
Qcur, Kcur, Vcur, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
@@ -415,6 +436,11 @@ ggml_tensor * clip_graph::build_vit(
|
||||
cb(cur, "attn_out_scaled", il);
|
||||
}
|
||||
|
||||
if (layer.attn_post_norm_w) {
|
||||
cur = build_norm(cur, layer.attn_post_norm_w, nullptr, norm_t, eps, il);
|
||||
cb(cur, "attn_post_normed", il);
|
||||
}
|
||||
|
||||
// re-add the layer input, e.g., residual
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
|
||||
@@ -422,7 +448,7 @@ ggml_tensor * clip_graph::build_vit(
|
||||
|
||||
cb(cur, "ffn_inp", il);
|
||||
|
||||
// layernorm2
|
||||
// layernorm2 (pre-ffn norm)
|
||||
cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
|
||||
cb(cur, "ffn_inp_normed", il);
|
||||
|
||||
@@ -435,6 +461,11 @@ ggml_tensor * clip_graph::build_vit(
|
||||
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
if (layer.ff_post_norm_w) {
|
||||
cur = build_norm(cur, layer.ff_post_norm_w, nullptr, norm_t, eps, il);
|
||||
cb(cur, "ffn_post_normed", il);
|
||||
}
|
||||
|
||||
if (layer.ls_2_w) {
|
||||
cur = ggml_mul(ctx0, cur, layer.ls_2_w);
|
||||
cb(cur, "ffn_out_scaled", il);
|
||||
@@ -444,6 +475,11 @@ ggml_tensor * clip_graph::build_vit(
|
||||
cur = ggml_add(ctx0, inpL, cur);
|
||||
cb(cur, "layer_out", il);
|
||||
|
||||
if (layer.ls_out_w) {
|
||||
cur = ggml_mul(ctx0, cur, layer.ls_out_w);
|
||||
cb(cur, "layer_out_scaled", il);
|
||||
}
|
||||
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
@@ -808,6 +844,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
{
|
||||
builder = std::make_unique<clip_graph_mobilenetv5>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_gemma4v>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
@@ -1257,6 +1297,17 @@ struct clip_model_loader {
|
||||
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
|
||||
} break;
|
||||
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
{
|
||||
hparams.rope_theta = 100.0f;
|
||||
hparams.n_merge = 3; // pooling_kernel_size
|
||||
hparams.image_resize_algo = RESIZE_ALGO_BILINEAR;
|
||||
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
|
||||
// @ngxson : the model performs quite poor with small images, we need to bump minimum image tokens to 40 to avoid that
|
||||
hparams.set_limit_image_tokens(252, 280);
|
||||
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
|
||||
} break;
|
||||
|
||||
case PROJECTOR_TYPE_GEMMA3NV:
|
||||
{
|
||||
// Gemma3n uses MobileNetV5 which produces 256 tokens (16x16)
|
||||
@@ -1442,6 +1493,11 @@ struct clip_model_loader {
|
||||
std::map<std::string, size_t> tensor_offset;
|
||||
std::vector<ggml_tensor *> tensors_to_load;
|
||||
|
||||
auto fin = std::ifstream(fname, std::ios::binary);
|
||||
if (!fin) {
|
||||
throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
|
||||
}
|
||||
|
||||
// TODO @ngxson : support both audio and video in the future
|
||||
const char * prefix = model.modality == CLIP_MODALITY_AUDIO ? "a" : "v";
|
||||
|
||||
@@ -1478,6 +1534,18 @@ struct clip_model_loader {
|
||||
return cur;
|
||||
};
|
||||
|
||||
auto get_scalar = [&](const std::string & name, float default_val) {
|
||||
auto it = tensor_offset.find(name);
|
||||
if (it == tensor_offset.end()) {
|
||||
return default_val;
|
||||
}
|
||||
size_t offset = it->second;
|
||||
fin.seekg(offset, std::ios::beg);
|
||||
float value;
|
||||
fin.read(reinterpret_cast<char*>(&value), sizeof(float));
|
||||
return value;
|
||||
};
|
||||
|
||||
model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
|
||||
|
||||
model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
|
||||
@@ -1512,8 +1580,11 @@ struct clip_model_loader {
|
||||
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
|
||||
layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);
|
||||
layer.ln_2_w = get_tensor(string_format(TN_LN_2, prefix, il, "weight"), false);
|
||||
layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias
|
||||
layer.ls_2_w = get_tensor(string_format(TN_LS_2, prefix, il, "weight"), false); // no bias
|
||||
layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias
|
||||
layer.ls_2_w = get_tensor(string_format(TN_LS_2, prefix, il, "weight"), false); // no bias
|
||||
layer.ls_out_w = get_tensor(string_format(TN_LS_OUT, prefix, il, "weight"), false); // no bias
|
||||
layer.attn_post_norm_w = get_tensor(string_format(TN_ATTN_POST_NORM, prefix, il, "weight"), false); // no bias
|
||||
layer.ff_post_norm_w = get_tensor(string_format(TN_FFN_POST_NORM, prefix, il, "weight"), false); // no bias
|
||||
|
||||
layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false);
|
||||
layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
|
||||
@@ -1713,6 +1784,32 @@ struct clip_model_loader {
|
||||
model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
|
||||
model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
{
|
||||
model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
|
||||
model.std_bias = get_tensor(TN_STD_BIAS, false);
|
||||
model.std_scale = get_tensor(TN_STD_SCALE, false);
|
||||
// load scalar for Gemma4ClippableLinear
|
||||
for (auto * tensor : tensors_to_load) {
|
||||
std::string name = tensor->name;
|
||||
if (string_ends_with(name, ".weight")) {
|
||||
std::string name_inp_max = name;
|
||||
std::string name_inp_min = name;
|
||||
std::string name_out_max = name;
|
||||
std::string name_out_min = name;
|
||||
string_replace_all(name_inp_max, ".weight", ".input_max");
|
||||
string_replace_all(name_inp_min, ".weight", ".input_min");
|
||||
string_replace_all(name_out_max, ".weight", ".output_max");
|
||||
string_replace_all(name_out_min, ".weight", ".output_min");
|
||||
model.clamp_info_map[name] = {
|
||||
get_scalar(name_inp_max, FLT_MAX),
|
||||
get_scalar(name_inp_min, -FLT_MAX),
|
||||
get_scalar(name_out_max, FLT_MAX),
|
||||
get_scalar(name_out_min, -FLT_MAX)
|
||||
};
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA3NV:
|
||||
{
|
||||
model.mobilenet_stem_conv_w = get_tensor(TN_MNV5_STEM_CONV, false);
|
||||
@@ -2042,11 +2139,6 @@ struct clip_model_loader {
|
||||
{
|
||||
std::vector<uint8_t> read_buf;
|
||||
|
||||
auto fin = std::ifstream(fname, std::ios::binary);
|
||||
if (!fin) {
|
||||
throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
|
||||
}
|
||||
|
||||
// alloc memory and offload data
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
|
||||
ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
|
||||
@@ -2345,7 +2437,8 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
|
||||
|
||||
// TODO: we don't support audio for Gemma 3N, but GGUF contains audio tensors
|
||||
// we can remove this check when we implement audio support for Gemma 3N
|
||||
skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV;
|
||||
skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV
|
||||
|| ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA4V;
|
||||
}
|
||||
|
||||
if (loader.has_audio && !skip_audio) {
|
||||
@@ -2581,6 +2674,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
n_patches = x_patch * y_patch;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
@@ -3031,6 +3125,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
}
|
||||
set_input_i32("patches", patches);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
{
|
||||
// set (col, row) patch positions for learned positional embedding
|
||||
const int n_cols = image_size_width / patch_size;
|
||||
std::vector<int> pos_x(num_patches), pos_y(num_patches);
|
||||
for (int i = 0; i < num_patches; i++) {
|
||||
pos_x[i] = i % n_cols;
|
||||
pos_y[i] = i / n_cols;
|
||||
}
|
||||
set_input_i32("pos_x", pos_x);
|
||||
set_input_i32("pos_y", pos_y);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_DEEPSEEKOCR:
|
||||
{
|
||||
GGML_ASSERT(pos_w == pos_h);
|
||||
@@ -3218,6 +3324,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_GEMMA3NV:
|
||||
return ctx->model.mm_input_proj_w->ne[0];
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
return ctx->model.mm_input_proj_w->ne[1];
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
return ctx->model.mm_fc_w->ne[1];
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
|
||||
151
tools/mtmd/models/gemma4v.cpp
Normal file
151
tools/mtmd/models/gemma4v.cpp
Normal file
@@ -0,0 +1,151 @@
|
||||
#include "models.h"
|
||||
#include <cmath>
|
||||
|
||||
ggml_cgraph * clip_graph_gemma4v::build() {
|
||||
ggml_tensor * inp_raw = build_inp_raw();
|
||||
|
||||
// patches = 2 * (patches - 0.5)
|
||||
// equivalent to: patches * 2 - 1
|
||||
inp_raw = ggml_scale_bias(ctx0, inp_raw, 2.0f, -1.0f);
|
||||
ggml_set_name(inp_raw, "inp_raw_scaled");
|
||||
|
||||
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
|
||||
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
|
||||
ggml_set_name(inp, "inp");
|
||||
// note: no patch bias
|
||||
|
||||
ggml_tensor * pos_x = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
|
||||
ggml_set_name(pos_x, "pos_x");
|
||||
ggml_set_input(pos_x);
|
||||
|
||||
ggml_tensor * pos_y = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
|
||||
ggml_set_name(pos_y, "pos_y");
|
||||
ggml_set_input(pos_y);
|
||||
|
||||
{
|
||||
const int64_t pos_size = model.position_embeddings->ne[1];
|
||||
const size_t nb1 = ggml_row_size(model.position_embeddings->type, n_embd);
|
||||
|
||||
// positional embeddings are stored as lookup tables (one for x, one for y)
|
||||
ggml_tensor * tbl_x = ggml_view_2d(ctx0, model.position_embeddings,
|
||||
n_embd, pos_size, nb1, 0);
|
||||
ggml_tensor * tbl_y = ggml_view_2d(ctx0, model.position_embeddings,
|
||||
n_embd, pos_size, nb1, pos_size * nb1);
|
||||
|
||||
// ggml_get_rows: [n_embd, n_patches]
|
||||
ggml_tensor * emb_x = ggml_get_rows(ctx0, tbl_x, pos_x);
|
||||
ggml_tensor * emb_y = ggml_get_rows(ctx0, tbl_y, pos_y);
|
||||
|
||||
inp = ggml_add(ctx0, inp, emb_x);
|
||||
inp = ggml_add(ctx0, inp, emb_y);
|
||||
cb(inp, "pos_embd", -1);
|
||||
}
|
||||
|
||||
// similar to build_rope_2d, but use neox ordering
|
||||
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
|
||||
const int64_t n_dim = cur->ne[0];
|
||||
const int64_t n_head = cur->ne[1];
|
||||
const int64_t n_pos = cur->ne[2];
|
||||
|
||||
// first half
|
||||
ggml_tensor * first;
|
||||
{
|
||||
first = ggml_view_3d(ctx0, cur,
|
||||
n_dim/2, n_head, n_pos,
|
||||
cur->nb[1],
|
||||
cur->nb[2],
|
||||
0);
|
||||
first = ggml_rope_ext(
|
||||
ctx0,
|
||||
first,
|
||||
pos_x, // positions
|
||||
nullptr, // freq factors
|
||||
n_dim/2, // n_dims
|
||||
GGML_ROPE_TYPE_NEOX, 0, hparams.rope_theta,
|
||||
1.0f, 0.0f, 1.0f, 0.0f, 0.0f
|
||||
);
|
||||
}
|
||||
|
||||
// second half
|
||||
ggml_tensor * second;
|
||||
{
|
||||
second = ggml_view_3d(ctx0, cur,
|
||||
n_dim/2, n_head, n_pos,
|
||||
cur->nb[1],
|
||||
cur->nb[2],
|
||||
n_dim/2 * ggml_element_size(cur));
|
||||
second = ggml_rope_ext(
|
||||
ctx0,
|
||||
second,
|
||||
pos_y, // positions
|
||||
nullptr, // freq factors
|
||||
n_dim/2, // n_dims
|
||||
GGML_ROPE_TYPE_NEOX, 0, hparams.rope_theta,
|
||||
1.0f, 0.0f, 1.0f, 0.0f, 0.0f
|
||||
);
|
||||
}
|
||||
|
||||
cur = ggml_concat(ctx0, first, second, 0);
|
||||
return cur;
|
||||
};
|
||||
|
||||
kq_scale = 1.0f;
|
||||
ggml_tensor * cur = build_vit(
|
||||
inp, n_patches,
|
||||
NORM_TYPE_RMS,
|
||||
hparams.ffn_op,
|
||||
nullptr, // pos embd is already handled above
|
||||
add_pos);
|
||||
|
||||
// Gemma4VisionPooler
|
||||
{
|
||||
const int kernel_size = hparams.n_merge;
|
||||
GGML_ASSERT(kernel_size > 0);
|
||||
|
||||
// [n_embd, n_patches] -> [n_patches_x, n_patches_y, n_embd, 1]
|
||||
cur = ggml_cont_4d(ctx0, ggml_transpose(ctx0, cur), n_patches_x, n_patches_y, n_embd, 1);
|
||||
cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG,
|
||||
kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
|
||||
const int out_x = n_patches_x / kernel_size;
|
||||
const int out_y = n_patches_y / kernel_size;
|
||||
// [out_x, out_y, n_embd, 1] -> [n_embd, out_x * out_y]
|
||||
cur = ggml_reshape_3d(ctx0, cur, out_x * out_y, n_embd, 1);
|
||||
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
|
||||
cur = ggml_scale(ctx0, cur, sqrtf((float)n_embd));
|
||||
cb(cur, "pooled", -1);
|
||||
}
|
||||
|
||||
// hidden_states = (hidden_states - self.std_bias) * self.std_scale
|
||||
if (model.std_bias && model.std_scale) {
|
||||
cur = ggml_sub(ctx0, cur, model.std_bias);
|
||||
cur = ggml_mul(ctx0, cur, model.std_scale);
|
||||
cb(cur, "std_scaled", -1);
|
||||
}
|
||||
|
||||
// Gemma4MultimodalEmbedder
|
||||
cur = build_mm(model.mm_input_proj_w, cur);
|
||||
cb(cur, "projected", -1);
|
||||
|
||||
// embedding_post_projection_norm
|
||||
cur = ggml_rms_norm(ctx0, cur, hparams.eps);
|
||||
cb(cur, "projected_normed", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
return gf;
|
||||
}
|
||||
|
||||
ggml_tensor * clip_graph_gemma4v::build_mm(ggml_tensor * w, ggml_tensor * x) const {
|
||||
// Gemma4ClippableLinear
|
||||
|
||||
auto it = model.clamp_info_map.find(w->name);
|
||||
if (it == model.clamp_info_map.end()) {
|
||||
return ggml_mul_mat(ctx0, w, x);
|
||||
} else {
|
||||
const auto & clamp_info = it->second;
|
||||
ggml_tensor * clamped = ggml_clamp(ctx0, x, clamp_info.inp_min, clamp_info.inp_max);
|
||||
ggml_tensor * out = ggml_mul_mat(ctx0, w, clamped);
|
||||
out = ggml_clamp(ctx0, out, clamp_info.out_min, clamp_info.out_max);
|
||||
return out;
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,12 @@ struct clip_graph_siglip : clip_graph {
|
||||
ggml_cgraph * build() override;
|
||||
};
|
||||
|
||||
struct clip_graph_gemma4v : clip_graph {
|
||||
clip_graph_gemma4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const override;
|
||||
};
|
||||
|
||||
struct clip_graph_pixtral : clip_graph {
|
||||
clip_graph_pixtral(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
|
||||
@@ -394,6 +394,13 @@ struct mtmd_context {
|
||||
img_end = "<|IMAGE_END|>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_dyn_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
{
|
||||
// <|image> ... (image embeddings) ... <image|>
|
||||
img_beg = "<|image>";
|
||||
img_end = "<image|>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_dyn_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_DEEPSEEKOCR:
|
||||
{
|
||||
img_end = "\n"; // prevent empty batch on llama-server
|
||||
@@ -974,6 +981,7 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
|
||||
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
|
||||
switch (ctx->proj_type_v()) {
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_GEMMA4V:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -1196,6 +1196,10 @@ server_http_proxy::server_http_proxy(
|
||||
// disable Accept-Encoding to avoid compressed responses
|
||||
continue;
|
||||
}
|
||||
if (key == "Transfer-Encoding") {
|
||||
// the body is already decoded
|
||||
continue;
|
||||
}
|
||||
if (key == "Host" || key == "host") {
|
||||
bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80);
|
||||
req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port));
|
||||
|
||||
Reference in New Issue
Block a user