mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-12 14:03:20 +02:00
Compare commits
59 Commits
compilade/
...
b5176
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7604a7d6b8 | ||
|
|
b3b6d862cf | ||
|
|
5630406959 | ||
|
|
ecda2ec4b3 | ||
|
|
eb1776b15a | ||
|
|
2cca6c01e4 | ||
|
|
658987cfc9 | ||
|
|
dc39a5e7a8 | ||
|
|
ab47dec3d3 | ||
|
|
7b53389c24 | ||
|
|
243453533e | ||
|
|
1d735c0b4f | ||
|
|
5368ddda7a | ||
|
|
84a9bf2fc2 | ||
|
|
2016f07bd1 | ||
|
|
6602304814 | ||
|
|
66168204be | ||
|
|
4ba9d711ba | ||
|
|
00137157fc | ||
|
|
fb28f4f80e | ||
|
|
37b9f0d29d | ||
|
|
6408210082 | ||
|
|
aff9d107b0 | ||
|
|
35370ba945 | ||
|
|
8d66005763 | ||
|
|
b9154ecff9 | ||
|
|
2db9ba1464 | ||
|
|
2f74c354c0 | ||
|
|
207c22ec2d | ||
|
|
7a395f67a7 | ||
|
|
971f245b3b | ||
|
|
12b17501e6 | ||
|
|
015022bb53 | ||
|
|
b43d89e311 | ||
|
|
80f19b4186 | ||
|
|
f8f820cc4d | ||
|
|
54a7272043 | ||
|
|
84778e9770 | ||
|
|
510676475f | ||
|
|
daa422881a | ||
|
|
eccc7a1602 | ||
|
|
0019279bb5 | ||
|
|
b0c75ac9f9 | ||
|
|
d6d2c2ab8c | ||
|
|
75afa0ae31 | ||
|
|
c772d54926 | ||
|
|
81c7e64fc2 | ||
|
|
526739b879 | ||
|
|
a25355e264 | ||
|
|
e959d32b1c | ||
|
|
307bfa253d | ||
|
|
71e90e8813 | ||
|
|
bc091a4dc5 | ||
|
|
a4837577aa | ||
|
|
e59ea539b8 | ||
|
|
c94085df28 | ||
|
|
e8a62631b3 | ||
|
|
b6930ebc42 | ||
|
|
68b08f36d0 |
10
.github/workflows/build.yml
vendored
10
.github/workflows/build.yml
vendored
@@ -601,8 +601,9 @@ jobs:
|
||||
-DGGML_SYCL_F16=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
build-linux-cross:
|
||||
uses: ./.github/workflows/build-linux-cross.yml
|
||||
# Disabled for now due to sporadic issue syncing.
|
||||
# build-linux-cross:
|
||||
# uses: ./.github/workflows/build-linux-cross.yml
|
||||
|
||||
macOS-latest-cmake-ios:
|
||||
runs-on: macos-latest
|
||||
@@ -1766,16 +1767,17 @@ jobs:
|
||||
if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -el {0}
|
||||
runs-on: ubuntu-24.04-arm
|
||||
shell: bash -el {0}
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [x86, aarch64]
|
||||
cann:
|
||||
- '8.1.RC1.alpha001-910b-openeuler22.03-py3.10'
|
||||
device:
|
||||
- 'ascend910b3'
|
||||
build:
|
||||
- 'Release'
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
container: ascendai/cann:${{ matrix.cann }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
4
Makefile
4
Makefile
@@ -780,10 +780,6 @@ ifdef GGML_HIP
|
||||
|
||||
MK_CPPFLAGS += -DGGML_USE_HIP -DGGML_USE_CUDA
|
||||
|
||||
ifdef GGML_HIP_UMA
|
||||
MK_CPPFLAGS += -DGGML_HIP_UMA
|
||||
endif # GGML_HIP_UMA
|
||||
|
||||
MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
|
||||
MK_LDFLAGS += -L$(ROCM_PATH)/lib64 -Wl,-rpath=$(ROCM_PATH)/lib64
|
||||
MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas
|
||||
|
||||
@@ -16,6 +16,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
|
||||
|
||||
## Hot topics
|
||||
|
||||
- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli` and `gemma3-cli` https://github.com/ggml-org/llama.cpp/pull/13012, `libllava` will be deprecated
|
||||
- **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggml-org/llama.cpp/pull/11427
|
||||
- **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode
|
||||
- Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
|
||||
|
||||
@@ -40,7 +40,8 @@ To protect sensitive data from potential leaks or unauthorized access, it is cru
|
||||
### Untrusted environments or networks
|
||||
|
||||
If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions:
|
||||
* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value
|
||||
* Do not use the RPC backend, [rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/examples/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/examples/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
|
||||
* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value.
|
||||
* Encrypt your data if sending it over the network.
|
||||
|
||||
### Multi-Tenant environments
|
||||
|
||||
@@ -976,14 +976,13 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
|
||||
"llama-gritlm",
|
||||
"llama-imatrix",
|
||||
"llama-infill",
|
||||
"llama-llava-cli",
|
||||
"llama-mtmd-cli",
|
||||
"llama-llava-clip-quantize-cli",
|
||||
"llama-lookahead",
|
||||
"llama-lookup",
|
||||
"llama-lookup-create",
|
||||
"llama-lookup-merge",
|
||||
"llama-lookup-stats",
|
||||
"llama-minicpmv-cli",
|
||||
"llama-parallel",
|
||||
"llama-passkey",
|
||||
"llama-perplexity",
|
||||
@@ -2726,7 +2725,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.chat_template = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LLAVA}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
|
||||
string_format(
|
||||
|
||||
@@ -1622,7 +1622,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
}
|
||||
|
||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
||||
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null() && params.tools.is_array() && params.json_schema.is_null()) {
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
||||
}
|
||||
|
||||
|
||||
@@ -830,7 +830,7 @@ std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#ifdef __linux__
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else {
|
||||
@@ -840,7 +840,9 @@ std::string fs_get_cache_directory() {
|
||||
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
||||
#elif defined(_WIN32)
|
||||
cache_directory = std::getenv("LOCALAPPDATA");
|
||||
#endif // __linux__
|
||||
#else
|
||||
# error Unknown architecture
|
||||
#endif
|
||||
cache_directory = ensure_trailing_slash(cache_directory);
|
||||
cache_directory += "llama.cpp";
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -115,6 +115,7 @@ models = [
|
||||
{"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
|
||||
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", },
|
||||
{"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ if 'NO_LOCAL_GGUF' not in os.environ:
|
||||
import gguf
|
||||
|
||||
# reuse model definitions from convert_hf_to_gguf.py
|
||||
from convert_hf_to_gguf import LazyTorchTensor, Model
|
||||
from convert_hf_to_gguf import LazyTorchTensor, ModelBase
|
||||
|
||||
logger = logging.getLogger("lora-to-gguf")
|
||||
|
||||
@@ -340,11 +340,11 @@ if __name__ == '__main__':
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.info(f"Loading base model: {dir_base_model.name}")
|
||||
hparams = Model.load_hparams(dir_base_model)
|
||||
hparams = ModelBase.load_hparams(dir_base_model)
|
||||
|
||||
with torch.inference_mode():
|
||||
try:
|
||||
model_class = Model.from_model_architecture(hparams["architectures"][0])
|
||||
model_class = ModelBase.from_model_architecture(hparams["architectures"][0])
|
||||
except NotImplementedError:
|
||||
logger.error(f"Model {hparams['architectures'][0]} is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
@@ -259,8 +259,6 @@ You can download it from your Linux distro's package manager or from here: [ROCm
|
||||
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
&& cmake --build build --config Release -- -j 16
|
||||
```
|
||||
On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DGGML_HIP_UMA=ON`.
|
||||
However, this hurts performance for non-integrated GPUs (but enables working with integrated GPUs).
|
||||
|
||||
To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system.
|
||||
|
||||
@@ -296,6 +294,10 @@ You can download it from your Linux distro's package manager or from here: [ROCm
|
||||
The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
|
||||
If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3.
|
||||
|
||||
### Unified Memory
|
||||
|
||||
On Linux it is possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1`. However, this hurts performance for non-integrated GPUs (but enables working with integrated GPUs).
|
||||
|
||||
## Vulkan
|
||||
|
||||
**Windows**
|
||||
|
||||
@@ -9,15 +9,15 @@ The implementation is based on llava, and is compatible with llava and mobileVLM
|
||||
Notice: The overall process of model inference for both **MobileVLM** and **MobileVLM_V2** models is the same, but the process of model conversion is a little different. Therefore, using **MobileVLM-1.7B** as an example, the different conversion step will be shown.
|
||||
|
||||
## Usage
|
||||
Build with cmake or run `make llama-llava-cli` to build it.
|
||||
|
||||
After building, run: `./llama-llava-cli` to see the usage. For example:
|
||||
Build the `llama-mtmd-cli` binary.
|
||||
|
||||
After building, run: `./llama-mtmd-cli` to see the usage. For example:
|
||||
|
||||
```sh
|
||||
./llama-llava-cli -m MobileVLM-1.7B/ggml-model-q4_k.gguf \
|
||||
./llama-mtmd-cli -m MobileVLM-1.7B/ggml-model-q4_k.gguf \
|
||||
--mmproj MobileVLM-1.7B/mmproj-model-f16.gguf \
|
||||
--image path/to/an/image.jpg \
|
||||
-p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWho is the author of this book? Answer the question using a single word or phrase. ASSISTANT:"
|
||||
--chat-template deepseek
|
||||
```
|
||||
|
||||
## Model conversion
|
||||
@@ -82,7 +82,7 @@ refer to `android/adb_run.sh`, modify resources' `name` and `path`
|
||||
### case 1
|
||||
**input**
|
||||
```sh
|
||||
/data/local/tmp/llama-llava-cli \
|
||||
/data/local/tmp/llama-mtmd-cli \
|
||||
-m /data/local/tmp/ggml-model-q4_k.gguf \
|
||||
--mmproj /data/local/tmp/mmproj-model-f16.gguf \
|
||||
-t 4 \
|
||||
@@ -102,7 +102,7 @@ llama_print_timings: total time = 34731.93 ms
|
||||
### case 2
|
||||
**input**
|
||||
```sh
|
||||
/data/local/tmp/llama-llava-cli \
|
||||
/data/local/tmp/llama-mtmd-cli \
|
||||
-m /data/local/tmp/ggml-model-q4_k.gguf \
|
||||
--mmproj /data/local/tmp/mmproj-model-f16.gguf \
|
||||
-t 4 \
|
||||
@@ -123,10 +123,10 @@ llama_print_timings: total time = 34570.79 ms
|
||||
|
||||
## Some result on Android with `Snapdragon 778G` chip
|
||||
### MobileVLM-1.7B case
|
||||
#### llava-cli release-b2005
|
||||
#### mtmd-cli release-b2005
|
||||
**input**
|
||||
```sh
|
||||
/data/local/tmp/llama-llava-cli \
|
||||
/data/local/tmp/llama-mtmd-cli \
|
||||
-m /data/local/tmp/ggml-model-q4_k.gguf \
|
||||
--mmproj /data/local/tmp/mmproj-model-f16.gguf \
|
||||
-t 4 \
|
||||
@@ -147,7 +147,7 @@ llama_print_timings: prompt eval time = 8119.49 ms / 191 tokens ( 42.51 m
|
||||
llama_print_timings: eval time = 1005.75 ms / 14 runs ( 71.84 ms per token, 13.92 tokens per second)
|
||||
llama_print_timings: total time = 28038.34 ms / 205 tokens
|
||||
```
|
||||
#### llava-cli latest-version
|
||||
#### mtmd-cli latest-version
|
||||
**input**
|
||||
|
||||
Just the same as above.
|
||||
@@ -169,7 +169,7 @@ llama_print_timings: eval time = 43894.02 ms / 13 runs ( 3376.46 m
|
||||
llama_print_timings: total time = 865441.76 ms / 204 tokens
|
||||
```
|
||||
### MobileVLM_V2-1.7B case
|
||||
#### llava-cli release-2005b
|
||||
#### mtmd-cli release-2005b
|
||||
**input**
|
||||
|
||||
Just the same as above.
|
||||
@@ -200,7 +200,7 @@ make GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_87 GGML_CUDA_F16=1 -j 32
|
||||
### case 1
|
||||
**input**
|
||||
```sh
|
||||
./llama-llava-cli \
|
||||
./llama-mtmd-cli \
|
||||
-m /data/local/tmp/ggml-model-q4_k.gguf \
|
||||
--mmproj /data/local/tmp/mmproj-model-f16.gguf \
|
||||
--image /data/local/tmp/demo.jpeg \
|
||||
@@ -224,7 +224,7 @@ llama_print_timings: total time = 1352.63 ms / 252 tokens
|
||||
### case 2
|
||||
**input**
|
||||
```sh
|
||||
./llama-llava-cli \
|
||||
./llama-mtmd-cli \
|
||||
-m /data/local/tmp/ggml-model-q4_k.gguf \
|
||||
--mmproj /data/local/tmp/mmproj-model-f16.gguf \
|
||||
-p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWhat is in the image? ASSISTANT:" \
|
||||
@@ -11,26 +11,27 @@ You can use pre-quantized model from [ggml-org](https://huggingface.co/ggml-org)
|
||||
```bash
|
||||
# build
|
||||
cmake -B build
|
||||
cmake --build build --target llama-gemma3-cli
|
||||
cmake --build build --target llama-mtmd-cli
|
||||
|
||||
# alternatively, install from brew (MacOS)
|
||||
brew install llama.cpp
|
||||
|
||||
# run it
|
||||
llama-gemma3-cli -hf ggml-org/gemma-3-4b-it-GGUF
|
||||
llama-gemma3-cli -hf ggml-org/gemma-3-12b-it-GGUF
|
||||
llama-gemma3-cli -hf ggml-org/gemma-3-27b-it-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF
|
||||
|
||||
# note: 1B model does not support vision
|
||||
```
|
||||
|
||||
## How to get mmproj.gguf?
|
||||
|
||||
Simply to add `--mmproj` in when converting model via `convert_hf_to_gguf.py`:
|
||||
|
||||
```bash
|
||||
cd gemma-3-4b-it
|
||||
python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py .
|
||||
|
||||
# output file is mmproj.gguf
|
||||
python ../llama.cpp/convert_hf_to_gguf.py --outfile model.gguf --outtype f16 --mmproj .
|
||||
# output file: mmproj-model.gguf
|
||||
```
|
||||
|
||||
## How to run it?
|
||||
@@ -43,8 +44,8 @@ What you need:
|
||||
```bash
|
||||
# build
|
||||
cmake -B build
|
||||
cmake --build build --target llama-gemma3-cli
|
||||
cmake --build build --target llama-mtmd-cli
|
||||
|
||||
# run it
|
||||
./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
|
||||
./build/bin/llama-mtmd-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
|
||||
```
|
||||
@@ -3,12 +3,12 @@
|
||||
Currently this implementation supports [glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b) and [glm-edge-v-5b](https://huggingface.co/THUDM/glm-edge-v-5b).
|
||||
|
||||
## Usage
|
||||
Build with cmake or run `make llama-llava-cli` to build it.
|
||||
Build the `llama-mtmd-cli` binary.
|
||||
|
||||
After building, run: `./llama-llava-cli` to see the usage. For example:
|
||||
After building, run: `./llama-mtmd-cli` to see the usage. For example:
|
||||
|
||||
```sh
|
||||
./llama-llava-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf --image img_path/image.jpg -p "<|system|>\n system prompt <image><|user|>\n prompt <|assistant|>\n"
|
||||
./llama-mtmd-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf
|
||||
```
|
||||
|
||||
**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.
|
||||
@@ -176,15 +176,11 @@ Note that currently you cannot quantize the visual encoder because granite visio
|
||||
|
||||
|
||||
### 5. Running the Model in Llama cpp
|
||||
Build llama cpp normally; you should have a target binary named `llama-llava-cli`, which you can pass two binaries to. As an example, we pass the the llama.cpp banner.
|
||||
Build llama cpp normally; you should have a target binary named `llama-mtmd-cli`, which you can pass two binaries to. As an example, we pass the the llama.cpp banner.
|
||||
|
||||
```bash
|
||||
$ ./build/bin/llama-llava-cli -m $LLM_GGUF_PATH \
|
||||
$ ./build/bin/llama-mtmd-cli -m $LLM_GGUF_PATH \
|
||||
--mmproj $VISUAL_GGUF_PATH \
|
||||
--image ./media/llama0-banner.png \
|
||||
-c 16384 \
|
||||
-p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n\<image>\nWhat does the text in this image say?\n<|assistant|>\n" \
|
||||
--temp 0
|
||||
```
|
||||
|
||||
Sample output: `The text in the image reads "LLAMA C++ Can it run DOOM Llama?"`
|
||||
143
docs/multimodal/llava.md
Normal file
143
docs/multimodal/llava.md
Normal file
@@ -0,0 +1,143 @@
|
||||
# LLaVA
|
||||
|
||||
Currently this implementation supports [llava-v1.5](https://huggingface.co/liuhaotian/llava-v1.5-7b) variants,
|
||||
as well as llava-1.6 [llava-v1.6](https://huggingface.co/collections/liuhaotian/llava-16-65b9e40155f60fd046a5ccf2) variants.
|
||||
|
||||
The pre-converted [7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
|
||||
and [13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
|
||||
models are available.
|
||||
For llava-1.6 a variety of prepared gguf models are available as well [7b-34b](https://huggingface.co/cmp-nct/llava-1.6-gguf)
|
||||
|
||||
After API is confirmed, more models will be supported / uploaded.
|
||||
|
||||
## Usage
|
||||
Build the `llama-mtmd-cli` binary.
|
||||
|
||||
After building, run: `./llama-mtmd-cli` to see the usage. For example:
|
||||
|
||||
```sh
|
||||
./llama-mtmd-cli -m ../llava-v1.5-7b/ggml-model-f16.gguf \
|
||||
--mmproj ../llava-v1.5-7b/mmproj-model-f16.gguf \
|
||||
--chat-template vicuna
|
||||
```
|
||||
|
||||
**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.
|
||||
**note**: For GPU offloading ensure to use the `-ngl` flag just like usual
|
||||
|
||||
## LLaVA 1.5
|
||||
|
||||
1. Clone a LLaVA and a CLIP model ([available options](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)). For example:
|
||||
|
||||
```sh
|
||||
git clone https://huggingface.co/liuhaotian/llava-v1.5-7b
|
||||
|
||||
git clone https://huggingface.co/openai/clip-vit-large-patch14-336
|
||||
```
|
||||
|
||||
2. Install the required Python packages:
|
||||
|
||||
```sh
|
||||
pip install -r examples/llava/requirements.txt
|
||||
```
|
||||
|
||||
3. Use `llava_surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents:
|
||||
|
||||
```sh
|
||||
python ./examples/llava/llava_surgery.py -m ../llava-v1.5-7b
|
||||
```
|
||||
|
||||
4. Use `convert_image_encoder_to_gguf.py` to convert the LLaVA image encoder to GGUF:
|
||||
|
||||
```sh
|
||||
python ./examples/llava/convert_image_encoder_to_gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b
|
||||
```
|
||||
|
||||
5. Use `examples/convert_legacy_llama.py` to convert the LLaMA part of LLaVA to GGUF:
|
||||
|
||||
```sh
|
||||
python ./examples/convert_legacy_llama.py ../llava-v1.5-7b --skip-unknown
|
||||
```
|
||||
|
||||
Now both the LLaMA part and the image encoder are in the `llava-v1.5-7b` directory.
|
||||
|
||||
## LLaVA 1.6 gguf conversion
|
||||
1) First clone a LLaVA 1.6 model:
|
||||
```console
|
||||
git clone https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b
|
||||
```
|
||||
|
||||
2) Install the required Python packages:
|
||||
|
||||
```sh
|
||||
pip install -r examples/llava/requirements.txt
|
||||
```
|
||||
|
||||
3) Use `llava_surgery_v2.py` which also supports llava-1.5 variants pytorch as well as safetensor models:
|
||||
```console
|
||||
python examples/llava/llava_surgery_v2.py -C -m ../llava-v1.6-vicuna-7b/
|
||||
```
|
||||
- you will find a llava.projector and a llava.clip file in your model directory
|
||||
|
||||
4) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin and add a fitting vit configuration to the directory:
|
||||
```console
|
||||
mkdir vit
|
||||
cp ../llava-v1.6-vicuna-7b/llava.clip vit/pytorch_model.bin
|
||||
cp ../llava-v1.6-vicuna-7b/llava.projector vit/
|
||||
curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json
|
||||
```
|
||||
|
||||
5) Create the visual gguf model:
|
||||
```console
|
||||
python ./examples/llava/convert_image_encoder_to_gguf.py -m vit --llava-projector vit/llava.projector --output-dir vit --clip-model-is-vision
|
||||
```
|
||||
- This is similar to llava-1.5, the difference is that we tell the encoder that we are working with the pure vision model part of CLIP
|
||||
|
||||
6) Then convert the model to gguf format:
|
||||
```console
|
||||
python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknown
|
||||
```
|
||||
|
||||
7) And finally we can run the llava cli using the 1.6 model version:
|
||||
```console
|
||||
./llama-mtmd-cli -m ../llava-v1.6-vicuna-7b/ggml-model-f16.gguf --mmproj vit/mmproj-model-f16.gguf
|
||||
```
|
||||
|
||||
**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)
|
||||
|
||||
**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)
|
||||
|
||||
**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way handle the LLM model conversion is to load the model in transformers, and export only the LLM from the llava next model.
|
||||
|
||||
```python
|
||||
import os
|
||||
import transformers
|
||||
|
||||
model_path = ...
|
||||
llm_export_path = ...
|
||||
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
|
||||
model = transformers.AutoModelForImageTextToText.from_pretrained(model_path)
|
||||
|
||||
tokenizer.save_pretrained(llm_export_path)
|
||||
model.language_model.save_pretrained(llm_export_path)
|
||||
```
|
||||
|
||||
Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures.
|
||||
|
||||
## Chat template
|
||||
|
||||
For llava-1.5 and llava-1.6, you need to use `vicuna` chat template. Simply add `--chat-template vicuna` to activate this template.
|
||||
|
||||
|
||||
## How to know if you are running in llava-1.5 or llava-1.6 mode
|
||||
|
||||
When running llava-cli you will see a visual information right before the prompt is being processed:
|
||||
|
||||
**Llava-1.5:**
|
||||
`encode_image_with_clip: image embedding created: 576 tokens`
|
||||
|
||||
**Llava-1.6 (anything above 576):**
|
||||
`encode_image_with_clip: image embedding created: 2880 tokens`
|
||||
|
||||
|
||||
Alternatively just pay notice to how many "tokens" have been used for your prompt, it will also show 1000+ tokens for llava-1.6
|
||||
@@ -40,9 +40,9 @@ python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model
|
||||
|
||||
Inference on Linux or Mac
|
||||
```bash
|
||||
# run f16 version
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
# run in single-turn mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# run quantized int4 version
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
# run in conversation mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf
|
||||
```
|
||||
@@ -39,9 +39,9 @@ python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model
|
||||
|
||||
Inference on Linux or Mac
|
||||
```bash
|
||||
# run f16 version
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
# run in single-turn mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# run quantized int4 version
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
# run in conversation mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf
|
||||
```
|
||||
@@ -39,9 +39,9 @@ python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model
|
||||
|
||||
Inference on Linux or Mac
|
||||
```bash
|
||||
# run f16 version
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
# run in single-turn mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# run quantized int4 version
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
# run in conversation mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf
|
||||
```
|
||||
@@ -61,19 +61,9 @@ if(TARGET BUILD_INFO)
|
||||
add_dependencies(mtmd BUILD_INFO)
|
||||
endif()
|
||||
|
||||
set(TARGET llama-llava-cli)
|
||||
add_executable(${TARGET} llava-cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-cli)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
set(TARGET llama-minicpmv-cli)
|
||||
add_executable(${TARGET} minicpmv-cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-minicpmv-cli)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
add_executable(llama-llava-cli deprecation-warning.cpp)
|
||||
add_executable(llama-gemma3-cli deprecation-warning.cpp)
|
||||
add_executable(llama-minicpmv-cli deprecation-warning.cpp)
|
||||
|
||||
set(TARGET llama-qwen2vl-cli)
|
||||
add_executable(${TARGET} qwen2vl-cli.cpp)
|
||||
@@ -82,9 +72,9 @@ install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
set(TARGET llama-gemma3-cli)
|
||||
add_executable(${TARGET} gemma3-cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
|
||||
set(TARGET llama-mtmd-cli)
|
||||
add_executable(${TARGET} mtmd-cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -1,158 +1,75 @@
|
||||
# LLaVA
|
||||
# Multimodal Support in llama.cpp
|
||||
|
||||
Currently this implementation supports [llava-v1.5](https://huggingface.co/liuhaotian/llava-v1.5-7b) variants,
|
||||
as well as llava-1.6 [llava-v1.6](https://huggingface.co/collections/liuhaotian/llava-16-65b9e40155f60fd046a5ccf2) variants.
|
||||
This directory provides multimodal capabilities for `llama.cpp`. Initially intended as a showcase for running LLaVA models, its scope has expanded significantly over time to include various other vision-capable models. As a result, LLaVA is no longer the only multimodal architecture supported.
|
||||
|
||||
The pre-converted [7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
|
||||
and [13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
|
||||
models are available.
|
||||
For llava-1.6 a variety of prepared gguf models are available as well [7b-34b](https://huggingface.co/cmp-nct/llava-1.6-gguf)
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> Multimodal support can be viewed as a sub-project within `llama.cpp`. It is under **very heavy development**, and **breaking changes are expected**.
|
||||
|
||||
After API is confirmed, more models will be supported / uploaded.
|
||||
The naming and structure related to multimodal support have evolved, which might cause some confusion. Here's a brief timeline to clarify:
|
||||
|
||||
## Usage
|
||||
Build with cmake or run `make llama-llava-cli` to build it.
|
||||
- [#3436](https://github.com/ggml-org/llama.cpp/pull/3436): Initial support for LLaVA 1.5 was added, introducing `llava.cpp` and `clip.cpp`. The `llava-cli` binary was created for model interaction.
|
||||
- [#4954](https://github.com/ggml-org/llama.cpp/pull/4954): Support for MobileVLM was added, becoming the second vision model supported. This built upon the existing `llava.cpp`, `clip.cpp`, and `llava-cli` infrastructure.
|
||||
- **Expansion & Fragmentation:** Many new models were subsequently added (e.g., [#7599](https://github.com/ggml-org/llama.cpp/pull/7599), [#10361](https://github.com/ggml-org/llama.cpp/pull/10361), [#12344](https://github.com/ggml-org/llama.cpp/pull/12344), and others). However, `llava-cli` lacked support for the increasingly complex chat templates required by these models. This led to the creation of model-specific binaries like `qwen2vl-cli`, `minicpmv-cli`, and `gemma3-cli`. While functional, this proliferation of command-line tools became confusing for users.
|
||||
- [#12849](https://github.com/ggml-org/llama.cpp/pull/12849): `libmtmd` was introduced as a replacement for `llava.cpp`. Its goals include providing a single, unified command-line interface, improving the user/developer experience (UX/DX), and supporting both audio and image inputs.
|
||||
- [#13012](https://github.com/ggml-org/llama.cpp/pull/13012): `mtmd-cli` was added, consolidating the various model-specific CLIs into a single tool powered by `libmtmd`.
|
||||
|
||||
After building, run: `./llama-llava-cli` to see the usage. For example:
|
||||
## Pre-quantized models
|
||||
|
||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default:
|
||||
|
||||
```sh
|
||||
./llama-llava-cli -m ../llava-v1.5-7b/ggml-model-f16.gguf --mmproj ../llava-v1.5-7b/mmproj-model-f16.gguf --image path/to/an/image.jpg
|
||||
# Gemma 3
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF
|
||||
|
||||
# SmolVLM
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM-256M-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM-500M-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
|
||||
|
||||
# Pixtral 12B
|
||||
llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
|
||||
```
|
||||
|
||||
**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.
|
||||
**note**: For GPU offloading ensure to use the `-ngl` flag just like usual
|
||||
## How it works and what is `mmproj`?
|
||||
|
||||
## LLaVA 1.5
|
||||
Multimodal support in `llama.cpp` works by encoding images into embeddings using a separate model component, and then feeding these embeddings into the language model.
|
||||
|
||||
1. Clone a LLaVA and a CLIP model ([available options](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)). For example:
|
||||
This approach keeps the multimodal components distinct from the core `libllama` library. Separating these allows for faster, independent development cycles. While many modern vision models are based on Vision Transformers (ViTs), their specific pre-processing and projection steps can vary significantly. Integrating this diverse complexity directly into `libllama` is currently challenging.
|
||||
|
||||
```sh
|
||||
git clone https://huggingface.co/liuhaotian/llava-v1.5-7b
|
||||
Consequently, running a multimodal model typically requires two GGUF files:
|
||||
1. The standard language model file.
|
||||
2. A corresponding **multimodal projector (`mmproj`)** file, which handles the image encoding and projection.
|
||||
|
||||
git clone https://huggingface.co/openai/clip-vit-large-patch14-336
|
||||
```
|
||||
## What is `libmtmd`?
|
||||
|
||||
2. Install the required Python packages:
|
||||
As outlined in the history, `libmtmd` is the modern library designed to replace the original `llava.cpp` implementation for handling multimodal inputs.
|
||||
|
||||
```sh
|
||||
pip install -r examples/llava/requirements.txt
|
||||
```
|
||||
Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advantages:
|
||||
- **Unified Interface:** Aims to consolidate interaction for various multimodal models.
|
||||
- **Improved UX/DX:** Features a more intuitive API, inspired by the `Processor` class in the Hugging Face `transformers` library.
|
||||
- **Flexibility:** Designed to support multiple input types (text, audio, images) while respecting the wide variety of chat templates used by different models.
|
||||
|
||||
3. Use `llava_surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents:
|
||||
## How to obtain `mmproj`
|
||||
|
||||
```sh
|
||||
python ./examples/llava/llava_surgery.py -m ../llava-v1.5-7b
|
||||
```
|
||||
Multimodal projector (`mmproj`) files are specific to each model architecture. Please refer to the relevant guide for instructions on how to obtain or create them:
|
||||
|
||||
4. Use `convert_image_encoder_to_gguf.py` to convert the LLaVA image encoder to GGUF:
|
||||
- [LLaVA](../../docs/multimodal/llava.md)
|
||||
- [MobileVLM](../../docs/multimodal/MobileVLM.md)
|
||||
- [GLM-Edge](../../docs/multimodal/glmedge.md)
|
||||
- [MiniCPM-V 2.5](../../docs/multimodal/minicpmv2.5.md)
|
||||
- [MiniCPM-V 2.6](../../docs/multimodal/minicpmv2.6.md)
|
||||
- [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
|
||||
- [IBM Granite Vision](../../docs/multimodal/granitevision.md)
|
||||
- [Google Gemma 3](../../docs/multimodal/gemma3.md)
|
||||
|
||||
```sh
|
||||
python ./examples/llava/convert_image_encoder_to_gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b
|
||||
```
|
||||
|
||||
5. Use `examples/convert_legacy_llama.py` to convert the LLaMA part of LLaVA to GGUF:
|
||||
|
||||
```sh
|
||||
python ./examples/convert_legacy_llama.py ../llava-v1.5-7b --skip-unknown
|
||||
```
|
||||
|
||||
Now both the LLaMA part and the image encoder are in the `llava-v1.5-7b` directory.
|
||||
|
||||
## LLaVA 1.6 gguf conversion
|
||||
1) First clone a LLaVA 1.6 model:
|
||||
```console
|
||||
git clone https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b
|
||||
```
|
||||
|
||||
2) Install the required Python packages:
|
||||
|
||||
```sh
|
||||
pip install -r examples/llava/requirements.txt
|
||||
```
|
||||
|
||||
3) Use `llava_surgery_v2.py` which also supports llava-1.5 variants pytorch as well as safetensor models:
|
||||
```console
|
||||
python examples/llava/llava_surgery_v2.py -C -m ../llava-v1.6-vicuna-7b/
|
||||
```
|
||||
- you will find a llava.projector and a llava.clip file in your model directory
|
||||
|
||||
4) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin and add a fitting vit configuration to the directory:
|
||||
```console
|
||||
mkdir vit
|
||||
cp ../llava-v1.6-vicuna-7b/llava.clip vit/pytorch_model.bin
|
||||
cp ../llava-v1.6-vicuna-7b/llava.projector vit/
|
||||
curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json
|
||||
```
|
||||
|
||||
5) Create the visual gguf model:
|
||||
```console
|
||||
python ./examples/llava/convert_image_encoder_to_gguf.py -m vit --llava-projector vit/llava.projector --output-dir vit --clip-model-is-vision
|
||||
```
|
||||
- This is similar to llava-1.5, the difference is that we tell the encoder that we are working with the pure vision model part of CLIP
|
||||
|
||||
6) Then convert the model to gguf format:
|
||||
```console
|
||||
python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknown
|
||||
```
|
||||
|
||||
7) And finally we can run the llava cli using the 1.6 model version:
|
||||
```console
|
||||
./llama-llava-cli -m ../llava-v1.6-vicuna-7b/ggml-model-f16.gguf --mmproj vit/mmproj-model-f16.gguf --image some-image.jpg -c 4096
|
||||
```
|
||||
|
||||
**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)
|
||||
|
||||
**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)
|
||||
|
||||
**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way handle the LLM model conversion is to load the model in transformers, and export only the LLM from the llava next model.
|
||||
|
||||
```python
|
||||
import os
|
||||
import transformers
|
||||
|
||||
model_path = ...
|
||||
llm_export_path = ...
|
||||
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
|
||||
model = transformers.AutoModelForImageTextToText.from_pretrained(model_path)
|
||||
|
||||
tokenizer.save_pretrained(llm_export_path)
|
||||
model.language_model.save_pretrained(llm_export_path)
|
||||
```
|
||||
|
||||
Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures.
|
||||
|
||||
## llava-cli templating and llava-1.6 prompting
|
||||
|
||||
llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."`
|
||||
For llava-1.5 models which are not vicuna (mistral and Yi) you need to adapt system prompt as well as user prompt, for this purpose llava-cli has a basic templating system:
|
||||
|
||||
**For Mistral and using llava-cli binary:**
|
||||
Add this: `-p "<image>\nUSER:\nProvide a full description.\nASSISTANT:\n"`
|
||||
The mistral template for llava-1.6 seems to be no system print and a USER/ASSISTANT role
|
||||
|
||||
**For the 34B this should work:**
|
||||
Add this: `-e -p <|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nProvide a full description.<|im_end|><|im_start|>assistant\n`
|
||||
|
||||
|
||||
## How to know if you are running in llava-1.5 or llava-1.6 mode
|
||||
|
||||
When running llava-cli you will see a visual information right before the prompt is being processed:
|
||||
|
||||
**Llava-1.5:**
|
||||
`encode_image_with_clip: image embedding created: 576 tokens`
|
||||
|
||||
**Llava-1.6 (anything above 576):**
|
||||
`encode_image_with_clip: image embedding created: 2880 tokens`
|
||||
|
||||
|
||||
Alternatively just pay notice to how many "tokens" have been used for your prompt, it will also show 1000+ tokens for llava-1.6
|
||||
|
||||
|
||||
|
||||
|
||||
## TODO
|
||||
|
||||
- [x] Support non-CPU backend for the image encoding part.
|
||||
- [ ] Support different sampling methods.
|
||||
- [ ] Support more model variants.
|
||||
For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` flag to get the `mmproj` file:
|
||||
- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
|
||||
- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
|
||||
- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
|
||||
- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
|
||||
|
||||
@@ -10,7 +10,7 @@ prompt="A chat between a curious user and an artificial intelligence assistant.
|
||||
# prompt="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWhat is in the image? ASSISTANT:"
|
||||
|
||||
program_dir="build_64/bin"
|
||||
binName="llama-llava-cli"
|
||||
binName="llama-mtmd-cli"
|
||||
n_threads=4
|
||||
|
||||
|
||||
|
||||
@@ -33,13 +33,13 @@
|
||||
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
|
||||
#define KEY_PROJ_DIM "clip.%s.projection_dim"
|
||||
#define KEY_TOKENS "tokenizer.ggml.tokens"
|
||||
#define KEY_N_POSITIONS "clip.text.context_length"
|
||||
#define KEY_IMAGE_SIZE "clip.vision.image_size"
|
||||
#define KEY_PATCH_SIZE "clip.vision.patch_size"
|
||||
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
||||
#define KEY_IMAGE_STD "clip.vision.image_std"
|
||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
||||
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||
|
||||
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
||||
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
|
||||
@@ -50,7 +50,6 @@
|
||||
// tensor name constants
|
||||
//
|
||||
|
||||
#define TN_TOKEN_EMBD "%s.token_embd.weight"
|
||||
#define TN_POS_EMBD "%s.position_embd.weight"
|
||||
#define TN_CLASS_EMBD "v.class_embd"
|
||||
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
|
||||
@@ -61,13 +60,12 @@
|
||||
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
|
||||
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
|
||||
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
|
||||
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
|
||||
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
|
||||
#define TN_LN_1 "%s.blk.%d.ln1.%s"
|
||||
#define TN_LN_2 "%s.blk.%d.ln2.%s"
|
||||
#define TN_LN_PRE "%s.pre_ln.%s"
|
||||
#define TN_LN_POST "%s.post_ln.%s"
|
||||
#define TN_TEXT_PROJ "text_projection.weight"
|
||||
#define TN_VIS_PROJ "visual_projection.weight"
|
||||
#define TN_LLAVA_PROJ "mm.%d.%s"
|
||||
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
|
||||
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
|
||||
@@ -75,6 +73,8 @@
|
||||
#define TN_IMAGE_NEWLINE "model.image_newline"
|
||||
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
|
||||
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
|
||||
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
|
||||
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
|
||||
|
||||
// mimicpmv
|
||||
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
|
||||
@@ -102,6 +102,8 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_GLM_EDGE,
|
||||
PROJECTOR_TYPE_MERGER,
|
||||
PROJECTOR_TYPE_GEMMA3,
|
||||
PROJECTOR_TYPE_IDEFICS3,
|
||||
PROJECTOR_TYPE_PIXTRAL,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -113,6 +115,8 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
|
||||
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
|
||||
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
||||
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
|
||||
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -30,12 +30,13 @@ struct clip_image_size {
|
||||
int height;
|
||||
};
|
||||
|
||||
struct clip_image_f32;
|
||||
struct clip_image_u8_batch;
|
||||
struct clip_image_f32_batch;
|
||||
|
||||
struct clip_context_params {
|
||||
bool use_gpu;
|
||||
ggml_log_level verbosity;
|
||||
enum ggml_log_level verbosity;
|
||||
};
|
||||
|
||||
// deprecated, use clip_init
|
||||
@@ -84,7 +85,7 @@ CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
|
||||
CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
|
||||
CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
|
||||
CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
|
||||
CLIP_API clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
|
||||
CLIP_API struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
|
||||
|
||||
/**
|
||||
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
|
||||
|
||||
22
examples/llava/deprecation-warning.cpp
Normal file
22
examples/llava/deprecation-warning.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::string filename = "main";
|
||||
if (argc >= 1) {
|
||||
filename = argv[0];
|
||||
}
|
||||
|
||||
// Get only the program name from the full path
|
||||
size_t pos = filename.find_last_of("/\\");
|
||||
if (pos != std::string::npos) {
|
||||
filename = filename.substr(pos+1);
|
||||
}
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str());
|
||||
fprintf(stdout, "Please use 'llama-mtmd-cli' instead.\n");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
@@ -1,307 +0,0 @@
|
||||
import gguf
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import torch
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
from typing import cast, ContextManager, Any, Iterator
|
||||
from pathlib import Path
|
||||
from torch import Tensor
|
||||
|
||||
logger = logging.getLogger("gemma3-mmproj")
|
||||
|
||||
|
||||
# (copied from convert_hf_to_gguf.py)
|
||||
# tree of lazy tensors
|
||||
class LazyTorchTensor(gguf.LazyBase):
|
||||
_tensor_type = torch.Tensor
|
||||
# to keep the type-checker happy
|
||||
dtype: torch.dtype
|
||||
shape: torch.Size
|
||||
|
||||
# only used when converting a torch.Tensor to a np.ndarray
|
||||
_dtype_map: dict[torch.dtype, type] = {
|
||||
torch.float16: np.float16,
|
||||
torch.float32: np.float32,
|
||||
}
|
||||
|
||||
# used for safetensors slices
|
||||
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
|
||||
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
|
||||
_dtype_str_map: dict[str, torch.dtype] = {
|
||||
"F64": torch.float64,
|
||||
"F32": torch.float32,
|
||||
"BF16": torch.bfloat16,
|
||||
"F16": torch.float16,
|
||||
# "U64": torch.uint64,
|
||||
"I64": torch.int64,
|
||||
# "U32": torch.uint32,
|
||||
"I32": torch.int32,
|
||||
# "U16": torch.uint16,
|
||||
"I16": torch.int16,
|
||||
"U8": torch.uint8,
|
||||
"I8": torch.int8,
|
||||
"BOOL": torch.bool,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
"F8_E5M2": torch.float8_e5m2,
|
||||
}
|
||||
|
||||
def numpy(self) -> gguf.LazyNumpyTensor:
|
||||
dtype = self._dtype_map[self.dtype]
|
||||
return gguf.LazyNumpyTensor(
|
||||
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
|
||||
args=(self,),
|
||||
func=(lambda s: s.numpy())
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
|
||||
return torch.empty(size=shape, dtype=dtype, device="meta")
|
||||
|
||||
@classmethod
|
||||
def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
|
||||
dtype = cls._dtype_str_map[st_slice.get_dtype()]
|
||||
shape: tuple[int, ...] = tuple(st_slice.get_shape())
|
||||
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
|
||||
return cast(torch.Tensor, lazy)
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
del types # unused
|
||||
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if func is torch.Tensor.numpy:
|
||||
return args[0].numpy()
|
||||
|
||||
return cls._wrap_fn(func)(*args, **kwargs)
|
||||
|
||||
|
||||
class Gemma3VisionTower:
|
||||
hparams: dict
|
||||
gguf_writer: gguf.GGUFWriter
|
||||
fname_out: Path
|
||||
ftype: gguf.LlamaFileType
|
||||
|
||||
@staticmethod
|
||||
def load_hparams(dir_model: Path):
|
||||
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
|
||||
part_names: list[str] = []
|
||||
for filename in os.listdir(dir_model):
|
||||
if filename.startswith(prefix) and filename.endswith(suffix):
|
||||
part_names.append(filename)
|
||||
part_names.sort()
|
||||
return part_names
|
||||
|
||||
def __init__(self,
|
||||
dir_model: Path,
|
||||
fname_out: Path,
|
||||
ftype: gguf.LlamaFileType,
|
||||
is_big_endian: bool,):
|
||||
hparams = Gemma3VisionTower.load_hparams(dir_model)
|
||||
self.hparams = hparams
|
||||
self.fname_out = fname_out
|
||||
self.ftype = ftype
|
||||
endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch="clip", endianess=endianess)
|
||||
|
||||
text_config = hparams["text_config"]
|
||||
vision_config = hparams["vision_config"]
|
||||
|
||||
assert hparams["architectures"][0] == "Gemma3ForConditionalGeneration"
|
||||
assert text_config is not None
|
||||
assert vision_config is not None
|
||||
|
||||
self.gguf_writer.add_string ("clip.projector_type", "gemma3")
|
||||
self.gguf_writer.add_bool ("clip.has_text_encoder", False)
|
||||
self.gguf_writer.add_bool ("clip.has_vision_encoder", True)
|
||||
self.gguf_writer.add_bool ("clip.has_llava_projector", False) # legacy
|
||||
self.gguf_writer.add_uint32 ("clip.vision.image_size", vision_config["image_size"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.patch_size", vision_config["patch_size"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.embedding_length", vision_config["hidden_size"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length", vision_config["intermediate_size"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.projection_dim", text_config["hidden_size"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.block_count", vision_config["num_hidden_layers"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", vision_config["num_attention_heads"])
|
||||
self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", vision_config.get("layer_norm_eps", 1e-6))
|
||||
# default values taken from HF tranformers code
|
||||
self.gguf_writer.add_array ("clip.vision.image_mean", [0.5, 0.5, 0.5])
|
||||
self.gguf_writer.add_array ("clip.vision.image_std", [0.5, 0.5, 0.5])
|
||||
self.gguf_writer.add_bool ("clip.use_gelu", True)
|
||||
|
||||
# load tensors
|
||||
for name, data_torch in self.get_tensors(dir_model):
|
||||
# convert any unsupported data types to float32
|
||||
if data_torch.dtype not in (torch.float16, torch.float32):
|
||||
data_torch = data_torch.to(torch.float32)
|
||||
self.add_tensor(name, data_torch)
|
||||
|
||||
def get_tensors(self, dir_model: Path) -> Iterator[tuple[str, Tensor]]:
|
||||
part_names = Gemma3VisionTower.get_model_part_names(dir_model, "model", ".safetensors")
|
||||
tensor_names_from_parts: set[str] = set()
|
||||
for part_name in part_names:
|
||||
logger.info(f"gguf: loading model part '{part_name}'")
|
||||
from safetensors import safe_open
|
||||
ctx = cast(ContextManager[Any], safe_open(dir_model / part_name, framework="pt", device="cpu"))
|
||||
with ctx as model_part:
|
||||
tensor_names_from_parts.update(model_part.keys())
|
||||
|
||||
for name in model_part.keys():
|
||||
data = model_part.get_slice(name)
|
||||
data = LazyTorchTensor.from_safetensors_slice(data)
|
||||
yield name, data
|
||||
|
||||
def add_tensor(self, name: str, data_torch: Tensor):
|
||||
is_1d = len(data_torch.shape) == 1
|
||||
is_embd = ".embeddings." in name
|
||||
old_dtype = data_torch.dtype
|
||||
can_quantize = not is_1d and not is_embd
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
|
||||
# this is to support old checkpoint
|
||||
# TODO: remove this when we have the final model
|
||||
name = name.replace("vision_model.vision_model.", "vision_tower.vision_model.")
|
||||
name = name.replace("multimodal_projector.", "multi_modal_projector.")
|
||||
|
||||
# filter only vision tensors
|
||||
if not name.startswith("vision_tower.vision_model.") and not name.startswith("multi_modal_projector."):
|
||||
return
|
||||
# prefix
|
||||
name = name.replace("vision_tower.vision_model.encoder.layers.", "v.blk.")
|
||||
name = name.replace("vision_tower.vision_model.", "v.")
|
||||
# projector and input embd
|
||||
name = name.replace(".embeddings.patch_embedding.", ".patch_embd.")
|
||||
name = name.replace(".embeddings.position_embedding.", ".position_embd.")
|
||||
name = name.replace(
|
||||
"multi_modal_projector.mm_input_projection_weight",
|
||||
"mm.input_projection.weight"
|
||||
)
|
||||
name = name.replace(
|
||||
"multi_modal_projector.mm_soft_emb_norm.weight",
|
||||
"mm.soft_emb_norm.weight"
|
||||
)
|
||||
name = name.replace("post_layernorm.", "post_ln.")
|
||||
# each block
|
||||
name = name.replace(".self_attn.k_proj.", ".attn_k.")
|
||||
name = name.replace(".self_attn.v_proj.", ".attn_v.")
|
||||
name = name.replace(".self_attn.q_proj.", ".attn_q.")
|
||||
name = name.replace(".self_attn.out_proj.", ".attn_out.")
|
||||
name = name.replace(".layer_norm1.", ".ln1.")
|
||||
name = name.replace(".layer_norm2.", ".ln2.")
|
||||
name = name.replace(".mlp.fc1.", ".ffn_down.")
|
||||
name = name.replace(".mlp.fc2.", ".ffn_up.")
|
||||
|
||||
if can_quantize:
|
||||
if self.ftype == gguf.LlamaFileType.ALL_F32:
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
||||
data_qtype = gguf.GGMLQuantizationType.BF16
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
|
||||
data_qtype = gguf.GGMLQuantizationType.Q8_0
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {self.ftype}")
|
||||
|
||||
# corrent norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector
|
||||
# the other norm values are part of SigLIP model, and they are already correct
|
||||
# ref code: Gemma3RMSNorm
|
||||
if "soft_emb_norm.weight" in name:
|
||||
logger.info(f"Correcting norm value for '{name}'")
|
||||
data_torch = data_torch + 1
|
||||
|
||||
data = data_torch.numpy()
|
||||
|
||||
try:
|
||||
data = gguf.quants.quantize(data, data_qtype)
|
||||
except Exception as e:
|
||||
logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
data = gguf.quants.quantize(data, data_qtype)
|
||||
|
||||
# reverse shape to make it similar to the internal ggml dimension order
|
||||
shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
|
||||
logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
|
||||
|
||||
self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)
|
||||
|
||||
def write(self):
|
||||
self.gguf_writer.write_header_to_file(path=self.fname_out)
|
||||
self.gguf_writer.write_kv_data_to_file()
|
||||
self.gguf_writer.write_tensors_to_file(progress=True)
|
||||
self.gguf_writer.close()
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Gemma 3 vision tower safetensors to GGUF format",)
|
||||
parser.add_argument(
|
||||
"--outfile", type=Path, default="mmproj.gguf",
|
||||
help="path to write to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
|
||||
help="output format",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bigendian", action="store_true",
|
||||
help="model is executed on big endian machine",
|
||||
)
|
||||
parser.add_argument(
|
||||
"model", type=Path,
|
||||
help="directory containing model file",
|
||||
nargs="?",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true",
|
||||
help="increase output verbosity",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.model is None:
|
||||
parser.error("the following arguments are required: model")
|
||||
return args
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
dir_model = args.model
|
||||
|
||||
if not dir_model.is_dir():
|
||||
logger.error(f'Error: {args.model} is not a directory')
|
||||
sys.exit(1)
|
||||
|
||||
ftype_map: dict[str, gguf.LlamaFileType] = {
|
||||
"f32": gguf.LlamaFileType.ALL_F32,
|
||||
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
||||
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
||||
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
|
||||
}
|
||||
|
||||
logger.info(f"Loading model: {dir_model.name}")
|
||||
|
||||
with torch.inference_mode():
|
||||
gemma3_vision_tower = Gemma3VisionTower(
|
||||
dir_model=dir_model,
|
||||
fname_out=args.outfile,
|
||||
ftype=ftype_map[args.outtype],
|
||||
is_big_endian=args.bigendian,
|
||||
)
|
||||
gemma3_vision_tower.write()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
@@ -1,332 +0,0 @@
|
||||
#include "arg.h"
|
||||
#include "base64.hpp"
|
||||
#include "log.h"
|
||||
#include "common.h"
|
||||
#include "sampling.h"
|
||||
#include "clip.h"
|
||||
#include "llava.h"
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
|
||||
static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
|
||||
int N = (int) tokens.size();
|
||||
for (int i = 0; i < N; i += n_batch) {
|
||||
int n_eval = (int) tokens.size() - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
|
||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||
return false;
|
||||
}
|
||||
*n_past += n_eval;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
|
||||
std::vector<llama_token> tokens;
|
||||
tokens.push_back(id);
|
||||
return eval_tokens(ctx_llama, tokens, 1, n_past);
|
||||
}
|
||||
|
||||
static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){
|
||||
std::string str2 = str;
|
||||
std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
|
||||
eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
|
||||
return true;
|
||||
}
|
||||
|
||||
static const char * sample(struct common_sampler * smpl,
|
||||
struct llama_context * ctx_llama,
|
||||
int * n_past) {
|
||||
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
|
||||
common_sampler_accept(smpl, id, true);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx_llama);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
static std::string ret;
|
||||
if (llama_vocab_is_eog(vocab, id)) {
|
||||
ret = "</s>";
|
||||
} else {
|
||||
ret = common_token_to_piece(ctx_llama, id);
|
||||
}
|
||||
eval_id(ctx_llama, id, n_past);
|
||||
return ret.c_str();
|
||||
}
|
||||
|
||||
static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
|
||||
static const char* IMG_BASE64_TAG_END = "\">";
|
||||
|
||||
static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
|
||||
begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
|
||||
end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
|
||||
}
|
||||
|
||||
static bool prompt_contains_image(const std::string& prompt) {
|
||||
size_t begin, end;
|
||||
find_image_tag_in_prompt(prompt, begin, end);
|
||||
return (begin != std::string::npos);
|
||||
}
|
||||
|
||||
// replaces the base64 image tag in the prompt with `replacement`
|
||||
static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) {
|
||||
size_t img_base64_str_start, img_base64_str_end;
|
||||
find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
|
||||
if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
|
||||
LOG_ERR("%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
|
||||
auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
|
||||
auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
|
||||
|
||||
auto required_bytes = base64::required_encode_size(base64_str.size());
|
||||
auto img_bytes = std::vector<unsigned char>(required_bytes);
|
||||
base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
|
||||
|
||||
auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
|
||||
if (!embed) {
|
||||
LOG_ERR("%s: could not load image from base64 string.\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return embed;
|
||||
}
|
||||
|
||||
static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
|
||||
size_t begin, end;
|
||||
find_image_tag_in_prompt(prompt, begin, end);
|
||||
if (begin == std::string::npos || end == std::string::npos) {
|
||||
return prompt;
|
||||
}
|
||||
auto pre = prompt.substr(0, begin);
|
||||
auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
|
||||
return pre + replacement + post;
|
||||
}
|
||||
|
||||
struct llava_context {
|
||||
struct clip_ctx * ctx_clip = NULL;
|
||||
struct llama_context * ctx_llama = NULL;
|
||||
struct llama_model * model = NULL;
|
||||
};
|
||||
|
||||
static void print_usage(int, char ** argv) {
|
||||
LOG("\n example usage:\n");
|
||||
LOG("\n %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
|
||||
LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
|
||||
}
|
||||
|
||||
static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) {
|
||||
|
||||
// load and preprocess the image
|
||||
llava_image_embed * embed = NULL;
|
||||
auto prompt = params->prompt;
|
||||
if (prompt_contains_image(prompt)) {
|
||||
if (!params->image.empty()) {
|
||||
LOG_INF("using base64 encoded image instead of command line image path\n");
|
||||
}
|
||||
embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt);
|
||||
if (!embed) {
|
||||
LOG_ERR("%s: can't load image from prompt\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
params->prompt = remove_image_from_prompt(prompt);
|
||||
} else {
|
||||
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str());
|
||||
if (!embed) {
|
||||
fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return embed;
|
||||
}
|
||||
|
||||
static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) {
|
||||
int n_past = 0;
|
||||
|
||||
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
|
||||
|
||||
std::string system_prompt, user_prompt;
|
||||
size_t image_pos = prompt.find("<image>");
|
||||
if (image_pos != std::string::npos) {
|
||||
// new templating mode: Provide the full prompt including system message and use <image> as a placeholder for the image
|
||||
system_prompt = prompt.substr(0, image_pos);
|
||||
user_prompt = prompt.substr(image_pos + std::string("<image>").length());
|
||||
LOG_INF("system_prompt: %s\n", system_prompt.c_str());
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
LOG_INF("user_prompt: %s\n", user_prompt.c_str());
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// llava-1.5 native mode
|
||||
system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:";
|
||||
user_prompt = prompt + "\nASSISTANT:";
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true);
|
||||
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
|
||||
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
|
||||
|
||||
// generate the response
|
||||
|
||||
LOG("\n");
|
||||
|
||||
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
|
||||
if (!smpl) {
|
||||
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string response = "";
|
||||
for (int i = 0; i < max_tgt_len; i++) {
|
||||
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0) break;
|
||||
if (strstr(tmp, "###")) break; // Yi-VL behavior
|
||||
LOG("%s", tmp);
|
||||
if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
|
||||
if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
|
||||
if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
common_sampler_free(smpl);
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
static struct llama_model * llava_init(common_params * params) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params->numa);
|
||||
|
||||
llama_model_params model_params = common_model_params_to_llama(*params);
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: unable to load model\n" , __func__);
|
||||
return NULL;
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
|
||||
const char * clip_path = params->mmproj.path.c_str();
|
||||
|
||||
auto prompt = params->prompt;
|
||||
if (prompt.empty()) {
|
||||
prompt = "describe the image in detail.";
|
||||
}
|
||||
|
||||
auto ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
|
||||
|
||||
llama_context_params ctx_params = common_context_params_to_llama(*params);
|
||||
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
|
||||
|
||||
llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
|
||||
|
||||
if (ctx_llama == NULL) {
|
||||
LOG_ERR("%s: failed to create the llama_context\n" , __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
|
||||
|
||||
ctx_llava->ctx_llama = ctx_llama;
|
||||
ctx_llava->ctx_clip = ctx_clip;
|
||||
ctx_llava->model = model;
|
||||
return ctx_llava;
|
||||
}
|
||||
|
||||
static void llava_free(struct llava_context * ctx_llava) {
|
||||
if (ctx_llava->ctx_clip) {
|
||||
clip_free(ctx_llava->ctx_clip);
|
||||
ctx_llava->ctx_clip = NULL;
|
||||
}
|
||||
|
||||
llama_free(ctx_llava->ctx_llama);
|
||||
llama_model_free(ctx_llava->model);
|
||||
llama_backend_free();
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
|
||||
print_usage(argc, argv);
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto * model = llava_init(¶ms);
|
||||
if (model == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (prompt_contains_image(params.prompt)) {
|
||||
auto * ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
auto * image_embed = load_image(ctx_llava, ¶ms, "");
|
||||
|
||||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
} else {
|
||||
for (auto & image : params.image) {
|
||||
auto * ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
auto * image_embed = load_image(ctx_llava, ¶ms, image);
|
||||
if (!image_embed) {
|
||||
LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
}
|
||||
}
|
||||
|
||||
llama_model_free(model);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,354 +0,0 @@
|
||||
#include "arg.h"
|
||||
#include "log.h"
|
||||
#include "common.h"
|
||||
#include "sampling.h"
|
||||
#include "clip.h"
|
||||
#include "llava.h"
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <iostream> // TODO: remove me
|
||||
|
||||
struct llava_context {
|
||||
struct clip_ctx * ctx_clip = NULL;
|
||||
struct llama_context * ctx_llama = NULL;
|
||||
struct llama_model * model = NULL;
|
||||
};
|
||||
|
||||
static void show_additional_info(int /*argc*/, char ** argv) {
|
||||
LOG("\nexample usage:\n\n%s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
|
||||
LOG("\nnote: a lower temperature value like 0.1 is recommended for better quality.\n");
|
||||
}
|
||||
|
||||
static struct llama_model * llava_init(common_params * params) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params->numa);
|
||||
|
||||
llama_model_params model_params = common_model_params_to_llama(*params);
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: unable to load model\n" , __func__);
|
||||
return NULL;
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
|
||||
auto prompt = params->prompt;
|
||||
if (prompt.empty()) {
|
||||
prompt = "describe the image in detail.";
|
||||
}
|
||||
|
||||
llama_context_params ctx_params = common_context_params_to_llama(*params);
|
||||
if (params->n_ctx < 2048) {
|
||||
// warn user here, "Image processing requires at least 2048 context, setting context to 2048"
|
||||
LOG_WRN("%s: Image processing requires at least 2048 context, setting context to 2048\n" , __func__);
|
||||
ctx_params.n_ctx = 2048;
|
||||
} else {
|
||||
ctx_params.n_ctx = params->n_ctx;
|
||||
}
|
||||
|
||||
llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
|
||||
|
||||
if (ctx_llama == NULL) {
|
||||
LOG_ERR("%s: failed to create the llama_context\n" , __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
|
||||
|
||||
ctx_llava->ctx_llama = ctx_llama;
|
||||
ctx_llava->model = model;
|
||||
return ctx_llava;
|
||||
}
|
||||
|
||||
static void llava_free(struct llava_context * ctx_llava) {
|
||||
if (ctx_llava->ctx_clip) {
|
||||
clip_free(ctx_llava->ctx_clip);
|
||||
ctx_llava->ctx_clip = NULL;
|
||||
}
|
||||
|
||||
llama_free(ctx_llava->ctx_llama);
|
||||
llama_model_free(ctx_llava->model);
|
||||
llama_backend_free();
|
||||
}
|
||||
|
||||
static struct clip_ctx * clip_init_context(common_params * params) {
|
||||
const char * clip_path = params->mmproj.path.c_str();
|
||||
|
||||
auto prompt = params->prompt;
|
||||
if (prompt.empty()) {
|
||||
prompt = "describe the image in detail.";
|
||||
}
|
||||
struct clip_context_params clip_params = {
|
||||
/* use_gpu */ params->n_gpu_layers != 0,
|
||||
/* verbosity */ GGML_LOG_LEVEL_INFO, // TODO: make this configurable
|
||||
};
|
||||
auto * ctx_clip = clip_init(clip_path, clip_params);
|
||||
return ctx_clip;
|
||||
}
|
||||
|
||||
static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
|
||||
int N = (int) tokens.size();
|
||||
for (int i = 0; i < N; i += n_batch) {
|
||||
int n_eval = (int) tokens.size() - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
|
||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||
return false;
|
||||
}
|
||||
*n_past += n_eval;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
|
||||
std::vector<llama_token> tokens;
|
||||
tokens.push_back(id);
|
||||
return eval_tokens(ctx_llama, tokens, 1, n_past);
|
||||
}
|
||||
|
||||
static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){
|
||||
std::string str2 = str;
|
||||
std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
|
||||
return eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
|
||||
}
|
||||
|
||||
static void process_eval_image_embed(struct llava_context * ctx_llava, const struct llava_image_embed * embeds, int n_batch, int * n_past, int idx) {
|
||||
float * image_embed = (float *)malloc(clip_embd_nbytes(ctx_llava->ctx_clip));
|
||||
std::memcpy(image_embed, embeds->embed + idx * clip_n_patches(ctx_llava->ctx_clip) * clip_n_mmproj_embd(ctx_llava->ctx_clip), clip_embd_nbytes(ctx_llava->ctx_clip));
|
||||
|
||||
auto * slice_embed = (llava_image_embed*)malloc(sizeof(llava_image_embed));
|
||||
slice_embed->embed = image_embed;
|
||||
slice_embed->n_image_pos = clip_n_patches(ctx_llava->ctx_clip);
|
||||
llava_eval_image_embed(ctx_llava->ctx_llama, slice_embed, n_batch, n_past);
|
||||
llava_image_embed_free(slice_embed);
|
||||
}
|
||||
|
||||
static void process_image(struct llava_context * ctx_llava, struct llava_image_embed * embeds, common_params * params, int &n_past) {
|
||||
std::string system_prompt;
|
||||
int idx = 0;
|
||||
int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip);
|
||||
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
|
||||
if (has_minicpmv_projector == 2) {
|
||||
system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
system_prompt = "<|im_start|>user\n";
|
||||
}
|
||||
else if (has_minicpmv_projector == 4) {
|
||||
system_prompt = "<|im_start|>user\n";
|
||||
}
|
||||
LOG_INF("%s: image token past: %d\n", __func__, n_past);
|
||||
eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
|
||||
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
|
||||
eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
|
||||
if (num_image_embeds > 1) {
|
||||
if (has_minicpmv_projector == 2) {
|
||||
size_t num_image_embeds_col = clip_uhd_num_image_embeds_col(ctx_llava->ctx_clip);
|
||||
eval_string(ctx_llava->ctx_llama, std::string("<slice>").c_str(), params->n_batch, &n_past, false);
|
||||
for (size_t i = 0; i < (num_image_embeds-1)/num_image_embeds_col; ++i) {
|
||||
for (size_t j = 0; j < num_image_embeds_col; ++j) {
|
||||
eval_string(ctx_llava->ctx_llama, std::string("<image>").c_str(), params->n_batch, &n_past, false);
|
||||
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
|
||||
eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
|
||||
if (j == num_image_embeds_col - 1) {
|
||||
eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
eval_string(ctx_llava->ctx_llama, std::string("</slice>").c_str(), params->n_batch, &n_past, false);
|
||||
}
|
||||
else if (has_minicpmv_projector == 3 || has_minicpmv_projector == 4) {
|
||||
size_t num_image_embeds_col = clip_uhd_num_image_embeds_col(ctx_llava->ctx_clip);
|
||||
for (size_t i = 0; i < (num_image_embeds-1)/num_image_embeds_col; ++i) {
|
||||
for (size_t j = 0; j < num_image_embeds_col; ++j) {
|
||||
eval_string(ctx_llava->ctx_llama, std::string("<slice>").c_str(), params->n_batch, &n_past, false);
|
||||
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
|
||||
eval_string(ctx_llava->ctx_llama, std::string("</slice>").c_str(), params->n_batch, &n_past, false);
|
||||
if (j == num_image_embeds_col - 1) {
|
||||
eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
LOG_INF("%s: image token past: %d\n", __func__, n_past);
|
||||
}
|
||||
|
||||
static const char * sample(struct common_sampler * smpl,
|
||||
struct llama_context * ctx_llama,
|
||||
int * n_past) {
|
||||
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
|
||||
common_sampler_accept(smpl, id, true);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx_llama);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
static std::string ret;
|
||||
if (llama_vocab_is_eog(vocab, id)) {
|
||||
ret = "</s>";
|
||||
} else {
|
||||
ret = common_token_to_piece(ctx_llama, id);
|
||||
}
|
||||
eval_id(ctx_llama, id, n_past);
|
||||
return ret.c_str();
|
||||
}
|
||||
|
||||
static struct llava_context * minicpmv_init(common_params * params, const std::string & fname, int &n_past){
|
||||
auto * ctx_clip = clip_init_context(params);
|
||||
auto * embeds = llava_image_embed_make_with_filename(ctx_clip, params->cpuparams.n_threads, fname.c_str());
|
||||
if (!embeds) {
|
||||
LOG_ERR("failed to load image %s. Terminating\n\n", fname.c_str());
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// process the prompt
|
||||
if (params->prompt.empty() && params->interactive == false) {
|
||||
LOG_ERR("prompt should be given or interactive mode should be on");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto * model = llava_init(params);
|
||||
if (model == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
const int64_t t_llava_init_start_us = ggml_time_us();
|
||||
auto * ctx_llava = llava_init_context(params, model);
|
||||
ctx_llava->ctx_clip = ctx_clip;
|
||||
const int64_t t_llava_init_end_us = ggml_time_us();
|
||||
float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0;
|
||||
LOG_INF("%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms);
|
||||
|
||||
const int64_t t_process_image_start_us = ggml_time_us();
|
||||
process_image(ctx_llava, embeds, params, n_past);
|
||||
const int64_t t_process_image_end_us = ggml_time_us();
|
||||
float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0;
|
||||
LOG_INF("%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms);
|
||||
|
||||
llava_image_embed_free(embeds);
|
||||
return ctx_llava;
|
||||
}
|
||||
|
||||
static struct common_sampler * llama_init(struct llava_context * ctx_llava, common_params * params, const std::string & prompt, int & n_past, bool is_first = false){
|
||||
std::string user_prompt = prompt;
|
||||
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
|
||||
if (!is_first) {
|
||||
if (has_minicpmv_projector == 2) {
|
||||
user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
user_prompt = "<|im_start|>user\n" + prompt;
|
||||
}
|
||||
else if (has_minicpmv_projector == 4) {
|
||||
user_prompt = "<|im_start|>user\n" + prompt;
|
||||
}
|
||||
}
|
||||
|
||||
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
|
||||
if (has_minicpmv_projector == 2) {
|
||||
eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
|
||||
}
|
||||
else if (has_minicpmv_projector == 4) {
|
||||
eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
|
||||
}
|
||||
|
||||
// generate the response
|
||||
|
||||
LOG_INF("\n");
|
||||
|
||||
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
|
||||
return smpl;
|
||||
}
|
||||
|
||||
static const char * llama_loop(struct llava_context * ctx_llava,struct common_sampler * smpl, int &n_past){
|
||||
|
||||
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
|
||||
return tmp;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.mmproj.path.empty() || (params.image.empty())) {
|
||||
show_additional_info(argc, argv);
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (auto & image : params.image) {
|
||||
int n_past = 0;
|
||||
auto * ctx_llava = minicpmv_init(¶ms, image, n_past);
|
||||
|
||||
if (!params.prompt.empty()) {
|
||||
LOG("<user>%s\n", params.prompt.c_str());
|
||||
LOG("<assistant>");
|
||||
auto * smpl = llama_init(ctx_llava, ¶ms, params.prompt, n_past, true);
|
||||
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
|
||||
std::string response;
|
||||
bool have_tmp = false;
|
||||
for (int i = 0; i < max_tgt_len; i++) {
|
||||
const auto * tmp = llama_loop(ctx_llava, smpl, n_past);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0){
|
||||
if (!have_tmp) {
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (strstr(tmp, "###")) break; // Yi-VL behavior
|
||||
have_tmp = true;
|
||||
printf("%s", tmp);
|
||||
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
common_sampler_free(smpl);
|
||||
}else {
|
||||
while (true) {
|
||||
LOG("<user>");
|
||||
std::string prompt;
|
||||
std::getline(std::cin, prompt);
|
||||
LOG("<assistant>");
|
||||
auto * smpl = llama_init(ctx_llava, ¶ms, prompt, n_past, true);
|
||||
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
|
||||
std::string response;
|
||||
for (int i = 0; i < max_tgt_len; i++) {
|
||||
const auto * tmp = llama_loop(ctx_llava, smpl, n_past);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0) break;
|
||||
printf("%s", tmp);// mistral llava-1.6
|
||||
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
|
||||
fflush(stdout);
|
||||
}
|
||||
common_sampler_free(smpl);
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -24,19 +24,22 @@
|
||||
#include <signal.h>
|
||||
#endif
|
||||
|
||||
static bool g_is_generating = false;
|
||||
// volatile, because of signal being an interrupt
|
||||
static volatile bool g_is_generating = false;
|
||||
static volatile bool g_is_interrupted = false;
|
||||
|
||||
/**
|
||||
* Please note that this is NOT a production-ready stuff.
|
||||
* It is a playground for trying Gemma 3 vision capabilities.
|
||||
* It is a playground for trying multimodal support in llama.cpp.
|
||||
* For contributors: please keep this code simple and easy to understand.
|
||||
*/
|
||||
|
||||
static void show_additional_info(int /*argc*/, char ** argv) {
|
||||
LOG(
|
||||
"Experimental CLI for using Gemma 3 vision model\n\n"
|
||||
"Experimental CLI for multimodal\n\n"
|
||||
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
|
||||
" -m and --mmproj are required\n"
|
||||
" -hf user/repo can replace both -m and --mmproj in most cases\n"
|
||||
" --image and -p are optional, if NOT provided, the CLI will run in chat mode\n",
|
||||
argv[0]
|
||||
);
|
||||
@@ -49,14 +52,16 @@ static void sigint_handler(int signo) {
|
||||
g_is_generating = false;
|
||||
} else {
|
||||
console::cleanup();
|
||||
LOG("\nInterrupted by user\n");
|
||||
_exit(130);
|
||||
if (g_is_interrupted) {
|
||||
_exit(1);
|
||||
}
|
||||
g_is_interrupted = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
struct gemma3_context {
|
||||
struct mtmd_cli_context {
|
||||
mtmd_context_ptr ctx_vision;
|
||||
common_init_result llama_init;
|
||||
|
||||
@@ -70,18 +75,38 @@ struct gemma3_context {
|
||||
// so here we don't need to keep track of chat history
|
||||
common_chat_templates_ptr tmpls;
|
||||
|
||||
// support for legacy templates (models not having EOT token)
|
||||
llama_tokens antiprompt_tokens;
|
||||
|
||||
int n_threads = 1;
|
||||
llama_pos n_past = 0;
|
||||
|
||||
gemma3_context(common_params & params) : llama_init(common_init_from_params(params)) {
|
||||
mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
|
||||
model = llama_init.model.get();
|
||||
lctx = llama_init.context.get();
|
||||
vocab = llama_model_get_vocab(model);
|
||||
n_threads = params.cpuparams.n_threads;
|
||||
batch = llama_batch_init(params.n_batch, 0, 1);
|
||||
n_batch = params.n_batch;
|
||||
|
||||
if (!llama_model_chat_template(model, nullptr) && params.chat_template.empty()) {
|
||||
LOG_ERR("Model does not have chat template.\n");
|
||||
LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n");
|
||||
LOG_ERR(" For MobileVLM models, use '--chat-template deepseek'\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
tmpls = common_chat_templates_init(model, params.chat_template);
|
||||
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja).c_str());
|
||||
|
||||
init_vision_context(params);
|
||||
|
||||
// load antiprompt tokens for legacy templates
|
||||
if (params.chat_template == "vicuna") {
|
||||
antiprompt_tokens = common_tokenize(lctx, "ASSISTANT:", false, true);
|
||||
} else if (params.chat_template == "deepseek") {
|
||||
antiprompt_tokens = common_tokenize(lctx, "###", false, true);
|
||||
}
|
||||
}
|
||||
|
||||
void init_vision_context(common_params & params) {
|
||||
@@ -97,6 +122,17 @@ struct gemma3_context {
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
bool check_antiprompt(const llama_tokens & generated_tokens) {
|
||||
if (antiprompt_tokens.empty() || generated_tokens.size() < antiprompt_tokens.size()) {
|
||||
return false;
|
||||
}
|
||||
return std::equal(
|
||||
generated_tokens.end() - antiprompt_tokens.size(),
|
||||
generated_tokens.end(),
|
||||
antiprompt_tokens.begin()
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
struct decode_embd_batch {
|
||||
@@ -132,17 +168,19 @@ struct decode_embd_batch {
|
||||
}
|
||||
};
|
||||
|
||||
static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
|
||||
static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
|
||||
llama_tokens generated_tokens;
|
||||
for (int i = 0; i < n_predict; i++) {
|
||||
if (i > n_predict || !g_is_generating) {
|
||||
if (i > n_predict || !g_is_generating || g_is_interrupted) {
|
||||
printf("\n");
|
||||
break;
|
||||
}
|
||||
|
||||
llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
|
||||
generated_tokens.push_back(token_id);
|
||||
common_sampler_accept(smpl, token_id, true);
|
||||
|
||||
if (llama_vocab_is_eog(ctx.vocab, token_id)) {
|
||||
if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
|
||||
printf("\n");
|
||||
break; // end of generation
|
||||
}
|
||||
@@ -150,6 +188,11 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
|
||||
printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
|
||||
fflush(stdout);
|
||||
|
||||
if (g_is_interrupted) {
|
||||
printf("\n");
|
||||
break;
|
||||
}
|
||||
|
||||
// eval the token
|
||||
common_batch_clear(ctx.batch);
|
||||
common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
|
||||
@@ -161,7 +204,7 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
|
||||
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
|
||||
std::vector<mtmd_bitmap> bitmaps;
|
||||
|
||||
common_chat_templates_inputs tmpl_inputs;
|
||||
@@ -184,18 +227,22 @@ static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector
|
||||
text.text = formatted_chat.prompt;
|
||||
text.add_special = add_bos;
|
||||
text.parse_special = true;
|
||||
mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps));
|
||||
if (chunks == nullptr) {
|
||||
LOG_ERR("Unable to tokenize prompt\n");
|
||||
mtmd_input_chunks chunks;
|
||||
|
||||
if (g_is_interrupted) return 0;
|
||||
|
||||
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
|
||||
if (res != 0) {
|
||||
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) {
|
||||
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
|
||||
LOG_ERR("Unable to eval prompt\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
ctx.n_past += mtmd_helper_get_n_tokens(chunks.get());
|
||||
ctx.n_past += mtmd_helper_get_n_tokens(chunks);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -217,7 +264,7 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
gemma3_context ctx(params);
|
||||
mtmd_cli_context ctx(params);
|
||||
printf("%s: %s\n", __func__, params.model.path.c_str());
|
||||
|
||||
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
|
||||
@@ -241,6 +288,8 @@ int main(int argc, char ** argv) {
|
||||
#endif
|
||||
}
|
||||
|
||||
if (g_is_interrupted) return 130;
|
||||
|
||||
if (is_single_turn) {
|
||||
g_is_generating = true;
|
||||
if (params.prompt.find("<__image__>") == std::string::npos) {
|
||||
@@ -252,7 +301,7 @@ int main(int argc, char ** argv) {
|
||||
if (eval_message(ctx, msg, params.image, true)) {
|
||||
return 1;
|
||||
}
|
||||
if (generate_response(ctx, smpl, n_predict)) {
|
||||
if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -267,12 +316,13 @@ int main(int argc, char ** argv) {
|
||||
std::vector<std::string> images_fname;
|
||||
std::string content;
|
||||
|
||||
while (true) {
|
||||
while (!g_is_interrupted) {
|
||||
g_is_generating = false;
|
||||
LOG("\n> ");
|
||||
console::set_display(console::user_input);
|
||||
std::string line;
|
||||
console::readline(line, false);
|
||||
if (g_is_interrupted) break;
|
||||
console::set_display(console::reset);
|
||||
line = string_strip(line);
|
||||
if (line.empty()) {
|
||||
@@ -300,6 +350,7 @@ int main(int argc, char ** argv) {
|
||||
msg.role = "user";
|
||||
msg.content = content;
|
||||
int ret = eval_message(ctx, msg, images_fname, is_first_msg);
|
||||
if (g_is_interrupted) break;
|
||||
if (ret == 2) {
|
||||
// non-fatal error
|
||||
images_fname.clear();
|
||||
@@ -317,6 +368,7 @@ int main(int argc, char ** argv) {
|
||||
is_first_msg = false;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
if (g_is_interrupted) LOG("\nInterrupted by user\n");
|
||||
llama_perf_context_print(ctx.lctx);
|
||||
return g_is_interrupted ? 130 : 0;
|
||||
}
|
||||
@@ -12,19 +12,43 @@
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
// slice template, used by some llava-uhd models to correctly place the special tokens around image embeddings
|
||||
// models not having it (llava-1.6) will process embeddings without any special tokens in-between
|
||||
enum mtmd_slice_tmpl {
|
||||
MTMD_SLICE_TMPL_NONE,
|
||||
MTMD_SLICE_TMPL_MINICPMV_2_5,
|
||||
MTMD_SLICE_TMPL_MINICPMV_2_6,
|
||||
// TODO @ngxson : add support for idefics (SmolVLM)
|
||||
};
|
||||
|
||||
struct mtmd_context {
|
||||
struct clip_ctx * ctx_clip;
|
||||
const struct llama_model * text_model;
|
||||
std::vector<float> image_embd_v; // image embedding vector
|
||||
|
||||
bool print_timings;
|
||||
int n_threads;
|
||||
std::string image_marker;
|
||||
|
||||
// for minicpmv, we need special tokens in-between slices
|
||||
mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE;
|
||||
llama_token tok_ov_img_start = LLAMA_TOKEN_NULL; // overview image
|
||||
llama_token tok_ov_img_end = LLAMA_TOKEN_NULL; // overview image
|
||||
llama_token tok_slices_start = LLAMA_TOKEN_NULL; // start of all slices
|
||||
llama_token tok_slices_end = LLAMA_TOKEN_NULL; // end of all slices
|
||||
llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice
|
||||
llama_token tok_sli_img_end = LLAMA_TOKEN_NULL; // single slice
|
||||
llama_token tok_row_end = LLAMA_TOKEN_NULL; // end of row
|
||||
|
||||
// TODO @ngxson : add timings
|
||||
|
||||
mtmd_context(const char * mmproj_fname,
|
||||
const llama_model * text_model,
|
||||
const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) {
|
||||
const mtmd_context_params & ctx_params) :
|
||||
print_timings(ctx_params.print_timings),
|
||||
n_threads (ctx_params.n_threads),
|
||||
image_marker (ctx_params.image_marker)
|
||||
{
|
||||
clip_context_params ctx_clip_params;
|
||||
ctx_clip_params.use_gpu = ctx_params.use_gpu;
|
||||
ctx_clip_params.verbosity = ctx_params.verbosity;
|
||||
@@ -33,11 +57,66 @@ struct mtmd_context {
|
||||
throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
|
||||
}
|
||||
this->text_model = text_model;
|
||||
|
||||
GGML_ASSERT(!clip_is_qwen2vl(ctx_clip) && "Qwen2VL model is not supported yet, use llama-qwen2vl-cli instead");
|
||||
|
||||
int minicpmv_version = clip_is_minicpmv(ctx_clip);
|
||||
if (minicpmv_version == 2) {
|
||||
// minicpmv 2.5 format:
|
||||
// <image> (overview) </image><slice><image> (slice) </image><image> (slice) </image>\n ... </slice>
|
||||
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_5;
|
||||
tok_ov_img_start = lookup_token("<image>");
|
||||
tok_ov_img_end = lookup_token("</image>");
|
||||
tok_slices_start = lookup_token("<slice>");
|
||||
tok_slices_end = lookup_token("</slice>");
|
||||
tok_sli_img_start = tok_ov_img_start;
|
||||
tok_sli_img_end = tok_ov_img_end;
|
||||
tok_row_end = lookup_token("\n");
|
||||
|
||||
} else if (minicpmv_version == 3 || minicpmv_version == 4) {
|
||||
// minicpmv 2.6 format:
|
||||
// <image> (overview) </image><slice> (slice) </slice><slice> (slice) </slice>\n ...
|
||||
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6;
|
||||
tok_ov_img_start = lookup_token("<image>");
|
||||
tok_ov_img_end = lookup_token("</image>");
|
||||
tok_sli_img_start = lookup_token("<slice>");
|
||||
tok_sli_img_end = lookup_token("</slice>");
|
||||
tok_row_end = lookup_token("\n");
|
||||
|
||||
} else if (minicpmv_version != 0) {
|
||||
GGML_ASSERT(false && "unsupported minicpmv version");
|
||||
}
|
||||
}
|
||||
|
||||
~mtmd_context() {
|
||||
clip_free(ctx_clip);
|
||||
}
|
||||
|
||||
private:
|
||||
llama_token lookup_token(const std::string & token_text) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(text_model);
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
if (token_to_piece(vocab, i, true) == token_text) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return LLAMA_TOKEN_NULL;
|
||||
}
|
||||
|
||||
std::string token_to_piece(const llama_vocab * vocab, llama_token token, bool special) {
|
||||
std::string piece;
|
||||
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
|
||||
const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
|
||||
if (n_chars < 0) {
|
||||
piece.resize(-n_chars);
|
||||
int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
|
||||
GGML_ASSERT(check == -n_chars);
|
||||
} else {
|
||||
piece.resize(n_chars);
|
||||
}
|
||||
return piece;
|
||||
}
|
||||
};
|
||||
|
||||
struct mtmd_image_tokens_data {
|
||||
@@ -49,6 +128,7 @@ struct mtmd_image_tokens {
|
||||
uint32_t ny; // number of tokens in y direction
|
||||
uint32_t n_tokens() const { return nx * ny; }
|
||||
clip_image_f32_batch batch_f32; // preprocessed image patches
|
||||
std::string id; // optional user-defined ID, useful for KV cache tracking
|
||||
};
|
||||
|
||||
mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
||||
@@ -88,29 +168,78 @@ static std::vector<llama_token> mtmd_tokenize_text_internal(
|
||||
return result;
|
||||
}
|
||||
|
||||
mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
|
||||
const mtmd_input_text & text,
|
||||
const std::vector<mtmd_bitmap> & bitmaps) {
|
||||
mtmd_input_chunks * output = new mtmd_input_chunks;
|
||||
int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
std::vector<mtmd_input_chunk> & output,
|
||||
const mtmd_input_text & text,
|
||||
const std::vector<mtmd_bitmap> & bitmaps) {
|
||||
auto vocab = llama_model_get_vocab(ctx->text_model);
|
||||
|
||||
std::string prompt_modified(text.text);
|
||||
std::string marker_modified(ctx->image_marker);
|
||||
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
|
||||
|
||||
// a bit hacky here, but works for now
|
||||
// for some models, we need to add prefix and suffix to the image embeddings
|
||||
if (proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
if (clip_is_gemma3(ctx->ctx_clip)) {
|
||||
// gemma 3
|
||||
// <start_of_image> ... (image embeddings) ... <end_of_image>
|
||||
marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
|
||||
marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "<fake_token_around_image>";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
|
||||
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
|
||||
marker_modified = ctx->image_marker + "[IMG_END]";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
}
|
||||
|
||||
std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
|
||||
output->clear();
|
||||
output->reserve(parts.size());
|
||||
// llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
|
||||
// for glm-edge, we don't need to add because the tokens are already in the returned embeddings
|
||||
|
||||
// TODO @ngxson : glm-edge : remove BOI / EOI tokens embeddings, decode them as normal tokens
|
||||
|
||||
std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
|
||||
output.clear();
|
||||
output.reserve(parts.size());
|
||||
|
||||
size_t i_img = 0;
|
||||
|
||||
// utility for adding raw tokens
|
||||
auto add_text_chunk = [&output](std::vector<llama_token> && tokens) {
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_TEXT,
|
||||
std::move(tokens),
|
||||
{},
|
||||
};
|
||||
output.emplace_back(std::move(chunk));
|
||||
};
|
||||
|
||||
// utility for splitting batch of multiple images into chunks of batch having single images
|
||||
auto split_batch_to_chunk = [&ctx](clip_image_f32_batch && batch_f32, const std::string & id) {
|
||||
std::vector<mtmd_input_chunk> chunks;
|
||||
|
||||
for (auto & entry : batch_f32.entries) {
|
||||
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
|
||||
image_tokens->nx = clip_n_patches_by_img(ctx->ctx_clip, entry.get());
|
||||
image_tokens->ny = 1;
|
||||
image_tokens->batch_f32.entries.push_back(std::move(entry));
|
||||
image_tokens->id = id;
|
||||
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
{},
|
||||
std::move(image_tokens),
|
||||
};
|
||||
chunks.emplace_back(std::move(chunk));
|
||||
}
|
||||
|
||||
return chunks;
|
||||
};
|
||||
|
||||
for (const auto & part : parts) {
|
||||
//printf("tokenizing part: %s\n", part.c_str());
|
||||
bool add_bos = &parts.front() == ∂
|
||||
@@ -123,66 +252,161 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
|
||||
std::move(tokens),
|
||||
{},
|
||||
};
|
||||
output->emplace_back(std::move(chunk));
|
||||
output.emplace_back(std::move(chunk));
|
||||
|
||||
if (&parts.back() != &part) {
|
||||
// add image token to middle of 2 parts
|
||||
|
||||
if (i_img >= bitmaps.size()) {
|
||||
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
|
||||
return nullptr;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// shim layer
|
||||
// convert mtmd_bitmap to clip_image_u8
|
||||
clip_image_u8_ptr img_u8(clip_image_u8_init());
|
||||
img_u8->nx = bitmaps[i_img].nx;
|
||||
img_u8->ny = bitmaps[i_img].ny;
|
||||
img_u8->buf.resize(bitmaps[i_img].data.size());
|
||||
std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3);
|
||||
clip_image_size img_u8_size{img_u8->nx, img_u8->ny};
|
||||
|
||||
// preprocess image
|
||||
clip_image_f32_batch batch_f32;
|
||||
bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32);
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to preprocess image\n");
|
||||
return nullptr;
|
||||
return 2;
|
||||
}
|
||||
|
||||
mtmd_image_tokens * image_tokens = new mtmd_image_tokens;
|
||||
image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
|
||||
image_tokens->ny = 1; // TODO
|
||||
image_tokens->batch_f32 = std::move(batch_f32);
|
||||
if (ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6) {
|
||||
// split batch into chunks of single images
|
||||
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img].id);
|
||||
GGML_ASSERT(chunks.size() > 0);
|
||||
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
{},
|
||||
image_tokens,
|
||||
};
|
||||
output->emplace_back(std::move(chunk));
|
||||
i_img++;
|
||||
// add overview image
|
||||
add_text_chunk({ctx->tok_ov_img_start});
|
||||
output.emplace_back(std::move(chunks.front()));
|
||||
chunks.erase(chunks.begin());
|
||||
add_text_chunk({ctx->tok_ov_img_end});
|
||||
|
||||
// add slices
|
||||
if (!chunks.empty()) {
|
||||
clip_add_load_image_size(ctx->ctx_clip, &img_u8_size);
|
||||
int n_col = clip_uhd_num_image_embeds_col(ctx->ctx_clip);
|
||||
int n_row = (int)chunks.size() / n_col;
|
||||
GGML_ASSERT(n_row * n_col == (int)chunks.size());
|
||||
if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) {
|
||||
add_text_chunk({ctx->tok_slices_start});
|
||||
}
|
||||
for (int y = 0; y < n_row; y++) {
|
||||
for (int x = 0; x < n_col; x++) {
|
||||
if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) {
|
||||
add_text_chunk({ctx->tok_sli_img_start});
|
||||
}
|
||||
output.emplace_back(std::move(chunks[y * n_col + x]));
|
||||
if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) {
|
||||
add_text_chunk({ctx->tok_sli_img_end});
|
||||
}
|
||||
}
|
||||
if (ctx->tok_row_end != LLAMA_TOKEN_NULL && y != n_row - 1) {
|
||||
add_text_chunk({ctx->tok_row_end});
|
||||
}
|
||||
}
|
||||
if (ctx->tok_slices_end != LLAMA_TOKEN_NULL) {
|
||||
add_text_chunk({ctx->tok_slices_end});
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
size_t n_tokens = 0;
|
||||
for (const auto & entry : batch_f32.entries) {
|
||||
n_tokens += clip_n_patches_by_img(ctx->ctx_clip, entry.get());
|
||||
}
|
||||
|
||||
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
|
||||
image_tokens->nx = n_tokens;
|
||||
image_tokens->ny = 1; // TODO
|
||||
image_tokens->batch_f32 = std::move(batch_f32);
|
||||
image_tokens->id = bitmaps[i_img].id; // optional
|
||||
|
||||
LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx);
|
||||
LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny);
|
||||
LOG_DBG("batch_f32 size = %d\n", (int)image_tokens->batch_f32.entries.size());
|
||||
|
||||
if (clip_is_glm(ctx->ctx_clip)) {
|
||||
// glm-edge
|
||||
image_tokens->nx += 2; // add 2 for the begin_of_image and end_of_image token embeddings
|
||||
}
|
||||
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
{},
|
||||
std::move(image_tokens),
|
||||
};
|
||||
output.emplace_back(std::move(chunk));
|
||||
}
|
||||
|
||||
i_img++; // move to next image
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
return 0;
|
||||
}
|
||||
|
||||
void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
|
||||
for (auto & chunk : *chunks) {
|
||||
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
|
||||
delete chunk.tokens_image;
|
||||
}
|
||||
void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
|
||||
if (image_tokens) {
|
||||
delete image_tokens;
|
||||
}
|
||||
delete chunks;
|
||||
}
|
||||
|
||||
size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
|
||||
return image_tokens->n_tokens();
|
||||
}
|
||||
|
||||
size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
|
||||
return image_tokens->nx;
|
||||
}
|
||||
|
||||
size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
|
||||
return image_tokens->ny;
|
||||
}
|
||||
|
||||
std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
|
||||
return image_tokens->id;
|
||||
}
|
||||
|
||||
int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
|
||||
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
|
||||
ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
|
||||
bool ok = clip_image_batch_encode(
|
||||
ctx->ctx_clip,
|
||||
ctx->n_threads,
|
||||
&image_tokens->batch_f32,
|
||||
ctx->image_embd_v.data());
|
||||
bool ok = false;
|
||||
|
||||
// only effective for minicpmv and qwen2vl, other models will ignore load_image_size
|
||||
{
|
||||
clip_image_size slice_size{
|
||||
image_tokens->batch_f32.entries[0]->nx,
|
||||
image_tokens->batch_f32.entries[0]->ny};
|
||||
clip_add_load_image_size(ctx->ctx_clip, &slice_size);
|
||||
}
|
||||
|
||||
if (clip_is_llava(ctx->ctx_clip) || clip_is_minicpmv(ctx->ctx_clip) || clip_is_glm(ctx->ctx_clip)) {
|
||||
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
|
||||
const auto & entries = image_tokens->batch_f32.entries;
|
||||
for (size_t i = 0; i < entries.size(); i++) {
|
||||
int n_tokens_per_image = clip_n_patches_by_img(ctx->ctx_clip, entries[i].get());
|
||||
ok = clip_image_encode(
|
||||
ctx->ctx_clip,
|
||||
ctx->n_threads,
|
||||
entries[i].get(),
|
||||
ctx->image_embd_v.data() + i*n_mmproj_embd*n_tokens_per_image);
|
||||
}
|
||||
} else {
|
||||
ok = clip_image_batch_encode(
|
||||
ctx->ctx_clip,
|
||||
ctx->n_threads,
|
||||
&image_tokens->batch_f32,
|
||||
ctx->image_embd_v.data());
|
||||
}
|
||||
|
||||
return ok ? 0 : 1;
|
||||
}
|
||||
|
||||
@@ -190,9 +414,9 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
|
||||
return ctx->image_embd_v.data();
|
||||
}
|
||||
|
||||
size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) {
|
||||
size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
|
||||
size_t n_tokens = 0;
|
||||
for (auto & chunk : *chunks) {
|
||||
for (auto & chunk : chunks) {
|
||||
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
n_tokens += chunk.tokens_text.size();
|
||||
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
@@ -241,35 +465,38 @@ struct decode_embd_batch {
|
||||
|
||||
int32_t mtmd_helper_eval(mtmd_context * ctx,
|
||||
llama_context * lctx,
|
||||
mtmd_input_chunks * chunks,
|
||||
mtmd_input_chunks & chunks,
|
||||
llama_pos pos0,
|
||||
llama_seq_id seq_id,
|
||||
int32_t n_batch) {
|
||||
int32_t ret;
|
||||
llama_pos n_past = pos0;
|
||||
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
|
||||
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
|
||||
|
||||
for (auto & chunk : *chunks) {
|
||||
bool is_last = &chunk == &chunks->back();
|
||||
for (auto & chunk : chunks) {
|
||||
bool is_last = &chunk == &chunks.back();
|
||||
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
// TODO @ngxson : may need to split into smaller batches
|
||||
text_batch.n_tokens = chunk.tokens_text.size();
|
||||
for (size_t i = 0; i < chunk.tokens_text.size(); i++) {
|
||||
text_batch.token [i] = chunk.tokens_text[i];
|
||||
text_batch.pos [i] = n_past++;
|
||||
text_batch.n_seq_id[i] = 1;
|
||||
text_batch.seq_id [i][0] = seq_id;
|
||||
text_batch.logits [i] = false;
|
||||
}
|
||||
if (is_last) {
|
||||
// always get logits for last input chunk
|
||||
text_batch.logits[text_batch.n_tokens - 1] = true;
|
||||
}
|
||||
ret = llama_decode(lctx, text_batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode text\n");
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
size_t i = 0;
|
||||
while (i < chunk.tokens_text.size()) { // split into batches
|
||||
for (; i < chunk.tokens_text.size() && text_batch.n_tokens < n_batch; i++) {
|
||||
text_batch.token [i] = chunk.tokens_text[i];
|
||||
text_batch.pos [i] = n_past++;
|
||||
text_batch.n_seq_id[i] = 1;
|
||||
text_batch.seq_id [i][0] = seq_id;
|
||||
text_batch.logits [i] = false;
|
||||
}
|
||||
if (is_last) {
|
||||
// always get logits for last input chunk
|
||||
text_batch.logits[text_batch.n_tokens - 1] = true;
|
||||
}
|
||||
ret = llama_decode(lctx, text_batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode text\n");
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
@@ -277,33 +504,56 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
|
||||
GGML_ASSERT(chunk.tokens_image != nullptr);
|
||||
int64_t t0 = ggml_time_ms();
|
||||
if (ctx->print_timings) {
|
||||
LOG_INF("encoding image...\n");
|
||||
LOG_INF("encoding image or slice...\n");
|
||||
}
|
||||
ret = mtmd_encode(ctx, chunk.tokens_image);
|
||||
ret = mtmd_encode(ctx, chunk.tokens_image.get());
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to encode image\n");
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
}
|
||||
if (ctx->print_timings) {
|
||||
LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
|
||||
LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
|
||||
}
|
||||
|
||||
int32_t n_tokens = chunk.tokens_image->n_tokens();
|
||||
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
|
||||
int32_t i_batch = 0;
|
||||
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
|
||||
float * embd = mtmd_get_output_embd(ctx);
|
||||
decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
ret = llama_decode(lctx, batch_img.batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
}
|
||||
if (ctx->print_timings) {
|
||||
LOG_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
llama_set_causal_attn(lctx, false);
|
||||
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
|
||||
}
|
||||
|
||||
n_past += n_tokens;
|
||||
while (i_batch < n_img_batches) { // split into batches
|
||||
int32_t pos_offset = i_batch*n_batch;
|
||||
int32_t n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
|
||||
float * embd_batch = embd + pos_offset*n_mmproj_embd;
|
||||
decode_embd_batch batch_img(embd_batch, n_tokens_batch, n_past, 0);
|
||||
|
||||
printf("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
ret = llama_decode(lctx, batch_img.batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
llama_set_causal_attn(lctx, true); // restore causal attn
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (ctx->print_timings) {
|
||||
LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
|
||||
}
|
||||
|
||||
i_batch++;
|
||||
n_past += n_tokens_batch;
|
||||
}
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
llama_set_causal_attn(lctx, true);
|
||||
}
|
||||
|
||||
} else {
|
||||
GGML_ASSERT(false && "chunk type not supported");
|
||||
@@ -339,3 +589,15 @@ int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & outp
|
||||
std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
|
||||
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
|
||||
if (proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
|
||||
mtmd_image_tokens_free(val);
|
||||
}
|
||||
|
||||
@@ -39,12 +39,18 @@ struct mtmd_bitmap {
|
||||
uint32_t nx;
|
||||
uint32_t ny;
|
||||
std::vector<unsigned char> data;
|
||||
std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
|
||||
};
|
||||
|
||||
struct mtmd_image_tokens_deleter {
|
||||
void operator()(mtmd_image_tokens * val); // forward declaration
|
||||
};
|
||||
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
|
||||
|
||||
struct mtmd_input_chunk {
|
||||
mtmd_input_chunk_type type;
|
||||
std::vector<llama_token> tokens_text;
|
||||
mtmd_image_tokens * tokens_image = nullptr;
|
||||
mtmd_image_tokens_ptr tokens_image;
|
||||
};
|
||||
|
||||
using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
|
||||
@@ -82,12 +88,21 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
|
||||
// 3. "<end_of_image>\ndescribe it in detail."
|
||||
// number of bitmaps must be equal to the number of image markers in the prompt
|
||||
// this function is thread-safe (shared ctx)
|
||||
MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
|
||||
// return values:
|
||||
// 0 on success
|
||||
// 1 on number of images not matching the number of markers
|
||||
// 2 on image preprocessing error
|
||||
MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
std::vector<mtmd_input_chunk> & output,
|
||||
const mtmd_input_text & text,
|
||||
const std::vector<mtmd_bitmap> & bitmaps);
|
||||
|
||||
// free image chunk data
|
||||
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
|
||||
// access mtmd_image_tokens
|
||||
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);
|
||||
|
||||
// returns 0 on success
|
||||
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
|
||||
@@ -96,12 +111,17 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
|
||||
// get output embeddings from the last encode pass
|
||||
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
||||
|
||||
// whether we need to set non-causal mask before llama_decode
|
||||
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
|
||||
|
||||
|
||||
|
||||
//
|
||||
// helper functions (can be implemented based on other functions)
|
||||
//
|
||||
|
||||
// helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
|
||||
MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
|
||||
MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks);
|
||||
|
||||
// helper function that automatically:
|
||||
// 1. run llama_decode() on text chunks
|
||||
@@ -110,7 +130,7 @@ MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
|
||||
// otherwise, returns 0 on success
|
||||
MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx,
|
||||
llama_context * lctx,
|
||||
mtmd_input_chunks * chunks,
|
||||
mtmd_input_chunks & chunks,
|
||||
llama_pos pos0,
|
||||
llama_seq_id seq_id,
|
||||
int32_t n_batch);
|
||||
@@ -132,11 +152,6 @@ struct mtmd_context_deleter {
|
||||
};
|
||||
using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
|
||||
|
||||
struct mtmd_input_chunks_deleter {
|
||||
void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
|
||||
};
|
||||
using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;
|
||||
|
||||
#else
|
||||
|
||||
static_assert(false && "C header is not yet supported by this library");
|
||||
|
||||
@@ -13,30 +13,60 @@ mkdir -p $SCRIPT_DIR/output
|
||||
PROJ_ROOT="$SCRIPT_DIR/../.."
|
||||
cd $PROJ_ROOT
|
||||
|
||||
# Check if the first argument is "big", then run test with big models
|
||||
# This is useful if we're running the script on a larger machine, so we can test the big models
|
||||
RUN_BIG_TESTS=false
|
||||
if [ "${1:-}" = "big" ]; then
|
||||
RUN_BIG_TESTS=true
|
||||
echo "Include BIG models..."
|
||||
fi
|
||||
|
||||
###############
|
||||
|
||||
arr_bin=()
|
||||
arr_hf=()
|
||||
arr_tmpl=() # chat template
|
||||
|
||||
add_test() {
|
||||
local bin=$1
|
||||
local hf=$2
|
||||
local tmpl=${3:-""} # default to empty string if not provided
|
||||
arr_bin+=("$bin")
|
||||
arr_hf+=("$hf")
|
||||
arr_tmpl+=("$tmpl")
|
||||
}
|
||||
|
||||
add_test "llama-gemma3-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
|
||||
add_test "llama-llava-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K"
|
||||
add_test "llama-llava-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M"
|
||||
add_test "llama-llava-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
|
||||
add_test "llama-llava-cli" "second-state/Llava-v1.5-7B-GGUF:Q2_K"
|
||||
add_test "llama-llava-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K"
|
||||
add_test "llama-llava-cli" "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
|
||||
add_test "llama-minicpmv-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted
|
||||
add_test "llama-minicpmv-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
|
||||
add_test "llama-minicpmv-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
|
||||
add_test_big() {
|
||||
if [ "$RUN_BIG_TESTS" = true ]; then
|
||||
add_test "$@"
|
||||
fi
|
||||
}
|
||||
|
||||
add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
|
||||
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
|
||||
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
|
||||
add_test "llama-mtmd-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
|
||||
add_test "llama-mtmd-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek"
|
||||
add_test "llama-mtmd-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
|
||||
add_test "llama-mtmd-cli" "second-state/Llava-v1.5-7B-GGUF:Q2_K" "vicuna"
|
||||
add_test "llama-mtmd-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K" "vicuna"
|
||||
add_test "llama-mtmd-cli" "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
|
||||
add_test "llama-mtmd-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted
|
||||
add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
|
||||
add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
|
||||
add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
|
||||
|
||||
# to test the big models, run: ./tests.sh big
|
||||
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
|
||||
|
||||
# these models always give the wrong answer, not sure why
|
||||
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
|
||||
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0"
|
||||
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-256M-Video-Instruct-GGUF:Q8_0"
|
||||
|
||||
# this model has broken chat template, not usable
|
||||
# add_test "llama-mtmd-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K"
|
||||
|
||||
###############
|
||||
|
||||
cmake --build build -j --target "${arr_bin[@]}"
|
||||
@@ -46,12 +76,20 @@ arr_res=()
|
||||
for i in "${!arr_bin[@]}"; do
|
||||
bin="${arr_bin[$i]}"
|
||||
hf="${arr_hf[$i]}"
|
||||
tmpl="${arr_tmpl[$i]}"
|
||||
|
||||
echo "Running test with binary: $bin and HF model: $hf"
|
||||
echo ""
|
||||
echo ""
|
||||
|
||||
output=$("$PROJ_ROOT/build/bin/$bin" -hf "$hf" --image $SCRIPT_DIR/test-1.jpeg -p "what is the publisher name of the newspaper?" --temp 0 2>&1 | tee /dev/tty)
|
||||
output=$(\
|
||||
"$PROJ_ROOT/build/bin/$bin" \
|
||||
-hf "$hf" \
|
||||
--image $SCRIPT_DIR/test-1.jpeg \
|
||||
-p "what is the publisher name of the newspaper?" \
|
||||
--temp 0 -n 128 \
|
||||
${tmpl:+--chat-template "$tmpl"} \
|
||||
2>&1 | tee /dev/tty)
|
||||
|
||||
echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log
|
||||
|
||||
|
||||
@@ -865,9 +865,22 @@ int main(int argc, char ** argv) {
|
||||
console::set_display(console::reset);
|
||||
display = true;
|
||||
|
||||
// Add tokens to embd only if the input buffer is non-empty
|
||||
// Entering a empty line lets the user pass control back
|
||||
if (buffer.length() > 1) {
|
||||
if (buffer.empty()) { // Ctrl+D on empty line exits
|
||||
LOG("EOF by user\n");
|
||||
break;
|
||||
}
|
||||
|
||||
if (buffer.back() == '\n') {
|
||||
// Implement #587:
|
||||
// If the user wants the text to end in a newline,
|
||||
// this should be accomplished by explicitly adding a newline by using \ followed by return,
|
||||
// then returning control by pressing return again.
|
||||
buffer.pop_back();
|
||||
}
|
||||
|
||||
if (buffer.empty()) { // Enter key on empty line lets the user pass control back
|
||||
LOG_DBG("empty line, passing control back\n");
|
||||
} else { // Add tokens to embd only if the input buffer is non-empty
|
||||
// append input suffix if any
|
||||
if (!params.input_suffix.empty() && !params.conversation_mode) {
|
||||
LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
|
||||
@@ -915,8 +928,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
n_remain -= line_inp.size();
|
||||
LOG_DBG("n_remain: %d\n", n_remain);
|
||||
} else {
|
||||
LOG_DBG("empty line, passing control back\n");
|
||||
}
|
||||
|
||||
input_echo = false; // do not echo this again
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <fstream>
|
||||
#include <cmath>
|
||||
#include <cctype>
|
||||
#include <algorithm>
|
||||
|
||||
struct quant_option {
|
||||
std::string name;
|
||||
@@ -16,7 +17,7 @@ struct quant_option {
|
||||
std::string desc;
|
||||
};
|
||||
|
||||
static const std::vector<struct quant_option> QUANT_OPTIONS = {
|
||||
static const std::vector<quant_option> QUANT_OPTIONS = {
|
||||
{ "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
|
||||
{ "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 4.78G, +0.4511 ppl @ Llama-3-8B", },
|
||||
{ "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 5.21G, +0.1316 ppl @ Llama-3-8B", },
|
||||
@@ -105,7 +106,8 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
|
||||
//
|
||||
[[noreturn]]
|
||||
static void usage(const char * executable) {
|
||||
printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
|
||||
printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable);
|
||||
printf(" [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
|
||||
printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
|
||||
printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
|
||||
printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
|
||||
@@ -114,6 +116,8 @@ static void usage(const char * executable) {
|
||||
printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
|
||||
printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
|
||||
printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
|
||||
printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
|
||||
printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n");
|
||||
printf(" --keep-split: will generate quantized model in the same shards as input\n");
|
||||
printf(" --override-kv KEY=TYPE:VALUE\n");
|
||||
printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
|
||||
@@ -244,6 +248,107 @@ static ggml_type parse_ggml_type(const char * arg) {
|
||||
return GGML_TYPE_COUNT;
|
||||
}
|
||||
|
||||
// Allowed tensors for arbitrary quantization with --tensor-type option
|
||||
static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
|
||||
"attn_k",
|
||||
"attn_kv_a_mqa",
|
||||
"attn_kv_b",
|
||||
"attn_o",
|
||||
"attn_output",
|
||||
"attn_q",
|
||||
"attn_q_a",
|
||||
"attn_q_b",
|
||||
"attn_qkv",
|
||||
"attn_v",
|
||||
"channel_mix_key",
|
||||
"channel_mix_receptance",
|
||||
"channel_mix_value",
|
||||
"cls",
|
||||
"cls.output",
|
||||
"cross_attn_k",
|
||||
"cross_attn_o",
|
||||
"cross_attn_q",
|
||||
"cross_attn_v",
|
||||
"ffn_act",
|
||||
"ffn_down",
|
||||
"ffn_down_exps",
|
||||
"ffn_down_shexp",
|
||||
"ffn_gate",
|
||||
"ffn_gate_exps",
|
||||
"ffn_gate_shexp",
|
||||
"ffn_up",
|
||||
"ffn_up_exps",
|
||||
"ffn_up_shexp",
|
||||
"ssm_in",
|
||||
"ssm_out",
|
||||
"time_mix_gate",
|
||||
"time_mix_key",
|
||||
"time_mix_output",
|
||||
"time_mix_receptance",
|
||||
"time_mix_value",
|
||||
};
|
||||
|
||||
// changes to this struct must be replicated in llama-quant.cpp
|
||||
struct tensor_quantization {
|
||||
std::string name;
|
||||
ggml_type quant = GGML_TYPE_COUNT;
|
||||
};
|
||||
|
||||
static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
|
||||
const char * sep = strchr(data, '=');
|
||||
if (sep == nullptr) {
|
||||
printf("\n%s: malformed tensor type '%s'\n\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t tn_len = sep - data;
|
||||
if (tn_len == 0) {
|
||||
printf("\n%s: missing tensor name\n\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const size_t qt_len = strlen(sep); qt_len == 1) {
|
||||
printf("\n%s: missing quantization type\n\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string tn(data, tn_len);
|
||||
std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
|
||||
sep++;
|
||||
const std::string qt(sep);
|
||||
|
||||
bool found = false;
|
||||
for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
|
||||
std::string tensor;
|
||||
tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
|
||||
// handle special case of cls.output
|
||||
std::string cls_output = "cls.output";
|
||||
if (tn.find(cls_output) != std::string::npos) {
|
||||
tensor = "cls.output";
|
||||
}
|
||||
// check if an allowed tensor exists and it's at the end of the kv string
|
||||
if (tensor == allowed) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
|
||||
printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_quantization tqz;
|
||||
tqz.name = tn;
|
||||
tqz.quant = parse_ggml_type(qt.c_str());
|
||||
tensor_type.emplace_back(std::move(tqz));
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
if (argc < 3) {
|
||||
usage(argv[0]);
|
||||
@@ -255,6 +360,7 @@ int main(int argc, char ** argv) {
|
||||
std::string imatrix_file;
|
||||
std::vector<std::string> included_weights, excluded_weights;
|
||||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
std::vector<tensor_quantization> tensor_types;
|
||||
|
||||
for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
|
||||
if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
|
||||
@@ -277,6 +383,10 @@ int main(int argc, char ** argv) {
|
||||
} else {
|
||||
usage(argv[0]);
|
||||
}
|
||||
} else if (strcmp(argv[arg_idx], "--tensor-type") == 0) {
|
||||
if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
|
||||
usage(argv[0]);
|
||||
}
|
||||
} else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
|
||||
if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) {
|
||||
usage(argv[0]);
|
||||
@@ -361,6 +471,9 @@ int main(int argc, char ** argv) {
|
||||
kv_overrides.back().key[0] = 0;
|
||||
params.kv_overrides = &kv_overrides;
|
||||
}
|
||||
if (!tensor_types.empty()) {
|
||||
params.tensor_types = &tensor_types;
|
||||
}
|
||||
|
||||
llama_backend_init();
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
|
||||
#include "ggml-rpc.h"
|
||||
#ifdef _WIN32
|
||||
# define NOMINMAX
|
||||
# define DIRECTORY_SEPARATOR '\\'
|
||||
# include <locale>
|
||||
# include <windows.h>
|
||||
@@ -37,6 +38,8 @@
|
||||
#include <stdio.h>
|
||||
#include <vector>
|
||||
#include <filesystem>
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
@@ -126,7 +129,7 @@ static std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#ifdef __linux__
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else {
|
||||
@@ -136,7 +139,9 @@ static std::string fs_get_cache_directory() {
|
||||
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
||||
#elif defined(_WIN32)
|
||||
cache_directory = std::getenv("LOCALAPPDATA");
|
||||
#endif // __linux__
|
||||
#else
|
||||
# error Unknown architecture
|
||||
#endif
|
||||
cache_directory = ensure_trailing_slash(cache_directory);
|
||||
cache_directory += "llama.cpp";
|
||||
}
|
||||
@@ -148,12 +153,14 @@ struct rpc_server_params {
|
||||
int port = 50052;
|
||||
size_t backend_mem = 0;
|
||||
bool use_cache = false;
|
||||
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
|
||||
};
|
||||
|
||||
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
|
||||
fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help show this help message and exit\n");
|
||||
fprintf(stderr, " -t, --threads number of threads for the CPU backend (default: %d)\n", params.n_threads);
|
||||
fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
|
||||
fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
|
||||
fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
|
||||
@@ -170,6 +177,15 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
|
||||
return false;
|
||||
}
|
||||
params.host = argv[i];
|
||||
} else if (arg == "-t" || arg == "--threads") {
|
||||
if (++i >= argc) {
|
||||
return false;
|
||||
}
|
||||
params.n_threads = std::stoi(argv[i]);
|
||||
if (params.n_threads <= 0) {
|
||||
fprintf(stderr, "error: invalid number of threads: %d\n", params.n_threads);
|
||||
return false;
|
||||
}
|
||||
} else if (arg == "-p" || arg == "--port") {
|
||||
if (++i >= argc) {
|
||||
return false;
|
||||
@@ -197,7 +213,7 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
|
||||
return true;
|
||||
}
|
||||
|
||||
static ggml_backend_t create_backend() {
|
||||
static ggml_backend_t create_backend(const rpc_server_params & params) {
|
||||
ggml_backend_t backend = NULL;
|
||||
#ifdef GGML_USE_CUDA
|
||||
fprintf(stderr, "%s: using CUDA backend\n", __func__);
|
||||
@@ -229,6 +245,7 @@ static ggml_backend_t create_backend() {
|
||||
if (!backend) {
|
||||
fprintf(stderr, "%s: using CPU backend\n", __func__);
|
||||
backend = ggml_backend_cpu_init();
|
||||
ggml_backend_cpu_set_n_threads(backend, params.n_threads);
|
||||
}
|
||||
return backend;
|
||||
}
|
||||
@@ -273,7 +290,7 @@ int main(int argc, char * argv[]) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
ggml_backend_t backend = create_backend();
|
||||
ggml_backend_t backend = create_backend(params);
|
||||
if (!backend) {
|
||||
fprintf(stderr, "Failed to create backend\n");
|
||||
return 1;
|
||||
@@ -295,7 +312,10 @@ int main(int argc, char * argv[]) {
|
||||
}
|
||||
cache_dir = cache_dir_str.c_str();
|
||||
}
|
||||
printf("Starting RPC server\n");
|
||||
printf("Starting RPC server v%d.%d.%d\n",
|
||||
RPC_PROTO_MAJOR_VERSION,
|
||||
RPC_PROTO_MINOR_VERSION,
|
||||
RPC_PROTO_PATCH_VERSION);
|
||||
printf(" endpoint : %s\n", endpoint.c_str());
|
||||
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
|
||||
printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
|
||||
|
||||
@@ -1552,29 +1552,30 @@ struct server_queue {
|
||||
std::condition_variable condition_tasks;
|
||||
|
||||
// callback functions
|
||||
std::function<void(server_task)> callback_new_task;
|
||||
std::function<void(void)> callback_update_slots;
|
||||
std::function<void(server_task &&)> callback_new_task;
|
||||
std::function<void(void)> callback_update_slots;
|
||||
|
||||
// Add a new task to the end of the queue
|
||||
int post(server_task task, bool front = false) {
|
||||
int post(server_task && task, bool front = false) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
GGML_ASSERT(task.id != -1);
|
||||
// if this is cancel task make sure to clean up pending tasks
|
||||
if (task.type == SERVER_TASK_TYPE_CANCEL) {
|
||||
cleanup_pending_task(task.id_target);
|
||||
}
|
||||
QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
|
||||
const int task_id = task.id;
|
||||
QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
|
||||
if (front) {
|
||||
queue_tasks.push_front(std::move(task));
|
||||
} else {
|
||||
queue_tasks.push_back(std::move(task));
|
||||
}
|
||||
condition_tasks.notify_one();
|
||||
return task.id;
|
||||
return task_id;
|
||||
}
|
||||
|
||||
// multi-task version of post()
|
||||
int post(std::vector<server_task> & tasks, bool front = false) {
|
||||
int post(std::vector<server_task> && tasks, bool front = false) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
for (auto & task : tasks) {
|
||||
if (task.id == -1) {
|
||||
@@ -1596,7 +1597,7 @@ struct server_queue {
|
||||
}
|
||||
|
||||
// Add a new task, but defer until one slot is available
|
||||
void defer(server_task task) {
|
||||
void defer(server_task && task) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
QUE_DBG("defer task, id = %d\n", task.id);
|
||||
queue_tasks_deferred.push_back(std::move(task));
|
||||
@@ -1611,7 +1612,7 @@ struct server_queue {
|
||||
}
|
||||
|
||||
// Register function to process a new task
|
||||
void on_new_task(std::function<void(server_task)> callback) {
|
||||
void on_new_task(std::function<void(server_task &&)> callback) {
|
||||
callback_new_task = std::move(callback);
|
||||
}
|
||||
|
||||
@@ -1660,7 +1661,7 @@ struct server_queue {
|
||||
lock.unlock();
|
||||
break;
|
||||
}
|
||||
server_task task = queue_tasks.front();
|
||||
server_task task = std::move(queue_tasks.front());
|
||||
queue_tasks.pop_front();
|
||||
lock.unlock();
|
||||
|
||||
@@ -2004,7 +2005,7 @@ struct server_context {
|
||||
|
||||
slot.reset();
|
||||
|
||||
slots.push_back(slot);
|
||||
slots.push_back(std::move(slot));
|
||||
}
|
||||
|
||||
default_generation_settings_for_props = slots[0].to_json();
|
||||
@@ -2105,7 +2106,7 @@ struct server_context {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
|
||||
bool launch_slot_with_task(server_slot & slot, server_task && task) {
|
||||
slot.reset();
|
||||
slot.id_task = task.id;
|
||||
slot.index = task.index;
|
||||
@@ -2113,10 +2114,10 @@ struct server_context {
|
||||
slot.params = std::move(task.params);
|
||||
slot.prompt_tokens = std::move(task.prompt_tokens);
|
||||
|
||||
if (!are_lora_equal(task.params.lora, slot.lora)) {
|
||||
if (!are_lora_equal(slot.params.lora, slot.lora)) {
|
||||
// if lora is changed, we cannot reuse cached tokens
|
||||
slot.cache_tokens.clear();
|
||||
slot.lora = task.params.lora;
|
||||
slot.lora = slot.params.lora;
|
||||
}
|
||||
|
||||
bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
|
||||
@@ -2547,10 +2548,10 @@ struct server_context {
|
||||
server_task task(SERVER_TASK_TYPE_CANCEL);
|
||||
task.id_target = id_task;
|
||||
queue_results.remove_waiting_task_id(id_task);
|
||||
cancel_tasks.push_back(task);
|
||||
cancel_tasks.push_back(std::move(task));
|
||||
}
|
||||
// push to beginning of the queue, so it has highest priority
|
||||
queue_tasks.post(cancel_tasks, true);
|
||||
queue_tasks.post(std::move(cancel_tasks), true);
|
||||
}
|
||||
|
||||
// receive the results from task(s)
|
||||
@@ -2637,7 +2638,7 @@ struct server_context {
|
||||
// Functions to process the task
|
||||
//
|
||||
|
||||
void process_single_task(server_task task) {
|
||||
void process_single_task(server_task && task) {
|
||||
switch (task.type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
case SERVER_TASK_TYPE_INFILL:
|
||||
@@ -2651,17 +2652,17 @@ struct server_context {
|
||||
if (slot == nullptr) {
|
||||
// if no slot is available, we defer this task for processing later
|
||||
SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
|
||||
queue_tasks.defer(task);
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
if (slot->is_processing()) {
|
||||
// if requested slot is unavailable, we defer this task for processing later
|
||||
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
||||
queue_tasks.defer(task);
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
|
||||
if (!launch_slot_with_task(*slot, task)) {
|
||||
if (!launch_slot_with_task(*slot, std::move(task))) {
|
||||
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
|
||||
break;
|
||||
}
|
||||
@@ -2740,7 +2741,7 @@ struct server_context {
|
||||
if (slot->is_processing()) {
|
||||
// if requested slot is unavailable, we defer this task for processing later
|
||||
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
||||
queue_tasks.defer(task);
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -2776,7 +2777,7 @@ struct server_context {
|
||||
if (slot->is_processing()) {
|
||||
// if requested slot is unavailable, we defer this task for processing later
|
||||
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
||||
queue_tasks.defer(task);
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -2819,7 +2820,7 @@ struct server_context {
|
||||
if (slot->is_processing()) {
|
||||
// if requested slot is unavailable, we defer this task for processing later
|
||||
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
||||
queue_tasks.defer(task);
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -2871,7 +2872,7 @@ struct server_context {
|
||||
|
||||
server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
|
||||
task.id = queue_tasks.get_new_id();
|
||||
queue_tasks.post(task);
|
||||
queue_tasks.post(std::move(task));
|
||||
}
|
||||
|
||||
// apply context-shift if needed
|
||||
@@ -3633,14 +3634,17 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// request slots data using task queue
|
||||
server_task task(SERVER_TASK_TYPE_METRICS);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
||||
ctx_server.queue_tasks.post(task, true); // high-priority task
|
||||
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||
{
|
||||
server_task task(SERVER_TASK_TYPE_METRICS);
|
||||
task.id = task_id;
|
||||
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||
ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
|
||||
}
|
||||
|
||||
// get the result
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
@@ -3669,16 +3673,17 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// request slots data using task queue
|
||||
server_task task(SERVER_TASK_TYPE_METRICS);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.metrics_reset_bucket = true;
|
||||
|
||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
||||
ctx_server.queue_tasks.post(task, true); // high-priority task
|
||||
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||
{
|
||||
server_task task(SERVER_TASK_TYPE_METRICS);
|
||||
task.id = task_id;
|
||||
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||
ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
|
||||
}
|
||||
|
||||
// get the result
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
@@ -3775,17 +3780,20 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
|
||||
server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.slot_action.slot_id = id_slot;
|
||||
task.slot_action.filename = filename;
|
||||
task.slot_action.filepath = filepath;
|
||||
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||
{
|
||||
server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
|
||||
task.id = task_id;
|
||||
task.slot_action.slot_id = id_slot;
|
||||
task.slot_action.filename = filename;
|
||||
task.slot_action.filepath = filepath;
|
||||
|
||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
||||
ctx_server.queue_tasks.post(task);
|
||||
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||
ctx_server.queue_tasks.post(std::move(task));
|
||||
}
|
||||
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
@@ -3804,17 +3812,20 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
|
||||
server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.slot_action.slot_id = id_slot;
|
||||
task.slot_action.filename = filename;
|
||||
task.slot_action.filepath = filepath;
|
||||
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||
{
|
||||
server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
|
||||
task.id = task_id;
|
||||
task.slot_action.slot_id = id_slot;
|
||||
task.slot_action.filename = filename;
|
||||
task.slot_action.filepath = filepath;
|
||||
|
||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
||||
ctx_server.queue_tasks.post(task);
|
||||
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||
ctx_server.queue_tasks.post(std::move(task));
|
||||
}
|
||||
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
@@ -3826,15 +3837,18 @@ int main(int argc, char ** argv) {
|
||||
};
|
||||
|
||||
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
|
||||
server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.slot_action.slot_id = id_slot;
|
||||
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||
{
|
||||
server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
|
||||
task.id = task_id;
|
||||
task.slot_action.slot_id = id_slot;
|
||||
|
||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
||||
ctx_server.queue_tasks.post(task);
|
||||
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||
ctx_server.queue_tasks.post(std::move(task));
|
||||
}
|
||||
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
@@ -3907,6 +3921,21 @@ int main(int argc, char ** argv) {
|
||||
res_ok(res, {{ "success", true }});
|
||||
};
|
||||
|
||||
const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
json data = {
|
||||
{
|
||||
"template", common_chat_templates_source(ctx_server.chat_templates.get()),
|
||||
},
|
||||
{
|
||||
"model_info", {
|
||||
{ "llama.context_length", ctx_server.slots.back().n_ctx, },
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
res_ok(res, data);
|
||||
};
|
||||
|
||||
// handle completion-like requests (completion, chat, infill)
|
||||
// we can optionally provide a custom format for partial results and final results
|
||||
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
|
||||
@@ -3923,9 +3952,10 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
auto completion_id = gen_chatcmplid();
|
||||
std::vector<server_task> tasks;
|
||||
|
||||
std::unordered_set<int> task_ids;
|
||||
try {
|
||||
std::vector<server_task> tasks;
|
||||
|
||||
const auto & prompt = data.at("prompt");
|
||||
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
||||
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
@@ -3940,9 +3970,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
||||
task.params = server_task::params_from_json_cmpl(
|
||||
ctx_server.ctx,
|
||||
ctx_server.params_base,
|
||||
data);
|
||||
ctx_server.ctx,
|
||||
ctx_server.params_base,
|
||||
data);
|
||||
task.id_selected_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
@@ -3950,18 +3980,18 @@ int main(int argc, char ** argv) {
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
tasks.push_back(task);
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
task_ids = server_task::get_list_id(tasks);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(std::move(tasks));
|
||||
} catch (const std::exception & e) {
|
||||
res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
bool stream = json_value(data, "stream", false);
|
||||
const auto task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
if (!stream) {
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
@@ -4253,6 +4283,7 @@ int main(int argc, char ** argv) {
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
bool error = false;
|
||||
std::unordered_set<int> task_ids;
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
@@ -4265,28 +4296,27 @@ int main(int argc, char ** argv) {
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
|
||||
tasks.push_back(task);
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
task_ids = server_task::get_list_id(tasks);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
// get the result
|
||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
for (auto & res : results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
error = true;
|
||||
}, req.is_connection_closed);
|
||||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
ctx_server.queue_tasks.post(std::move(tasks));
|
||||
}
|
||||
|
||||
// get the result
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
for (auto & res : results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
error = true;
|
||||
}, req.is_connection_closed);
|
||||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
|
||||
if (error) {
|
||||
return;
|
||||
}
|
||||
@@ -4352,6 +4382,7 @@ int main(int argc, char ** argv) {
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
bool error = false;
|
||||
std::unordered_set<int> task_ids;
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
|
||||
@@ -4361,26 +4392,24 @@ int main(int argc, char ** argv) {
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
|
||||
tasks.push_back(task);
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
task_ids = server_task::get_list_id(tasks);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
// get the result
|
||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
for (auto & res : results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
error = true;
|
||||
}, req.is_connection_closed);
|
||||
ctx_server.queue_tasks.post(std::move(tasks));
|
||||
}
|
||||
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
for (auto & res : results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
error = true;
|
||||
}, req.is_connection_closed);
|
||||
|
||||
if (error) {
|
||||
return;
|
||||
}
|
||||
@@ -4416,14 +4445,19 @@ int main(int argc, char ** argv) {
|
||||
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
server_task task(SERVER_TASK_TYPE_SET_LORA);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
|
||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
||||
ctx_server.queue_tasks.post(task);
|
||||
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||
{
|
||||
server_task task(SERVER_TASK_TYPE_SET_LORA);
|
||||
task.id = task_id;
|
||||
task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
|
||||
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||
ctx_server.queue_tasks.post(std::move(task));
|
||||
}
|
||||
|
||||
// get the result
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
@@ -4471,6 +4505,7 @@ int main(int argc, char ** argv) {
|
||||
svr->Get ("/metrics", handle_metrics);
|
||||
svr->Get ("/props", handle_props);
|
||||
svr->Post("/props", handle_props_change);
|
||||
svr->Post("/api/show", handle_api_show);
|
||||
svr->Get ("/models", handle_models); // public endpoint (no API key check)
|
||||
svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
|
||||
svr->Post("/completion", handle_completions); // legacy
|
||||
@@ -4566,8 +4601,8 @@ int main(int argc, char ** argv) {
|
||||
common_chat_templates_source(ctx_server.chat_templates.get()),
|
||||
common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
|
||||
|
||||
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
|
||||
ctx_server.process_single_task(task);
|
||||
ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
|
||||
ctx_server.process_single_task(std::move(task));
|
||||
});
|
||||
|
||||
ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
|
||||
|
||||
@@ -8,10 +8,10 @@ cd build
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
#for FP16
|
||||
#cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON # faster for long-prompt inference
|
||||
#cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DLLAMA_CURL=OFF # faster for long-prompt inference
|
||||
|
||||
#for FP32
|
||||
cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
|
||||
cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_CURL=OFF
|
||||
|
||||
#build example/main
|
||||
#cmake --build . --config Release --target main
|
||||
|
||||
@@ -107,6 +107,7 @@ message(DEBUG "INS_ENB : ${INS_ENB}")
|
||||
option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
|
||||
option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
|
||||
option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
|
||||
option(GGML_SSE42 "ggml: enable SSE 4.2" ${INS_ENB})
|
||||
option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
|
||||
option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
|
||||
option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
|
||||
@@ -170,7 +171,6 @@ option(GGML_HIP "ggml: use HIP"
|
||||
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
|
||||
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
|
||||
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
|
||||
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
|
||||
option(GGML_VULKAN "ggml: use Vulkan" OFF)
|
||||
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
|
||||
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)
|
||||
|
||||
@@ -7,6 +7,9 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define RPC_PROTO_MAJOR_VERSION 1
|
||||
#define RPC_PROTO_MINOR_VERSION 0
|
||||
#define RPC_PROTO_PATCH_VERSION 0
|
||||
#define GGML_RPC_MAX_SERVERS 16
|
||||
|
||||
// backend API
|
||||
|
||||
@@ -267,6 +267,7 @@ function(ggml_add_cpu_backend_variant tag_name)
|
||||
set(GGML_CPU_TAG_NAME ${tag_name})
|
||||
# other: OPENMP LLAMAFILE CPU_HBM
|
||||
foreach (feat NATIVE
|
||||
SSE42
|
||||
AVX AVX2 BMI2 AVX_VNNI FMA F16C
|
||||
AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16
|
||||
AMX_TILE AMX_INT8 AMX_BF16)
|
||||
@@ -286,14 +287,16 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
if (NOT GGML_BACKEND_DL)
|
||||
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
|
||||
endif()
|
||||
ggml_add_cpu_backend_variant(sandybridge AVX)
|
||||
ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 BMI2 FMA)
|
||||
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 BMI2 FMA AVX512)
|
||||
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
||||
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 BMI2 FMA AVX_VNNI)
|
||||
ggml_add_cpu_backend_variant(x64)
|
||||
ggml_add_cpu_backend_variant(sse42 SSE42)
|
||||
ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
|
||||
ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA)
|
||||
ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
|
||||
ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
||||
ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
|
||||
if (NOT MSVC)
|
||||
# MSVC doesn't support AMX
|
||||
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||
ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||
endif()
|
||||
elseif (GGML_CPU)
|
||||
ggml_add_cpu_backend_variant_impl("")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,6 +23,7 @@
|
||||
#ifndef CANN_ACLNN_OPS
|
||||
#define CANN_ACLNN_OPS
|
||||
|
||||
#include <functional>
|
||||
#include <aclnnop/aclnn_abs.h>
|
||||
#include <aclnnop/aclnn_neg.h>
|
||||
#include <aclnnop/aclnn_exp.h>
|
||||
@@ -713,6 +714,270 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
*/
|
||||
void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
|
||||
/*
|
||||
* @brief A generic wrapper for ACL resources with custom deleter support.
|
||||
*/
|
||||
using any_acl_resource = std::unique_ptr<void, std::function<void(void*)>>;
|
||||
|
||||
/**
|
||||
* @brief Trait structure used to define how to destroy a given ACL resource type.
|
||||
*
|
||||
* @tparam T ACL resource type.
|
||||
*/
|
||||
template<typename T>
|
||||
struct acl_resource_traits;
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclTensor> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyTensor(static_cast<aclTensor*>(p)));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclIntArray> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray*>(p)));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclScalar> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyScalar(static_cast<aclScalar*>(p)));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclTensorList> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList*>(p)));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Creates a generic ACL resource wrapper with proper destruction logic.
|
||||
*
|
||||
* @tparam T ACL resource type.
|
||||
* @param ptr Raw pointer to ACL resource.
|
||||
* @return any_acl_resource Smart pointer that handles destruction.
|
||||
*/
|
||||
template<typename T>
|
||||
any_acl_resource make_acl_resource(T* ptr) {
|
||||
return any_acl_resource(
|
||||
static_cast<void*>(ptr),
|
||||
[](void* p) {
|
||||
acl_resource_traits<T>::destroy(p);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Registers multiple ACL resources into a vector for lifetime management.
|
||||
*
|
||||
* @tparam Args Variadic list of ACL resource types.
|
||||
* @param vec Target vector to hold ACL resources.
|
||||
* @param args Raw pointers to ACL resources.
|
||||
*/
|
||||
template<typename... Args>
|
||||
void register_acl_resources(std::vector<any_acl_resource>& vec, Args*... args) {
|
||||
(vec.emplace_back(make_acl_resource(args)), ...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Task class that wraps the execution of an aclnn function call.
|
||||
*/
|
||||
class aclnn_task : public cann_task {
|
||||
public:
|
||||
aclnn_task(aclnn_func_t aclnn_func, void * workspace_addr,
|
||||
uint64_t workspace_size, aclOpExecutor * executor,
|
||||
aclrtStream stream) :
|
||||
aclnn_func_(aclnn_func),
|
||||
workspace_addr_(workspace_addr),
|
||||
workspace_size_(workspace_size),
|
||||
executor_(executor),
|
||||
stream_(stream) {}
|
||||
virtual void run_task() override {
|
||||
ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_));
|
||||
}
|
||||
private:
|
||||
aclnn_func_t aclnn_func_;
|
||||
void * workspace_addr_;
|
||||
uint64_t workspace_size_;
|
||||
aclOpExecutor * executor_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class that releases ACL resources after usage.
|
||||
*/
|
||||
class release_resource_task : public cann_task {
|
||||
public:
|
||||
release_resource_task(std::vector<any_acl_resource>&& resources){
|
||||
resource_ = std::move(resources);
|
||||
}
|
||||
|
||||
virtual void run_task() override {
|
||||
resource_.clear();
|
||||
}
|
||||
private:
|
||||
std::vector<any_acl_resource> resource_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class for performing asynchronous memory copy operations.
|
||||
*/
|
||||
class async_memcpy_task : public cann_task {
|
||||
public:
|
||||
async_memcpy_task(void* dst, const void* src, size_t size,
|
||||
aclrtMemcpyKind kind, aclrtStream stream)
|
||||
: dst_(dst), src_(src), size_(size), kind_(kind), stream_(stream) {}
|
||||
|
||||
virtual void run_task() override {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_));
|
||||
}
|
||||
private:
|
||||
void* dst_;
|
||||
const void* src_;
|
||||
size_t size_;
|
||||
aclrtMemcpyKind kind_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class for performing asynchronous memory set operations.
|
||||
*/
|
||||
class async_memset_task : public cann_task {
|
||||
public:
|
||||
async_memset_task(void* buffer, size_t size, int32_t value, aclrtStream stream)
|
||||
: buffer_(buffer), size_(size), value_(value), stream_(stream) {}
|
||||
|
||||
virtual void run_task() override {
|
||||
ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_));
|
||||
}
|
||||
private:
|
||||
void* buffer_;
|
||||
size_t size_;
|
||||
int32_t value_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Launches an asynchronous task using the memory allocator.
|
||||
*
|
||||
* This macro submit an asynchronous task on the specified stream.
|
||||
* The task uses memory allocated by the allocator. It is guaranteed
|
||||
* that the memory will not be accessed by other tasks until this task
|
||||
* completes, due to the sequential execution order within the same stream.
|
||||
*
|
||||
* @param OP_NAME aclnn operator name.
|
||||
* @param args Additional arguments required by the task.
|
||||
*
|
||||
* @note
|
||||
* Memory from the allocator will be "freed" immediately and can be
|
||||
* reallocated to other pointers. However, it won't be accessed by any
|
||||
* other task before this asynchronous task ends, because all tasks in the
|
||||
* same stream are executed in queue order.
|
||||
*/
|
||||
|
||||
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
|
||||
do { \
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor));\
|
||||
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
|
||||
workspaceAddr = workspace_allocator.get(); \
|
||||
} \
|
||||
if (CTX.async_mode) { \
|
||||
auto task = \
|
||||
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, \
|
||||
executor, CTX.stream()); \
|
||||
CTX.task_queue.submit_task(std::move(task)); \
|
||||
} else { \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));\
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
/**
|
||||
* @brief Registers and releases multiple ACL resources, optionally deferring the release
|
||||
* using a task.
|
||||
*
|
||||
* @tparam Args Types of the ACL resources.
|
||||
* @param ctx Backend context which manages task submission and async mode.
|
||||
* @param args Pointers to ACL resources to be released.
|
||||
*/
|
||||
template <typename... Args>
|
||||
void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
|
||||
std::vector<any_acl_resource> resources;
|
||||
register_acl_resources(resources, std::forward<Args>(args)...);
|
||||
if(ctx.async_mode) {
|
||||
auto task = std::make_unique<release_resource_task>(std::move(resources));
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs an asynchronous memory copy operation, optionally deferred via task submission.
|
||||
*
|
||||
* @param ctx Backend context containing stream and async configuration.
|
||||
* @param dst Destination memory address.
|
||||
* @param src Source memory address.
|
||||
* @param len Size of memory to copy (in bytes).
|
||||
* @param kind Type of memory copy (host-to-device, device-to-host, etc).
|
||||
*/
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
|
||||
const void * src, size_t len, aclrtMemcpyKind kind) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
|
||||
const void * src, size_t len, aclrtMemcpyKind kind) {
|
||||
if (ctx->async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
|
||||
ctx->task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs an asynchronous memory set operation, optionally deferred via task submission.
|
||||
*
|
||||
* @param ctx Backend context containing stream and async configuration.
|
||||
* @param buffer Memory buffer to be set.
|
||||
* @param size Size of the memory buffer (in bytes).
|
||||
* @param value Value to set in the buffer.
|
||||
*/
|
||||
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer,
|
||||
size_t size, int value) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Applies a element-wise operation to two input tensors using the CANN
|
||||
* backend.
|
||||
@@ -742,42 +1007,9 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
|
||||
binary_op(ctx, acl_src0, acl_src1, acl_dst);
|
||||
|
||||
ACL_CHECK(aclDestroyTensor(acl_src0));
|
||||
ACL_CHECK(aclDestroyTensor(acl_src1));
|
||||
ACL_CHECK(aclDestroyTensor(acl_dst));
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Launches an asynchronous task using the memory allocator.
|
||||
*
|
||||
* This macro submit an asynchronous task on the specified stream.
|
||||
* The task uses memory allocated by the allocator. It is guaranteed
|
||||
* that the memory will not be accessed by other tasks until this task
|
||||
* completes, due to the sequential execution order within the same stream.
|
||||
*
|
||||
* @param OP_NAME aclnn operator name.
|
||||
* @param args Additional arguments required by the task.
|
||||
*
|
||||
* @note
|
||||
* Memory from the allocator will be "freed" immediately and can be
|
||||
* reallocated to other pointers. However, it won't be accessed by any
|
||||
* other task before this asynchronous task ends, because all tasks in the
|
||||
* same stream are executed in queue order.
|
||||
*/
|
||||
#define GGML_CANN_CALL_ACLNN_OP(OP_NAME, ...) \
|
||||
do { \
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
\
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
|
||||
\
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); \
|
||||
workspaceAddr = workspace_allocator.get(); \
|
||||
} \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, ctx.stream())); \
|
||||
} while (0)
|
||||
|
||||
/**
|
||||
* @brief Applies a unary operation to an input tensor using the CANN backend.
|
||||
@@ -799,9 +1031,7 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
|
||||
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
unary_op(ctx, acl_src, acl_dst);
|
||||
|
||||
ACL_CHECK(aclDestroyTensor(acl_src));
|
||||
ACL_CHECK(aclDestroyTensor(acl_dst));
|
||||
ggml_cann_release_resources(ctx, acl_src, acl_dst);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -832,7 +1062,7 @@ void ggml_cann_unary_op(
|
||||
*
|
||||
* Internally, the lambda will call:
|
||||
* @code
|
||||
* GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst);
|
||||
* GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
|
||||
* @endcode
|
||||
*
|
||||
* @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
|
||||
@@ -840,14 +1070,14 @@ void ggml_cann_unary_op(
|
||||
* @see ggml_cann_unary_op
|
||||
* @see GGML_CANN_CALL_ACLNN_OP
|
||||
*/
|
||||
#define GGML_CANN_CALL_UNARY_OP(OP_NAME) \
|
||||
do { \
|
||||
auto lambda = [](ggml_backend_cann_context& ctx, \
|
||||
aclTensor* acl_src, \
|
||||
aclTensor* acl_dst) { \
|
||||
GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst); \
|
||||
}; \
|
||||
ggml_cann_unary_op(lambda, ctx, dst); \
|
||||
} \
|
||||
#define GGML_CANN_CALL_UNARY_OP(OP_NAME) \
|
||||
do { \
|
||||
auto lambda = [](ggml_backend_cann_context& ctx, \
|
||||
aclTensor* acl_src, \
|
||||
aclTensor* acl_dst) { \
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
|
||||
}; \
|
||||
ggml_cann_unary_op(lambda, ctx, dst); \
|
||||
} \
|
||||
while (0)
|
||||
#endif // CANN_ACLNN_OPS
|
||||
|
||||
@@ -31,9 +31,16 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
#include <functional>
|
||||
|
||||
#include "../include/ggml-cann.h"
|
||||
#include "../include/ggml.h"
|
||||
#include "../ggml-impl.h"
|
||||
|
||||
#define MATRIX_ROW_PADDING 512
|
||||
#define GGML_CANN_MAX_STREAMS 8
|
||||
@@ -205,6 +212,127 @@ struct ggml_cann_pool_alloc {
|
||||
ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Function pointer type for ACLNN operator calls.
|
||||
*/
|
||||
using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream);
|
||||
|
||||
/**
|
||||
* @brief Base class for all CANN tasks to be submitted to the task queue.
|
||||
*
|
||||
* Users should override the run_task() method with actual task logic.
|
||||
*/
|
||||
class cann_task {
|
||||
public:
|
||||
virtual void run_task() {}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
|
||||
*/
|
||||
class cann_task_queue {
|
||||
public:
|
||||
/**
|
||||
* @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
|
||||
*
|
||||
* @param capacity Queue capacity. Must be a power of 2.
|
||||
* @param device Target device ID (used for context setting).
|
||||
*/
|
||||
explicit cann_task_queue(size_t capacity, int32_t device)
|
||||
: buffer_(capacity), capacity_(capacity), head_(0), tail_(0),
|
||||
running_(false), device_(device) {
|
||||
GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
|
||||
mask_ = capacity_ - 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Attempts to enqueue a task into the queue.
|
||||
*
|
||||
* @param item Unique pointer to the task.
|
||||
* @return true if the task was successfully enqueued, false if the queue was full.
|
||||
*/
|
||||
bool enqueue(std::unique_ptr<cann_task>&& item) {
|
||||
size_t next_tail = (tail_ + 1) & mask_;
|
||||
|
||||
if (next_tail == head_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
buffer_[tail_] = std::move(item);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
tail_ = next_tail;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Submits a task to the queue, and starts the worker thread if not already running.
|
||||
*
|
||||
* @param task Task to be submitted.
|
||||
*/
|
||||
void submit_task(std::unique_ptr<cann_task>&& task) {
|
||||
while(!enqueue(std::move(task))) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!running_) {
|
||||
running_ = true;
|
||||
thread_ = std::thread(&cann_task_queue::execute, this);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Waits until the queue is completely empty and no tasks are being processed.
|
||||
*/
|
||||
void wait() {
|
||||
while (running_ && head_ != tail_) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Stops the task queue and joins the worker thread.
|
||||
*/
|
||||
void stop() {
|
||||
running_ = false;
|
||||
if (thread_.joinable()) {
|
||||
thread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Worker thread function that continuously dequeues and executes tasks.
|
||||
*/
|
||||
void execute() {
|
||||
ggml_cann_set_device(device_);
|
||||
|
||||
while (running_) {
|
||||
if(head_ == tail_) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
buffer_[head_]->run_task();
|
||||
buffer_[head_].reset();
|
||||
head_ = (head_ + 1) & mask_;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<cann_task>> buffer_;
|
||||
const size_t capacity_;
|
||||
size_t mask_;
|
||||
size_t head_;
|
||||
size_t tail_;
|
||||
bool running_;
|
||||
std::thread thread_;
|
||||
int32_t device_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Context for managing CANN backend operations.
|
||||
*/
|
||||
@@ -213,6 +341,8 @@ struct ggml_backend_cann_context {
|
||||
std::string name; /**< Name of the device. */
|
||||
std::string description; /**< Description of the device. */
|
||||
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
|
||||
cann_task_queue task_queue;
|
||||
bool async_mode;
|
||||
|
||||
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
|
||||
|
||||
@@ -221,9 +351,12 @@ struct ggml_backend_cann_context {
|
||||
* @param device Device ID.
|
||||
*/
|
||||
explicit ggml_backend_cann_context(int device)
|
||||
: device(device), name("CANN" + std::to_string(device)) {
|
||||
: device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
|
||||
ggml_cann_set_device(device);
|
||||
description = aclrtGetSocName();
|
||||
async_mode = (getenv("GGML_CANN_ASYNC_MODE") != nullptr);
|
||||
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
|
||||
device, async_mode ? "ON" : "OFF");
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -231,6 +364,7 @@ struct ggml_backend_cann_context {
|
||||
*/
|
||||
~ggml_backend_cann_context() {
|
||||
ggml_cann_set_device(device);
|
||||
task_queue.stop();
|
||||
if (copy_event != nullptr) {
|
||||
ACL_CHECK(aclrtDestroyEvent(copy_event));
|
||||
}
|
||||
|
||||
@@ -29,6 +29,8 @@
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <chrono>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
@@ -119,9 +121,10 @@ static ggml_cann_device_info ggml_cann_init() {
|
||||
prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = id;
|
||||
prop.reserve = 0;
|
||||
ACL_CHECK(aclrtMemGetAllocationGranularity(
|
||||
err = aclrtMemGetAllocationGranularity(
|
||||
&prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
|
||||
&info.devices[id].vmm_granularity));
|
||||
&info.devices[id].vmm_granularity);
|
||||
info.devices[id].vmm = err == ACL_SUCCESS;
|
||||
|
||||
size_t free, total;
|
||||
ggml_backend_cann_get_device_memory(id, &free, &total);
|
||||
@@ -148,11 +151,223 @@ const ggml_cann_device_info& ggml_cann_info() {
|
||||
|
||||
//#define DEBUG_CANN_MALLOC
|
||||
/**
|
||||
* @brief A pool of CANN buffers(legacy).
|
||||
* @brief A pool of CANN buffers(priority segment buffer).
|
||||
*
|
||||
* This class manages a pool of CANN buffers for a specific device.
|
||||
*/
|
||||
struct ggml_cann_pool_leg : public ggml_cann_pool {
|
||||
struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
|
||||
/**
|
||||
* @brief The maximum reuse margin for a buffer.
|
||||
*/
|
||||
static const size_t max_reuse_margin = 1ull << 22; // 4MB
|
||||
|
||||
/**
|
||||
* @brief The minimum free margin for a buffer.
|
||||
*/
|
||||
static const size_t min_free_margin = 1ull << 20; // 1MB
|
||||
|
||||
/**
|
||||
* @brief The alignment for buffer allocation.
|
||||
*/
|
||||
static const size_t alignment = 128;
|
||||
|
||||
/**
|
||||
* @brief The device ID associated with this buffer pool.
|
||||
*/
|
||||
int device;
|
||||
|
||||
/**
|
||||
* @brief Whether to disable clean during buffer allocation.
|
||||
*/
|
||||
bool disable_clean = false;
|
||||
|
||||
/**
|
||||
* @brief Structure representing a CANN buffer.
|
||||
*/
|
||||
struct ggml_cann_buffer {
|
||||
void* ptr = nullptr; ///< Pointer to the buffer.
|
||||
size_t size = 0; ///< Size of the buffer.
|
||||
std::chrono::steady_clock::time_point last_used; ///< Last used time.
|
||||
|
||||
bool operator>(const ggml_cann_buffer& other) const {
|
||||
return size > other.size;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Array of CANN buffers in the pool.
|
||||
*/
|
||||
std::unordered_map<void*, size_t> buffer_pool;
|
||||
std::priority_queue<ggml_cann_buffer,
|
||||
std::vector<ggml_cann_buffer>,
|
||||
std::greater<>> free_buffers ;
|
||||
|
||||
/**
|
||||
* @brief Total size of all buffers in the pool.
|
||||
*/
|
||||
size_t pool_size = 0;
|
||||
|
||||
/**
|
||||
* @brief Constructor to initialize the buffer pool for a specific device.
|
||||
*
|
||||
* @param device The device ID to associate with this buffer pool.
|
||||
*/
|
||||
explicit ggml_cann_pool_buf_prio(int device) : device(device) {
|
||||
disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Destructor to free all buffers in the pool.
|
||||
*/
|
||||
~ggml_cann_pool_buf_prio() {
|
||||
ggml_cann_set_device(device);
|
||||
for (auto& [b_ptr, b_size] : buffer_pool) {
|
||||
aclrtFree(b_ptr);
|
||||
pool_size -= b_size;
|
||||
}
|
||||
buffer_pool.clear();
|
||||
GGML_ASSERT(pool_size == 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Allocate a buffer of the given size.
|
||||
*
|
||||
* @param size The size of the buffer to allocate.
|
||||
* @param actual_size A pointer to a variable to receive the actual size of
|
||||
* the allocated buffer.
|
||||
* @return A pointer to the allocated buffer.
|
||||
*/
|
||||
void* alloc(size_t size, size_t* actual_size) override {
|
||||
size = GGML_PAD(size, alignment);
|
||||
if (size == 0) {
|
||||
size = alignment;
|
||||
}
|
||||
|
||||
void* ptr = nullptr;
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
|
||||
std::vector<ggml_cann_buffer> free_buffers_rest;
|
||||
free_buffers_rest.reserve(free_buffers.size());
|
||||
while (!free_buffers.empty()) {
|
||||
auto b = free_buffers.top();
|
||||
free_buffers.pop();
|
||||
|
||||
if (b.size >= size) {
|
||||
// reuse the buffer if the size is enough
|
||||
const size_t margin = b.size - size;
|
||||
if (margin <= max_reuse_margin) {
|
||||
*actual_size = b.size;
|
||||
ptr = b.ptr;
|
||||
#ifdef DEBUG_CANN_MALLOC
|
||||
GGML_LOG_INFO(
|
||||
"cann pool[%d]: reused %p, "
|
||||
"pool_size = %5u MB, "
|
||||
"size = %5u MB, "
|
||||
"margin = %5u MB\n",
|
||||
device, b.ptr,
|
||||
(uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
|
||||
(uint32_t)(GGML_PAD(size, 1048576) / 1048576),
|
||||
(uint32_t)(GGML_PAD(margin, 1048576) / 1048576));
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool should_clean = !disable_clean &&
|
||||
b.size > min_free_margin &&
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
|
||||
if (should_clean) {
|
||||
// free the buffer if the size is needed to be freed
|
||||
ACL_CHECK(aclrtFree(b.ptr));
|
||||
pool_size -= b.size;
|
||||
buffer_pool.erase(b.ptr);
|
||||
#ifdef DEBUG_CANN_MALLOC
|
||||
GGML_LOG_INFO(
|
||||
"cann pool[%d]: clean %p, "
|
||||
"pool_size = %5u MB, "
|
||||
"size = %5u MB\n",
|
||||
device, b.ptr,
|
||||
(uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
|
||||
(uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
|
||||
#endif
|
||||
continue;
|
||||
}
|
||||
free_buffers_rest.push_back(b);
|
||||
}
|
||||
for (ggml_cann_buffer &b : free_buffers_rest) {
|
||||
free_buffers.push(std::move(b));
|
||||
}
|
||||
|
||||
#ifdef DEBUG_CANN_MALLOC
|
||||
GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
|
||||
#endif
|
||||
if (ptr != nullptr) {
|
||||
return ptr;
|
||||
}
|
||||
|
||||
// allocate a new buffer if no buffer can be reused
|
||||
ggml_cann_set_device(device);
|
||||
ACL_CHECK(aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
*actual_size = size;
|
||||
pool_size += size;
|
||||
#ifdef DEBUG_CANN_MALLOC
|
||||
GGML_LOG_INFO(
|
||||
"cann pool[%d]: allocate %p, "
|
||||
"pool_size = %5u MB, "
|
||||
"size = %5u MB\n",
|
||||
device, ptr, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
|
||||
(uint32_t)(GGML_PAD(size, 1048576) / 1048576));
|
||||
#endif
|
||||
buffer_pool.emplace(ptr, size);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Free a buffer and return it to the pool.
|
||||
*
|
||||
* @param ptr Pointer to the buffer to free.
|
||||
* @param size Size of the buffer to free.
|
||||
*/
|
||||
void free(void* ptr, size_t size) override {
|
||||
GGML_UNUSED(size);
|
||||
auto it = buffer_pool.find(ptr);
|
||||
if (it == buffer_pool.end()) {
|
||||
GGML_ABORT("cann pool[%d]: buffer %p not found in pool\n", device, ptr);
|
||||
}
|
||||
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
free_buffers.emplace(ggml_cann_buffer{ptr, it->second, now});
|
||||
#ifdef DEBUG_CANN_MALLOC
|
||||
GGML_LOG_INFO(
|
||||
"cann pool[%d]: return %p, "
|
||||
"pool_size = %5u MB\n",
|
||||
device, ptr,
|
||||
(uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief A pool of CANN buffers(segment buffer).
|
||||
*
|
||||
* This class manages a pool of CANN buffers for a specific device.
|
||||
*/
|
||||
struct ggml_cann_pool_buf : public ggml_cann_pool {
|
||||
/**
|
||||
* @brief The maximum reuse margin for a buffer.
|
||||
*/
|
||||
static const size_t max_reuse_margin = 1ull << 22; // 4MB
|
||||
|
||||
/**
|
||||
* @brief The minimum free margin for a buffer.
|
||||
*/
|
||||
static const size_t min_free_margin = 1ull << 20; // 1MB
|
||||
|
||||
/**
|
||||
* @brief The alignment for buffer allocation.
|
||||
*/
|
||||
static const size_t alignment = 128;
|
||||
|
||||
/**
|
||||
* @brief The maximum number of buffers in the pool.
|
||||
*/
|
||||
@@ -163,12 +378,19 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
|
||||
*/
|
||||
int device;
|
||||
|
||||
/**
|
||||
* @brief Whether to disable clean during buffer allocation.
|
||||
*/
|
||||
bool disable_clean = false;
|
||||
|
||||
/**
|
||||
* @brief Structure representing a CANN buffer.
|
||||
*/
|
||||
struct ggml_cann_buffer {
|
||||
void* ptr = nullptr; ///< Pointer to the buffer memory.
|
||||
size_t size = 0; ///< Size of the buffer.
|
||||
bool used = false; ///< Whether the buffer is currently in use.
|
||||
std::chrono::steady_clock::time_point last_used; ///< Last used time.
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -186,17 +408,19 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
|
||||
*
|
||||
* @param device The device ID to associate with this buffer pool.
|
||||
*/
|
||||
explicit ggml_cann_pool_leg(int device) : device(device) {}
|
||||
explicit ggml_cann_pool_buf(int device) : device(device) {
|
||||
disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Destructor to free all buffers in the pool.
|
||||
*/
|
||||
~ggml_cann_pool_leg() {
|
||||
~ggml_cann_pool_buf() {
|
||||
ggml_cann_set_device(device);
|
||||
for (int i = 0; i < MAX_BUFFERS; ++i) {
|
||||
ggml_cann_buffer& b = buffer_pool[i];
|
||||
if (b.ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(b.ptr));
|
||||
aclrtFree(b.ptr);
|
||||
pool_size -= b.size;
|
||||
}
|
||||
}
|
||||
@@ -212,63 +436,93 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
|
||||
* @return A pointer to the allocated buffer.
|
||||
*/
|
||||
void* alloc(size_t size, size_t* actual_size) override {
|
||||
const size_t alignment = 128;
|
||||
size = GGML_PAD(size, alignment);
|
||||
if (size == 0) {
|
||||
size = alignment;
|
||||
}
|
||||
#ifdef DEBUG_CANN_MALLOC
|
||||
int nnz = 0;
|
||||
size_t max_size = 0;
|
||||
#endif
|
||||
size_t best_diff = 1ull << 36;
|
||||
int ibest = -1;
|
||||
for (int i = 0; i < MAX_BUFFERS; ++i) {
|
||||
|
||||
void* ptr = nullptr;
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
|
||||
int i = 0;
|
||||
for (; i < MAX_BUFFERS; ++i) {
|
||||
ggml_cann_buffer& b = buffer_pool[i];
|
||||
if (b.ptr != nullptr) {
|
||||
if (b.ptr == nullptr) {
|
||||
break;
|
||||
}
|
||||
if (b.used) {
|
||||
continue;
|
||||
}
|
||||
if (b.size >= size) {
|
||||
// reuse the buffer if the size is enough
|
||||
const size_t margin = b.size - size;
|
||||
if (margin <= max_reuse_margin) {
|
||||
*actual_size = b.size;
|
||||
b.used = true;
|
||||
ptr = b.ptr;
|
||||
#ifdef DEBUG_CANN_MALLOC
|
||||
++nnz;
|
||||
if (b.size > max_size) max_size = b.size;
|
||||
GGML_LOG_INFO(
|
||||
"cann pool[%d]: reused %p, "
|
||||
"pool_size = %5u MB, "
|
||||
"size = %5u MB, "
|
||||
"margin = %5u MB\n",
|
||||
device, b.ptr,
|
||||
(uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
|
||||
(uint32_t)(GGML_PAD(size, 1048576) / 1048576),
|
||||
(uint32_t)(GGML_PAD(margin, 1048576) / 1048576));
|
||||
#endif
|
||||
if (b.size >= size) {
|
||||
size_t diff = b.size - size;
|
||||
if (diff < best_diff) {
|
||||
best_diff = diff;
|
||||
ibest = i;
|
||||
if (!best_diff) {
|
||||
void* ptr = b.ptr;
|
||||
*actual_size = b.size;
|
||||
b.ptr = nullptr;
|
||||
b.size = 0;
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool should_clean = !disable_clean &&
|
||||
b.size > min_free_margin &&
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
|
||||
if (should_clean) {
|
||||
// free the buffer if the size is needed to be freed
|
||||
ACL_CHECK(aclrtFree(b.ptr));
|
||||
pool_size -= b.size;
|
||||
#ifdef DEBUG_CANN_MALLOC
|
||||
GGML_LOG_INFO(
|
||||
"cann pool[%d]: clean %p, "
|
||||
"pool_size = %5u MB, "
|
||||
"size = %5u MB\n",
|
||||
device, b.ptr,
|
||||
(uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
|
||||
(uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
|
||||
#endif
|
||||
b.ptr = nullptr;
|
||||
}
|
||||
}
|
||||
if (ibest >= 0) {
|
||||
ggml_cann_buffer& b = buffer_pool[ibest];
|
||||
void* ptr = b.ptr;
|
||||
*actual_size = b.size;
|
||||
b.ptr = nullptr;
|
||||
b.size = 0;
|
||||
if (ptr != nullptr) {
|
||||
return ptr;
|
||||
}
|
||||
void* ptr;
|
||||
ggml_cann_set_device(device);
|
||||
ACL_CHECK(
|
||||
aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
*actual_size = size;
|
||||
pool_size += size;
|
||||
|
||||
if (i < MAX_BUFFERS) {
|
||||
// allocate a new buffer if no buffer can be reused
|
||||
ggml_cann_buffer& b = buffer_pool[i];
|
||||
ggml_cann_set_device(device);
|
||||
ACL_CHECK(aclrtMalloc(&b.ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
pool_size += size;
|
||||
*actual_size = size;
|
||||
b.size = size;
|
||||
b.used = true;
|
||||
if (i >= MAX_BUFFERS - 8) {
|
||||
GGML_LOG_WARN("cann pool[%d]: slots almost full\n", device);
|
||||
}
|
||||
#ifdef DEBUG_CANN_MALLOC
|
||||
GGML_LOG_INFO(
|
||||
"%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
|
||||
"requested %u MB\n",
|
||||
__func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
|
||||
(uint32_t)(pool_size / 1024 / 1024),
|
||||
(uint32_t)(size / 1024 / 1024));
|
||||
GGML_LOG_INFO(
|
||||
"cann pool[%d]: allocate %p, "
|
||||
"pool_size = %5u MB, "
|
||||
"size = %5u MB\n",
|
||||
device, b.ptr,
|
||||
(uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
|
||||
(uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
|
||||
#endif
|
||||
return ptr;
|
||||
return b.ptr;
|
||||
}
|
||||
|
||||
GGML_ABORT("cann pool[%d]: slots full\n", device);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -278,18 +532,24 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
|
||||
* @param size Size of the buffer to free.
|
||||
*/
|
||||
void free(void* ptr, size_t size) override {
|
||||
GGML_UNUSED(size);
|
||||
for (int i = 0; i < MAX_BUFFERS; ++i) {
|
||||
ggml_cann_buffer& b = buffer_pool[i];
|
||||
if (b.ptr == nullptr) {
|
||||
b.ptr = ptr;
|
||||
b.size = size;
|
||||
return;
|
||||
if (b.ptr != ptr) {
|
||||
continue;
|
||||
}
|
||||
b.used = false;
|
||||
b.last_used = std::chrono::steady_clock::now();
|
||||
#ifdef DEBUG_CANN_MALLOC
|
||||
GGML_LOG_INFO(
|
||||
"cann pool[%d]: return %p, "
|
||||
"pool_size = %5u MB\n",
|
||||
device, b.ptr,
|
||||
(uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
// memory should always buffered. these memory may still needed by
|
||||
// tasks in stream.
|
||||
// TODO, fix me.
|
||||
GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n");
|
||||
GGML_ABORT("cann pool[%d]: slots full\n", device);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -347,8 +607,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
|
||||
* @param device The device ID to associate with this buffer pool.
|
||||
*/
|
||||
explicit ggml_cann_pool_vmm(int device)
|
||||
: device(device),
|
||||
granularity(ggml_cann_info().devices[device].vmm_granularity) {
|
||||
: device(device) {
|
||||
auto dev = ggml_cann_info().devices[device];
|
||||
granularity = dev.vmm_granularity;
|
||||
max_size = dev.total_vram;
|
||||
@@ -471,7 +730,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
|
||||
*/
|
||||
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
|
||||
int device) {
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
||||
bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr);
|
||||
if (!disable_vmm && ggml_cann_info().devices[device].vmm) {
|
||||
GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
||||
}
|
||||
bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr);
|
||||
if (enable_buf_prio) {
|
||||
GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));
|
||||
}
|
||||
GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device);
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));
|
||||
}
|
||||
|
||||
// cann buffer
|
||||
@@ -1020,8 +1290,11 @@ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
|
||||
|
||||
ggml_cann_set_device(buft_ctx->device);
|
||||
|
||||
size = std::max(size, (size_t)1);
|
||||
|
||||
const size_t alignment = 128;
|
||||
size = GGML_PAD(size, alignment);
|
||||
if (size == 0) {
|
||||
size = alignment;
|
||||
}
|
||||
void* dev_ptr;
|
||||
aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
|
||||
if (err != ACL_SUCCESS) {
|
||||
@@ -1333,7 +1606,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
|
||||
auto lambda = [](ggml_backend_cann_context& ctx,
|
||||
aclTensor* acl_src,
|
||||
aclTensor* acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(GeluV2, acl_src, 0, acl_dst);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
|
||||
};
|
||||
ggml_cann_unary_op(lambda, ctx, dst);
|
||||
} break;
|
||||
@@ -1516,12 +1789,11 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
|
||||
delete backend;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Sets tensor data asynchronously in the CANN backend.
|
||||
*
|
||||
* This function asynchronously sets tensor data in the CANN backend. Depending
|
||||
* on the tensor type, it may perform data transformations before copying data
|
||||
* to the device.
|
||||
* This function asynchronously sets tensor data in the CANN backend.
|
||||
*
|
||||
* @param backend Pointer to the CANN backend structure.
|
||||
* @param tensor Pointer to the tensor structure to set data for.
|
||||
@@ -1536,23 +1808,28 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
|
||||
size_t size) {
|
||||
ggml_backend_cann_context *cann_ctx =
|
||||
(ggml_backend_cann_context *)backend->context;
|
||||
ggml_backend_buffer_t buf =
|
||||
tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
|
||||
|
||||
if (!need_transform(tensor->type)) {
|
||||
ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
|
||||
size, ACL_MEMCPY_HOST_TO_DEVICE,
|
||||
cann_ctx->stream()));
|
||||
} else {
|
||||
void *transform_buffer = malloc(size);
|
||||
ggml_backend_cann_transform(tensor, data, transform_buffer);
|
||||
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
|
||||
"unsupported buffer type");
|
||||
GGML_ASSERT(!ggml_is_quantized(tensor->type));
|
||||
|
||||
ACL_CHECK(aclrtMemcpyAsync(
|
||||
(char *)tensor->data + offset, size, transform_buffer, size,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
|
||||
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
||||
free(transform_buffer);
|
||||
}
|
||||
ggml_cann_async_memcpy(cann_ctx, (char *)tensor->data + offset, data, size,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gets tensor data asynchronously in the CANN backend.
|
||||
*
|
||||
* This function asynchronously gets tensor data in the CANN backend.
|
||||
*
|
||||
* @param backend Pointer to the CANN backend structure.
|
||||
* @param tensor Pointer to the tensor structure to get data from.
|
||||
* @param data Pointer to the host data to copy from the tensor.
|
||||
* @param offset Offset in bytes within the host data.
|
||||
* @param size Size of the data to copy in bytes.
|
||||
*/
|
||||
static void ggml_backend_cann_get_tensor_async(
|
||||
ggml_backend_t backend, const ggml_tensor *tensor, void *data,
|
||||
size_t offset, size_t size) {
|
||||
@@ -1563,20 +1840,11 @@ static void ggml_backend_cann_get_tensor_async(
|
||||
|
||||
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
|
||||
"unsupported buffer type");
|
||||
GGML_ASSERT(!ggml_is_quantized(tensor->type));
|
||||
|
||||
ggml_cann_async_memcpy(cann_ctx, data, (char *)tensor->data + offset, size,
|
||||
ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
|
||||
if (!need_transform(tensor->type)) {
|
||||
ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
|
||||
size, ACL_MEMCPY_DEVICE_TO_HOST,
|
||||
cann_ctx->stream()));
|
||||
} else {
|
||||
void *transform_buffer = malloc(size);
|
||||
ACL_CHECK(aclrtMemcpyAsync(
|
||||
transform_buffer, size, (char *)tensor->data + offset, size,
|
||||
ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
|
||||
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
||||
ggml_backend_cann_transform_back(tensor, transform_buffer, data);
|
||||
free(transform_buffer);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1636,6 +1904,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
|
||||
ggml_cann_set_device(cann_ctx_src->device);
|
||||
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
|
||||
|
||||
// wait for task_queue empty to keep task order.
|
||||
cann_ctx_src->task_queue.wait();
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE,
|
||||
cann_ctx_src->stream()));
|
||||
@@ -1663,9 +1933,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
|
||||
static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
|
||||
ggml_backend_cann_context* cann_ctx =
|
||||
(ggml_backend_cann_context*)backend->context;
|
||||
|
||||
cann_ctx->task_queue.wait();
|
||||
ggml_cann_set_device(cann_ctx->device);
|
||||
|
||||
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
||||
}
|
||||
|
||||
@@ -1749,6 +2018,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
return true;
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
#ifdef ASCEND_310P
|
||||
// Q4 && Q8 per group is not suppor on 310p device
|
||||
return false;
|
||||
#endif
|
||||
// only support contiguous for quantized types.
|
||||
return ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op->src[1]);
|
||||
@@ -1816,6 +2089,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!ggml_is_contiguous(op->src[0])){
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_UPSCALE: {
|
||||
@@ -1831,6 +2107,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
}
|
||||
case GGML_OP_POOL_2D: {
|
||||
const int32_t * opts = (const int32_t *) op->op_params;
|
||||
#ifdef ASCEND_310P
|
||||
enum ggml_op_pool opt = static_cast<ggml_op_pool>(opts[0]);
|
||||
if(opt == GGML_OP_POOL_MAX){
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
const int k0 = opts[1];
|
||||
const int k1 = opts[2];
|
||||
const int p0 = opts[5];
|
||||
|
||||
@@ -222,7 +222,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
elseif (GGML_AVX)
|
||||
list(APPEND ARCH_FLAGS /arch:AVX)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX)
|
||||
else ()
|
||||
elseif (GGML_SSE42)
|
||||
list(APPEND ARCH_FLAGS /arch:SSE4.2)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_SSE42)
|
||||
endif()
|
||||
@@ -237,8 +237,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
if (GGML_NATIVE)
|
||||
list(APPEND ARCH_FLAGS -march=native)
|
||||
else ()
|
||||
list(APPEND ARCH_FLAGS -msse4.2)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_SSE42)
|
||||
if (GGML_SSE42)
|
||||
list(APPEND ARCH_FLAGS -msse4.2)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_SSE42)
|
||||
endif()
|
||||
if (GGML_F16C)
|
||||
list(APPEND ARCH_FLAGS -mf16c)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_F16C)
|
||||
|
||||
@@ -263,7 +263,7 @@ void test_x86_is() {
|
||||
static int ggml_backend_cpu_x86_score() {
|
||||
// FIXME: this does not check for OS support
|
||||
|
||||
int score = 0;
|
||||
int score = 1;
|
||||
cpuid_x86 is;
|
||||
|
||||
#ifdef GGML_FMA
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -425,6 +425,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
|
||||
}
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
return src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16;
|
||||
case GGML_OP_OUT_PROD:
|
||||
return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
|
||||
src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
|
||||
@@ -551,7 +551,7 @@ static void ggml_cpy_f16_f16_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||
|
||||
@@ -588,7 +588,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
char ** dest_ptrs_d = nullptr;
|
||||
int graph_cpynode_index = -1;
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
|
||||
if(ctx.cuda_graph->use_cpy_indirection) {
|
||||
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
|
||||
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
|
||||
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
|
||||
}
|
||||
@@ -636,7 +636,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
}
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
|
||||
if(ctx.cuda_graph->use_cpy_indirection) {
|
||||
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
|
||||
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
|
||||
}
|
||||
#endif
|
||||
@@ -645,7 +645,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
ggml_cuda_cpy(ctx, src0, dst);
|
||||
bool disable_indirection = true;
|
||||
ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
|
||||
}
|
||||
|
||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#define CUDA_CPY_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
|
||||
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
|
||||
@@ -96,31 +96,32 @@ int ggml_cuda_get_device() {
|
||||
|
||||
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
||||
ggml_cuda_set_device(device);
|
||||
#if defined(GGML_USE_HIP) && defined(GGML_HIP_UMA)
|
||||
auto res = hipMallocManaged(ptr, size);
|
||||
if (res == hipSuccess) {
|
||||
// if error we "need" to know why...
|
||||
CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
|
||||
}
|
||||
return res;
|
||||
#else
|
||||
|
||||
#if !defined(GGML_USE_HIP)
|
||||
cudaError_t err;
|
||||
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
|
||||
{
|
||||
err = cudaMallocManaged(ptr, size);
|
||||
#if defined(GGML_USE_HIP)
|
||||
if (err == hipSuccess) {
|
||||
CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
|
||||
}
|
||||
|
||||
// fall back to cudaMalloc if not supported (e.g. on Windows)
|
||||
if (err == hipErrorNotSupported) {
|
||||
static bool warned_unsupported = false;
|
||||
if (!warned_unsupported) {
|
||||
GGML_LOG_WARN("hipMallocManaged unsupported, falling back to hipMalloc.\n");
|
||||
warned_unsupported = true;
|
||||
}
|
||||
|
||||
err = cudaMalloc(ptr, size);
|
||||
}
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
}
|
||||
else
|
||||
{
|
||||
err = cudaMalloc(ptr, size);
|
||||
}
|
||||
return err;
|
||||
#else
|
||||
return cudaMalloc(ptr, size);
|
||||
#endif // !defined(GGML_USE_HIP)
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
@@ -1409,6 +1410,11 @@ static void ggml_cuda_op_mul_mat(
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
|
||||
// const int64_t nb10 = src1->nb[0];
|
||||
const int64_t nb11 = src1->nb[1];
|
||||
const int64_t nb12 = src1->nb[2];
|
||||
const int64_t nb13 = src1->nb[3];
|
||||
|
||||
const int64_t nb2 = dst->nb[2];
|
||||
const int64_t nb3 = dst->nb[3];
|
||||
|
||||
@@ -1544,7 +1550,10 @@ static void ggml_cuda_op_mul_mat(
|
||||
dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
|
||||
|
||||
if (src1_on_device && src1_is_contiguous) {
|
||||
quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
|
||||
quantize_src1(
|
||||
dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10,
|
||||
nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
|
||||
src1_padded_col_size, ne11, ne12, ne13, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
}
|
||||
@@ -1639,7 +1648,9 @@ static void ggml_cuda_op_mul_mat(
|
||||
}
|
||||
|
||||
if (quantize_src1 && !src1_is_contiguous) {
|
||||
quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
|
||||
quantize_src1(
|
||||
src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
|
||||
src1_padded_col_size, src1_ncols, 1, 1, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
@@ -1877,7 +1888,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
|
||||
|
||||
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
|
||||
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
||||
@@ -1918,10 +1929,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
||||
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
||||
|
||||
if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
|
||||
if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
|
||||
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
|
||||
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
|
||||
ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
|
||||
ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
|
||||
} else if (!split && use_mul_mat_vec_q) {
|
||||
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
|
||||
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
|
||||
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
||||
// general KQ + KQV multi-batch without FlashAttention
|
||||
@@ -1998,6 +2011,15 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ne2 == 1) {
|
||||
if (ggml_is_quantized(src0->type)) {
|
||||
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
|
||||
} else {
|
||||
ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
@@ -2034,97 +2056,75 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
dst_row.nb[2] = nb1;
|
||||
dst_row.nb[3] = nb1;
|
||||
|
||||
if (ne12 == 1) {
|
||||
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
|
||||
ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
|
||||
|
||||
src1_row.data = src1_contiguous.get();
|
||||
dst_row.data = dst_contiguous.get();
|
||||
|
||||
for (int64_t i02 = 0; i02 < n_as; i02++) {
|
||||
int64_t num_src1_rows = 0;
|
||||
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
|
||||
GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
||||
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
|
||||
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = iid1;
|
||||
if (row_id_i != i02) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t i1 = id;
|
||||
const int64_t i2 = i12;
|
||||
|
||||
src0_row.data = src0_original + i02*nb02;
|
||||
src1_row.data = src1_original + i11*nb11 + i12*nb12;
|
||||
dst_row.data = dst_original + i1*nb1 + i2*nb2;
|
||||
|
||||
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
|
||||
num_src1_rows++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
|
||||
ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
|
||||
|
||||
src1_row.data = src1_contiguous.get();
|
||||
dst_row.data = dst_contiguous.get();
|
||||
if (num_src1_rows == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int64_t i02 = 0; i02 < n_as; i02++) {
|
||||
int64_t num_src1_rows = 0;
|
||||
ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
|
||||
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
|
||||
CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
|
||||
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
{
|
||||
dim3 block_dims(std::min((unsigned int)ne10, 768u));
|
||||
dim3 grid_dims(ids->ne[1], n_ids);
|
||||
k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
|
||||
src1_original, src1_contiguous.get(),
|
||||
dev_cur_src1_row.get(), dev_row_mapping.get(),
|
||||
ids_dev, i02, ids->nb[1], ids->nb[0],
|
||||
ne11, ne10,
|
||||
nb11, nb12);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
|
||||
src0_row.data = src0_original + i02*nb02;
|
||||
|
||||
if (row_id_i != i02) {
|
||||
continue;
|
||||
}
|
||||
GGML_ASSERT(nb11 == sizeof(float)*ne10);
|
||||
GGML_ASSERT(nb1 == sizeof(float)*ne0);
|
||||
|
||||
num_src1_rows++;
|
||||
}
|
||||
}
|
||||
src1_row.ne[1] = num_src1_rows;
|
||||
src1_row.nb[1] = nb11;
|
||||
src1_row.nb[2] = num_src1_rows*nb11;
|
||||
src1_row.nb[3] = num_src1_rows*nb11;
|
||||
|
||||
if (num_src1_rows == 0) {
|
||||
continue;
|
||||
}
|
||||
dst_row.ne[1] = num_src1_rows;
|
||||
dst_row.nb[1] = nb1;
|
||||
dst_row.nb[2] = num_src1_rows*nb1;
|
||||
dst_row.nb[3] = num_src1_rows*nb1;
|
||||
|
||||
ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
|
||||
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
|
||||
CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
|
||||
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
|
||||
|
||||
{
|
||||
dim3 block_dims(std::min((unsigned int)ne10, 768u));
|
||||
dim3 grid_dims(ids->ne[1], n_ids);
|
||||
k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
|
||||
src1_original, src1_contiguous.get(),
|
||||
dev_cur_src1_row.get(), dev_row_mapping.get(),
|
||||
ids_dev, i02, ids->nb[1], ids->nb[0],
|
||||
ne11, ne10,
|
||||
nb11, nb12);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
src0_row.data = src0_original + i02*nb02;
|
||||
|
||||
GGML_ASSERT(nb11 == sizeof(float)*ne10);
|
||||
GGML_ASSERT(nb1 == sizeof(float)*ne0);
|
||||
|
||||
src1_row.ne[1] = num_src1_rows;
|
||||
src1_row.nb[1] = nb11;
|
||||
src1_row.nb[2] = num_src1_rows*nb11;
|
||||
src1_row.nb[3] = num_src1_rows*nb11;
|
||||
|
||||
dst_row.ne[1] = num_src1_rows;
|
||||
dst_row.nb[1] = nb1;
|
||||
dst_row.nb[2] = num_src1_rows*nb1;
|
||||
dst_row.nb[3] = num_src1_rows*nb1;
|
||||
|
||||
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
|
||||
|
||||
{
|
||||
dim3 block_dims(std::min((unsigned int)ne0, 768u));
|
||||
dim3 grid_dims(num_src1_rows);
|
||||
k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
|
||||
dst_original, dst_contiguous.get(),
|
||||
dev_row_mapping.get(),
|
||||
ne0,
|
||||
nb1, nb2);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
{
|
||||
dim3 block_dims(std::min((unsigned int)ne0, 768u));
|
||||
dim3 grid_dims(num_src1_rows);
|
||||
k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
|
||||
dst_original, dst_contiguous.get(),
|
||||
dev_row_mapping.get(),
|
||||
ne0,
|
||||
nb1, nb2);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2488,10 +2488,10 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_MUL_MAT_ID) {
|
||||
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
|
||||
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -3202,9 +3202,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
}
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK: {
|
||||
const size_t ts = ggml_type_size(op->src[0]->type);
|
||||
const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
|
||||
return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
|
||||
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
|
||||
}
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_POOL_2D:
|
||||
@@ -3236,6 +3234,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (op->src[0]->ne[0] == 192) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek MLA
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -4,18 +4,23 @@
|
||||
|
||||
template <typename T, typename type_acc, int block_size>
|
||||
static __global__ void mul_mat_vec(
|
||||
const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
|
||||
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
|
||||
const int64_t row = blockIdx.x;
|
||||
const int64_t channel = blockIdx.y;
|
||||
const int64_t sample = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int64_t row = blockIdx.x;
|
||||
const int64_t channel_dst = blockIdx.y;
|
||||
const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
|
||||
const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
|
||||
const int64_t sample_dst = blockIdx.z;
|
||||
const int64_t sample_x = sample_dst / sample_ratio;
|
||||
const int64_t sample_y = sample_dst;
|
||||
const int tid = threadIdx.x;
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
|
||||
x += (sample/sample_ratio)*stride_sample_x + (channel/channel_ratio)*stride_channel_x + row*stride_row;
|
||||
y += sample *stride_sample_y + channel *stride_channel_y;
|
||||
dst += sample *stride_sample_dst + channel *stride_channel_dst;
|
||||
x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
|
||||
y += sample_y *stride_sample_y + channel_y *stride_channel_y;
|
||||
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
|
||||
|
||||
const float2 * y2 = (const float2 *) y;
|
||||
|
||||
@@ -31,12 +36,19 @@ static __global__ void mul_mat_vec(
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
const float2 * x2 = (const float2 *) x;
|
||||
|
||||
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const float2 tmpx = x2[col2];
|
||||
const float2 tmpy = y2[col2];
|
||||
sumf += tmpx.x*tmpy.x;
|
||||
sumf += tmpx.y*tmpy.y;
|
||||
}
|
||||
} else if constexpr (std::is_same<T, half>::value) {
|
||||
const half2 * x2 = (const half2 *) x;
|
||||
|
||||
if (std::is_same<type_acc, float>::value) {
|
||||
sumf = 0.0f;
|
||||
|
||||
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const float2 tmpx = __half22float2(x2[col2]);
|
||||
const float2 tmpy = y2[col2];
|
||||
@@ -59,8 +71,6 @@ static __global__ void mul_mat_vec(
|
||||
}
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
const int * x2 = (const int *) x;
|
||||
sumf = 0.0f;
|
||||
|
||||
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const int tmpx = x2[col2];
|
||||
const float2 tmpy = y2[col2];
|
||||
@@ -92,17 +102,17 @@ static __global__ void mul_mat_vec(
|
||||
|
||||
template <typename T, typename type_acc>
|
||||
static void launch_mul_mat_vec_cuda(
|
||||
const T * x, const float * y, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
GGML_ASSERT(stride_row % 2 == 0);
|
||||
GGML_ASSERT(nchannels_y % nchannels_x == 0);
|
||||
GGML_ASSERT(nsamples_y % nsamples_x == 0);
|
||||
const int64_t channel_ratio = nchannels_y / nchannels_x;
|
||||
const int64_t sample_ratio = nsamples_y / nsamples_x;
|
||||
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
||||
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
||||
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
||||
const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
||||
int device;
|
||||
int warp_size;
|
||||
|
||||
@@ -124,48 +134,48 @@ static void launch_mul_mat_vec_cuda(
|
||||
}
|
||||
|
||||
const int smem = warp_size*sizeof(float);
|
||||
const dim3 block_nums(nrows, nchannels_y, nsamples_y);
|
||||
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
|
||||
const dim3 block_dims(block_size_best, 1, 1);
|
||||
switch (block_size_best) {
|
||||
case 32: {
|
||||
mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 64: {
|
||||
mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 96: {
|
||||
mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 128: {
|
||||
mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 160: {
|
||||
mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 192: {
|
||||
mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 224: {
|
||||
mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 256: {
|
||||
mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -175,28 +185,28 @@ static void launch_mul_mat_vec_cuda(
|
||||
|
||||
template<typename T>
|
||||
static void mul_mat_vec_cuda(
|
||||
const T * x, const float * y, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
enum ggml_prec prec, cudaStream_t stream) {
|
||||
switch (prec) {
|
||||
case GGML_PREC_DEFAULT: {
|
||||
if constexpr(std::is_same<T, half>::value) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
launch_mul_mat_vec_cuda<T, half>
|
||||
(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case GGML_PREC_F32: {
|
||||
launch_mul_mat_vec_cuda<T, float>
|
||||
(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
return;
|
||||
}
|
||||
}
|
||||
launch_mul_mat_vec_cuda<T, float>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||
|
||||
@@ -204,21 +214,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
||||
const size_t ts_src1 = ggml_type_size(src1->type);
|
||||
const size_t ts_dst = ggml_type_size(dst->type);
|
||||
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
GGML_ASSERT(ne12 == ne2);
|
||||
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
|
||||
GGML_ASSERT(ne13 == ne3);
|
||||
|
||||
GGML_ASSERT(nb00 == ts_src0);
|
||||
GGML_ASSERT(nb10 == ts_src1);
|
||||
GGML_ASSERT(nb0 == ts_dst);
|
||||
GGML_ASSERT( nb00 == ts_src0);
|
||||
GGML_ASSERT( nb10 == ts_src1);
|
||||
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
|
||||
GGML_ASSERT( nb0 == ts_dst);
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
||||
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
const int64_t s01 = src0->nb[1] / ts_src0;
|
||||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s1 = dst->nb[1] / ts_dst;
|
||||
const int64_t s02 = src0->nb[2] / ts_src0;
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s2 = dst->nb[2] / ts_dst;
|
||||
@@ -226,14 +239,33 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
const int64_t s3 = dst->nb[3] / ts_dst;
|
||||
|
||||
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
|
||||
const int64_t ncols_dst = ids ? ne2 : ne1;
|
||||
const int64_t nchannels_y = ids ? ne11 : ne12;
|
||||
const int64_t nchannels_dst = ids ? ne1 : ne2;
|
||||
const int64_t stride_channel_dst = ids ? s1 : s2;
|
||||
const int64_t stride_channel_y = ids ? s11 : s12;
|
||||
|
||||
GGML_ASSERT(ncols_dst == 1);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half * src0_d = (const half *) src0->data;
|
||||
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
|
||||
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
||||
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
|
||||
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
||||
@@ -262,27 +294,34 @@ void ggml_cuda_op_mul_mat_vec(
|
||||
const int64_t stride_row = ne00;
|
||||
const int64_t nchannels_x = 1;
|
||||
const int64_t nchannels_y = 1;
|
||||
const int64_t nchannels_dst = 1;
|
||||
const int64_t stride_channel_x = 0;
|
||||
const int64_t stride_channel_y = 0;
|
||||
const int64_t stride_channel_dst = 0;
|
||||
const int64_t nsamples_x = 1;
|
||||
const int64_t nsamples_y = 1;
|
||||
const int64_t nsamples_dst = 1;
|
||||
const int64_t stride_sample_x = 0;
|
||||
const int64_t stride_sample_y = 0;
|
||||
const int64_t stride_sample_dst = 0;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0_dd_i;
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half * src0_d = (const half *) src0_dd_i;
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
|
||||
nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
|
||||
nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
|
||||
#define MMV_MAX_ROWS 512
|
||||
|
||||
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
#include "mmvq.cuh"
|
||||
#include "quantize.cuh"
|
||||
#include "vecdotq.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
|
||||
|
||||
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
|
||||
@@ -73,9 +76,9 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
|
||||
return MMVQ_PARAMETERS_GENERIC;
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
|
||||
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
|
||||
if (table_id == MMVQ_PARAMETERS_GENERIC) {
|
||||
switch (ncols_y) {
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
case 2:
|
||||
case 3:
|
||||
@@ -90,7 +93,7 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_paramete
|
||||
return 1;
|
||||
}
|
||||
} else if (table_id == MMVQ_PARAMETERS_GCN) {
|
||||
switch (ncols_y) {
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
case 2:
|
||||
case 3:
|
||||
@@ -107,9 +110,9 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_paramete
|
||||
return 1;
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
|
||||
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
|
||||
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
|
||||
switch (ncols_y) {
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
return 1;
|
||||
case 2:
|
||||
@@ -127,19 +130,21 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int ta
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <ggml_type type, int ncols_y>
|
||||
template <ggml_type type, int ncols_dst>
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
||||
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||
constexpr int vdr = get_vdr_mmvq(type);
|
||||
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
|
||||
constexpr int nwarps = calc_nwarps(ncols_y, table_id);
|
||||
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
|
||||
constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
|
||||
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
|
||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||
@@ -147,13 +152,21 @@ static __global__ void mul_mat_vec_q(
|
||||
const int tid = warp_size*threadIdx.y + threadIdx.x;
|
||||
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||
const int blocks_per_row_x = ncols_x / qk;
|
||||
const int blocks_per_col_y = nrows_y / QK8_1;
|
||||
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
|
||||
// The MUL_MAT_ID code path with ids != nullptr is only implemetned for ncols_dst == 1.
|
||||
const int channel_dst = blockIdx.y;
|
||||
const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
|
||||
const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
|
||||
const int sample_dst = blockIdx.z;
|
||||
const int sample_x = sample_dst / sample_ratio;
|
||||
const int sample_y = sample_dst;
|
||||
|
||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||
// partial sum for each thread
|
||||
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
|
||||
|
||||
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
|
||||
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
|
||||
|
||||
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
|
||||
const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
|
||||
@@ -162,18 +175,19 @@ static __global__ void mul_mat_vec_q(
|
||||
const int kqs = vdr * (tid % (qi/vdr));
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_y; ++j) {
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||
tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
|
||||
tmp[j][i] += vec_dot_q_cuda(
|
||||
vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
|
||||
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
|
||||
if (threadIdx.y > 0) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_y; ++j) {
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
|
||||
@@ -185,9 +199,11 @@ static __global__ void mul_mat_vec_q(
|
||||
return;
|
||||
}
|
||||
|
||||
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
|
||||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_y; ++j) {
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||
#pragma unroll
|
||||
@@ -197,88 +213,121 @@ static __global__ void mul_mat_vec_q(
|
||||
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
|
||||
}
|
||||
|
||||
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
|
||||
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
|
||||
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + int(threadIdx.x) < stride_col_dst)) {
|
||||
dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED(nrows_x);
|
||||
}
|
||||
|
||||
static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
|
||||
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
|
||||
const dim3 block_nums(nblocks, 1, 1);
|
||||
const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
|
||||
static std::pair<dim3, dim3> calc_launch_params(
|
||||
const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
|
||||
const int warp_size, const mmvq_parameter_table_id table_id) {
|
||||
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
|
||||
const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
|
||||
const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
|
||||
return {block_nums, block_dims};
|
||||
}
|
||||
|
||||
template <ggml_type type>
|
||||
static void mul_mat_vec_q_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
static void mul_mat_vec_q_switch_ncols_dst(
|
||||
const void * vx, const void * vy, const int32_t * ids, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst,
|
||||
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
||||
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
||||
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
|
||||
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
|
||||
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
|
||||
|
||||
const int channel_ratio = nchannels_dst / nchannels_x;
|
||||
const int sample_ratio = nsamples_dst / nsamples_x;
|
||||
|
||||
const int device = ggml_cuda_get_device();
|
||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
|
||||
|
||||
switch (ncols_y) {
|
||||
GGML_ASSERT(!ids || ncols_dst == 1);
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
{
|
||||
constexpr int c_ncols_y = 1;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
constexpr int c_ncols_dst = 1;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
break;
|
||||
}
|
||||
case 2:
|
||||
{
|
||||
constexpr int c_ncols_y = 2;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
constexpr int c_ncols_dst = 2;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
break;
|
||||
}
|
||||
case 3:
|
||||
{
|
||||
constexpr int c_ncols_y = 3;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
constexpr int c_ncols_dst = 3;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
break;
|
||||
}
|
||||
case 4:
|
||||
{
|
||||
constexpr int c_ncols_y = 4;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
constexpr int c_ncols_dst = 4;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
break;
|
||||
}
|
||||
case 5:
|
||||
{
|
||||
constexpr int c_ncols_y = 5;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
constexpr int c_ncols_dst = 5;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
break;
|
||||
}
|
||||
case 6:
|
||||
{
|
||||
constexpr int c_ncols_y = 6;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
constexpr int c_ncols_dst = 6;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
break;
|
||||
}
|
||||
case 7:
|
||||
{
|
||||
constexpr int c_ncols_y = 7;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
constexpr int c_ncols_dst = 7;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
break;
|
||||
}
|
||||
case 8:
|
||||
{
|
||||
constexpr int c_ncols_y = 8;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
constexpr int c_ncols_dst = 8;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -287,137 +336,213 @@ static void mul_mat_vec_q_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
static void mul_mat_vec_q_switch_type(
|
||||
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst,
|
||||
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
||||
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
||||
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
switch (type_x) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_1_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
void ggml_cuda_mul_mat_vec_q(
|
||||
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||
|
||||
static void mul_mat_vec_q5_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
const size_t ts_src0 = ggml_type_size(src0->type);
|
||||
const size_t ts_src1 = ggml_type_size(src1->type);
|
||||
const size_t ts_dst = ggml_type_size(dst->type);
|
||||
|
||||
static void mul_mat_vec_q5_1_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
GGML_ASSERT( nb00 == ts_src0);
|
||||
GGML_ASSERT( nb10 == ts_src1);
|
||||
GGML_ASSERT( nb0 == ts_dst);
|
||||
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
|
||||
|
||||
static void mul_mat_vec_q8_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
|
||||
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
|
||||
{
|
||||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
quantize_row_q8_1_cuda(src1_d, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q2_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int64_t s01 = src0->nb[1] / ts_src0;
|
||||
const int64_t s11 = ne10_padded / QK8_1;
|
||||
const int64_t s1 = dst->nb[1] / ts_dst;
|
||||
const int64_t s02 = src0->nb[2] / ts_src0;
|
||||
const int64_t s2 = dst->nb[2] / ts_dst;
|
||||
const int64_t s03 = src0->nb[3] / ts_src0;
|
||||
const int64_t s3 = dst->nb[3] / ts_dst;
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
const int64_t s12 = ne11*s11;
|
||||
const int64_t s13 = ne12*s12;
|
||||
|
||||
static void mul_mat_vec_q3_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
|
||||
const int64_t ncols_dst = ids ? ne2 : ne1;
|
||||
const int64_t nchannels_y = ids ? ne11 : ne12;
|
||||
const int64_t nchannels_dst = ids ? ne1 : ne2;
|
||||
const int64_t stride_col_dst = ids ? s2 : s1;
|
||||
const int64_t stride_col_y = ids ? s12 : s11;
|
||||
const int64_t stride_channel_dst = ids ? s1 : s2;
|
||||
const int64_t stride_channel_y = ids ? s11 : s12;
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q5_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q6_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_xxs_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_xs_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_s_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq3_xxs_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq1_s_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq1_m_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq4_nl_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq4_xs_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq3_s_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
mul_mat_vec_q_switch_type(
|
||||
src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00,
|
||||
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec_q(
|
||||
@@ -440,68 +565,12 @@ void ggml_cuda_op_mul_mat_vec_q(
|
||||
// nrows_dst == nrows of the matrix that the kernel writes into
|
||||
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
const int stride_row_x = ne00 / ggml_blck_size(src0->type);
|
||||
const int stride_col_y = src1_padded_row_size / QK8_1;
|
||||
|
||||
mul_mat_vec_q_switch_type(
|
||||
src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
|
||||
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
||||
|
||||
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec_q(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||
|
||||
@@ -1,30 +1,40 @@
|
||||
#include "quantize.cuh"
|
||||
#include <cstdint>
|
||||
|
||||
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
|
||||
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
||||
static __global__ void quantize_q8_1(
|
||||
const float * __restrict__ x, void * __restrict__ vy,
|
||||
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int ne1, const int ne2) {
|
||||
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (ix0 >= kx0_padded) {
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t ix1 = blockIdx.y;
|
||||
const int64_t i1 = blockIdx.y;
|
||||
const int64_t i2 = blockIdx.z % ne2;
|
||||
const int64_t i3 = blockIdx.z / ne2;
|
||||
|
||||
const int64_t i_padded = ix1*kx0_padded + ix0;
|
||||
const int64_t & i00 = i0;
|
||||
const int64_t & i01 = i1;
|
||||
const int64_t & i02 = i2;
|
||||
const int64_t & i03 = i3;
|
||||
|
||||
const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
|
||||
|
||||
block_q8_1 * y = (block_q8_1 *) vy;
|
||||
|
||||
const int64_t ib = i_padded / QK8_1; // block index
|
||||
const int64_t iqs = i_padded % QK8_1; // quant index
|
||||
const int64_t ib = i_cont / QK8_1; // block index
|
||||
const int64_t iqs = i_cont % QK8_1; // quant index
|
||||
|
||||
const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
|
||||
const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
|
||||
float amax = fabsf(xi);
|
||||
float sum = xi;
|
||||
|
||||
amax = warp_reduce_max(amax);
|
||||
sum = warp_reduce_sum(sum);
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
const float d = amax / 127;
|
||||
const float d = amax / 127;
|
||||
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
|
||||
|
||||
y[ib].qs[iqs] = q;
|
||||
@@ -127,43 +137,45 @@ static __global__ void quantize_mmq_q8_1(
|
||||
}
|
||||
|
||||
void quantize_row_q8_1_cuda(
|
||||
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
|
||||
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
|
||||
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(kx0_padded % QK8_1 == 0);
|
||||
GGML_ASSERT(ne0 % QK8_1 == 0);
|
||||
|
||||
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
||||
const dim3 num_blocks(block_num_x, kx1*channels, 1);
|
||||
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
||||
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
|
||||
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
|
||||
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
|
||||
|
||||
GGML_UNUSED(type_x);
|
||||
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
||||
GGML_UNUSED(type_src0);
|
||||
}
|
||||
|
||||
void quantize_mmq_q8_1_cuda(
|
||||
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
|
||||
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
|
||||
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
|
||||
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
|
||||
|
||||
const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||
const dim3 num_blocks(block_num_x, kx1, channels);
|
||||
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
|
||||
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
|
||||
switch (mmq_get_q8_1_ds_layout(type_x)) {
|
||||
switch (mmq_get_q8_1_ds_layout(type_src0)) {
|
||||
case MMQ_Q8_1_DS_LAYOUT_D4:
|
||||
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
|
||||
break;
|
||||
case MMQ_Q8_1_DS_LAYOUT_DS4:
|
||||
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
|
||||
break;
|
||||
case MMQ_Q8_1_DS_LAYOUT_D2S6:
|
||||
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
GGML_UNUSED(s01);
|
||||
GGML_UNUSED(s02);
|
||||
GGML_UNUSED(s03);
|
||||
}
|
||||
|
||||
@@ -12,13 +12,13 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk
|
||||
static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
|
||||
|
||||
typedef void (*quantize_cuda_t)(
|
||||
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
|
||||
const ggml_type type_x, cudaStream_t stream);
|
||||
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
|
||||
|
||||
void quantize_row_q8_1_cuda(
|
||||
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
|
||||
const ggml_type type_x, cudaStream_t stream);
|
||||
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
|
||||
|
||||
void quantize_mmq_q8_1_cuda(
|
||||
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
|
||||
const ggml_type type_x, cudaStream_t stream);
|
||||
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
#include <cstdint>
|
||||
|
||||
|
||||
2
ggml/src/ggml-cuda/vendors/hip.h
vendored
2
ggml/src/ggml-cuda/vendors/hip.h
vendored
@@ -71,6 +71,8 @@
|
||||
#define cudaLaunchHostFunc hipLaunchHostFunc
|
||||
#define cudaMalloc hipMalloc
|
||||
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
||||
#define cudaMallocManaged hipMallocManaged
|
||||
#define cudaMemAdvise hipMemAdvise
|
||||
#define cudaMemcpy hipMemcpy
|
||||
#define cudaMemcpyAsync hipMemcpyAsync
|
||||
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
|
||||
|
||||
@@ -89,10 +89,6 @@ endif()
|
||||
|
||||
add_compile_definitions(GGML_USE_HIP)
|
||||
|
||||
if (GGML_HIP_UMA)
|
||||
add_compile_definitions(GGML_HIP_UMA)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_FORCE_MMQ)
|
||||
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
|
||||
endif()
|
||||
|
||||
@@ -44,8 +44,8 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
|
||||
// note: assumes single GPU device - the default one
|
||||
// TODO: support multiple GPU devices
|
||||
static struct ggml_backend_metal_device_context {
|
||||
id<MTLDevice> mtl_device;
|
||||
int mtl_device_ref_count;
|
||||
id<MTLDevice> mtl_device;
|
||||
int mtl_device_ref_count;
|
||||
id<MTLLibrary> mtl_library;
|
||||
|
||||
bool has_simdgroup_reduction;
|
||||
@@ -354,6 +354,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
||||
@@ -362,6 +363,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
||||
@@ -370,6 +372,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
||||
@@ -378,6 +381,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
||||
@@ -386,6 +390,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
||||
@@ -394,6 +399,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
||||
@@ -402,6 +408,14 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
||||
@@ -430,6 +444,13 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_SET_I32,
|
||||
GGML_METAL_KERNEL_TYPE_SET_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
@@ -460,6 +481,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_SQRT,
|
||||
GGML_METAL_KERNEL_TYPE_SIN,
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_NEG,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
@@ -468,7 +490,259 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_COUNT
|
||||
};
|
||||
|
||||
//
|
||||
// ggml_metal_heap
|
||||
//
|
||||
|
||||
struct ggml_metal_heap {
|
||||
// number of times the heap was unused
|
||||
int n_unused;
|
||||
|
||||
// total number of buffer allocations in this heap across all computes
|
||||
int64_t n_alloc;
|
||||
|
||||
// current offset in the heap - we reset this after each node in order to reuse the memory
|
||||
size_t offs;
|
||||
|
||||
// the currently allocated MTLBuffer objects in this heap
|
||||
id<MTLHeap> obj;
|
||||
|
||||
NSMutableArray * bufs;
|
||||
};
|
||||
|
||||
static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
|
||||
struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
|
||||
|
||||
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
|
||||
desc.storageMode = MTLStorageModePrivate;
|
||||
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
|
||||
desc.type = MTLHeapTypePlacement;
|
||||
desc.size = size;
|
||||
|
||||
heap->n_unused = 0;
|
||||
heap->n_alloc = 0;
|
||||
|
||||
heap->obj = [device newHeapWithDescriptor:desc];
|
||||
if (!heap->obj) {
|
||||
GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
|
||||
|
||||
free(heap);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
[desc release];
|
||||
|
||||
heap->bufs = [[NSMutableArray alloc] init];
|
||||
|
||||
return heap;
|
||||
}
|
||||
|
||||
static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
|
||||
heap->offs = 0;
|
||||
|
||||
// count how many graph computes the heap ended up being unused
|
||||
if ([heap->bufs count] > 0) {
|
||||
heap->n_unused = 0;
|
||||
} else {
|
||||
heap->n_unused++;
|
||||
}
|
||||
|
||||
for (id<MTLBuffer> buf in heap->bufs) {
|
||||
[buf release];
|
||||
}
|
||||
[heap->bufs removeAllObjects];
|
||||
|
||||
// tell the OS that it can reuse this memory if needed
|
||||
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
||||
[heap->obj setPurgeableState:MTLPurgeableStateVolatile];
|
||||
}
|
||||
|
||||
static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
|
||||
if (heap == nil) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_metal_heap_reset(heap);
|
||||
|
||||
[heap->obj release];
|
||||
[heap->bufs release];
|
||||
|
||||
free(heap);
|
||||
}
|
||||
|
||||
@interface ggml_metal_heap_ptr : NSObject
|
||||
|
||||
@property (nonatomic, assign) struct ggml_metal_heap * data;
|
||||
|
||||
@end
|
||||
|
||||
@implementation ggml_metal_heap_ptr
|
||||
@end
|
||||
|
||||
//
|
||||
// ggml_metal_mem_pool
|
||||
//
|
||||
|
||||
struct ggml_metal_mem_pool {
|
||||
id<MTLDevice> device;
|
||||
|
||||
int n_heaps; // total number of heaps ever created (including those that were removed)
|
||||
|
||||
NSMutableArray * heaps;
|
||||
NSMutableArray * heaps_to_remove;
|
||||
};
|
||||
|
||||
static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
|
||||
struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool));
|
||||
|
||||
mem_pool->n_heaps = 0;
|
||||
|
||||
mem_pool->heaps = [[NSMutableArray alloc] init];
|
||||
mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
|
||||
|
||||
return mem_pool;
|
||||
}
|
||||
|
||||
static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
|
||||
GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
|
||||
|
||||
size_t size_all = 0;
|
||||
size_t size_cur = 0;
|
||||
|
||||
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
||||
GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
|
||||
GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
|
||||
GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
|
||||
GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
|
||||
GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
|
||||
|
||||
if ([ptr.data->bufs count] > 0) {
|
||||
size_cur += [ptr.data->obj size];
|
||||
}
|
||||
size_all += [ptr.data->obj size];
|
||||
|
||||
ggml_metal_heap_free(ptr.data);
|
||||
[ptr release];
|
||||
}
|
||||
[mem_pool->heaps release];
|
||||
[mem_pool->heaps_to_remove release];
|
||||
|
||||
if (size_all > 0) {
|
||||
GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
|
||||
GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
|
||||
}
|
||||
|
||||
free(mem_pool);
|
||||
}
|
||||
|
||||
static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
|
||||
for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
|
||||
ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
|
||||
|
||||
struct ggml_metal_heap * heap = ptr.data;
|
||||
ggml_metal_heap_reset(heap);
|
||||
|
||||
// if the heap hasn't been used for a while, remove it
|
||||
if (heap->n_unused >= 128) {
|
||||
[mem_pool->heaps_to_remove addObject:@(i)];
|
||||
}
|
||||
}
|
||||
|
||||
if (mem_pool->heaps_to_remove.count > 0) {
|
||||
for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
|
||||
NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
|
||||
ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
|
||||
|
||||
struct ggml_metal_heap * heap = ptr.data;
|
||||
ggml_metal_heap_free(heap);
|
||||
|
||||
[mem_pool->heaps removeObjectAtIndex:index];
|
||||
[ptr release];
|
||||
}
|
||||
|
||||
[mem_pool->heaps_to_remove removeAllObjects];
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
|
||||
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
||||
ptr.data->offs = 0;
|
||||
}
|
||||
}
|
||||
|
||||
static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
|
||||
const size_t alignment = 32;
|
||||
|
||||
const size_t size_aligned = GGML_PAD(size, alignment);
|
||||
|
||||
// try one of the existing heaps
|
||||
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
||||
struct ggml_metal_heap * heap = ptr.data;
|
||||
if (heap->offs + size_aligned <= [heap->obj size]) {
|
||||
// if this is the first buffer in the heap for the current command buffer, tell the OS that
|
||||
// it cannot free the memory used by the heap
|
||||
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
||||
if ([heap->bufs count] == 0) {
|
||||
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
||||
}
|
||||
|
||||
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
||||
if (buf == nil) {
|
||||
GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
||||
return nil;
|
||||
}
|
||||
|
||||
heap->n_alloc++;
|
||||
heap->offs += size_aligned;
|
||||
|
||||
[heap->bufs addObject:buf];
|
||||
|
||||
return buf;
|
||||
}
|
||||
}
|
||||
|
||||
// create a new heap that can fit this buffer
|
||||
ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
|
||||
|
||||
struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
|
||||
if (heap == NULL) {
|
||||
GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
//GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
|
||||
|
||||
heap_ptr.data = heap;
|
||||
ggml_metal_heap_reset(heap);
|
||||
|
||||
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
||||
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
||||
if (buf == nil) {
|
||||
GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
heap->n_alloc++;
|
||||
heap->offs += size_aligned;
|
||||
|
||||
[heap->bufs addObject:buf];
|
||||
|
||||
[mem_pool->heaps addObject:heap_ptr];
|
||||
mem_pool->n_heaps++;
|
||||
|
||||
return buf;
|
||||
}
|
||||
|
||||
struct ggml_metal_command_buffer {
|
||||
id<MTLCommandBuffer> obj;
|
||||
|
||||
// each command buffer has a memory pool from which it can allocate temporary buffers during the compute
|
||||
struct ggml_metal_mem_pool * mem_pool;
|
||||
};
|
||||
|
||||
struct ggml_backend_metal_context {
|
||||
id<MTLDevice> device;
|
||||
id<MTLCommandQueue> queue;
|
||||
|
||||
dispatch_queue_t d_queue;
|
||||
@@ -493,7 +767,7 @@ struct ggml_backend_metal_context {
|
||||
void (^encode_async)(size_t ith);
|
||||
|
||||
// n_cb command buffers + 1 used by the main thread
|
||||
id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
||||
struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
||||
|
||||
// abort ggml_metal_graph_compute if callback returns true
|
||||
ggml_abort_callback abort_callback;
|
||||
@@ -683,9 +957,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
||||
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
||||
|
||||
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
||||
|
||||
ctx->queue = [device newCommandQueue];
|
||||
ctx->device = device;
|
||||
ctx->queue = [device newCommandQueue];
|
||||
if (ctx->queue == nil) {
|
||||
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
||||
return NULL;
|
||||
@@ -746,7 +1022,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
ctx->gf = nil;
|
||||
ctx->encode_async = nil;
|
||||
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
||||
ctx->command_buffers[i] = nil;
|
||||
ctx->cmd_bufs[i].obj = nil;
|
||||
|
||||
ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
|
||||
ctx->cmd_bufs[i].mem_pool->device = device;
|
||||
}
|
||||
|
||||
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
||||
@@ -1011,6 +1290,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
||||
@@ -1019,6 +1299,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
||||
@@ -1027,6 +1308,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
||||
@@ -1035,6 +1317,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
||||
@@ -1043,6 +1326,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
||||
@@ -1051,6 +1335,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
||||
@@ -1059,6 +1344,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
||||
@@ -1087,6 +1380,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||
@@ -1117,6 +1417,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
@@ -1137,6 +1438,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
|
||||
|
||||
[ctx->queue release];
|
||||
|
||||
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
||||
// ctx->cmd_bufs[i].obj is auto released
|
||||
|
||||
ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
|
||||
}
|
||||
|
||||
dispatch_release(ctx->d_queue);
|
||||
|
||||
free(ctx);
|
||||
@@ -1278,6 +1585,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_NEG:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
@@ -1351,6 +1659,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
// TODO: not sure if it is worth adding kernels for this size
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek sizes
|
||||
// TODO: disabled for now, until optmized
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
return false;
|
||||
}
|
||||
@@ -1436,10 +1749,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_metal_encode_node(
|
||||
static bool ggml_metal_encode_node(
|
||||
ggml_backend_t backend,
|
||||
int idx,
|
||||
id<MTLComputeCommandEncoder> encoder) {
|
||||
id<MTLComputeCommandEncoder> encoder,
|
||||
struct ggml_metal_mem_pool * mem_pool) {
|
||||
struct ggml_backend_metal_context * ctx = backend->context;
|
||||
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
||||
|
||||
@@ -1455,7 +1769,7 @@ static void ggml_metal_encode_node(
|
||||
struct ggml_tensor * dst = node;
|
||||
|
||||
if (ggml_is_empty(dst)) {
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
|
||||
switch (dst->op) {
|
||||
@@ -1466,7 +1780,7 @@ static void ggml_metal_encode_node(
|
||||
case GGML_OP_PERMUTE:
|
||||
{
|
||||
// noop -> next node
|
||||
} return;
|
||||
} return true;
|
||||
default:
|
||||
{
|
||||
} break;
|
||||
@@ -1477,6 +1791,8 @@ static void ggml_metal_encode_node(
|
||||
GGML_ABORT("unsupported op");
|
||||
}
|
||||
|
||||
ggml_metal_mem_pool_clear(mem_pool);
|
||||
|
||||
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
||||
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
||||
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
||||
@@ -1963,6 +2279,18 @@ static void ggml_metal_encode_node(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_NEG:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||
@@ -2111,26 +2439,76 @@ static void ggml_metal_encode_node(
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
ggml_metal_kargs_soft_max args = {
|
||||
// use this branch to test the ggml_metal_mem_pool functionality
|
||||
#if 0
|
||||
// cpy to tmp buffer in MTLHeap
|
||||
|
||||
id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
|
||||
if (!h_src0) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
|
||||
return false;
|
||||
}
|
||||
|
||||
offs_src0 = 0;
|
||||
|
||||
ggml_metal_kargs_cpy args_cpy = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.scale =*/ scale,
|
||||
/*.max_bias =*/ max_bias,
|
||||
/*.m0 =*/ m0,
|
||||
/*.m1 =*/ m1,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne00,
|
||||
/*.ne1 =*/ ne01,
|
||||
/*.ne2 =*/ ne02,
|
||||
/*.ne3 =*/ ne03,
|
||||
/*.nb0 =*/ nb00,
|
||||
/*.nb1 =*/ nb01,
|
||||
/*.nb2 =*/ nb02,
|
||||
/*.nb3 =*/ nb03,
|
||||
};
|
||||
|
||||
if (src0->type == GGML_TYPE_F16) {
|
||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
|
||||
}
|
||||
[encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:h_src0 offset:0 atIndex:2];
|
||||
|
||||
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
||||
int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
|
||||
|
||||
#else
|
||||
id<MTLBuffer> h_src0 = id_src0;
|
||||
#endif
|
||||
// softmax
|
||||
|
||||
ggml_metal_kargs_soft_max args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.scale =*/ scale,
|
||||
/*.max_bias =*/ max_bias,
|
||||
/*.m0 =*/ m0,
|
||||
/*.m1 =*/ m1,
|
||||
/*.n_head_log2 =*/ n_head_log2,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
|
||||
if (id_src1) {
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
} else {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
@@ -3843,12 +4221,14 @@ static void ggml_metal_encode_node(
|
||||
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
||||
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
||||
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
|
||||
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
|
||||
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
@@ -3871,6 +4251,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
||||
@@ -3893,6 +4275,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
||||
@@ -3915,6 +4299,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
||||
@@ -3937,6 +4323,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
||||
@@ -3959,6 +4347,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
||||
@@ -3981,6 +4371,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
||||
@@ -4010,6 +4402,24 @@ static void ggml_metal_encode_node(
|
||||
use_vec_kernel = true;
|
||||
|
||||
switch (ne00) {
|
||||
case 96:
|
||||
{
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
||||
GGML_LOG_ERROR("add template specialization for this type\n");
|
||||
GGML_ABORT("add template specialization for this type");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case 128:
|
||||
{
|
||||
switch (src1->type) {
|
||||
@@ -4082,12 +4492,36 @@ static void ggml_metal_encode_node(
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case 576:
|
||||
{
|
||||
if (ne20 == 512) {
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
||||
GGML_LOG_ERROR("add template specialization for this type\n");
|
||||
GGML_ABORT("add template specialization for this type");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4483,6 +4917,8 @@ static void ggml_metal_encode_node(
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_metal_graph_compute(
|
||||
@@ -4536,25 +4972,25 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
}
|
||||
|
||||
// the main thread commits the first few commands immediately
|
||||
// command_buffer[n_cb]
|
||||
// cmd_buf[n_cb]
|
||||
{
|
||||
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
||||
ctx->command_buffers[n_cb] = command_buffer;
|
||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
||||
ctx->cmd_bufs[n_cb].obj = cmd_buf;
|
||||
|
||||
[command_buffer enqueue];
|
||||
[cmd_buf enqueue];
|
||||
ctx->encode_async(n_cb);
|
||||
}
|
||||
|
||||
// prepare the rest of the command buffers asynchronously
|
||||
// command_buffer[0.. n_cb)
|
||||
// cmd_buf[0.. n_cb)
|
||||
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
||||
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
||||
ctx->command_buffers[cb_idx] = command_buffer;
|
||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
||||
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
|
||||
|
||||
// always enqueue the first two command buffers
|
||||
// enqueue all of the command buffers if we don't need to abort
|
||||
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||
[command_buffer enqueue];
|
||||
[cmd_buf enqueue];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4563,14 +4999,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
// wait for completion and check status of each command buffer
|
||||
// needed to detect if the device ran out-of-memory for example (#1881)
|
||||
{
|
||||
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
|
||||
[command_buffer waitUntilCompleted];
|
||||
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
|
||||
[cmd_buf waitUntilCompleted];
|
||||
|
||||
MTLCommandBufferStatus status = [command_buffer status];
|
||||
MTLCommandBufferStatus status = [cmd_buf status];
|
||||
if (status != MTLCommandBufferStatusCompleted) {
|
||||
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
|
||||
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
|
||||
return GGML_STATUS_FAILED;
|
||||
@@ -4578,20 +5014,20 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_cb; ++i) {
|
||||
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
|
||||
[command_buffer waitUntilCompleted];
|
||||
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
|
||||
[cmd_buf waitUntilCompleted];
|
||||
|
||||
MTLCommandBufferStatus status = [command_buffer status];
|
||||
MTLCommandBufferStatus status = [cmd_buf status];
|
||||
if (status != MTLCommandBufferStatusCompleted) {
|
||||
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
|
||||
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
|
||||
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
|
||||
if (!next_buffer) {
|
||||
continue;
|
||||
}
|
||||
@@ -4974,8 +5410,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
||||
|
||||
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
||||
|
||||
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
||||
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
|
||||
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
|
||||
|
||||
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
|
||||
|
||||
int node_start = 0;
|
||||
int node_end = n_nodes_0;
|
||||
@@ -4987,22 +5424,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
||||
|
||||
const bool should_capture = ctx->capture_next_compute;
|
||||
|
||||
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
|
||||
ggml_metal_mem_pool_reset(mem_pool);
|
||||
|
||||
for (int idx = node_start; idx < node_end; ++idx) {
|
||||
if (should_capture) {
|
||||
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
||||
}
|
||||
|
||||
ggml_metal_encode_node(backend, idx, encoder);
|
||||
const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
|
||||
|
||||
if (should_capture) {
|
||||
[encoder popDebugGroup];
|
||||
}
|
||||
|
||||
if (!res) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
[encoder endEncoding];
|
||||
|
||||
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||
[command_buffer commit];
|
||||
[cmd_buf commit];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -949,6 +949,13 @@ kernel void kernel_cos(
|
||||
dst[tpig] = cos(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_neg(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_sum_rows(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
@@ -3185,7 +3192,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
{
|
||||
float S[Q] = { [0 ... Q-1] = 0.0f };
|
||||
float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
|
||||
float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
|
||||
|
||||
// thread indices inside the simdgroup
|
||||
// TODO: see if we can utilize quad-group functions for better performance
|
||||
@@ -3445,7 +3452,7 @@ kernel void kernel_flash_attn_ext(
|
||||
// reduce the warps sequentially
|
||||
for (ushort sg = 1; sg < nsg; ++sg) {
|
||||
float S = { 0.0f };
|
||||
float M = { -__FLT16_MAX__/2 };
|
||||
float M = { -__FLT_MAX__/2 };
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@@ -3546,6 +3553,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
@@ -3556,6 +3564,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
@@ -3566,6 +3575,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
@@ -3575,6 +3585,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
@@ -3584,6 +3595,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
@@ -3593,6 +3605,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
@@ -3602,6 +3615,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
@@ -3685,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
{
|
||||
float S = 0.0f;
|
||||
float M = -__FLT16_MAX__/2;
|
||||
float M = -__FLT_MAX__/2;
|
||||
|
||||
// thread indices inside the simdgroup
|
||||
const short tx = tiisg%NL;
|
||||
@@ -3959,6 +3973,16 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 4>;
|
||||
@@ -3999,6 +4023,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
template<typename T>
|
||||
|
||||
@@ -54,16 +54,41 @@ function(ggml_opencl_add_kernel KNAME)
|
||||
endfunction()
|
||||
|
||||
set(GGML_OPENCL_KERNELS
|
||||
ggml-opencl
|
||||
ggml-opencl_mm
|
||||
ggml-opencl_cvt
|
||||
ggml-opencl_gemv_noshuffle
|
||||
ggml-opencl_gemv_noshuffle_general
|
||||
ggml-opencl_mul_mat_Ab_Bi_8x4
|
||||
ggml-opencl_transpose_16
|
||||
ggml-opencl_transpose_32
|
||||
ggml-opencl_transpose_32_16
|
||||
ggml-opencl_im2col
|
||||
add
|
||||
clamp
|
||||
cpy
|
||||
cvt
|
||||
diag_mask_inf
|
||||
gelu
|
||||
gemv_noshuffle_general
|
||||
gemv_noshuffle
|
||||
get_rows
|
||||
im2col_f32
|
||||
im2col_f16
|
||||
mul_mat_Ab_Bi_8x4
|
||||
mul_mv_f16_f16
|
||||
mul_mv_f16_f32_1row
|
||||
mul_mv_f16_f32_l4
|
||||
mul_mv_f16_f32
|
||||
mul_mv_f32_f32
|
||||
mul_mv_q4_0_f32
|
||||
mul_mv_q4_0_f32_v
|
||||
mul_mv_q4_0_f32_8x_flat
|
||||
mul_mv_q4_0_f32_1d_8x_flat
|
||||
mul_mv_q4_0_f32_1d_16x_flat
|
||||
mul_mv_q6_k
|
||||
mul
|
||||
norm
|
||||
relu
|
||||
rms_norm
|
||||
rope
|
||||
scale
|
||||
silu
|
||||
softmax_4_f32
|
||||
softmax_4_f16
|
||||
softmax_f32
|
||||
softmax_f16
|
||||
transpose
|
||||
)
|
||||
|
||||
foreach (K ${GGML_OPENCL_KERNELS})
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
83
ggml/src/ggml-opencl/kernels/add.cl
Normal file
83
ggml/src/ggml-opencl/kernels/add.cl
Normal file
@@ -0,0 +1,83 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// add
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// general-purpose kernel for addition of two tensors
|
||||
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
|
||||
// cons: not very efficient
|
||||
kernel void kernel_add(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
int i13 = i03 % ne13;
|
||||
int i12 = i02 % ne12;
|
||||
int i11 = i01 % ne11;
|
||||
|
||||
global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
||||
global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
||||
global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
||||
|
||||
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||||
const int i10 = i0 % ne10;
|
||||
*((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10));
|
||||
}
|
||||
}
|
||||
|
||||
// assumption: src1 is a row
|
||||
// broadcast src1 into src0
|
||||
kernel void kernel_add_row(
|
||||
global float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float4 * dst,
|
||||
ulong offsetd,
|
||||
int ne
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
|
||||
// This performs better than using %.
|
||||
uint gid = get_global_id(0);
|
||||
uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
|
||||
dst[gid] = src0[gid] + src1[idx1];
|
||||
}
|
||||
20
ggml/src/ggml-opencl/kernels/clamp.cl
Normal file
20
ggml/src/ggml-opencl/kernels/clamp.cl
Normal file
@@ -0,0 +1,20 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// clamp
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_clamp(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
float min,
|
||||
float max
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = src0[get_global_id(0)] < min ?
|
||||
min :
|
||||
(src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]);
|
||||
}
|
||||
184
ggml/src/ggml-opencl/kernels/cpy.cl
Normal file
184
ggml/src/ggml-opencl/kernels/cpy.cl
Normal file
@@ -0,0 +1,184 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// cpy
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
kernel void kernel_cpy_f16_f16(
|
||||
global half * src0,
|
||||
ulong offset0,
|
||||
global half * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = (global half*)((global char*)src0 + offset0);
|
||||
dst = (global half*)((global char*)dst + offsetd);
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
int i3 = n / (ne2*ne1*ne0);
|
||||
int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||
|
||||
global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
dst_data[i00] = src[0];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f16_f32(
|
||||
global half * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
|
||||
src0 = (global half*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
int i3 = n / (ne2*ne1*ne0);
|
||||
int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||
|
||||
global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
dst_data[i00] = src[0];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_f16(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global half * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global half*)((global char*)dst + offsetd);
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
int i3 = n / (ne2*ne1*ne0);
|
||||
int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||
|
||||
global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
|
||||
dst_data[i00] = src[0];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_f32(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
int i3 = n / (ne2*ne1*ne0);
|
||||
int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||
|
||||
global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
|
||||
dst_data[i00] = src[0];
|
||||
}
|
||||
}
|
||||
@@ -1,39 +1,20 @@
|
||||
//------------------------------------------------------------------------------
|
||||
// This file is contains additional kernels for data conversion.
|
||||
// This file is contains kernels for data conversion.
|
||||
// These kernels are used when loading the model, so its performance is less
|
||||
// important.
|
||||
//------------------------------------------------------------------------------
|
||||
#ifdef cl_khr_fp16
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#elif defined(cl_amd_fp16)
|
||||
#pragma OPENCL EXTENSION cl_amd_fp16 : enable
|
||||
#else
|
||||
#error "Half precision floating point not supportedby OpenCL implementation on your device."
|
||||
#endif
|
||||
|
||||
#ifdef cl_khr_subgroups
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#elif defined(cl_intel_subgroups)
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#error "Subgroup not supported on your device."
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
// Always use subgroup size of 32 on Intel.
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
// Always use subgroups size of 64 on Adreno.
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#else
|
||||
// TODO: do not know how to choose subgroup size on other GPUs.
|
||||
#error "Selecting subgroup size is not supported on your device."
|
||||
#endif
|
||||
|
||||
#define QK4_0 32
|
||||
@@ -66,13 +47,44 @@ struct block_q4_0
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// mul_vec_q_n_f32_flat_noshuffle
|
||||
//
|
||||
// This variation uses flat arrays (struct of arrays, SOA) representation for
|
||||
// quant tensors. It also uses non shuffled bit order for weights.
|
||||
//
|
||||
// The shuffled version is kept in the original file because moving it here
|
||||
// seems to result in worse performance for adreno.
|
||||
// kernel_convert_block_q4_0
|
||||
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
|
||||
// This kernel does not deshuffle the bits.
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_convert_block_q4_0(
|
||||
global struct block_q4_0 * src0,
|
||||
global uchar * dst_q,
|
||||
global half * dst_d
|
||||
) {
|
||||
global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
|
||||
global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
|
||||
global half * d = (global half *) dst_d + get_global_id(0);
|
||||
|
||||
*d = b->d;
|
||||
|
||||
for (int i = 0; i < QK4_0/2; ++i) {
|
||||
q[i] = b->qs[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_q4_0(
|
||||
global uchar * src_q,
|
||||
global half * src_d,
|
||||
global struct block_q4_0 * dst
|
||||
) {
|
||||
global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0);
|
||||
global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0);
|
||||
global half * d = (global half *) src_d + get_global_id(0);
|
||||
|
||||
b->d = *d;
|
||||
for (int i = 0; i < QK4_0/2; ++i) {
|
||||
b->qs[i] = q[i];
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_convert_block_q4_0_noshuffle
|
||||
// Flatten q4_0 weights and unshuffle the bits
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
kernel void kernel_convert_block_q4_0_noshuffle(
|
||||
58
ggml/src/ggml-opencl/kernels/diag_mask_inf.cl
Normal file
58
ggml/src/ggml-opencl/kernels/diag_mask_inf.cl
Normal file
@@ -0,0 +1,58 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// diag_mask_inf kernels
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_diag_mask_inf(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int n_past
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i02 = get_global_id(2);
|
||||
int i01 = get_global_id(1);
|
||||
int i00 = get_global_id(0);
|
||||
|
||||
if (i00 > n_past + i01) {
|
||||
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
||||
} else {
|
||||
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_diag_mask_inf_8(
|
||||
global float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int n_past
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
|
||||
int i = 2*get_global_id(0);
|
||||
|
||||
dst[i+0] = src0[i+0];
|
||||
dst[i+1] = src0[i+1];
|
||||
int i4 = 4*i;
|
||||
int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
|
||||
int i01 = i4/(ne00); i4 -= i01*ne00;
|
||||
int i00 = i4;
|
||||
for (int k = 3; k >= 0; --k) {
|
||||
if (i00 + 4 + k <= n_past + i01) {
|
||||
break;
|
||||
}
|
||||
(&dst[i+1])[k] = -INFINITY;
|
||||
if (i00 + k > n_past + i01) {
|
||||
(&dst[i])[k] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
62
ggml/src/ggml-opencl/kernels/gelu.cl
Normal file
62
ggml/src/ggml-opencl/kernels/gelu.cl
Normal file
@@ -0,0 +1,62 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// gelu
|
||||
//------------------------------------------------------------------------------
|
||||
#define GELU_COEF_A 0.044715f
|
||||
#define GELU_QUICK_COEF -1.702f
|
||||
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
|
||||
|
||||
kernel void kernel_gelu(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
float x = src0[get_global_id(0)];
|
||||
|
||||
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_4(
|
||||
global float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
|
||||
float4 x = src0[get_global_id(0)];
|
||||
|
||||
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
float x = src0[get_global_id(0)];
|
||||
dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick_4(
|
||||
global float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
|
||||
float4 x = src0[get_global_id(0)];
|
||||
dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
||||
}
|
||||
163
ggml/src/ggml-opencl/kernels/get_rows.cl
Normal file
163
ggml/src/ggml-opencl/kernels/get_rows.cl
Normal file
@@ -0,0 +1,163 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
typedef char int8_t;
|
||||
typedef uchar uint8_t;
|
||||
typedef short int16_t;
|
||||
typedef ushort uint16_t;
|
||||
typedef int int32_t;
|
||||
typedef uint uint32_t;
|
||||
|
||||
#define QK4_0 32
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_0
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_q4_0
|
||||
{
|
||||
half d;
|
||||
uint8_t qs[QK4_0 / 2];
|
||||
};
|
||||
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// dequantize_q4_0_f32, dequantize_q4_0_f16
|
||||
//------------------------------------------------------------------------------
|
||||
void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) {
|
||||
global ushort * qs = ((global ushort *)xb + 1);
|
||||
float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||
float d2 = d1 / 256.f;
|
||||
float md = -8.h * xb->d;
|
||||
ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||
ushort mask1 = mask0 << 8;
|
||||
|
||||
reg->s0 = d1 * (qs[0] & mask0) + md;
|
||||
reg->s1 = d2 * (qs[0] & mask1) + md;
|
||||
|
||||
reg->s2 = d1 * (qs[1] & mask0) + md;
|
||||
reg->s3 = d2 * (qs[1] & mask1) + md;
|
||||
|
||||
reg->s4 = d1 * (qs[2] & mask0) + md;
|
||||
reg->s5 = d2 * (qs[2] & mask1) + md;
|
||||
|
||||
reg->s6 = d1 * (qs[3] & mask0) + md;
|
||||
reg->s7 = d2 * (qs[3] & mask1) + md;
|
||||
|
||||
reg->s8 = d1 * (qs[4] & mask0) + md;
|
||||
reg->s9 = d2 * (qs[4] & mask1) + md;
|
||||
|
||||
reg->sa = d1 * (qs[5] & mask0) + md;
|
||||
reg->sb = d2 * (qs[5] & mask1) + md;
|
||||
|
||||
reg->sc = d1 * (qs[6] & mask0) + md;
|
||||
reg->sd = d2 * (qs[6] & mask1) + md;
|
||||
|
||||
reg->se = d1 * (qs[7] & mask0) + md;
|
||||
reg->sf = d2 * (qs[7] & mask1) + md;
|
||||
}
|
||||
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// get_rows
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_get_rows_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
int ne10,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb1,
|
||||
ulong nb2
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i10 = get_group_id(0);
|
||||
int i11 = get_group_id(1);
|
||||
|
||||
int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
|
||||
int i02 = i11;
|
||||
|
||||
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
|
||||
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||
((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_get_rows_f16(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
int ne10,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb1,
|
||||
ulong nb2
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i10 = get_group_id(0);
|
||||
int i11 = get_group_id(1);
|
||||
|
||||
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
|
||||
int i02 = i11;
|
||||
|
||||
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
|
||||
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||
((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_get_rows_q4_0(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
int ne10,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb1,
|
||||
ulong nb2
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
const int NL = 2;
|
||||
|
||||
int i10 = get_group_id(0);
|
||||
int i11 = get_group_id(1);
|
||||
|
||||
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
|
||||
int i02 = i11;
|
||||
|
||||
for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
|
||||
float16 temp;
|
||||
dequantize_q4_0_f32(
|
||||
((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp);
|
||||
*(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,146 +0,0 @@
|
||||
#ifdef cl_khr_fp16
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#elif defined(cl_amd_fp16)
|
||||
#pragma OPENCL EXTENSION cl_amd_fp16 : enable
|
||||
#else
|
||||
#error "Half precision floating point not supportedby OpenCL implementation on your device."
|
||||
#endif
|
||||
|
||||
#ifdef cl_khr_subgroups
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#elif defined(cl_intel_subgroups)
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#error "Subgroup not supported on your device."
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
// Always use subgroup size of 32 on Intel.
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
// Always use subgroups size of 64 on Adreno.
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#else
|
||||
// TODO: do not know how to choose subgroup size on other GPUs.
|
||||
#error "Selecting subgroup size is not supported on your device."
|
||||
#endif
|
||||
|
||||
kernel void kernel_im2col_f32(
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
ulong batch_offset,
|
||||
ulong delta_offset,
|
||||
long IW,
|
||||
long IH,
|
||||
long IC,
|
||||
long OW,
|
||||
long OH,
|
||||
long KW,
|
||||
long KH,
|
||||
long pelements,
|
||||
long CHW,
|
||||
int s0,
|
||||
int s1,
|
||||
int p0,
|
||||
int p1,
|
||||
int d0,
|
||||
int d1
|
||||
) {
|
||||
// threadIdx.x + blockIdx.x * blockDim.x
|
||||
long i = get_global_id(0);
|
||||
if (i >= pelements) {
|
||||
return;
|
||||
}
|
||||
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
long ksize = OW * (KH > 1 ? KW : 1);
|
||||
long kx = i / ksize;
|
||||
long kd = kx * ksize;
|
||||
long ky = (i - kd) / OW;
|
||||
long ix = i % OW;
|
||||
|
||||
long oh = get_group_id(1);
|
||||
long batch = get_group_id(2) / IC;
|
||||
long ic = get_group_id(2) % IC;
|
||||
|
||||
long iiw = ix * s0 + kx * d0 - p0;
|
||||
long iih = oh * s1 + ky * d1 - p1;
|
||||
|
||||
long offset_dst =
|
||||
((batch * OH + oh) * OW + ix) * CHW +
|
||||
(ic * (KW * KH) + ky * KW + kx);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
long offset_src = ic * delta_offset + batch * batch_offset;
|
||||
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_im2col_f16(
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global half * dst,
|
||||
ulong offsetd,
|
||||
ulong batch_offset,
|
||||
ulong delta_offset,
|
||||
long IW,
|
||||
long IH,
|
||||
long IC,
|
||||
long OW,
|
||||
long OH,
|
||||
long KW,
|
||||
long KH,
|
||||
long pelements,
|
||||
long CHW,
|
||||
int s0,
|
||||
int s1,
|
||||
int p0,
|
||||
int p1,
|
||||
int d0,
|
||||
int d1
|
||||
) {
|
||||
long i = get_global_id(0);
|
||||
|
||||
if (i >= pelements) {
|
||||
return;
|
||||
}
|
||||
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global half*)((global char*)dst + offsetd);
|
||||
|
||||
long ksize = OW * (KH > 1 ? KW : 1);
|
||||
long kx = i / ksize;
|
||||
long kd = kx * ksize;
|
||||
long ky = (i - kd) / OW;
|
||||
long ix = i % OW;
|
||||
|
||||
long oh = get_group_id(1);
|
||||
long batch = get_group_id(2) / IC;
|
||||
long ic = get_group_id(2) % IC;
|
||||
|
||||
long iiw = ix * s0 + kx * d0 - p0;
|
||||
long iih = oh * s1 + ky * d1 - p1;
|
||||
|
||||
long offset_dst =
|
||||
((batch * OH + oh) * OW + ix) * CHW +
|
||||
(ic * (KW * KH) + ky * KW + kx);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
long offset_src = ic * delta_offset + batch * batch_offset;
|
||||
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,26 +0,0 @@
|
||||
// 16-bit transpose, loading/storing a 4x4 tile of elements
|
||||
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
kernel void kernel_transpose_16(
|
||||
__read_only image1d_buffer_t input,
|
||||
__write_only image1d_buffer_t output,
|
||||
const uint rows,
|
||||
const uint cols
|
||||
) {
|
||||
|
||||
const int i = get_global_id(0);
|
||||
const int j = get_global_id(1);
|
||||
const int i_2 = i<<2;
|
||||
const int j_2 = j<<2;
|
||||
|
||||
half4 temp0 = read_imageh(input, (j_2+0)*cols+i);
|
||||
half4 temp1 = read_imageh(input, (j_2+1)*cols+i);
|
||||
half4 temp2 = read_imageh(input, (j_2+2)*cols+i);
|
||||
half4 temp3 = read_imageh(input, (j_2+3)*cols+i);
|
||||
|
||||
write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
|
||||
write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
|
||||
write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
|
||||
write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
// 32-bit transpose, loading/storing a 4x4 tile of elements
|
||||
|
||||
kernel void kernel_transpose_32(
|
||||
__read_only image1d_buffer_t input,
|
||||
__write_only image1d_buffer_t output,
|
||||
const uint rows,
|
||||
const uint cols
|
||||
) {
|
||||
|
||||
const int i = get_global_id(0);
|
||||
const int j = get_global_id(1);
|
||||
const int i_2 = i<<2;
|
||||
const int j_2 = j<<2;
|
||||
|
||||
float4 temp0 = read_imagef(input, (j_2+0)*cols+i);
|
||||
float4 temp1 = read_imagef(input, (j_2+1)*cols+i);
|
||||
float4 temp2 = read_imagef(input, (j_2+2)*cols+i);
|
||||
float4 temp3 = read_imagef(input, (j_2+3)*cols+i);
|
||||
|
||||
write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
|
||||
write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
|
||||
write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
|
||||
write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
|
||||
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
// 32-bit transpose, loading/storing a 4x4 tile of elements
|
||||
// Only used for activations
|
||||
// converts to FP16
|
||||
// also adds zero padding for non multiple of 8 prompt lengths
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) {
|
||||
|
||||
const int i = get_global_id(0);
|
||||
const int j = get_global_id(1);
|
||||
const int i_2 = i<<2;
|
||||
const int j_2 = j<<2;
|
||||
half4 temp0 = {0,0,0,0}; // initialize outputs to 0
|
||||
half4 temp1 = {0,0,0,0};
|
||||
half4 temp2 = {0,0,0,0};
|
||||
half4 temp3 = {0,0,0,0};
|
||||
|
||||
if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0
|
||||
temp0 = read_imageh(input, (j_2+0)*cols+i);
|
||||
}
|
||||
if((j_2+1)*cols+i*4+3 < rows*cols*16){
|
||||
temp1 = read_imageh(input, (j_2+1)*cols+i);
|
||||
}
|
||||
if((j_2+2)*cols+i*4+3 < rows*cols*16){
|
||||
temp2 = read_imageh(input, (j_2+2)*cols+i);
|
||||
}
|
||||
if((j_2+3)*cols+i*4+3 < rows*cols*16){
|
||||
temp3 = read_imageh(input, (j_2+3)*cols+i);
|
||||
}
|
||||
|
||||
write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding
|
||||
write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
|
||||
write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
|
||||
write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
|
||||
}
|
||||
57
ggml/src/ggml-opencl/kernels/im2col_f16.cl
Normal file
57
ggml/src/ggml-opencl/kernels/im2col_f16.cl
Normal file
@@ -0,0 +1,57 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
kernel void kernel_im2col_f16(
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global half * dst,
|
||||
ulong offsetd,
|
||||
ulong batch_offset,
|
||||
ulong delta_offset,
|
||||
long IW,
|
||||
long IH,
|
||||
long IC,
|
||||
long OW,
|
||||
long OH,
|
||||
long KW,
|
||||
long KH,
|
||||
long pelements,
|
||||
long CHW,
|
||||
int s0,
|
||||
int s1,
|
||||
int p0,
|
||||
int p1,
|
||||
int d0,
|
||||
int d1
|
||||
) {
|
||||
long i = get_global_id(0);
|
||||
if (i >= pelements) {
|
||||
return;
|
||||
}
|
||||
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global half*)((global char*)dst + offsetd);
|
||||
|
||||
long ksize = OW * (KH > 1 ? KW : 1);
|
||||
long kx = i / ksize;
|
||||
long kd = kx * ksize;
|
||||
long ky = (i - kd) / OW;
|
||||
long ix = i % OW;
|
||||
|
||||
long oh = get_group_id(1);
|
||||
long batch = get_group_id(2) / IC;
|
||||
long ic = get_group_id(2) % IC;
|
||||
|
||||
long iiw = ix * s0 + kx * d0 - p0;
|
||||
long iih = oh * s1 + ky * d1 - p1;
|
||||
|
||||
long offset_dst =
|
||||
((batch * OH + oh) * OW + ix) * CHW +
|
||||
(ic * (KW * KH) + ky * KW + kx);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
long offset_src = ic * delta_offset + batch * batch_offset;
|
||||
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
57
ggml/src/ggml-opencl/kernels/im2col_f32.cl
Normal file
57
ggml/src/ggml-opencl/kernels/im2col_f32.cl
Normal file
@@ -0,0 +1,57 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
kernel void kernel_im2col_f32(
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
ulong batch_offset,
|
||||
ulong delta_offset,
|
||||
long IW,
|
||||
long IH,
|
||||
long IC,
|
||||
long OW,
|
||||
long OH,
|
||||
long KW,
|
||||
long KH,
|
||||
long pelements,
|
||||
long CHW,
|
||||
int s0,
|
||||
int s1,
|
||||
int p0,
|
||||
int p1,
|
||||
int d0,
|
||||
int d1
|
||||
) {
|
||||
long i = get_global_id(0);
|
||||
if (i >= pelements) {
|
||||
return;
|
||||
}
|
||||
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
long ksize = OW * (KH > 1 ? KW : 1);
|
||||
long kx = i / ksize;
|
||||
long kd = kx * ksize;
|
||||
long ky = (i - kd) / OW;
|
||||
long ix = i % OW;
|
||||
|
||||
long oh = get_group_id(1);
|
||||
long batch = get_group_id(2) / IC;
|
||||
long ic = get_group_id(2) % IC;
|
||||
|
||||
long iiw = ix * s0 + kx * d0 - p0;
|
||||
long iih = oh * s1 + ky * d1 - p1;
|
||||
|
||||
long offset_dst =
|
||||
((batch * OH + oh) * OW + ix) * CHW +
|
||||
(ic * (KW * KH) + ky * KW + kx);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
long offset_src = ic * delta_offset + batch * batch_offset;
|
||||
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
79
ggml/src/ggml-opencl/kernels/mul.cl
Normal file
79
ggml/src/ggml-opencl/kernels/mul.cl
Normal file
@@ -0,0 +1,79 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// mul
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_mul(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
int i13 = i03 % ne13;
|
||||
int i12 = i02 % ne12;
|
||||
int i11 = i01 % ne11;
|
||||
|
||||
global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
||||
global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
||||
global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
||||
|
||||
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||||
const int i10 = i0 % ne10;
|
||||
*((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10));
|
||||
}
|
||||
}
|
||||
|
||||
// assumption: src1 is a row
|
||||
// broadcast src1 into src0
|
||||
kernel void kernel_mul_row(
|
||||
global float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float4 * dst,
|
||||
ulong offsetd,
|
||||
int ne
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
|
||||
// This performs better than using %.
|
||||
uint gid = get_global_id(0);
|
||||
uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
|
||||
dst[gid] = src0[gid] * src1[idx1];
|
||||
}
|
||||
118
ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl
Normal file
118
ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl
Normal file
@@ -0,0 +1,118 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define N_F16_F16 4
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mat_f16_f16(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3)
|
||||
{
|
||||
src0 = (global char*)((global char*)src0 + offset0);
|
||||
src1 = (global char*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int rb = get_group_id(1)*N_F16_F16;
|
||||
int im = get_group_id(2);
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
|
||||
global half * x = (global half *) (src0 + offset_src0);
|
||||
|
||||
if (ne00 < 128) {
|
||||
for (int row = 0; row < N_F16_F16; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
||||
|
||||
global half * y = (global half *) (src1 + offset_src1);
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
|
||||
sumf += (half) x[i] * (half) y[i];
|
||||
}
|
||||
|
||||
float all_sum = sub_group_reduce_add(sumf);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
global half4 * x4 = (global half4 *)x;
|
||||
for (int row = 0; row < N_F16_F16; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
||||
|
||||
global half * y = (global half *) (src1 + offset_src1);
|
||||
global half4 * y4 = (global half4 *) y;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
|
||||
sumf += (half) x4[i].s0 * y4[i].s0;
|
||||
sumf += (half) x4[i].s1 * y4[i].s1;
|
||||
sumf += (half) x4[i].s2 * y4[i].s2;
|
||||
sumf += (half) x4[i].s3 * y4[i].s3;
|
||||
}
|
||||
|
||||
float all_sum = sub_group_reduce_add(sumf);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) {
|
||||
all_sum += (half) x[i] * y[i];
|
||||
}
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
118
ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl
Normal file
118
ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl
Normal file
@@ -0,0 +1,118 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define N_F16_F32 4
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mat_f16_f32(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global char*)((global char*)src0 + offset0);
|
||||
src1 = (global char*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int rb = get_group_id(1)*N_F16_F32;
|
||||
int im = get_group_id(2);
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
|
||||
global half * x = (global half *) (src0 + offset_src0);
|
||||
|
||||
if (ne00 < 128) {
|
||||
for (int row = 0; row < N_F16_F32; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
||||
|
||||
global float * y = (global float *) (src1 + offset_src1);
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
|
||||
sumf += convert_float(x[i]) * y[i];
|
||||
}
|
||||
|
||||
float all_sum = sub_group_reduce_add(sumf);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
global half4 * x4 = (global half4 *)x;
|
||||
for (int row = 0; row < N_F16_F32; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
||||
|
||||
global float * y = (global float *) (src1 + offset_src1);
|
||||
global float4 * y4 = (global float4 *) y;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
|
||||
sumf += convert_float(x4[i].s0) * y4[i].s0;
|
||||
sumf += convert_float(x4[i].s1) * y4[i].s1;
|
||||
sumf += convert_float(x4[i].s2) * y4[i].s2;
|
||||
sumf += convert_float(x4[i].s3) * y4[i].s3;
|
||||
}
|
||||
|
||||
float all_sum = sub_group_reduce_add(sumf);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) {
|
||||
all_sum += (float) x[i] * y[i];
|
||||
}
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
94
ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl
Normal file
94
ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl
Normal file
@@ -0,0 +1,94 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mat_f16_f32_1row(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global char*)((global char*)src0 + offset0);
|
||||
src1 = (global char*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
||||
|
||||
global half * x = (global half *) (src0 + offset_src0);
|
||||
global float * y = (global float *) (src1 + offset_src1);
|
||||
|
||||
float sumf = 0;
|
||||
if (ne00 < 128) {
|
||||
for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
|
||||
sumf += (float) x[i] * (float) y[i];
|
||||
}
|
||||
float all_sum = sub_group_reduce_add(sumf);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
} else {
|
||||
global half4 * x4 = (global half4 *) x;
|
||||
global float4 * y4 = (global float4 *) y;
|
||||
for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
|
||||
sumf += (float) x4[i].s0 * y4[i].s0;
|
||||
sumf += (float) x4[i].s1 * y4[i].s1;
|
||||
sumf += (float) x4[i].s2 * y4[i].s2;
|
||||
sumf += (float) x4[i].s3 * y4[i].s3;
|
||||
}
|
||||
float all_sum = sub_group_reduce_add(sumf);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) {
|
||||
all_sum += (float) x[i] * y[i];
|
||||
}
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
84
ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl
Normal file
84
ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl
Normal file
@@ -0,0 +1,84 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
// Assumes row size (ne00) is a multiple of 4
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mat_f16_f32_l4(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global char*)((global char*)src0 + offset0);
|
||||
src1 = (global char*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int nrows = ne11;
|
||||
int r0 = get_group_id(0);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
|
||||
global half4 * x4 = (global half4 *) (src0 + offset_src0);
|
||||
|
||||
for (int r1 = 0; r1 < nrows; ++r1) {
|
||||
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
||||
|
||||
global float4 * y4 = (global float4 *) (src1 + offset_src1);
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
|
||||
sumf += convert_float(x4[i].s0) * y4[i].s0;
|
||||
sumf += convert_float(x4[i].s1) * y4[i].s1;
|
||||
sumf += convert_float(x4[i].s2) * y4[i].s2;
|
||||
sumf += convert_float(x4[i].s3) * y4[i].s3;
|
||||
}
|
||||
|
||||
float all_sum = sub_group_reduce_add(sumf);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
118
ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl
Normal file
118
ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl
Normal file
@@ -0,0 +1,118 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define N_F32_F32 4
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mat_f32_f32(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global char*)((global char*)src0 + offset0);
|
||||
src1 = (global char*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int rb = get_group_id(1)*N_F32_F32;
|
||||
int im = get_group_id(2);
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
|
||||
global float * x = (global float *) (src0 + offset_src0);
|
||||
|
||||
if (ne00 < 128) {
|
||||
for (int row = 0; row < N_F32_F32; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
||||
|
||||
global float * y = (global float *) (src1 + offset_src1);
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
|
||||
sumf += (float) x[i] * (float) y[i];
|
||||
}
|
||||
|
||||
float all_sum = sub_group_reduce_add(sumf);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
global float4 * x4 = (global float4 *)x;
|
||||
for (int row = 0; row < N_F32_F32; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
||||
|
||||
global float * y = (global float *) (src1 + offset_src1);
|
||||
global float4 * y4 = (global float4 *) y;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
|
||||
sumf += (float) x4[i].s0 * y4[i].s0;
|
||||
sumf += (float) x4[i].s1 * y4[i].s1;
|
||||
sumf += (float) x4[i].s2 * y4[i].s2;
|
||||
sumf += (float) x4[i].s3 * y4[i].s3;
|
||||
}
|
||||
|
||||
float all_sum = sub_group_reduce_add(sumf);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) {
|
||||
all_sum += (float) x[i] * y[i];
|
||||
}
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
192
ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl
Normal file
192
ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl
Normal file
@@ -0,0 +1,192 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
#define QK4_1 32
|
||||
#define QR4_1 2
|
||||
#define QK5_0 32
|
||||
#define QR5_0 2
|
||||
#define QK5_1 32
|
||||
#define QR5_1 2
|
||||
#define QK8_0 32
|
||||
#define QR8_0 1
|
||||
#define QK_K 256
|
||||
#define K_QUANTS_PER_ITERATION 2
|
||||
|
||||
typedef char int8_t;
|
||||
typedef uchar uint8_t;
|
||||
typedef short int16_t;
|
||||
typedef ushort uint16_t;
|
||||
typedef int int32_t;
|
||||
typedef uint uint32_t;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_0
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_q4_0
|
||||
{
|
||||
half d;
|
||||
uint8_t qs[QK4_0 / 2];
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// mul_vec_q_n_f32
|
||||
//------------------------------------------------------------------------------
|
||||
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
||||
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
||||
// we assume that the yl's have been multiplied with the appropriate scale factor
|
||||
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
||||
inline float block_q_4_0_dot_y(
|
||||
global struct block_q4_0 * qb_curr,
|
||||
float sumy,
|
||||
private float * yl,
|
||||
int il
|
||||
) {
|
||||
float d = qb_curr->d;
|
||||
float2 acc = 0.f;
|
||||
global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);
|
||||
for (int i = 0; i < 8; i+=2) {
|
||||
acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F)
|
||||
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
||||
acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0)
|
||||
+ yl[i + 9] * (qs[i / 2] & 0xF000);
|
||||
}
|
||||
return d * (sumy * -8.f + acc.s0 + acc.s1);
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4 // each SIMD group works on 4 rows
|
||||
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
|
||||
#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
inline void mul_vec_q_n_f32(
|
||||
global void * src0,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
|
||||
const ulong nb = ne00/QK4_0;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
// (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global
|
||||
// id of a SIMD group in the grid.
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[16]; // src1 vector cache
|
||||
float sumf[N_DST]={0.f};
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix * QK4_0 + il;
|
||||
|
||||
// each thread in a SIMD group deals with half a block.
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
float sumy = 0;
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
sumy += yb[i] + yb[i+1];
|
||||
yl[i+0] = yb[i+ 0];
|
||||
yl[i+1] = yb[i+ 1]/256.f;
|
||||
sumy += yb[i+16] + yb[i+17];
|
||||
yl[i+8] = yb[i+16]/16.f;
|
||||
yl[i+9] = yb[i+17]/4096.f;
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il);
|
||||
}
|
||||
|
||||
// One thread in a SIMD group (i.e., subgroup) handles a half block,
|
||||
// hence then entire SIMD group handles SIMDWIDTH/2 blocks.
|
||||
// y points to the activation matrix (of type float). Therefore for
|
||||
// one thread, the # of blocks y should advance is SIMDWIDTH/2 (because
|
||||
// SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of
|
||||
// floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
|
||||
yb += QK4_0 * (N_SIMDWIDTH/2);
|
||||
}
|
||||
|
||||
// The above does not work for Adreno - it produces incorrect results for
|
||||
// row = 1, 2, 3 and only row = 0 gives the correct result.
|
||||
// If N_DST is changed, the below array must be initialized accordingly.
|
||||
// This also seems to perform better on Intel.
|
||||
float tot[N_DST] = {
|
||||
sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]),
|
||||
sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])};
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mat_q4_0_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
307
ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl
Normal file
307
ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl
Normal file
@@ -0,0 +1,307 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
#define QK4_1 32
|
||||
#define QR4_1 2
|
||||
#define QK5_0 32
|
||||
#define QR5_0 2
|
||||
#define QK5_1 32
|
||||
#define QR5_1 2
|
||||
#define QK8_0 32
|
||||
#define QR8_0 1
|
||||
#define QK_K 256
|
||||
#define K_QUANTS_PER_ITERATION 2
|
||||
|
||||
typedef char int8_t;
|
||||
typedef uchar uint8_t;
|
||||
typedef short int16_t;
|
||||
typedef ushort uint16_t;
|
||||
typedef int int32_t;
|
||||
typedef uint uint32_t;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_0
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_q4_0
|
||||
{
|
||||
half d;
|
||||
uint8_t qs[QK4_0 / 2];
|
||||
};
|
||||
|
||||
inline float mm_block_q_4_0_dot_y_flat(
|
||||
global uchar * x,
|
||||
global half * dh,
|
||||
float sumy,
|
||||
float16 yl,
|
||||
int il
|
||||
) {
|
||||
float d = *dh;
|
||||
global ushort * qs = ((global ushort *)x + il/2);
|
||||
float acc = 0.f;
|
||||
|
||||
acc += yl.s0 * (qs[0] & 0x000F);
|
||||
acc += yl.s1 * (qs[0] & 0x0F00);
|
||||
acc += yl.s8 * (qs[0] & 0x00F0);
|
||||
acc += yl.s9 * (qs[0] & 0xF000);
|
||||
|
||||
acc += yl.s2 * (qs[1] & 0x000F);
|
||||
acc += yl.s3 * (qs[1] & 0x0F00);
|
||||
acc += yl.sa * (qs[1] & 0x00F0);
|
||||
acc += yl.sb * (qs[1] & 0xF000);
|
||||
|
||||
acc += yl.s4 * (qs[2] & 0x000F);
|
||||
acc += yl.s5 * (qs[2] & 0x0F00);
|
||||
acc += yl.sc * (qs[2] & 0x00F0);
|
||||
acc += yl.sd * (qs[2] & 0xF000);
|
||||
|
||||
acc += yl.s6 * (qs[3] & 0x000F);
|
||||
acc += yl.s7 * (qs[3] & 0x0F00);
|
||||
acc += yl.se * (qs[3] & 0x00F0);
|
||||
acc += yl.sf * (qs[3] & 0xF000);
|
||||
|
||||
return d * (sumy * -8.f + acc);
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 16 // each SIMD group works on 8 rows (in weights matrix)
|
||||
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
|
||||
#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 16
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
//
|
||||
// This variant performs 1d blocking with 16x output.
|
||||
// Eeach simdgroup outputs 16 values on `n0` dim (row in the output matrix).
|
||||
//
|
||||
inline void mul_mat_q_n_f32_1d_16x_flat(
|
||||
global uchar * src0_q,
|
||||
global half * src0_d,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
const int nb = ne00/QK4_0;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
// (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
|
||||
// a SIMD group in the grid. Each SIMD group produces N_DST values in the
|
||||
// result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
|
||||
// Currently with llama2 7B, im is always 0.
|
||||
// TODO: how to handle im/gqa*(nb*ne0)?
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
// The number of scales is the same as the number of blocks.
|
||||
ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
// Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
|
||||
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
|
||||
|
||||
global uchar * x = (global uchar *) src0_q + offset0_q;
|
||||
global half * d = (global half *) src0_d + offset0_d;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float16 yl;
|
||||
float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix*QK4_0 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
float sumy = 0.f;
|
||||
|
||||
sumy += yb[0];
|
||||
sumy += yb[1];
|
||||
sumy += yb[2];
|
||||
sumy += yb[3];
|
||||
sumy += yb[4];
|
||||
sumy += yb[5];
|
||||
sumy += yb[6];
|
||||
sumy += yb[7];
|
||||
|
||||
sumy += yb[16];
|
||||
sumy += yb[17];
|
||||
sumy += yb[18];
|
||||
sumy += yb[19];
|
||||
sumy += yb[20];
|
||||
sumy += yb[21];
|
||||
sumy += yb[22];
|
||||
sumy += yb[23];
|
||||
|
||||
yl.s0 = yb[0];
|
||||
yl.s1 = yb[1]/256.f;
|
||||
|
||||
yl.s2 = yb[2];
|
||||
yl.s3 = yb[3]/256.f;
|
||||
|
||||
yl.s4 = yb[4];
|
||||
yl.s5 = yb[5]/256.f;
|
||||
|
||||
yl.s6 = yb[6];
|
||||
yl.s7 = yb[7]/256.f;
|
||||
|
||||
yl.s8 = yb[16]/16.f;
|
||||
yl.s9 = yb[17]/4096.f;
|
||||
|
||||
yl.sa = yb[18]/16.f;
|
||||
yl.sb = yb[19]/4096.f;
|
||||
|
||||
yl.sc = yb[20]/16.f;
|
||||
yl.sd = yb[21]/4096.f;
|
||||
|
||||
yl.se = yb[22]/16.f;
|
||||
yl.sf = yb[23]/4096.f;
|
||||
|
||||
sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
|
||||
sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
|
||||
sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
|
||||
sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
|
||||
|
||||
sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
|
||||
sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
|
||||
sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
|
||||
sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
|
||||
|
||||
sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il);
|
||||
sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il);
|
||||
sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il);
|
||||
sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il);
|
||||
|
||||
sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il);
|
||||
sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il);
|
||||
sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il);
|
||||
sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il);
|
||||
|
||||
yb += QK4_0 * (N_SIMDWIDTH/2);
|
||||
}
|
||||
|
||||
float16 tot = (float16)(
|
||||
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
|
||||
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
|
||||
sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
|
||||
sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7),
|
||||
|
||||
sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9),
|
||||
sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb),
|
||||
sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd),
|
||||
sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf)
|
||||
);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
if (first_row + 0 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
|
||||
}
|
||||
if (first_row + 1 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
|
||||
}
|
||||
if (first_row + 2 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
|
||||
}
|
||||
if (first_row + 3 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
|
||||
}
|
||||
|
||||
if (first_row + 4 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
|
||||
}
|
||||
if (first_row + 5 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
|
||||
}
|
||||
if (first_row + 6 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
|
||||
}
|
||||
if (first_row + 7 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
|
||||
}
|
||||
|
||||
if (first_row + 8 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8;
|
||||
}
|
||||
if (first_row + 9 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9;
|
||||
}
|
||||
if (first_row + 10 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa;
|
||||
}
|
||||
if (first_row + 11 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb;
|
||||
}
|
||||
|
||||
if (first_row + 12 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc;
|
||||
}
|
||||
if (first_row + 13 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd;
|
||||
}
|
||||
if (first_row + 14 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se;
|
||||
}
|
||||
if (first_row + 15 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat(
|
||||
global uchar * src0_q,
|
||||
global half * src0_d,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
265
ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl
Normal file
265
ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl
Normal file
@@ -0,0 +1,265 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
#define QK4_1 32
|
||||
#define QR4_1 2
|
||||
#define QK5_0 32
|
||||
#define QR5_0 2
|
||||
#define QK5_1 32
|
||||
#define QR5_1 2
|
||||
#define QK8_0 32
|
||||
#define QR8_0 1
|
||||
#define QK_K 256
|
||||
#define K_QUANTS_PER_ITERATION 2
|
||||
|
||||
typedef char int8_t;
|
||||
typedef uchar uint8_t;
|
||||
typedef short int16_t;
|
||||
typedef ushort uint16_t;
|
||||
typedef int int32_t;
|
||||
typedef uint uint32_t;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_0
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_q4_0
|
||||
{
|
||||
half d;
|
||||
uint8_t qs[QK4_0 / 2];
|
||||
};
|
||||
|
||||
inline float mm_block_q_4_0_dot_y_flat(
|
||||
global uchar * x,
|
||||
global half * dh,
|
||||
float sumy,
|
||||
float16 yl,
|
||||
int il
|
||||
) {
|
||||
float d = *dh;
|
||||
global ushort * qs = ((global ushort *)x + il/2);
|
||||
float acc = 0.f;
|
||||
|
||||
acc += yl.s0 * (qs[0] & 0x000F);
|
||||
acc += yl.s1 * (qs[0] & 0x0F00);
|
||||
acc += yl.s8 * (qs[0] & 0x00F0);
|
||||
acc += yl.s9 * (qs[0] & 0xF000);
|
||||
|
||||
acc += yl.s2 * (qs[1] & 0x000F);
|
||||
acc += yl.s3 * (qs[1] & 0x0F00);
|
||||
acc += yl.sa * (qs[1] & 0x00F0);
|
||||
acc += yl.sb * (qs[1] & 0xF000);
|
||||
|
||||
acc += yl.s4 * (qs[2] & 0x000F);
|
||||
acc += yl.s5 * (qs[2] & 0x0F00);
|
||||
acc += yl.sc * (qs[2] & 0x00F0);
|
||||
acc += yl.sd * (qs[2] & 0xF000);
|
||||
|
||||
acc += yl.s6 * (qs[3] & 0x000F);
|
||||
acc += yl.s7 * (qs[3] & 0x0F00);
|
||||
acc += yl.se * (qs[3] & 0x00F0);
|
||||
acc += yl.sf * (qs[3] & 0xF000);
|
||||
|
||||
return d * (sumy * -8.f + acc);
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix)
|
||||
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
|
||||
#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 8
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
//
|
||||
// This variant performs 1d blocking with 8x output.
|
||||
// Eeach simdgroup outputs 8 values on `n0` dim (row in the output matrix).
|
||||
//
|
||||
inline void mul_mat_q_n_f32_1d_8x_flat(
|
||||
global uchar * src0_q,
|
||||
global half * src0_d,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
const int nb = ne00/QK4_0;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
// (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
|
||||
// a SIMD group in the grid. Each SIMD group produces N_DST values in the
|
||||
// result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
|
||||
// Currently with llama2 7B, im is always 0.
|
||||
// TODO: how to handle im/gqa*(nb*ne0)?
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
// The number of scales is the same as the number of blocks.
|
||||
ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
// Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
|
||||
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
|
||||
|
||||
global uchar * x = (global uchar *) src0_q + offset0_q;
|
||||
global half * d = (global half *) src0_d + offset0_d;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float16 yl;
|
||||
float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix*QK4_0 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
float sumy = 0.f;
|
||||
|
||||
sumy += yb[0];
|
||||
sumy += yb[1];
|
||||
sumy += yb[2];
|
||||
sumy += yb[3];
|
||||
sumy += yb[4];
|
||||
sumy += yb[5];
|
||||
sumy += yb[6];
|
||||
sumy += yb[7];
|
||||
|
||||
sumy += yb[16];
|
||||
sumy += yb[17];
|
||||
sumy += yb[18];
|
||||
sumy += yb[19];
|
||||
sumy += yb[20];
|
||||
sumy += yb[21];
|
||||
sumy += yb[22];
|
||||
sumy += yb[23];
|
||||
|
||||
yl.s0 = yb[0];
|
||||
yl.s1 = yb[1]/256.f;
|
||||
|
||||
yl.s2 = yb[2];
|
||||
yl.s3 = yb[3]/256.f;
|
||||
|
||||
yl.s4 = yb[4];
|
||||
yl.s5 = yb[5]/256.f;
|
||||
|
||||
yl.s6 = yb[6];
|
||||
yl.s7 = yb[7]/256.f;
|
||||
|
||||
yl.s8 = yb[16]/16.f;
|
||||
yl.s9 = yb[17]/4096.f;
|
||||
|
||||
yl.sa = yb[18]/16.f;
|
||||
yl.sb = yb[19]/4096.f;
|
||||
|
||||
yl.sc = yb[20]/16.f;
|
||||
yl.sd = yb[21]/4096.f;
|
||||
|
||||
yl.se = yb[22]/16.f;
|
||||
yl.sf = yb[23]/4096.f;
|
||||
|
||||
sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
|
||||
sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
|
||||
sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
|
||||
sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
|
||||
|
||||
sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
|
||||
sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
|
||||
sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
|
||||
sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
|
||||
|
||||
yb += QK4_0 * (N_SIMDWIDTH/2);
|
||||
}
|
||||
|
||||
float8 tot = (float8)(
|
||||
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
|
||||
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
|
||||
sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
|
||||
sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
|
||||
);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
if (first_row + 0 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
|
||||
}
|
||||
if (first_row + 1 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
|
||||
}
|
||||
if (first_row + 2 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
|
||||
}
|
||||
if (first_row + 3 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
|
||||
}
|
||||
|
||||
if (first_row + 4 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
|
||||
}
|
||||
if (first_row + 5 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
|
||||
}
|
||||
if (first_row + 6 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
|
||||
}
|
||||
if (first_row + 7 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat(
|
||||
global uchar * src0_q,
|
||||
global half * src0_d,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
272
ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl
Normal file
272
ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl
Normal file
@@ -0,0 +1,272 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
#define QK4_1 32
|
||||
#define QR4_1 2
|
||||
#define QK5_0 32
|
||||
#define QR5_0 2
|
||||
#define QK5_1 32
|
||||
#define QR5_1 2
|
||||
#define QK8_0 32
|
||||
#define QR8_0 1
|
||||
#define QK_K 256
|
||||
#define K_QUANTS_PER_ITERATION 2
|
||||
|
||||
typedef char int8_t;
|
||||
typedef uchar uint8_t;
|
||||
typedef short int16_t;
|
||||
typedef ushort uint16_t;
|
||||
typedef int int32_t;
|
||||
typedef uint uint32_t;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_0
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_q4_0
|
||||
{
|
||||
half d;
|
||||
uint8_t qs[QK4_0 / 2];
|
||||
};
|
||||
|
||||
// This function requires the original shuffled weights.
|
||||
// As a reminder, the original weights are shuffled so that (q[0], q[16]) are
|
||||
// packed together in a byte, so are (q[1], q[17]) and so on.
|
||||
inline float block_q_4_0_dot_y_flat(
|
||||
global uchar * x,
|
||||
global half * dh,
|
||||
float sumy,
|
||||
float16 yl,
|
||||
int il
|
||||
) {
|
||||
float d = *dh;
|
||||
global ushort * qs = ((global ushort *)x + il/2);
|
||||
float acc = 0.f;
|
||||
|
||||
acc += yl.s0 * (qs[0] & 0x000F);
|
||||
acc += yl.s1 * (qs[0] & 0x0F00);
|
||||
acc += yl.s8 * (qs[0] & 0x00F0);
|
||||
acc += yl.s9 * (qs[0] & 0xF000);
|
||||
|
||||
acc += yl.s2 * (qs[1] & 0x000F);
|
||||
acc += yl.s3 * (qs[1] & 0x0F00);
|
||||
acc += yl.sa * (qs[1] & 0x00F0);
|
||||
acc += yl.sb * (qs[1] & 0xF000);
|
||||
|
||||
acc += yl.s4 * (qs[2] & 0x000F);
|
||||
acc += yl.s5 * (qs[2] & 0x0F00);
|
||||
acc += yl.sc * (qs[2] & 0x00F0);
|
||||
acc += yl.sd * (qs[2] & 0xF000);
|
||||
|
||||
acc += yl.s6 * (qs[3] & 0x000F);
|
||||
acc += yl.s7 * (qs[3] & 0x0F00);
|
||||
acc += yl.se * (qs[3] & 0x00F0);
|
||||
acc += yl.sf * (qs[3] & 0xF000);
|
||||
|
||||
return d * (sumy * -8.f + acc);
|
||||
}
|
||||
|
||||
//
|
||||
// This variant outputs 8 values.
|
||||
//
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 8 // each SIMD group works on 8 rows
|
||||
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
|
||||
#define N_SIMDWIDTH 16 // assuming SIMD group size is 32
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 8
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
inline void mul_vec_q_n_f32_8x_flat(
|
||||
global uchar * src0_q,
|
||||
global half * src0_d,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
const ulong nb = ne00/QK4_0;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
// (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
|
||||
// a SIMD group in the grid. Each SIMD group produces N_DST values in the
|
||||
// result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
|
||||
// Currently with llama2 7B, im is always 0.
|
||||
// TODO: how to handle im/gqa*(nb*ne0)?
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
// The number of scales is the same as the number of blocks.
|
||||
ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
// Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
|
||||
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
|
||||
|
||||
global uchar * x = (global uchar *) src0_q + offset0_q;
|
||||
global half * d = (global half *) src0_d + offset0_d;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float16 yl;
|
||||
float8 sumf = 0.f;
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix*QK4_0 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
float sumy = 0.f;
|
||||
|
||||
sumy += yb[0];
|
||||
sumy += yb[1];
|
||||
sumy += yb[2];
|
||||
sumy += yb[3];
|
||||
sumy += yb[4];
|
||||
sumy += yb[5];
|
||||
sumy += yb[6];
|
||||
sumy += yb[7];
|
||||
|
||||
sumy += yb[16];
|
||||
sumy += yb[17];
|
||||
sumy += yb[18];
|
||||
sumy += yb[19];
|
||||
sumy += yb[20];
|
||||
sumy += yb[21];
|
||||
sumy += yb[22];
|
||||
sumy += yb[23];
|
||||
|
||||
yl.s0 = yb[0];
|
||||
yl.s1 = yb[1]/256.f;
|
||||
|
||||
yl.s2 = yb[2];
|
||||
yl.s3 = yb[3]/256.f;
|
||||
|
||||
yl.s4 = yb[4];
|
||||
yl.s5 = yb[5]/256.f;
|
||||
|
||||
yl.s6 = yb[6];
|
||||
yl.s7 = yb[7]/256.f;
|
||||
|
||||
yl.s8 = yb[16]/16.f;
|
||||
yl.s9 = yb[17]/4096.f;
|
||||
|
||||
yl.sa = yb[18]/16.f;
|
||||
yl.sb = yb[19]/4096.f;
|
||||
|
||||
yl.sc = yb[20]/16.f;
|
||||
yl.sd = yb[21]/4096.f;
|
||||
|
||||
yl.se = yb[22]/16.f;
|
||||
yl.sf = yb[23]/4096.f;
|
||||
|
||||
sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
|
||||
sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
|
||||
sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
|
||||
sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
|
||||
|
||||
sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
|
||||
sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
|
||||
sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
|
||||
sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
|
||||
|
||||
yb += QK4_0 * (N_SIMDWIDTH/2);
|
||||
}
|
||||
|
||||
float8 tot = (float8)(
|
||||
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
|
||||
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
|
||||
sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
|
||||
sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
|
||||
);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
if (first_row + 0 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
|
||||
}
|
||||
if (first_row + 1 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
|
||||
}
|
||||
if (first_row + 2 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
|
||||
}
|
||||
if (first_row + 3 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
|
||||
}
|
||||
|
||||
if (first_row + 4 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
|
||||
}
|
||||
if (first_row + 5 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
|
||||
}
|
||||
if (first_row + 6 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
|
||||
}
|
||||
if (first_row + 7 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mat_q4_0_f32_8x_flat(
|
||||
global uchar * src0_q,
|
||||
global half * src0_d,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
254
ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl
Normal file
254
ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl
Normal file
@@ -0,0 +1,254 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
#define QK4_1 32
|
||||
#define QR4_1 2
|
||||
#define QK5_0 32
|
||||
#define QR5_0 2
|
||||
#define QK5_1 32
|
||||
#define QR5_1 2
|
||||
#define QK8_0 32
|
||||
#define QR8_0 1
|
||||
#define QK_K 256
|
||||
#define K_QUANTS_PER_ITERATION 2
|
||||
|
||||
typedef char int8_t;
|
||||
typedef uchar uint8_t;
|
||||
typedef short int16_t;
|
||||
typedef ushort uint16_t;
|
||||
typedef int int32_t;
|
||||
typedef uint uint32_t;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_0
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_q4_0
|
||||
{
|
||||
half d;
|
||||
uint8_t qs[QK4_0 / 2];
|
||||
};
|
||||
|
||||
//
|
||||
// This variant unrolls the loops and uses vector types instead of pointers.
|
||||
// It improves performance on Adreno but not so much on Intel.
|
||||
//
|
||||
inline float block_q_4_0_dot_y_v(
|
||||
global struct block_q4_0 * qb_curr,
|
||||
float sumy,
|
||||
float16 yl,
|
||||
int il
|
||||
) {
|
||||
float d = qb_curr->d;
|
||||
float acc = 0.f;
|
||||
global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);
|
||||
|
||||
acc += yl.s0 * (qs[0] & 0x000F);
|
||||
acc += yl.s1 * (qs[0] & 0x0F00);
|
||||
acc += yl.s8 * (qs[0] & 0x00F0);
|
||||
acc += yl.s9 * (qs[0] & 0xF000);
|
||||
|
||||
acc += yl.s2 * (qs[1] & 0x000F);
|
||||
acc += yl.s3 * (qs[1] & 0x0F00);
|
||||
acc += yl.sa * (qs[1] & 0x00F0);
|
||||
acc += yl.sb * (qs[1] & 0xF000);
|
||||
|
||||
acc += yl.s4 * (qs[2] & 0x000F);
|
||||
acc += yl.s5 * (qs[2] & 0x0F00);
|
||||
acc += yl.sc * (qs[2] & 0x00F0);
|
||||
acc += yl.sd * (qs[2] & 0xF000);
|
||||
|
||||
acc += yl.s6 * (qs[3] & 0x000F);
|
||||
acc += yl.s7 * (qs[3] & 0x0F00);
|
||||
acc += yl.se * (qs[3] & 0x00F0);
|
||||
acc += yl.sf * (qs[3] & 0xF000);
|
||||
|
||||
return d * (sumy * -8.f + acc);
|
||||
}
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4 // each SIMD group works on 4 rows
|
||||
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
|
||||
#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
inline void mul_vec_q_n_f32_v(
|
||||
global void * src0,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
const ulong nb = ne00/QK4_0;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
// (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global
|
||||
// id of a SIMD group in the grid.
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float16 yl; // src1 vector cache
|
||||
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix * QK4_0 + il;
|
||||
|
||||
// each thread in a SIMD group deals with half a block.
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
float sumy = 0;
|
||||
|
||||
sumy += yb[0];
|
||||
sumy += yb[1];
|
||||
sumy += yb[2];
|
||||
sumy += yb[3];
|
||||
sumy += yb[4];
|
||||
sumy += yb[5];
|
||||
sumy += yb[6];
|
||||
sumy += yb[7];
|
||||
|
||||
sumy += yb[16];
|
||||
sumy += yb[17];
|
||||
sumy += yb[18];
|
||||
sumy += yb[19];
|
||||
sumy += yb[20];
|
||||
sumy += yb[21];
|
||||
sumy += yb[22];
|
||||
sumy += yb[23];
|
||||
|
||||
|
||||
yl.s0 = yb[0];
|
||||
yl.s1 = yb[1]/256.f;
|
||||
|
||||
yl.s2 = yb[2];
|
||||
yl.s3 = yb[3]/256.f;
|
||||
|
||||
yl.s4 = yb[4];
|
||||
yl.s5 = yb[5]/256.f;
|
||||
|
||||
yl.s6 = yb[6];
|
||||
yl.s7 = yb[7]/256.f;
|
||||
|
||||
yl.s8 = yb[16]/16.f;
|
||||
yl.s9 = yb[17]/4096.f;
|
||||
|
||||
yl.sa = yb[18]/16.f;
|
||||
yl.sb = yb[19]/4096.f;
|
||||
|
||||
yl.sc = yb[20]/16.f;
|
||||
yl.sd = yb[21]/4096.f;
|
||||
|
||||
yl.se = yb[22]/16.f;
|
||||
yl.sf = yb[23]/4096.f;
|
||||
|
||||
sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il);
|
||||
sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il);
|
||||
sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il);
|
||||
sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il);
|
||||
|
||||
// One thread in a SIMD group (i.e., subgroup) handles a half block,
|
||||
// hence then entire SIMD group handles SIMDWIDTH/2 blocks.
|
||||
// y points to the activation matrix (of type float). Therefore for
|
||||
// one thread, the # of blocks y should advance is SIMDWIDTH/2 (because
|
||||
// SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of
|
||||
// floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
|
||||
yb += QK4_0 * (N_SIMDWIDTH/2);
|
||||
}
|
||||
|
||||
// The above does not work for Adreno - it produces incorrect results for
|
||||
// row = 1, 2, 3 and only row = 0 gives the correct result.
|
||||
// If N_DST is changed, the below array must be initialized accordingly.
|
||||
// This also seems to perform better on Intel.
|
||||
float4 tot = (float4)(
|
||||
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
|
||||
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
|
||||
);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
if (first_row + 0 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
|
||||
}
|
||||
if (first_row + 1 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
|
||||
}
|
||||
if (first_row + 2 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
|
||||
}
|
||||
if (first_row + 3 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mat_q4_0_f32_v(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
190
ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl
Normal file
190
ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl
Normal file
@@ -0,0 +1,190 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
#define QK4_1 32
|
||||
#define QR4_1 2
|
||||
#define QK5_0 32
|
||||
#define QR5_0 2
|
||||
#define QK5_1 32
|
||||
#define QR5_1 2
|
||||
#define QK8_0 32
|
||||
#define QR8_0 1
|
||||
#define QK_K 256
|
||||
#define K_QUANTS_PER_ITERATION 2
|
||||
|
||||
typedef char int8_t;
|
||||
typedef uchar uint8_t;
|
||||
typedef short int16_t;
|
||||
typedef ushort uint16_t;
|
||||
typedef int int32_t;
|
||||
typedef uint uint32_t;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q6_K
|
||||
//------------------------------------------------------------------------------
|
||||
// 6-bit quantization
|
||||
// weight is represented as x = a * q
|
||||
// 16 blocks of 16 elements each
|
||||
// Effectively 6.5625 bits per weight
|
||||
typedef struct {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
||||
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
||||
half d; // super-block scale
|
||||
} block_q6_K;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_mul_mv_q6_K_f32
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 1 // number of rows each SIMD group works on
|
||||
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
||||
#define N_SIMDWIDTH 16 // SIMD group size
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 1
|
||||
#define N_SIMDGROUP 2
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q6_K_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
uchar kmask1 = 0x03;
|
||||
uchar kmask2 = 0x0C;
|
||||
uchar kmask3 = 0x30;
|
||||
uchar kmask4 = 0xC0;
|
||||
|
||||
int nb = ne00/QK_K;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int row = N_SIMDGROUP * r0 + get_sub_group_id();
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0;
|
||||
global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
// For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a
|
||||
// block. Values in a subblock shares a scale that is quantized with 8 bits;
|
||||
// the entire block shares a single floating point scale.
|
||||
// For work distribution, each thread processes a subblock (16 weights), hence
|
||||
// 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16
|
||||
// (super) blocks -- this is the block stride.
|
||||
// The 16 threads that process a (super) block are split into 2 portions, each has
|
||||
// 8 threads; each portion works on 8 subblocks.
|
||||
// For subgroup of 16 threads, the entire subgroup works on a single (super) block
|
||||
// before moving to the next (super) block. Thread0 - thread7 work on the
|
||||
// first 8 subblocks; thread8 - thread15 works on the last 8 subblocks.
|
||||
// Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on
|
||||
// subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but
|
||||
// works on a total of 16 weight values.
|
||||
int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0
|
||||
int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1
|
||||
int ip = tid/8; // first or second half of (super) block (0 or 1)
|
||||
int il = tid%8; // each half has 8 parts, one per scale
|
||||
int n = 4; // 4 scales at a time (and 4 sums)
|
||||
int l0 = n*il; // offset into half-block, 0..28
|
||||
int is = 8*ip + l0/16; // 0, 1, 8, 9
|
||||
|
||||
int y_offset = 128*ip + l0;
|
||||
int q_offset_l = 64*ip + l0;
|
||||
int q_offset_h = 32*ip + l0;
|
||||
|
||||
for (int i = ix; i < nb; i += BLOCK_STRIDE) {
|
||||
|
||||
global uint8_t * q1 = x[i].ql + q_offset_l;
|
||||
global uint8_t * q2 = q1 + QK_K/8;
|
||||
global uint8_t * qh = x[i].qh + q_offset_h;
|
||||
global int8_t * sc = x[i].scales + is;
|
||||
|
||||
global float * y = yy + i * QK_K + y_offset;
|
||||
|
||||
float dall = x[i].d;
|
||||
|
||||
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
||||
|
||||
sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f);
|
||||
sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f);
|
||||
sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f);
|
||||
sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f);
|
||||
|
||||
sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f);
|
||||
sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f);
|
||||
sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f);
|
||||
sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f);
|
||||
|
||||
sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f);
|
||||
sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f);
|
||||
sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f);
|
||||
sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f);
|
||||
|
||||
sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f);
|
||||
sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f);
|
||||
sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f);
|
||||
sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f);
|
||||
|
||||
sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);
|
||||
}
|
||||
|
||||
float tot = sub_group_reduce_add(sumf);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
||||
}
|
||||
}
|
||||
81
ggml/src/ggml-opencl/kernels/norm.cl
Normal file
81
ggml/src/ggml-opencl/kernels/norm.cl
Normal file
@@ -0,0 +1,81 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// norm
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_norm(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
float eps,
|
||||
local float * sum
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
dst = (global void*)((global char*)dst + offsetd);
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
|
||||
// MEAN
|
||||
// parallel sum
|
||||
sum[get_local_id(0)] = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
sum[get_local_id(0)] += x[i00];
|
||||
}
|
||||
// reduce
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
|
||||
if (get_local_id(0) < i) {
|
||||
sum[get_local_id(0)] += sum[get_local_id(0) + i];
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
float mean = sum[0] / ne00;
|
||||
|
||||
// recenter and VARIANCE
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
sum[get_local_id(0)] = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
y[i00] = x[i00] - mean;
|
||||
sum[get_local_id(0)] += y[i00] * y[i00];
|
||||
}
|
||||
|
||||
// reduce
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
|
||||
if (get_local_id(0) < i) {
|
||||
sum[get_local_id(0)] += sum[get_local_id(0) + i];
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
float variance = sum[0] / ne00;
|
||||
|
||||
float scale = 1.0f/sqrt(variance + eps);
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
y[i00] = y[i00] * scale;
|
||||
}
|
||||
}
|
||||
16
ggml/src/ggml-opencl/kernels/relu.cl
Normal file
16
ggml/src/ggml-opencl/kernels/relu.cl
Normal file
@@ -0,0 +1,16 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// relu
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_relu(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]);
|
||||
}
|
||||
96
ggml/src/ggml-opencl/kernels/rms_norm.cl
Normal file
96
ggml/src/ggml-opencl/kernels/rms_norm.cl
Normal file
@@ -0,0 +1,96 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// rms_norm
|
||||
//------------------------------------------------------------------------------
|
||||
// This kernel depends on subgroup size.
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_32
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_rms_norm(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
float eps,
|
||||
local float * sum // Note, the size depends on number of subgroups
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
global float * x_scalar = (global float *) x;
|
||||
float4 sumf = 0;
|
||||
float all_sum = 0;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
sumf += x[i00] * x[i00];
|
||||
}
|
||||
all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3;
|
||||
all_sum = sub_group_reduce_add(all_sum);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
sum[get_sub_group_id()] = all_sum;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
// broadcast
|
||||
for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
|
||||
if (get_local_id(0) < i) {
|
||||
sum[get_local_id(0)] += sum[get_local_id(0) + i];
|
||||
}
|
||||
}
|
||||
if (get_local_id(0) == 0) {
|
||||
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
||||
sum[0] += x_scalar[i];
|
||||
}
|
||||
sum[0] /= ne00;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
const float mean = sum[0];
|
||||
const float scale = 1.0f/sqrt(mean + eps);
|
||||
|
||||
global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
global float * y_scalar = (global float *) y;
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
y[i00] = x[i00] * scale;
|
||||
}
|
||||
if (get_local_id(0) == 0) {
|
||||
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
||||
y_scalar[i00] = x_scalar[i00] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
721
ggml/src/ggml-opencl/kernels/rope.cl
Normal file
721
ggml/src/ggml-opencl/kernels/rope.cl
Normal file
@@ -0,0 +1,721 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_rope
|
||||
//------------------------------------------------------------------------------
|
||||
float rope_yarn_ramp(float low, float high, int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
}
|
||||
|
||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
float2 rope_yarn(
|
||||
float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale
|
||||
) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
|
||||
}
|
||||
return (float2)(cos(theta) * mscale, sin(theta) * mscale);
|
||||
}
|
||||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||
float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
||||
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
||||
}
|
||||
|
||||
float2 rope_yarn_corr_dims(
|
||||
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow
|
||||
) {
|
||||
// start and end correction dims
|
||||
return (float2)(
|
||||
max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))),
|
||||
min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)))
|
||||
);
|
||||
}
|
||||
|
||||
kernel void kernel_rope_norm_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
float theta_base = (float) pos[i2];
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
if (i0 < n_dims) {
|
||||
int ic = i0/2;
|
||||
|
||||
float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
float x0 = src[0];
|
||||
float x1 = src[1];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
} else {
|
||||
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rope_norm_f16(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
float theta_base = (float) pos[i2];
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
if (i0 < n_dims) {
|
||||
int ic = i0/2;
|
||||
|
||||
float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
float x0 = src[0];
|
||||
float x1 = src[1];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
} else {
|
||||
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rope_neox_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
float theta_base = (float) pos[i2];
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
if (i0 < n_dims) {
|
||||
int ic = i0/2;
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
} else {
|
||||
global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rope_neox_f16(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
float theta_base = (float) pos[i2];
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
if (i0 < n_dims) {
|
||||
int ic = i0/2;
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
} else {
|
||||
global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rope_multi_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
int4 sections
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
|
||||
const int sec_w = sections.s1 + sections.s0;
|
||||
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
if (i0 < n_dims) {
|
||||
int ic = i0/2;
|
||||
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
float theta_base = 0.0f;
|
||||
|
||||
if (sector < sections.s0) {
|
||||
theta_base = pos[i2];
|
||||
}
|
||||
else if (sector >= sections.s0 && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne2 * 1];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 3];
|
||||
}
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
} else {
|
||||
global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rope_multi_f16(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global half * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
int4 sections
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
|
||||
const int sec_w = sections.s1 + sections.s0;
|
||||
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
if (i0 < n_dims) {
|
||||
int ic = i0/2;
|
||||
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
float theta_base = 0.0f;
|
||||
|
||||
if (sector < sections.s0) {
|
||||
theta_base = pos[i2];
|
||||
}
|
||||
else if (sector >= sections.s0 && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne2 * 1];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 3];
|
||||
}
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
} else {
|
||||
global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rope_vision_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
int4 sections
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
const int sect_dims = sections.s0 + sections.s1;
|
||||
const int sec_w = sections.s1 + sections.s0;
|
||||
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
int ic = i0/2;
|
||||
|
||||
const int sector = (i0/2) % sect_dims;
|
||||
float theta_base = 0.0f;
|
||||
|
||||
if (sector < sections.s0) {
|
||||
const int p = sector;
|
||||
theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
|
||||
} else if (sector >= sections.s0 && sector < sec_w) {
|
||||
const int p = sector - sections.s0;
|
||||
theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
|
||||
}
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rope_vision_f16(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global half * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
int4 sections
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
const int sect_dims = sections.s0 + sections.s1;
|
||||
const int sec_w = sections.s1 + sections.s0;
|
||||
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
int ic = i0/2;
|
||||
|
||||
const int sector = (i0/2) % sect_dims;
|
||||
float theta_base = 0.0f;
|
||||
|
||||
if (sector < sections.s0) {
|
||||
const int p = sector;
|
||||
theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
|
||||
} else if (sector >= sections.s0 && sector < sec_w) {
|
||||
const int p = sector - sections.s0;
|
||||
theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
|
||||
}
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
}
|
||||
}
|
||||
16
ggml/src/ggml-opencl/kernels/scale.cl
Normal file
16
ggml/src/ggml-opencl/kernels/scale.cl
Normal file
@@ -0,0 +1,16 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// scale
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_scale(
|
||||
global float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * dst,
|
||||
ulong offsetd,
|
||||
float scale
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user