Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2026-04-23 16:37:33 +03:00)

Compare commits: 47 commits, `b5869...compilade/`
Commits in this range (SHA1):

- 942c55cd57
- 183eeb5518
- 50f53b3e40
- 42423ec4d3
- 0ee322cd0f
- c31e60647d
- 67eade1bf9
- 7de5c7cab6
- 8eff95544e
- 3120413ccd
- 215535701d
- 74bb294591
- 3e303b1107
- 0c1df14b5f
- b3ad3a0191
- 98197e5c98
- f5e96b368f
- 756aa1020a
- aaa088d87f
- 0d5375d54b
- e33de128c7
- 118d52fefc
- 0e79355075
- 43cd2b3eb5
- 1a9454a3d2
- ba6f6be6ce
- 2c0945027a
- 1d19025909
- 635f945ed1
- a5165a6ca9
- 16202d6f96
- 1be357d990
- db502ddd0e
- c7a32e761d
- 2d79a7077c
- 8c13e16bb0
- 2217247051
- efa9186dc8
- 894ed8d7b6
- 9e6b0e9419
- 503630e88a
- d19101c9a0
- 3ad0603c65
- c8ab6a3ba3
- 3de9300c37
- 347247a24e
- bce54642c8
README.md (10 changed lines)

```diff
@@ -6,9 +6,9 @@
 [](https://github.com/ggml-org/llama.cpp/releases)
 [](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml)
 
-[Roadmap](https://github.com/users/ggerganov/projects/7) / [Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml)
+[Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml) / [ops](https://github.com/ggml-org/llama.cpp/blob/master/docs/ops.md)
 
-Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++
+LLM inference in C/C++
 
 ## Recent API changes
```
```diff
@@ -17,10 +17,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 
 ## Hot topics
 
-- 🔥 Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
-- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141](https://github.com/ggml-org/llama.cpp/pull/13141)), `libllava` will be deprecated
+- Hot PRs: [All](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+) | [Open](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+is%3Aopen)
+- Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
 - VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
 - Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
 - Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123
 - Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggml-org/llama.cpp/discussions/9669
```
```diff
@@ -134,6 +133,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
 - [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
 - [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
+- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
 
 #### Multimodal
```
```diff
@@ -448,6 +448,15 @@ void string_replace_all(std::string & s, const std::string & search, const std::
 bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
     return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
 }
 
+bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
+    bool has_suffix = string_ends_with(str, suffix);
+    if (has_suffix) {
+        str = str.substr(0, str.size() - suffix.size());
+    }
+    return has_suffix;
+}
+
 size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
     if (!str.empty() && !stop.empty()) {
         const char text_last_char = str.back();
```
```diff
@@ -522,6 +522,7 @@ static bool string_starts_with(const std::string & str,
 
 // While we wait for C++20's std::string::ends_with...
 bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
+bool string_remove_suffix(std::string & str, const std::string_view & suffix);
 size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
 
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
```
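The new `string_remove_suffix` helper strips a suffix in place and reports whether it did, which pairs naturally with stop-token handling. A minimal usage sketch; it inlines the two helpers from the diff above so it compiles outside the tree:

```cpp
#include <cassert>
#include <string>
#include <string_view>

// Inlined from the diff above so this sketch is self-contained.
static bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
    return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}

static bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
    bool has_suffix = string_ends_with(str, suffix);
    if (has_suffix) {
        str = str.substr(0, str.size() - suffix.size());
    }
    return has_suffix;
}

int main() {
    std::string s = "Hello<|im_end|>";
    bool removed = string_remove_suffix(s, "<|im_end|>");
    assert(removed && s == "Hello");      // suffix present: stripped in place, returns true
    removed = string_remove_suffix(s, "<|im_end|>");
    assert(!removed && s == "Hello");     // already gone: string unchanged, returns false
    return 0;
}
```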
```diff
@@ -300,6 +300,7 @@ class ModelBase:
             gguf.MODEL_TENSOR.POS_EMBD,
             gguf.MODEL_TENSOR.TOKEN_TYPES,
             gguf.MODEL_TENSOR.SSM_CONV1D,
+            gguf.MODEL_TENSOR.SHORTCONV_CONV,
             gguf.MODEL_TENSOR.TIME_MIX_FIRST,
             gguf.MODEL_TENSOR.TIME_MIX_W1,
             gguf.MODEL_TENSOR.TIME_MIX_W2,
```
```diff
@@ -836,6 +837,9 @@ class TextModel(ModelBase):
         if chkhsh == "f6791d196f87ce6b56a7d234be618e0d58f8cda3549416635b2bebcd22cd95c4":
             # ref: https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct
             res = "midm-2.0"
+        if chkhsh == "169bf0296a13c4d9b7672313f749eb36501d931022de052aad6e36f2bf34dd51":
+            # ref: https://huggingface.co/LiquidAI/LFM2-Tokenizer
+            res = "lfm2"
 
         if res is None:
             logger.warning("\n")
```
```diff
@@ -7073,6 +7077,50 @@ class SmolLM3Model(LlamaModel):
         chat_template = tokenizer.chat_template.replace("[:]", "")
         self.gguf_writer.add_chat_template(chat_template)
 
 
+@ModelBase.register("Lfm2ForCausalLM")
+@ModelBase.register("LFM2ForCausalLM")
+class LFM2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.LFM2
+
+    def _add_feed_forward_length(self):
+        ff_dim = self.hparams["block_ff_dim"]
+
+        auto_adjust_ff_dim = self.hparams["block_auto_adjust_ff_dim"]
+        ff_dim = self.hparams["block_ff_dim"]
+        ffn_dim_multiplier = self.hparams["block_ffn_dim_multiplier"]
+        multiple_of = self.hparams["block_multiple_of"]
+
+        if auto_adjust_ff_dim:
+            ff_dim = int(2 * ff_dim / 3)
+            # custom dim factor multiplier
+            if ffn_dim_multiplier is not None:
+                ff_dim = int(ffn_dim_multiplier * ff_dim)
+            ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
+
+        self.gguf_writer.add_feed_forward_length(ff_dim)
+
+    def set_gguf_parameters(self):
+        # set num_key_value_heads only for attention layers
+        self.hparams["num_key_value_heads"] = [
+            self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
+            for layer_type in self.hparams["layer_types"]
+        ]
+
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["norm_eps"])
+        self._add_feed_forward_length()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # conv op requires 2d tensor
+        if 'conv.conv' in name:
+            data_torch = data_torch.squeeze(1)
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 ###### CONVERSION LOGIC ######
```
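When `block_auto_adjust_ff_dim` is set, `_add_feed_forward_length` applies the familiar LLaMA-style FFN sizing: scale to 2/3, apply the optional multiplier, then round up to a multiple of `block_multiple_of`. A worked example with hypothetical hyperparameters (not taken from any published LFM2 config):

```cpp
#include <cstdio>

int main() {
    // Hypothetical values: block_ff_dim = 8192, multiplier = 1.0, multiple_of = 256.
    long long ff_dim             = 8192;
    double    ffn_dim_multiplier = 1.0;
    long long multiple_of        = 256;

    ff_dim = 2 * ff_dim / 3;                                           // 5461
    ff_dim = (long long)(ffn_dim_multiplier * (double)ff_dim);         // 5461
    ff_dim = multiple_of * ((ff_dim + multiple_of - 1) / multiple_of); // round up: 5632
    printf("feed_forward_length = %lld\n", ff_dim);                    // prints 5632
    return 0;
}
```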
```diff
@@ -130,6 +130,7 @@ models = [
     {"name": "seed-coder",   "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
     {"name": "a.x-4.0",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", },
     {"name": "midm-2.0",     "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct", },
+    {"name": "lfm2",         "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
 ]
 
 # some models are known to be broken upstream, so we will skip them as exceptions
```
```diff
@@ -43,6 +43,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set-rows.cuh"
 #include "ggml.h"
 
 #include <algorithm>
```
```diff
@@ -2230,6 +2231,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GET_ROWS_BACK:
             ggml_cuda_op_get_rows_back(ctx, dst);
             break;
+        case GGML_OP_SET_ROWS:
+            ggml_cuda_op_set_rows(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;
```
```diff
@@ -3216,6 +3220,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             {
                 return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                       op->src[0]->type == GGML_TYPE_F32 &&
+                       op->src[1]->type == GGML_TYPE_I64;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
```
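For reference, `GGML_OP_SET_ROWS` is a row scatter: each source row is written to the destination row named by the I64 index tensor, optionally converting F32 to F16 on store. A scalar sketch of the core idea, ignoring the 4D strides and broadcasting the CUDA kernel handles:

```cpp
#include <cstdint>
#include <cstdio>

// Scatter: for each source row i, dst[row_ids[i]] = src[i].
static void set_rows_ref(const float * src, const int64_t * row_ids, float * dst,
                         int64_t n_src_rows, int64_t row_size) {
    for (int64_t i = 0; i < n_src_rows; i++) {
        const int64_t r = row_ids[i];
        for (int64_t j = 0; j < row_size; j++) {
            dst[r * row_size + j] = src[i * row_size + j];
        }
    }
}

int main() {
    const float   src[2][3] = {{1, 2, 3}, {4, 5, 6}};
    const int64_t ids[2]    = {2, 0};
    float         dst[4][3] = {};
    set_rows_ref(&src[0][0], ids, &dst[0][0], 2, 3);
    printf("%g %g %g\n", dst[2][0], dst[2][1], dst[2][2]); // 1 2 3
    printf("%g %g %g\n", dst[0][0], dst[0][1], dst[0][2]); // 4 5 6
    return 0;
}
```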
ggml/src/ggml-cuda/set-rows.cu (new file, 130 lines)
```diff
@@ -0,0 +1,130 @@
+#include "set-rows.cuh"
+
+typedef void (*set_rows_kernel_t)(const char * src, char * dst);
+
+template<typename src_t, typename dst_t>
+__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {}
+
+template<>
+__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
+    *dst_h = __float2half(*src_f);
+}
+
+template<>
+__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
+    *dst_f = *src_f;
+}
+
+template<typename src_t, typename dst_t>
+static __global__ void k_set_rows(
+        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t s10, const int64_t s11, const int64_t s12,
+        const int64_t s1, const int64_t s2, const int64_t s3) {
+
+    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
+
+    if (i >= ne_total) {
+        return;
+    }
+
+    const int64_t i03 = i / (ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
+    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
+    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+
+    const int64_t i12 = i03 % ne12;
+    const int64_t i11 = i02 % ne11;
+    const int64_t i10 = i01;
+
+    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
+
+    const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
+    dst_t * dst_row_ptr    = dst + dst_row*s1 + i02*s2 + i03*s3;
+
+    const src_t * src_elem = src0_row + i00;
+    dst_t * dst_elem       = dst_row_ptr + i00;
+    set_rows_1(src_elem, dst_elem);
+}
+
+template<typename src_t, typename dst_t>
+static void set_rows_cuda(
+        const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
+
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
+    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
+    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
+    const dim3 grid_size(num_blocks);
+
+    const int64_t s01 = nb01/sizeof(src_t);
+    const int64_t s02 = nb02/sizeof(src_t);
+    const int64_t s03 = nb03/sizeof(src_t);
+    const int64_t s10 = nb10/sizeof(int64_t);
+    const int64_t s11 = nb11/sizeof(int64_t);
+    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s1  = nb1/sizeof(dst_t);
+    const int64_t s2  = nb2/sizeof(dst_t);
+    const int64_t s3  = nb3/sizeof(dst_t);
+
+    if (ne_total > 0) {
+        k_set_rows<<<grid_size, block_size, 0, stream>>>(
+            src0_d, src1_d, dst_d,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            s01, s02, s03,
+            s10, s11, s12,
+            s1, s2, s3);
+    }
+}
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I64);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const float * src0_d = (const float *)src0->data;
+    const int64_t * src1_d = (const int64_t *)src1->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    if (dst->type == GGML_TYPE_F32) {
+        set_rows_cuda(
+            src0_d, src1_d, (float*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_F16) {
+        set_rows_cuda(
+            src0_d, src1_d, (half*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else {
+        GGML_ABORT("unsupported type");
+    }
+}
```
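The kernel launches one thread per source element, turns the flat thread index back into `(i00, i01, i02, i03)`, and derives the index-tensor coordinates by modular broadcast (`i11 = i02 % ne11`, `i12 = i03 % ne12`). A host-side check that the decomposition round-trips, using hypothetical dimensions:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // Hypothetical tensor shape: ne00 = row size, ne01 = rows, ne02/ne03 = batch dims.
    const int64_t ne00 = 4, ne01 = 3, ne02 = 2, ne03 = 2;
    for (int64_t i = 0; i < ne00*ne01*ne02*ne03; ++i) {
        // Same decomposition as k_set_rows: peel off the outermost dim first.
        const int64_t i03 = i / (ne00*ne01*ne02);
        const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
        const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
        // Recomposing must give back the linear index.
        assert(i == ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00);
    }
    return 0;
}
```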
ggml/src/ggml-cuda/set-rows.cuh (new file, 7 lines)
```diff
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.cuh"
+
+#define CUDA_SET_ROWS_BLOCK_SIZE 256
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
```
```diff
@@ -107,8 +107,11 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
         if (nc == 4) {
             ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
                                                                      dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+        } else if (nc == 3) {
+            ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
-            GGML_ABORT("Only support kernel size = 4 now.");
+            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
         }
     } else {
         if (nc == 4) {
@@ -116,8 +119,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
             dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
             ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
                 src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+        } else if (nc == 3) {
+            const int64_t split_n_t = 32;
+            dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
+            ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
+                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
-            GGML_ABORT("Only support kernel size = 4 right now.");
+            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
         }
     }
 }
```
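For context, `ssm_conv` is a depthwise causal convolution with a tiny fixed kernel, and these changes extend the CUDA path from kernel width 4 to widths 3 and 4 (width 3 matching LFM2's short convolution). A scalar sketch of the math for width 3; this illustrates the operation only, not the CUDA kernel's memory layout:

```cpp
#include <cstdio>

// Depthwise causal 1D conv, kernel width 3: y[t] = sum_k w[k] * xpad[t + k],
// where xpad carries nc - 1 = 2 states from the previous chunk in front of x.
static void conv3(const float * xpad, const float * w, float * y, int n_t) {
    for (int t = 0; t < n_t; t++) {
        y[t] = w[0]*xpad[t] + w[1]*xpad[t + 1] + w[2]*xpad[t + 2];
    }
}

int main() {
    const float xpad[2 + 4] = {0, 0, 1, 2, 3, 4}; // two zero states, then x = 1..4
    const float w[3] = {0.25f, 0.5f, 1.0f};
    float y[4];
    conv3(xpad, w, y, 4);
    for (int t = 0; t < 4; t++) printf("%g ", y[t]); // 1 2.5 4.25 6
    printf("\n");
    return 0;
}
```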
ggml/src/ggml-cuda/vendors/hip.h (19 changed lines, vendored)
```diff
@@ -10,9 +10,6 @@
 #include "rocblas/rocblas.h"
 #endif // __HIP_PLATFORM_AMD__
 
-#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
-#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
-#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_OP_N HIPBLAS_OP_N
```
```diff
@@ -30,7 +27,6 @@
 #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
 #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
-#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
 #define cublasCreate hipblasCreate
 #define cublasDestroy hipblasDestroy
 #define cublasGemmEx hipblasGemmEx
```
```diff
@@ -42,7 +38,6 @@
 #define cublasSgemm hipblasSgemm
 #define cublasStatus_t hipblasStatus_t
 #define cublasOperation_t hipblasOperation_t
-#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
```
```diff
@@ -144,6 +139,20 @@
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
 
+#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
+#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
+#define cublasComputeType_t hipblasComputeType_t
+#define cudaDataType_t hipDataType
+#else
+#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define cublasComputeType_t hipblasDatatype_t
+#define cudaDataType_t hipblasDatatype_t
+#endif
+
 #define __CUDA_ARCH__ 1300
 
 #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
```
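The new compute-type block is gated on `HIP_VERSION >= 70000000`. Assuming HIP's usual packed encoding (`major * 10000000 + minor * 100000 + patch`, worth verifying against `hip/hip_version.h`), that threshold decodes to HIP 7.0:

```cpp
#include <cstdio>

// Assumed HIP encoding: HIP_VERSION = major * 10000000 + minor * 100000 + patch.
int main() {
    const long hip_version = 70000000; // the gate's threshold
    printf("major=%ld minor=%ld patch=%ld\n",
           hip_version / 10000000,
           (hip_version / 100000) % 100,
           hip_version % 100000); // major=7 minor=0 patch=0
    return 0;
}
```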
```diff
@@ -425,18 +425,20 @@ struct vk_device_struct {
     vk_pipeline pipeline_div_norepeat[2][2][2];
 
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
-    vk_pipeline pipeline_upscale_f32;
+    vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
     vk_pipeline pipeline_scale_f32;
     vk_pipeline pipeline_sqr_f32;
     vk_pipeline pipeline_sin_f32;
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
+    vk_pipeline pipeline_roll_f32;
     vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
     vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
     vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
+    vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
```
```diff
@@ -693,6 +695,37 @@ struct vk_op_unary_push_constants {
 };
 static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
 
+static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
+    GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
+    ne = ne != 0 ? ne : ggml_nelements(dst);
+    GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
+
+    vk_op_unary_push_constants p{};
+    p.ne = (uint32_t)ne;
+
+    size_t src0_tsize = ggml_type_size(src0->type);
+    p.ne00 = (uint32_t)src0->ne[0];
+    p.ne01 = (uint32_t)src0->ne[1];
+    p.ne02 = (uint32_t)src0->ne[2];
+    p.ne03 = (uint32_t)src0->ne[3];
+    p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
+    p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
+    p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
+    p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
+
+    size_t dst_tsize = ggml_type_size(dst->type);
+    p.ne10 = (uint32_t)dst->ne[0];
+    p.ne11 = (uint32_t)dst->ne[1];
+    p.ne12 = (uint32_t)dst->ne[2];
+    p.ne13 = (uint32_t)dst->ne[3];
+    p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
+    p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
+    p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
+    p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
+
+    return p; // fastdiv values and offsets are initialized later in ggml_vk_op
+}
+
 // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
 // Precompute mp (m' in the paper) and L such that division
 // can be computed using a multiply (high 32b of 64b result)
```
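The comment above points at Granlund and Montgomery's division-by-invariant-integers trick. A self-contained sketch of one common formulation, precomputing `mp` and `L` so that `n / d == (mulhi(n, mp) + n) >> L`; the shader's actual constants and recombination may differ:

```cpp
#include <cassert>
#include <cstdint>

// Precompute mp and L for divisor d (one common multiply-high formulation).
static void precompute(uint32_t d, uint32_t & mp, uint32_t & L) {
    L = 0;                                  // L = ceil(log2(d))
    while (L < 32 && (uint32_t(1) << L) < d) L++;
    mp = (uint32_t)(((uint64_t(1) << 32) * ((uint64_t(1) << L) - d)) / d + 1);
}

static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
    const uint32_t hi = (uint32_t)(((uint64_t)n * mp) >> 32); // high 32 bits of 64-bit product
    return (hi + n) >> L;
}

int main() {
    for (uint32_t d = 1; d < 500; d++) {
        uint32_t mp, L;
        precompute(d, mp, L);
        for (uint32_t n = 0; n < 100000; n += 13) {
            assert(fastdiv(n, mp, L) == n / d);
        }
    }
    return 0;
}
```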
```diff
@@ -862,6 +895,7 @@ struct vk_op_conv2d_dw_push_constants {
 
 struct vk_op_upscale_push_constants {
     uint32_t ne; uint32_t a_offset; uint32_t d_offset;
+    uint32_t ne00; uint32_t ne01;
     uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
     uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
     float sf0; float sf1; float sf2; float sf3;
```
```diff
@@ -1735,7 +1769,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
-static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
+
+static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
+    if (hsv >= 512) {
+        return 2;
+    } else {
+        return 8;
+    }
+}
 
 // The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
 // 128 threads split into four subgroups, each subgroup does 1/4
```
```diff
@@ -1760,7 +1801,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
         if (small_rows) {
             return {scalar_flash_attention_num_small_rows, 64};
         } else {
-            return {scalar_flash_attention_num_large_rows, 32};
+            return {get_fa_scalar_num_large_rows(hsv), 32};
         }
     }
```
```diff
@@ -1779,7 +1820,11 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
 
     // small cols to reduce register count
     if (ggml_is_quantized(type) || hsk >= 256) {
-        return {64, 32};
+        if (hsk >= 512) {
+            return {32, 32};
+        } else {
+            return {64, 32};
+        }
     }
     return {64, 64};
 }
```
```diff
@@ -1821,7 +1866,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     const uint32_t warps = warptile[0] / warptile[10];
 
     const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
-    const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
+    const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
     const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
 
     const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
```
```diff
@@ -1946,10 +1991,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
         s_mmq_wg_denoms_k = { 32, 32, 1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        l_warptile_mmqid = { 256, 128, 128, 16, 0 };
         m_warptile_mmqid = { 256, 128, 64, 16, 0 };
         s_warptile_mmqid = { 256, 128, 64, 16, 0 };
-        l_mmqid_wg_denoms = { 128, 64, 1 };
+        l_mmqid_wg_denoms = { 128, 128, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
         s_mmqid_wg_denoms = { 128, 64, 1 };
```
```diff
@@ -2738,19 +2783,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     if (device->float_controls_rte_fp16) {
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
     } else {
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
     }
 
+    if (device->float_controls_rte_fp16) {
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+    } else {
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
```
```diff
@@ -2790,7 +2857,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
```
```diff
@@ -2802,6 +2871,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
```
```diff
@@ -6048,7 +6119,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
     // Needs to be kept up to date on shader changes
     GGML_UNUSED(hsv);
     const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = scalar_flash_attention_num_large_rows;
+    const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
    const uint32_t Bc = scalar_flash_attention_Bc;
 
     const uint32_t tmpsh = wg_size * sizeof(float);
```
```diff
@@ -6173,7 +6244,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     case FA_SCALAR:
     case FA_COOPMAT1:
         // We may switch from coopmat1 to scalar, so use the scalar limit for both
-        max_gqa = scalar_flash_attention_num_large_rows;
+        max_gqa = get_fa_scalar_num_large_rows(HSV);
         break;
     case FA_COOPMAT2:
         max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
```
```diff
@@ -6468,8 +6539,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_UPSCALE:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
-            return ctx->device->pipeline_upscale_f32;
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            int mode = ggml_get_op_params_i32(dst, 0);
+            switch (mode) {
+                case GGML_SCALE_MODE_NEAREST:
+                    return ctx->device->pipeline_upscale_nearest_f32;
+                case GGML_SCALE_MODE_BILINEAR:
+                    return ctx->device->pipeline_upscale_bilinear_f32;
+                case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
+                    return ctx->device->pipeline_upscale_bilinear_ac_f32;
+            }
         }
         return nullptr;
     case GGML_OP_SCALE:
```
```diff
@@ -6502,6 +6581,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_pad_f32;
         }
         return nullptr;
+    case GGML_OP_ROLL:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_roll_f32;
+        }
+        return nullptr;
     case GGML_OP_REPEAT:
         if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
             return ctx->device->pipeline_repeat_f32;
```
```diff
@@ -6516,6 +6600,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_CONT:
     case GGML_OP_DUP:
         return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+    case GGML_OP_SET_ROWS:
+        return ctx->device->pipeline_set_rows[dst->type];
     case GGML_OP_SILU_BACK:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_silu_back_f32;
```
```diff
@@ -6754,6 +6840,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_RMS_NORM:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_IM2COL:
+    case GGML_OP_SET_ROWS:
         return true;
     default:
         return false;
```
```diff
@@ -7048,6 +7135,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
+    case GGML_OP_ROLL:
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_CPY:
```
```diff
@@ -7067,6 +7155,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
                 ne *= ggml_type_size(src0->type) / 2;
             }
         }
+        // copy_to_quant has block size of 32, and each thread does QUANT_K elements.
+        // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
+        // So divide by block size here before splitting into 512x512 groups.
+        if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
+            ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
+        }
         if (ne > 262144) {
             elements = { 512, 512, CEIL_DIV(ne, 262144) };
         } else if (ne > 512) {
@@ -7075,6 +7169,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { ne, 1, 1 };
         }
     } break;
+    case GGML_OP_SET_ROWS:
+        {
+            uint32_t ne = ggml_nelements(src0);
+            if (ggml_is_quantized(dst->type)) {
+                // quants run 32 threads each doing QUANT_K elements
+                ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
+            } else {
+                // scalar types do one element per thread, running 512 threads
+                ne = CEIL_DIV(ne, 512);
+            }
+            if (ne > 262144) {
+                elements = { 512, 512, CEIL_DIV(ne, 262144) };
+            } else if (ne > 512) {
+                elements = { 512, CEIL_DIV(ne, 512), 1 };
+            } else {
+                elements = { ne, 1, 1 };
+            }
+        }
+        break;
     default:
         elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
         break;
```
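Both branches map a flat element count onto Vulkan's three dispatch dimensions, splitting at 512 and at 512 * 512 = 262144 because a single dimension is capped. The same bucketing, exercised with a few hypothetical sizes:

```cpp
#include <cstdint>
#include <cstdio>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

static void pick_elements(uint32_t ne, uint32_t elements[3]) {
    if (ne > 262144) {          // 512 * 512
        elements[0] = 512; elements[1] = 512; elements[2] = CEIL_DIV(ne, 262144);
    } else if (ne > 512) {
        elements[0] = 512; elements[1] = CEIL_DIV(ne, 512); elements[2] = 1;
    } else {
        elements[0] = ne;  elements[1] = 1;   elements[2] = 1;
    }
}

int main() {
    uint32_t e[3];
    pick_elements(300, e);     printf("%u %u %u\n", e[0], e[1], e[2]); // 300 1 1
    pick_elements(100000, e);  printf("%u %u %u\n", e[0], e[1], e[2]); // 512 196 1
    pick_elements(1000000, e); printf("%u %u %u\n", e[0], e[1], e[2]); // 512 512 4
    return 0;
}
```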
```diff
@@ -7484,14 +7597,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
 static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
 
-    const float sf0 = (float)dst->ne[0] / src0->ne[0];
-    const float sf1 = (float)dst->ne[1] / src0->ne[1];
-    const float sf2 = (float)dst->ne[2] / src0->ne[2];
-    const float sf3 = (float)dst->ne[3] / src0->ne[3];
+    float sf0 = (float)dst->ne[0] / src0->ne[0];
+    float sf1 = (float)dst->ne[1] / src0->ne[1];
+    float sf2 = (float)dst->ne[2] / src0->ne[2];
+    float sf3 = (float)dst->ne[3] / src0->ne[3];
+
+    if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+        sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
+        sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
+    }
 
     ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
         (uint32_t)ggml_nelements(dst), 0, 0,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
         (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
         sf0, sf1, sf2, sf3,
```
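The align-corners variant rescales so the first and last samples of source and destination coincide, which changes the effective scale factor. For example, upscaling an axis from 4 to 8 gives sf = 2.0 by default but (8-1)/(4-1) ≈ 2.33 with align-corners:

```cpp
#include <cstdio>

int main() {
    const long src = 4, dst = 8;
    const float sf_default       = (float)dst / src;             // 2.0
    const float sf_align_corners = (float)(dst - 1) / (src - 1); // 7/3 ~= 2.3333
    printf("%f %f\n", sf_default, sf_align_corners);
    return 0;
}
```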
@@ -7499,123 +7619,64 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||
}
|
||||
|
||||
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = ggml_get_op_params_f32(dst, 0);
|
||||
p.param2 = ggml_get_op_params_f32(dst, 1);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
op_params[0], op_params[1],
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = ggml_get_op_params_f32(dst, 0);
|
||||
p.param2 = ggml_get_op_params_f32(dst, 1);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
op_params[0], op_params[1],
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
|
||||
(uint32_t)ggml_nelements(dst),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
|
||||
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
|
||||
const int32_t s2 = ggml_get_op_params_i32(dst, 2);
|
||||
const int32_t s3 = ggml_get_op_params_i32(dst, 3);
|
||||
const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
|
||||
const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
|
||||
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
memcpy(&p.param1, &s01_packed, sizeof(float));
|
||||
memcpy(&p.param2, &s23_packed, sizeof(float));
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
|
||||
(uint32_t)ggml_nelements(dst),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
|
||||
(uint32_t)ggml_nelements(dst),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
uint32_t ne = (uint32_t)ggml_nelements(src0);
|
||||
if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
||||
// Convert from number of logical elements to 2- or 4-byte units.
|
||||
@@ -7627,13 +7688,22 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||
}
|
||||
}
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
|
||||
ne,
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);

ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0.0f, 0.0f, 0,
}, dryrun);
}
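
For reference, `GGML_OP_SET_ROWS` scatters rows of src0 into dst at the row indices carried by src1. A scalar sketch of the f32 case with shapes flattened to 2D — illustrative only, not the ggml implementation; the Vulkan path additionally quantizes when dst is a quantized type:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar reference of SET_ROWS for a 2D f32 case: dst row ids[r] <- src row r.
static void set_rows_ref(const std::vector<float> & src, int64_t ncols,
                         const std::vector<int64_t> & ids,
                         std::vector<float> & dst) {
    for (size_t r = 0; r < ids.size(); ++r) {
        const int64_t i1 = ids[r]; // destination row for source row r
        for (int64_t c = 0; c < ncols; ++c) {
            dst[i1 * ncols + c] = src[r * ncols + c];
        }
    }
}

int main() {
    std::vector<float> src = {1, 2, 3, 4};   // 2 rows x 2 cols
    std::vector<float> dst(6, 0.0f);         // 3 rows x 2 cols
    set_rows_ref(src, 2, {2, 0}, dst);       // row 0 -> dst row 2, row 1 -> dst row 0
    for (float v : dst) printf("%g ", v);    // prints: 3 4 0 0 1 2
    printf("\n");
}
```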

@@ -8956,7 +9026,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
case GGML_OP_CONT:
case GGML_OP_DUP:
case GGML_OP_SILU_BACK:
@@ -9023,6 +9095,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
case GGML_OP_CONT:
case GGML_OP_DUP:
case GGML_OP_SILU_BACK:
@@ -9125,12 +9198,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_PAD:
ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);

break;
case GGML_OP_ROLL:
ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);

break;
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DUP:
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);

break;
case GGML_OP_SET_ROWS:
ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);

break;
case GGML_OP_SILU_BACK:
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9345,7 +9426,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
case GGML_OP_CONT:
case GGML_OP_DUP:
case GGML_OP_SILU_BACK:
@@ -10411,9 +10494,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
} break;
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
switch (op->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
return true;
default:
return false;
}
} break;
case GGML_OP_CONT:
case GGML_OP_CPY:
@@ -10499,13 +10593,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE:
return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
case GGML_OP_ACC:
case GGML_OP_CONCAT:
case GGML_OP_SCALE:
case GGML_OP_PAD:
case GGML_OP_ROLL:
case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ARGSORT:
@@ -11028,6 +11121,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else {
tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
}
} else if (tensor->op == GGML_OP_SET_ROWS) {
tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_CONT) {
tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
} else if (tensor->op == GGML_OP_RESHAPE) {

@@ -6,17 +6,25 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi
#endif // RTE16

#include "types.comp"
#include "generic_unary_head.comp"

#if defined(DATA_A_IQ4_NL)
// 16 invocations needed for init_iq4nl_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#if defined(SET_ROWS) && QUANT_K == 1
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 512;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 32;
#endif

layout (binding = 0) readonly buffer S {float data_s[];};

#if defined(SET_ROWS)
#include "generic_binary_head.comp"
layout (binding = 1) readonly buffer C {uvec2 data_i[];};
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
#else
#include "generic_unary_head.comp"
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
#endif

#if defined(DATA_A_Q4_0)
void quantize(uint dst_idx, uint src_idx)
@@ -221,15 +229,56 @@ void quantize(uint dst_idx, uint src_idx)
}
#endif

#if defined(DATA_A_F32) || defined(DATA_A_F16)
void quantize(uint dst_idx, uint src_idx)
{
data_q[dst_idx] = A_TYPE(data_s[src_idx]);
}
#endif

#if defined(DATA_A_BF16)
void quantize(uint dst_idx, uint src_idx)
{
data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
}
#endif

#if defined(SET_ROWS)

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
if (gl_LocalInvocationIndex.x != 0) {
return;
}
#endif

const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;

if (idx >= p.ne) {
return;
}

uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);

uint i12 = fastmod(i03, p.ne12);
uint i11 = fastmod(i02, p.ne11);
uint i10 = i01;

uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;

uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();

quantize(dst_idx, src0_idx);
}

#else

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif

const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;

if (idx >= p.ne) {
return;
@@ -240,3 +289,5 @@ void main() {

quantize(dst_idx, src_idx);
}

#endif
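
The new SET_ROWS dispatch flattens a 3D workgroup ID into one element index: `idx = ((wz*262144 + wy*512 + wx)*BLOCK_SIZE + local)*QUANT_K`, where 262144 = 512*512 assumes a grid 512 workgroups wide in x and y. A quick host-side check of that arithmetic:

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the shader's flattening: workgroups laid out on a 512x512 grid per
// z-slice, BLOCK_SIZE threads per workgroup, QUANT_K logical elements per thread.
static uint32_t flat_index(uint32_t wx, uint32_t wy, uint32_t wz,
                           uint32_t local, uint32_t block_size, uint32_t quant_k) {
    return ((wz * 262144u + wy * 512u + wx) * block_size + local) * quant_k;
}

int main() {
    // With QUANT_K == 1 the shader uses BLOCK_SIZE == 512, so consecutive
    // workgroups in x cover disjoint, contiguous 512-element ranges.
    assert(flat_index(0, 0, 0, 0, 512, 1) == 0);
    assert(flat_index(0, 0, 0, 511, 512, 1) == 511);
    assert(flat_index(1, 0, 0, 0, 512, 1) == 512);
    // Stepping y by one skips a full 512-workgroup row.
    assert(flat_index(0, 1, 0, 0, 512, 1) == 512u * 512u);
}
```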

@@ -18,6 +18,7 @@
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif

#ifdef MUL_MAT_ID
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];

#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
uint _ne1;
#ifdef COOPMAT
shared uint _ne1_sh;
#endif
#endif // MUL_MAT_ID

#define NUM_WARPS (BLOCK_SIZE / WARP)
@@ -172,7 +177,47 @@ void main() {
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;

#ifdef MUL_MAT_ID
uint _ne1 = 0;
#ifdef COOPMAT
// Spread the search across all elements in the first subgroup
if (gl_SubgroupID == 0) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;

uint ids[16];
uint iter = 0;

for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_SubgroupInvocationID;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
uint idx = subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
}
_ne1 += subgroupBallotBitCount(ballot);
iter &= 15;
}
_ne1_sh = _ne1;
}

barrier();

_ne1 = _ne1_sh;
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
@@ -183,6 +228,7 @@ void main() {
}

barrier();
#endif

// Workgroup has no work
if (ic * BN >= _ne1) return;
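
The ballot loop above compacts the rows belonging to `expert_idx` without atomics: every matching lane writes at `_ne1` plus its exclusive bit count, and `_ne1` then advances by the total ballot count. A sequential C++ model of the same compaction — the subgroup is flattened into a plain loop, and the stored values are simplified to lane indices:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar equivalent of subgroupBallot + subgroupBallotExclusiveBitCount:
// each matching lane lands at (base + count of matching lanes before it),
// then the base advances by the total number of matches.
int main() {
    const uint32_t expert_idx = 3;
    std::vector<uint32_t> ids = {3, 1, 3, 3, 0, 2, 3, 1}; // one "subgroup" of 8 lanes
    std::vector<uint32_t> row_ids;

    uint32_t ne1 = 0;       // running compacted count (_ne1 in the shader)
    uint32_t exclusive = 0; // exclusive bit count seen so far this iteration
    for (uint32_t lane = 0; lane < ids.size(); ++lane) {
        if (ids[lane] == expert_idx) {
            row_ids.push_back(lane); // row_ids[ne1 + exclusive] = lane
            ++exclusive;
        }
    }
    ne1 += exclusive; // subgroupBallotBitCount(ballot)

    for (uint32_t r : row_ids) printf("%u ", r); // prints: 0 2 3 6
    printf("(_ne1 = %u)\n", ne1);                // _ne1 = 4
}
```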

@@ -162,17 +162,32 @@ void main() {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;

for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
uint ids[16];
uint iter = 0;

for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_SubgroupInvocationID;
bool in_range = i < num_elements;
uint ii0 = i % p.nei0;
uint ii1 = i / p.nei0;
uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
uint ii0 = i % p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
uint idx = subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
}
_ne1 += subgroupBallotBitCount(ballot);
iter &= 15;
}
_ne1_sh = _ne1;
}
@@ -414,17 +429,31 @@ void main() {
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
}

coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif

sum = coopMatMulAdd(mat_a, mat_b, sum);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif

sum = coopMatMulAdd(mat_a, mat_b, sum);
}
}

// Convert from ACC_TYPE to D_TYPE

ggml/src/ggml-vulkan/vulkan-shaders/roll.comp (new file, 46 lines)
@@ -0,0 +1,46 @@
#version 450

#include "types.comp"
#include "generic_unary_head.comp"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

uint wrap_idx(int i, uint ne) {
if (i < 0) {
return i + ne;
} else if (i >= ne) {
return i - ne;
}
return i;
}

void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}

const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);
const uint i2_offset = i2*p.ne11*p.ne10;
const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;

const uint p1 = floatBitsToUint(p.param1);
const uint p2 = floatBitsToUint(p.param2);
const int s0 = int(p1 >> 16) - 0x8000;
const int s1 = int(p1 & 0xFFFF) - 0x8000;
const int s2 = int(p2 >> 16) - 0x8000;
const int s3 = int(p2 & 0xFFFF) - 0x8000;

const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
const uint i03 = wrap_idx(int(i3) - s3, p.ne13);

const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;

data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
}
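
The shader recovers four signed 16-bit shifts from the bit patterns of `param1`/`param2`, each biased by 0x8000, then wraps indices once per dimension. A host-side round-trip of that decode — the packing helper is hypothetical, and the single wrap assumes shift magnitudes below the dimension size, as the shader does:

```cpp
#include <cassert>
#include <cstdint>

// Pack two biased 16-bit shifts into one 32-bit word; the shader recovers
// them with (word >> 16) - 0x8000 and (word & 0xFFFF) - 0x8000.
static uint32_t pack_shifts(int hi, int lo) {
    return (uint32_t(hi + 0x8000) << 16) | uint32_t(lo + 0x8000);
}

// Index wrap used by GGML_OP_ROLL: one wrap is enough for |shift| < ne.
static uint32_t wrap_idx(int i, uint32_t ne) {
    if (i < 0)        return uint32_t(i + int(ne));
    if (i >= int(ne)) return uint32_t(i - int(ne));
    return uint32_t(i);
}

int main() {
    const uint32_t p1 = pack_shifts(3, -2);
    assert(int(p1 >> 16)     - 0x8000 ==  3); // s0
    assert(int(p1 & 0xFFFF)  - 0x8000 == -2); // s1

    // Rolling a dim of size 10 by +3: output index 0 reads input index 7.
    assert(wrap_idx(0 - 3, 10) == 7);
    // Rolling by -2: output index 9 reads input index 1.
    assert(wrap_idx(9 - (-2), 10) == 1);
}
```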

@@ -3,6 +3,7 @@
layout (push_constant) uniform parameter
{
uint ne; uint a_offset; uint d_offset;
uint ne00; uint ne01;
uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13;
float sf0; float sf1; float sf2; float sf3;
@@ -15,6 +16,61 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
#define NEAREST 0
#define BILINEAR 1
#define ALIGN_CORNERS (1 << 8)

layout (constant_id = 0) const uint scale_mode = 0;

float fetch_nearest(uint i10, uint i11, uint i12, uint i13) {
const uint i00 = uint(i10 / p.sf0);
const uint i01 = uint(i11 / p.sf1);
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);

return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00];
}

float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);
const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;

const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00];
const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00];
const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00];
const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00];

return
v00 * (1.0-d.x) * (1.0-d.y) +
v01 * d.x * (1.0-d.y) +
v10 * (1.0-d.x) * d.y +
v11 * d.x * d.y;
}

float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
const ivec2 ne0 = ivec2(p.ne00, p.ne01);

const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
const vec2 c0f = floor(c);
const vec2 d = c - c0f;
const ivec2 c0 = max(ivec2(c0f), 0);
const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1);

return fetch_bilinear(c0, c1, d, i12, i13);
}

float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
const vec2 c0f = floor(c);
const vec2 d = c - c0f;
const ivec2 c0 = ivec2(c0f);
const ivec2 c1 = c0 + 1;

return fetch_bilinear(c0, c1, d, i12, i13);
}

void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

@@ -27,10 +83,18 @@ void main() {
const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;

const uint i00 = uint(i10 / p.sf0);
const uint i01 = uint(i11 / p.sf1);
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);
float result;
switch (scale_mode) {
case NEAREST:
result = fetch_nearest(i10, i11, i12, i13);
break;
case BILINEAR:
result = interpolate_bilinear(i10, i11, i12, i13);
break;
case BILINEAR | ALIGN_CORNERS:
result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
break;
}

data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
data_d[p.d_offset + idx] = D_TYPE(result);
}
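
The bilinear path blends four neighboring texels with weights taken from the fractional sample coordinate. A single-channel CPU reference of the same blend as in `fetch_bilinear` (a sketch for checking the weighting, not the shader itself):

```cpp
#include <cstdio>

// Reference bilinear blend: v00..v11 are the four corner texels and
// (dx, dy) is the fractional offset of the sample point inside the cell.
static float bilinear(float v00, float v01, float v10, float v11,
                      float dx, float dy) {
    return v00 * (1.0f - dx) * (1.0f - dy) +
           v01 *        dx   * (1.0f - dy) +
           v10 * (1.0f - dx) *        dy   +
           v11 *        dx   *        dy;
}

int main() {
    // Sampling exactly between four texels averages them.
    printf("%g\n", bilinear(0.0f, 1.0f, 2.0f, 3.0f, 0.5f, 0.5f)); // 1.5
    // d == (0, 0) returns the top-left texel unchanged.
    printf("%g\n", bilinear(0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f)); // 0
}
```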

@@ -518,6 +518,11 @@ void process_shaders() {
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}

for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
}

auto get_type_str = [](bool f16) {
return f16 ? "float16_t" : "float";
};
@@ -648,6 +653,8 @@ void process_shaders() {
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

for (auto &c : compiles) {
c.wait();
}

@@ -187,6 +187,9 @@ class Keys:
class Classifier:
OUTPUT_LABELS = "{arch}.classifier.output_labels"

class ShortConv:
L_CACHE = "{arch}.shortconv.l_cache"

class Tokenizer:
MODEL = "tokenizer.ggml.model"
PRE = "tokenizer.ggml.pre"
@@ -230,6 +233,11 @@ class Keys:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"

class IMatrix:
CHUNK_COUNT = "imatrix.chunk_count"
CHUNK_SIZE = "imatrix.chunk_size"
DATASETS = "imatrix.datasets"

class Clip:
PROJECTOR_TYPE = "clip.projector_type"
HAS_VISION_ENCODER = "clip.has_vision_encoder"
@@ -279,6 +287,7 @@ class Keys:
class GGUFType:
MODEL = "model"
ADAPTER = "adapter"
IMATRIX = "imatrix"
MMPROJ = "mmproj" # dummy, unused for now

@@ -362,6 +371,7 @@ class MODEL_ARCH(IntEnum):
ERNIE4_5 = auto()
HUNYUAN_MOE = auto()
SMOLLM3 = auto()
LFM2 = auto()

class VISION_PROJECTOR_TYPE(IntEnum):
@@ -533,6 +543,9 @@ class MODEL_TENSOR(IntEnum):
POSNET_ATTN_K = auto()
POSNET_ATTN_V = auto()
POSNET_ATTN_OUT = auto()
SHORTCONV_CONV = auto()
SHORTCONV_INPROJ = auto()
SHORTCONV_OUTPROJ = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
@@ -673,6 +686,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.FALCON_H1: "falcon-h1",
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
MODEL_ARCH.SMOLLM3: "smollm3",
MODEL_ARCH.LFM2: "lfm2",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -844,6 +858,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k",
MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v",
MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
MODEL_TENSOR.SHORTCONV_CONV: "blk.{bid}.shortconv.conv",
MODEL_TENSOR.SHORTCONV_INPROJ: "blk.{bid}.shortconv.in_proj",
MODEL_TENSOR.SHORTCONV_OUTPROJ: "blk.{bid}.shortconv.out_proj",
# vision
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
@@ -2356,6 +2373,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.LFM2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.SHORTCONV_CONV,
MODEL_TENSOR.SHORTCONV_INPROJ,
MODEL_TENSOR.SHORTCONV_OUTPROJ,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.ATTN_NORM, # operator_norm
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
],
# TODO
}

@@ -648,6 +648,9 @@ class GGUFWriter:
def add_convnext_block_count(self, length: int) -> None:
self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length)

def add_shortconv_l_cache(self, length: int) -> None:
self.add_uint32(Keys.ShortConv.L_CACHE.format(arch=self.arch), length)

def add_block_count(self, length: int) -> None:
self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)

@@ -50,6 +50,7 @@ class TensorNameMap:
"model.pre_ln", # rwkv7
"model.layers.0.pre_norm", # rwkv7
"backbone.norm", # wavtokenizer
"model.embedding_norm", # lfm2
),

# Position embeddings
@@ -136,6 +137,7 @@ class TensorNameMap:
"model.layers.{bid}.ln1", # rwkv7
"model.layers.{bid}.input_layernorm", # llama4
"transformer_encoder.{bid}.attention_norm", # neobert
"model.layers.{bid}.operator_norm", # lfm2
),

# Attention norm 2
@@ -220,6 +222,7 @@ class TensorNameMap:
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.out_proj", # lfm2
"model.layers.{bid}.self_attn.linear_attn", # deci
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
@@ -1015,6 +1018,18 @@ class TensorNameMap:
"backbone.posnet.{bid}.proj_out", # wavtokenizer
),

MODEL_TENSOR.SHORTCONV_CONV: (
"model.layers.{bid}.conv.conv",
),

MODEL_TENSOR.SHORTCONV_INPROJ: (
"model.layers.{bid}.conv.in_proj",
),

MODEL_TENSOR.SHORTCONV_OUTPROJ: (
"model.layers.{bid}.conv.out_proj",
),

#############################################################################
## Vision encoder

@@ -79,47 +79,6 @@ extern "C" {
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
};

// pre-tokenization types
enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
};

enum llama_rope_type {
LLAMA_ROPE_TYPE_NONE = -1,
LLAMA_ROPE_TYPE_NORM = 0,

@@ -1 +1 @@
0405219965324e11a29b6aadfe22a6d66131978f
d62df60a07ba3deeb85e5cfc9b1ee07645ff35e2

@@ -83,6 +83,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -188,6 +189,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {

{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },

{ LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },

{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@@ -1830,6 +1833,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_LFM2,
{
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
}
},
{
LLM_ARCH_UNKNOWN,
{
@@ -1997,6 +2021,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
};

LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
@@ -2068,6 +2095,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_JAMBA:
case LLM_ARCH_FALCON_H1:
case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_LFM2:
return true;
default:
return false;

@@ -87,6 +87,7 @@ enum llm_arch {
LLM_ARCH_ERNIE4_5,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_SMOLLM3,
LLM_ARCH_LFM2,
LLM_ARCH_UNKNOWN,
};

@@ -227,6 +228,8 @@ enum llm_kv {

LLM_KV_CLASSIFIER_OUTPUT_LABELS,

LLM_KV_SHORTCONV_L_CACHE,

// deprecated:
LLM_KV_TOKENIZER_PREFIX_ID,
LLM_KV_TOKENIZER_SUFFIX_ID,
@@ -396,6 +399,9 @@ enum llm_tensor {
LLM_TENSOR_POS_NET_ATTN_K,
LLM_TENSOR_POS_NET_ATTN_V,
LLM_TENSOR_POS_NET_ATTN_OUT,
LLM_TENSOR_SHORTCONV_CONV,
LLM_TENSOR_SHORTCONV_INPROJ,
LLM_TENSOR_SHORTCONV_OUTPROJ,
};

enum llm_tensor_layer {

@@ -71,6 +71,11 @@ uint32_t llama_hparams::n_embd_r() const {
return token_shift_count * n_embd;
}

if (n_shortconv_l_cache != 0) {
// for LFM2 models
return n_embd * (n_shortconv_l_cache - 1);
}

// TODO: maybe support other convolution strides than 1
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
// Corresponds to Mamba's conv_states size
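
So for LFM2 the per-layer recurrent buffer is one column narrower than the conv kernel, because the current input supplies the newest tap. A quick sizing check — n_embd = 2048 matches the 1.2B entry in the hunks below, while l_cache = 3 is an assumed illustrative value:

```cpp
#include <cassert>
#include <cstdint>

// LFM2 short-conv recurrent state: the kernel spans l_cache taps, but one
// tap is the current input, so only (l_cache - 1) columns are cached.
static uint64_t shortconv_state_size(uint32_t n_embd, uint32_t l_cache) {
    return uint64_t(n_embd) * (l_cache - 1);
}

int main() {
    assert(shortconv_state_size(2048, 3) == 2048u * 2u);
}
```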

@@ -55,6 +55,8 @@ struct llama_hparams {
struct llama_hparams_posnet posnet;
struct llama_hparams_convnext convnext;

uint32_t n_shortconv_l_cache = 0;

std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;

@@ -43,15 +43,18 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_256M: return "256M";
case LLM_TYPE_270M: return "270M";
case LLM_TYPE_335M: return "335M";
case LLM_TYPE_350M: return "350M";
case LLM_TYPE_410M: return "410M";
case LLM_TYPE_450M: return "450M";
case LLM_TYPE_475M: return "475M";
case LLM_TYPE_700M: return "700M";
case LLM_TYPE_770M: return "770M";
case LLM_TYPE_780M: return "780M";
case LLM_TYPE_0_3B: return "0.3B";
case LLM_TYPE_0_5B: return "0.5B";
case LLM_TYPE_0_6B: return "0.6B";
case LLM_TYPE_1B: return "1B";
case LLM_TYPE_1_2B: return "1.2B";
case LLM_TYPE_1_3B: return "1.3B";
case LLM_TYPE_1_4B: return "1.4B";
case LLM_TYPE_1_5B: return "1.5B";
@@ -1663,6 +1666,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_LFM2:
{
ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
}
switch (hparams.n_embd) {
case 1024: type = LLM_TYPE_350M; break;
case 1536: type = LLM_TYPE_700M; break;
case 2048: type = LLM_TYPE_1_2B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}

@@ -4906,6 +4923,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_LFM2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// ffn is same for transformer and conv layers
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);

// for operator_norm
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

if (!hparams.is_recurrent(i)) {
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, hparams.n_embd_k_gqa(i)}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, hparams.n_embd_v_gqa(i)}, 0);

layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
} else {
layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0);
layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0);
layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0);
}
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -15859,6 +15909,163 @@ struct llm_build_smollm3 : public llm_graph_context {
}
};

struct llm_build_lfm2 : public llm_graph_context {
const llama_model & model;

llm_build_lfm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {

ggml_tensor * cur = build_inp_embd(model.tok_embd);
cb(cur, "model.embed_tokens", -1);

ggml_tensor * inp_pos = build_inp_pos();
auto * inp_hybrid = build_inp_mem_hybrid();
ggml_tensor * inp_out_ids = build_inp_out_ids();

for (int il = 0; il < n_layer; ++il) {
auto * prev_cur = cur;
cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "model.layers.{}.operator_norm", il);

cur = hparams.is_recurrent(il) ?
build_shortconv_block(gf, cur, inp_hybrid->get_recr(), il) :
build_attn_block(gf, cur, inp_pos, inp_hybrid->get_attn(), il) ;

if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
prev_cur = ggml_get_rows(ctx0, prev_cur, inp_out_ids);
}

cur = ggml_add(ctx0, prev_cur, cur);
cur = ggml_add(ctx0, cur, build_feed_forward(cur, il));
}

cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "model.embedding_norm", -1);
res->t_embd = cur;

// lm_head is tied with embeddings
cur = build_lora_mm(model.tok_embd, cur);
cb(cur, "lm_head", -1);

res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}

ggml_tensor * build_feed_forward(ggml_tensor * cur,
int il) const {
cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "model.layers.{}.ffn_norm", il);

GGML_ASSERT(!model.layers[il].ffn_up_b);
GGML_ASSERT(!model.layers[il].ffn_gate_b);
GGML_ASSERT(!model.layers[il].ffn_down_b);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "model.layers.{}.feed_forward.w2", il);

return cur;
}

ggml_tensor * build_attn_block(ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv_unified * inp_attn,
int il) const {
GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
auto const n_embd_head = hparams.n_embd_head_v;
auto const n_head_kv = hparams.n_head_kv(il);

auto * q = build_lora_mm(model.layers[il].wq, cur);
cb(q, "model.layers.{}.self_attn.q_proj", il);
auto * k = build_lora_mm(model.layers[il].wk, cur);
cb(k, "model.layers.{}.self_attn.k_proj", il);
auto * v = build_lora_mm(model.layers[il].wv, cur);
cb(v, "model.layers.{}.self_attn.v_proj", il);

q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens);
k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens);
v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens);

// qk norm
q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(q, "model.layers.{}.self_attn.q_layernorm", il);
k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(k, "model.layers.{}.self_attn.k_layernorm", il);

// RoPE
q = ggml_rope_ext(
ctx0, q, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
k = ggml_rope_ext(
ctx0, k, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL,
q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);

cb(cur, "model.layers.{}.self_attn.out_proj", il);

return cur;
}

ggml_tensor * build_shortconv_block(ggml_cgraph * gf,
ggml_tensor * cur,
llm_graph_input_rs * inp_recr,
int il) {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();

auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
cb(bcx, "model.layers.{}.conv.in_proj", il);

constexpr auto n_chunks = 3;
GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
auto const chunk_size = bcx->ne[0] / n_chunks;
auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx));
auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx));
auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx));

auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));

// read conv state directly, with build_rs generation is slower
ggml_tensor * conv_state = mctx_cur->get_r_l(il);
const int64_t n_seqs = ubatch.n_seqs;
ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);

bx = ggml_concat(ctx0, conv, bx, 0);
GGML_ASSERT(bx->ne[0] > conv->ne[0]);

auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
GGML_ASSERT(ggml_are_same_shape(conv, new_conv));

// write conv state
ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state));

auto * conv_kernel = model.layers[il].shortconv.conv;
GGML_ASSERT(hparams.n_shortconv_l_cache > 0);

// construct ssm_conv op
ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
cb(conv_out, "model.layers.{}.conv.conv", il);

auto * y = ggml_mul(ctx0, c, conv_out);

y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
cb(y, "model.layers.{}.conv.out_proj", il);

return y;
}
};
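
Numerically, `build_shortconv_block` computes `y = out_proj(c ⊙ ssm_conv(concat(state, b ⊙ x)))`, where b, c, x are the three chunks of the in_proj output. A scalar per-channel reference of that math — a sketch of the computation the graph assembles, with the cross-channel `out_proj` matmul omitted:

```cpp
#include <cstdio>
#include <vector>

// One channel of the LFM2 short conv: gate the input with b, run a causal
// depthwise convolution over the last l_cache samples, gate with c.
static std::vector<float> shortconv_channel(
        const std::vector<float> & x, const std::vector<float> & b,
        const std::vector<float> & c, const std::vector<float> & kernel) {
    const size_t l_cache = kernel.size();
    std::vector<float> y(x.size(), 0.0f);
    for (size_t t = 0; t < x.size(); ++t) {
        float acc = 0.0f;
        for (size_t k = 0; k < l_cache; ++k) {
            // taps reach back in time; missing history reads as zero,
            // which is what the zero-initialized conv state provides
            const long ti = long(t) - long(l_cache - 1 - k);
            if (ti >= 0) acc += kernel[k] * b[ti] * x[ti];
        }
        y[t] = c[t] * acc;
    }
    return y;
}

int main() {
    auto y = shortconv_channel({1, 2, 3}, {1, 1, 1}, {1, 1, 1}, {0.5f, 0.5f});
    for (float v : y) printf("%g ", v); // prints: 0.5 1.5 2.5
    printf("\n");
}
```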

llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * res;

@@ -16261,6 +16468,10 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_falcon_h1>(*this, params, gf);
} break;
case LLM_ARCH_LFM2:
{
llm = std::make_unique<llm_build_lfm2>(*this, params, gf);
} break;
default:
GGML_ABORT("fatal error");
}
@@ -16454,6 +16665,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_MINICPM3:
case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
case LLM_ARCH_LFM2:
return LLAMA_ROPE_TYPE_NEOX;

case LLM_ARCH_QWEN2VL:

@@ -35,15 +35,18 @@ enum llm_type {
LLM_TYPE_256M,
LLM_TYPE_270M,
LLM_TYPE_335M,
LLM_TYPE_350M,
LLM_TYPE_410M,
LLM_TYPE_450M,
LLM_TYPE_475M,
LLM_TYPE_700M,
LLM_TYPE_770M,
LLM_TYPE_780M,
LLM_TYPE_0_3B,
LLM_TYPE_0_5B,
LLM_TYPE_0_6B,
LLM_TYPE_1B,
LLM_TYPE_1_2B,
LLM_TYPE_1_3B,
LLM_TYPE_1_4B,
LLM_TYPE_1_5B,
@@ -155,6 +158,12 @@ struct llama_layer_convnext {
struct ggml_tensor * gamma = nullptr;
};

struct llama_layer_shortconv {
struct ggml_tensor * in_proj = nullptr;
struct ggml_tensor * conv = nullptr;
struct ggml_tensor * out_proj = nullptr;
};

struct llama_layer {
// normalization
struct ggml_tensor * attn_norm = nullptr;
@@ -341,6 +350,8 @@ struct llama_layer {
struct llama_layer_posnet posnet;

struct llama_layer_convnext convnext;

struct llama_layer_shortconv shortconv;
};

struct llama_model {

@@ -844,6 +844,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// do not quantize Mamba's small yet 2D weights
// NOTE: can't use LLM_TN here because the layer number is not known
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
quantize &= name.find("shortconv.conv.weight") == std::string::npos;

// do not quantize RWKV's small yet 2D weights
quantize &= name.find("time_mix_first.weight") == std::string::npos;

@@ -1525,7 +1525,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "falcon3" ||
tokenizer_pre == "falcon-h1" ||
tokenizer_pre == "pixtral" ||
tokenizer_pre == "midm-2.0") {
tokenizer_pre == "midm-2.0" ||
tokenizer_pre == "lfm2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
ignore_merges = true;
add_bos = true;

@@ -6,6 +6,47 @@
#include <vector>
#include <memory>

// pre-tokenization types
enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
};

struct LLM_KV;
struct llama_model_loader;

@@ -4114,6 +4114,32 @@ struct test_pad_reflect_1d : public test_case {
}
};

// GGML_OP_ROLL
struct test_roll : public test_case {
const int shift0;
const int shift1;
const int shift3;
const int shift4;

std::string vars() override {
return VARS_TO_STR4(shift0, shift1, shift3, shift4);
}

test_roll(int shift0 = 3, int shift1 = -2, int shift3 = 1, int shift4 = -1)
: shift0(shift0), shift1(shift1), shift3(shift3), shift4(shift4) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
int64_t ne[4] = {10, 5, 4, 3};
ggml_tensor * a = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
ggml_set_name(a, "a");

ggml_tensor * out = ggml_roll(ctx, a, shift0, shift1, shift3, shift4);
ggml_set_name(out, "out");

return out;
}
};

// GGML_OP_ARANGE
struct test_arange : public test_case {
const ggml_type type;
@@ -5144,9 +5170,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));

test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
for (int64_t d_conv : {3, 4}) {
for (int64_t d_inner: {1024, 1536, 2048}) {
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
}
}

test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
@@ -5484,6 +5514,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_pad_reflect_1d());
test_cases.emplace_back(new test_roll());
test_cases.emplace_back(new test_arange());
test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());

@@ -7,14 +7,15 @@ More information is available here: https://github.com/ggml-org/llama.cpp/pull/4

```
./llama-imatrix \
-m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] [--verbosity 1] \
-m model.gguf -f some-text.txt [-o imatrix.gguf] [--process-output] \
[--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \
[--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...]
[--in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf ...] \
[--parse-special]
```

Here `-m` with a model name and `-f` with a file containing training data (such as e.g. `wiki.train.raw`) are mandatory.
The parameters in square brackets are optional and have the following meaning:
* `-o` (or `--output-file`) specifies the name of the file where the computed data will be stored. If missing `imatrix.dat` is used.
* `-o` (or `--output-file`) specifies the name of the file where the computed data will be stored. If missing `imatrix.gguf` is used.
* `--verbosity` specifies the verbosity level. If set to `0`, no output other than the perplexity of the processed chunks will be generated. If set to `1`, each time the results are saved a message is written to `stderr`. If `>=2`, a message is output each time data is collected for any tensor. Default verbosity level is `1`.
* `--output-frequency` specifies how often the so far computed result is saved to disk. Default is 10 (i.e., every 10 chunks)
* `--save-frequency` specifies how often to save a copy of the imatrix in a separate file. Default is 0 (i.e., never)
@@ -25,9 +26,9 @@ For faster computation, make sure to use GPU offloading via the `-ngl` argument
## Example

```bash
# generate importance matrix (imatrix.dat)
# generate importance matrix (imatrix.gguf)
./llama-imatrix -m ggml-model-f16.gguf -f train-data.txt -ngl 99

# use the imatrix to perform a Q4_K_M quantization
./llama-quantize --imatrix imatrix.dat ggml-model-f16.gguf ./ggml-model-q4_k_m.gguf q4_k_m
./llama-quantize --imatrix imatrix.gguf ggml-model-f16.gguf ./ggml-model-q4_k_m.gguf q4_k_m
```
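
With the GGUF rework below, `--output-frequency` and `--save-frequency` are counted in evaluated chunks rather than collector calls, and a single batch can advance the chunk counter by more than one; dividing the remainder by the chunk step keeps the trigger from being skipped when a multiple of the frequency is crossed mid-step. A small sketch of that check:

```cpp
#include <cstdint>
#include <cstdio>

// Save check mirroring the updated collect_imatrix logic:
// (last_chunk % n_out_freq) / chunk_step == 0  <=>  remainder < chunk_step,
// i.e. a multiple of n_out_freq was hit or crossed during the last step.
static bool should_save(int32_t last_chunk, int32_t chunk_step, int32_t n_out_freq) {
    return (last_chunk % n_out_freq) / chunk_step == 0;
}

int main() {
    // Stepping 4 chunks at a time with n_out_freq == 10:
    // 4 -> no, 8 -> no, 12 -> yes (crossed 10), 16 -> no, 20 -> yes.
    for (int32_t c = 4; c <= 20; c += 4) {
        printf("chunk %2d: %s\n", c, should_save(c, 4, 10) ? "save" : "-");
    }
}
```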
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
@@ -13,7 +15,7 @@
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
@@ -22,17 +24,20 @@
|
||||
static void print_usage(int, char ** argv) {
|
||||
LOG("\nexample usage:\n");
|
||||
LOG("\n %s \\\n"
|
||||
" -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] \\\n"
|
||||
" -m model.gguf -f some-text.txt [-o imatrix.gguf] [--process-output] \\\n"
|
||||
" [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n"
|
||||
" [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...] \\\n"
|
||||
" [--in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf ...] \\\n"
|
||||
" [--parse-special]\n" , argv[0]);
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
static const char * const LLM_KV_IMATRIX_DATASETS = "imatrix.datasets";
|
||||
static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count";
|
||||
static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size";
|
||||
|
||||
struct Stats {
|
||||
std::vector<float> values;
|
||||
std::vector<int> counts;
|
||||
int ncall = 0;
|
||||
std::vector<float> values;
|
||||
std::vector<int64_t> counts;
|
||||
};
|
||||
|
||||
class IMatrixCollector {
|
||||
@@ -40,13 +45,16 @@ public:
|
||||
IMatrixCollector() = default;
|
||||
void set_params(common_params params) { m_params = std::move(params); }
|
||||
bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data);
|
||||
void save_imatrix(int ncall = -1) const;
|
||||
bool load_imatrix(const char * fname);
|
||||
void save_imatrix_legacy(int32_t ncall = -1) const;
|
||||
void save_imatrix(int32_t n_chunk = -1) const;
|
||||
bool load_imatrix_legacy(const char * fname);
|
||||
bool load_imatrix(const char * file_name);
|
||||
private:
|
||||
std::unordered_map<std::string, Stats> m_stats;
|
||||
common_params m_params;
|
||||
std::mutex m_mutex;
|
||||
int m_last_call = 0;
|
||||
std::vector<std::string> m_datasets;
|
||||
int32_t m_last_chunk = 0;
|
||||
std::vector<char> m_src1_data;
|
||||
std::vector<char> m_ids; // the expert ids from ggml_mul_mat_id
|
||||
};
|
||||
@@ -77,6 +85,8 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
|
||||
const struct ggml_tensor * src1 = t->src[1];
|
||||
std::string wname = filter_tensor_name(src0->name);
|
||||
|
||||
const int32_t chunk_size = m_params.n_ctx / m_params.n_parallel;
|
||||
|
||||
// when ask is true, the scheduler wants to know if we are interested in data from this tensor
|
||||
// if we return true, a follow-up call will be made with ask=false in which we can do the actual collection
|
||||
if (ask) {
|
||||
@@ -102,14 +112,21 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
|
||||
const char * data = is_host ? (const char *) src1->data : m_src1_data.data();
|
||||
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
|
||||
|
||||
// TODO: 4d? (is that even used in practice?)
|
||||
// the extra dimension would need to be stored somewhere to be reflected in the imatrix file
|
||||
if (ggml_nrows(src1) != src1->ne[1] * src1->ne[2]) {
|
||||
LOG_ERR("%s: tensor has more than 3 dimensions: %s", __func__, wname.c_str());
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
// this has been adapted to the new format of storing merged experts in a single 3d tensor
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/6387
|
||||
if (t->op == GGML_OP_MUL_MAT_ID) {
|
||||
// ids -> [n_experts_used, n_tokens]
|
||||
// src1 -> [cols, n_expert_used, n_tokens]
|
||||
const ggml_tensor * ids = t->src[2];
|
||||
const int n_as = src0->ne[2];
|
||||
const int n_ids = ids->ne[0];
|
||||
const int64_t n_as = src0->ne[2];
|
||||
const int64_t n_ids = ids->ne[0];
|
||||
|
||||
// the top-k selected expert ids are stored in the ids tensor
|
||||
// for simplicity, always copy ids to host, because it is small
|
||||
@@ -122,23 +139,29 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *

         auto & e = m_stats[wname];

-        ++e.ncall;
-
+        if (e.counts.size() == 1 && n_as > 1) {
+            // broadcast, when loading an old imatrix
+            e.counts.resize(n_as, e.counts[0]);
+        }
         if (e.values.empty()) {
             e.values.resize(src1->ne[0]*n_as, 0);
-            e.counts.resize(src1->ne[0]*n_as, 0);
+            e.counts.resize(n_as, 0);
         }
         else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
-            LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
+            LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)(src1->ne[0]*n_as));
             exit(1); //GGML_ABORT("fatal error");
         }
-        LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
+        else if (e.counts.size() != (size_t)n_as) {
+            LOG_ERR("%s: inconsistent expert count for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.counts.size(), (int)n_as);
+            exit(1); //GGML_ABORT("fatal error");
+        }
+        LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
         // loop over all possible experts, regardless if they are used or not in the batch
-        for (int ex = 0; ex < n_as; ++ex) {
+        for (int64_t ex = 0; ex < n_as; ++ex) {
             size_t e_start = ex*src1->ne[0];

-            for (int idx = 0; idx < n_ids; ++idx) {
-                for (int row = 0; row < (int)src1->ne[2]; ++row) {
+            for (int64_t idx = 0; idx < n_ids; ++idx) {
+                for (int64_t row = 0; row < src1->ne[2]; ++row) {
                     const int excur = *(const int32_t *) (m_ids.data() + row*ids->nb[1] + idx*ids->nb[0]);

                     GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check
@@ -149,57 +172,73 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
                     const int64_t i12 = row;
                     const float * x = (const float *)(data + i11*src1->nb[1] + i12*src1->nb[2]);

-                    for (int j = 0; j < (int)src1->ne[0]; ++j) {
-                        e.values[e_start + j] += x[j]*x[j];
-                        e.counts[e_start + j]++;
-                        if (!std::isfinite(e.values[e_start + j])) {
-                            LOG("\n");
-                            LOG_ERR("%f detected in %s\n", e.values[e_start + j], wname.c_str());
+                    e.counts[ex]++;
+
+                    for (int64_t j = 0; j < src1->ne[0]; ++j) {
+                        e.values[e_start + j] += x[j] * x[j];
+                        if (!std::isfinite((float)e.values[e_start + j])) {
+                            LOG_ERR("%f detected in %s\n", (float)e.values[e_start + j], wname.c_str());
                             exit(1);
                         }
                     }
                 }
             }
-        if (e.ncall > m_last_call) {
-            m_last_call = e.ncall;
-            if (m_last_call % m_params.n_out_freq == 0) {
+            const int32_t n_chunk = e.counts[ex] / chunk_size;
+            if (n_chunk > m_last_chunk) {
+                const int32_t chunk_step = n_chunk - m_last_chunk;
+                m_last_chunk = n_chunk;
+                if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) {
                     save_imatrix();
                 }
-            if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) {
-                save_imatrix(m_last_call);
+                if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) {
+                    save_imatrix(m_last_chunk);
                 }
             }
         }
     } else {
         auto & e = m_stats[wname];
+        const int64_t n_mat = src1->ne[2] * src1->ne[3];
+
         if (e.values.empty()) {
-            e.values.resize(src1->ne[0], 0);
-            e.counts.resize(src1->ne[0], 0);
+            e.values.resize(src1->ne[0] * n_mat, 0);
+            e.counts.resize(n_mat, 0);
         }
-        else if (e.values.size() != (size_t)src1->ne[0]) {
-            LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
+        else if (e.values.size() != (size_t)(src1->ne[0] * n_mat)) {
+            LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)(src1->ne[0] * n_mat));
             exit(1); //GGML_ABORT("fatal error");
         }
-        ++e.ncall;
-        LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
-        for (int row = 0; row < (int)src1->ne[1]; ++row) {
-            const float * x = (const float *) (data + row * src1->nb[1]);
-            for (int j = 0; j < (int)src1->ne[0]; ++j) {
-                e.values[j] += x[j]*x[j];
-                e.counts[j]++;
-                if (!std::isfinite(e.values[j])) {
-                    LOG_ERR("%f detected in %s\n", e.values[j], wname.c_str());
-                    exit(1);
-                }
-            }
-        }
-        if (e.ncall > m_last_call) {
-            m_last_call = e.ncall;
-            if (m_last_call % m_params.n_out_freq == 0) {
-                save_imatrix();
-            }
-            if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) {
-                save_imatrix(m_last_call);
+        else if (e.counts.size() != (size_t)n_mat) {
+            LOG_ERR("%s: inconsistent expert count for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.counts.size(), (int)n_mat);
+            exit(1); //GGML_ABORT("fatal error");
+        }
+        LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->ne[2], (int)src1->type);
+        for (int64_t i3 = 0; i3 < src1->ne[3]; ++i3) {
+            for (int64_t i2 = 0; i2 < src1->ne[2]; ++i2) {
+                const int64_t mat_id    = i3 * src1->ne[2] + i2;
+                const int64_t mat_start = mat_id * src1->ne[0];
+
+                for (int64_t row = 0; row < src1->ne[1]; ++row) {
+                    const float * x = (const float *) (data + row * src1->nb[1] + i2 * src1->nb[2] + i3 * src1->ne[3]);
+                    e.counts[mat_id]++;
+                    for (int64_t j = 0; j < src1->ne[0]; ++j) {
+                        e.values[mat_start + j] += x[j] * x[j];
+                        if (!std::isfinite((float)e.values[j])) {
+                            LOG_ERR("%f detected in %s\n", (float)e.values[j], wname.c_str());
+                            exit(1);
+                        }
+                    }
+                }
+                const int32_t n_chunk = e.counts[mat_id] / chunk_size;
+                if (n_chunk > m_last_chunk) {
+                    const int32_t chunk_step = n_chunk - m_last_chunk;
+                    m_last_chunk = n_chunk;
+                    if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) {
+                        save_imatrix();
+                    }
+                    if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) {
+                        save_imatrix(m_last_chunk);
+                    }
+                }
             }
         }
     }
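The save trigger in this hunk replaces the old per-call counter: a chunk counter derived from e.counts is compared against the output frequency, and because several chunks can complete within a single call, the crossing test divides by chunk_step. A small standalone demonstration of that arithmetic (the chunk counts are made up):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int32_t n_out_freq = 10;
        int32_t last_chunk = 0;
        for (int32_t n_chunk : {3, 9, 12, 30}) { // chunk counts after successive calls
            if (n_chunk > last_chunk) {
                const int32_t chunk_step = n_chunk - last_chunk;
                last_chunk = n_chunk;
                // fires whenever a multiple of n_out_freq was crossed, even by more than one chunk
                if ((last_chunk % n_out_freq) / chunk_step == 0) {
                    std::printf("save at chunk %d\n", last_chunk); // stands in for save_imatrix()
                }
            }
        }
        return 0; // prints: save at chunk 12, save at chunk 30
    }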
@@ -207,7 +246,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
     return true;
 }

-void IMatrixCollector::save_imatrix(int ncall) const {
+void IMatrixCollector::save_imatrix_legacy(int32_t ncall) const {
     auto fname = m_params.out_file;

     if (ncall > 0) {
@@ -215,7 +254,7 @@ void IMatrixCollector::save_imatrix(int ncall) const {
         fname += std::to_string(ncall);
     }

-    // avoid writing imatrix entries that do not have full data
+    // warn when writing imatrix entries that do not have full data
     // this can happen with MoE models where some of the experts end up not being exercised by the provided training data

     int n_entries = 0;
@@ -247,8 +286,7 @@ void IMatrixCollector::save_imatrix(int ncall) const {
         }

         if (n_zeros > 0) {
-            LOG_WRN("%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
-            continue;
+            LOG_WRN("%s: entry '%40s' has partial data (%.2f%%)\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
         }

         n_entries++;
@@ -259,93 +297,378 @@ void IMatrixCollector::save_imatrix(int ncall) const {
         LOG_WRN("%s: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
     }

     // deterministic tensor name order
     std::sort(to_store.begin(), to_store.end());

+    const int32_t chunk_size = m_params.n_ctx / m_params.n_parallel;
+
     std::ofstream out(fname, std::ios::binary);
     out.write((const char *) &n_entries, sizeof(n_entries));
     for (const auto & name : to_store) {
         const auto & stat = m_stats.at(name);
-        int len = name.size();
+        const int32_t len = name.size();
         out.write((const char *) &len, sizeof(len));
         out.write(name.c_str(), len);
-        out.write((const char *) &stat.ncall, sizeof(stat.ncall));
-        int nval = stat.values.size();
+        // ceiling division to avoid accidental zeros
+        const int32_t ncall = (*std::max_element(stat.counts.begin(), stat.counts.end()) + (chunk_size - 1)) / chunk_size;
+        out.write((const char *) &ncall, sizeof(ncall));
+        const int32_t nval = stat.values.size();
+        const int32_t nmat = stat.counts.size();
         out.write((const char *) &nval, sizeof(nval));
-        if (nval > 0) {
+        if (nval > 0 && nmat > 0) {
             std::vector<float> tmp(nval);
-            for (int i = 0; i < nval; i++) {
-                tmp[i] = (stat.values[i] / static_cast<float>(stat.counts[i])) * static_cast<float>(stat.ncall);
+            for (int32_t i = 0; i < nval; i++) {
+                float count = static_cast<float>(stat.counts[i / (nval / nmat)]);
+                float value = stat.values[i];
+                if (count == 0.0f) {
+                    // store 1 for partial data
+                    value = 1.0f;
+                    count = 1.0f;
+                }
+                tmp[i] = (value / count) * static_cast<float>(ncall);
             }
-            out.write((const char*)tmp.data(), nval*sizeof(float));
+            out.write((const char *) tmp.data(), nval * sizeof(float));
         }
     }

     // Write the number of call the matrix was computed with
-    out.write((const char *) &m_last_call, sizeof(m_last_call));
+    out.write((const char *) &m_last_chunk, sizeof(m_last_chunk));

     // Write the input filename at the end of the file to later on specify it in quantize
     {
-        int len = m_params.prompt_file.size();
+        const char * dataset_file = m_params.prompt_file.c_str();
+        int32_t len = m_params.prompt_file.size();
+        // When there is no prompt but there were other imatrix files loaded, use the last dataset
+        if (m_params.prompt_file.empty() && !m_datasets.empty()) {
+            const std::string & dataset_str = m_datasets[m_datasets.size() - 1];
+            dataset_file = dataset_str.c_str();
+            len = dataset_str.size();
+        }
         out.write((const char *) &len, sizeof(len));
-        out.write(m_params.prompt_file.c_str(), len);
+        out.write(dataset_file, len);
     }

     LOGV(1, "\n");
-    LOG_DBGV(1, "%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname.c_str());
+    LOG_DBGV(1, "%s: stored collected data after %d chunks in %s\n", __func__, m_last_chunk, fname.c_str());
 }

-bool IMatrixCollector::load_imatrix(const char * fname) {
+void IMatrixCollector::save_imatrix(int32_t n_chunk) const {
+    auto fname = m_params.out_file;
+
+    // TODO: use the new format in more cases
+    if (!string_ends_with(fname, ".gguf")) {
+        LOG_WRN("\n%s: saving to legacy imatrix format because output suffix is not .gguf\n", __func__);
+        this->save_imatrix_legacy(n_chunk);
+        return;
+    }
+
+    if (n_chunk > 0) {
+        fname += ".at_";
+        fname += std::to_string(n_chunk);
+    }
+
+    // write imatrix entries even if they don't have full data. (can be corrected when reading)
+    // this can happen with MoE models where some of the experts end up not being exercised by the provided training data
+
+    std::vector<std::string> to_store;
+    size_t data_size = 0;
+
+    bool is_first = true; // for printing
+    for (const auto & kv : m_stats) {
+        const int n_all = kv.second.counts.size();
+
+        int n_zeros = 0;
+        for (const auto c : kv.second.counts) {
+            if (c == 0) {
+                n_zeros++;
+            }
+        }
+
+        if (n_zeros != 0 && is_first) {
+            LOG_INF("\n");
+            is_first = false;
+        }
+
+        if (n_zeros > 0) {
+            LOG_WRN("%s: entry '%40s' has partial data (%.2f%%)\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
+        }
+
+        to_store.push_back(kv.first);
+        data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.values.size(), GGML_MEM_ALIGN);
+        data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.counts.size(), GGML_MEM_ALIGN);
+    }
+
+    // deterministic tensor name order
+    std::sort(to_store.begin(), to_store.end());
+
+    struct ggml_init_params params = {
+        /* .mem_size   = */ data_size,
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ false,
+    };
+    struct ggml_context * ctx = ggml_init(params);
+    struct gguf_context * ctx_gguf = gguf_init_empty();
+
+    {
+        std::vector<const char *> datasets;
+        datasets.reserve(m_datasets.size() + 1);
+        for (size_t i = 0; i < m_datasets.size(); ++i) {
+            datasets.push_back(m_datasets[i].c_str());
+        }
+        if (!m_params.prompt_file.empty()) {
+            datasets.push_back(m_params.prompt_file.c_str());
+        }
+
+        gguf_set_val_str(ctx_gguf, "general.type", "imatrix");
+        // Write the dataset paths
+        gguf_set_arr_str(ctx_gguf, LLM_KV_IMATRIX_DATASETS, datasets.data(), datasets.size());
+        // Write the number of chunks the matrix was computed with
+        gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT, m_last_chunk);
+        gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE, m_params.n_ctx / m_params.n_parallel);
+    }
+
+    for (const auto & name : to_store) {
+        const auto & stat = m_stats.at(name);
+        const int32_t nval = (int32_t) stat.values.size();
+        const int32_t nmat = (int32_t) stat.counts.size();
+        if (nval > 0 && nmat > 0) {
+            struct ggml_tensor * in_sum2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nval / nmat, nmat);
+            struct ggml_tensor * counts  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, nmat);
+            ggml_format_name(in_sum2, "%s.in_sum2", name.c_str());
+            ggml_format_name(counts,  "%s.counts",  name.c_str());
+
+            for (int32_t j = 0; j < nval; ++j) {
+                ((float *) in_sum2->data)[j] = (float) stat.values[j];
+            }
+            for (int32_t j = 0; j < nmat; ++j) {
+                ((float *) counts->data)[j] = (float) stat.counts[j];
+            }
+
+            gguf_add_tensor(ctx_gguf, in_sum2);
+            gguf_add_tensor(ctx_gguf, counts);
+        }
+    }
+
+    gguf_write_to_file(ctx_gguf, fname.c_str(), false);
+
+    LOGV(1, "\n");
+    LOG_DBGV(1, "%s: stored collected data after %d chunks in %s\n", __func__, m_last_chunk, fname.c_str());
+
+    gguf_free(ctx_gguf);
+    ggml_free(ctx);
+}
+
+bool IMatrixCollector::load_imatrix_legacy(const char * fname) {
     std::ifstream in(fname, std::ios::binary);
     if (!in) {
-        LOG_ERR("%s: failed to open %s\n",__func__, fname);
+        LOG_ERR("%s: failed to open %s\n", __func__, fname);
         return false;
     }
     int n_entries;
-    in.read((char*)&n_entries, sizeof(n_entries));
+    in.read((char *) &n_entries, sizeof(n_entries));
     if (in.fail() || n_entries < 1) {
         LOG_ERR("%s: no data in file %s\n", __func__, fname);
         return false;
     }
+    // Guess the chunk size because it's not stored in the file
+    const int32_t chunk_size = m_params.n_ctx / m_params.n_parallel;
+
     for (int i = 0; i < n_entries; ++i) {
-        int len; in.read((char *)&len, sizeof(len));
-        std::vector<char> name_as_vec(len+1);
-        in.read((char *)name_as_vec.data(), len);
+        int32_t len = 0;
+        in.read((char *) &len, sizeof(len));
+        std::vector<char> name_as_vec(len + 1);
+        in.read((char *) name_as_vec.data(), len);
         if (in.fail()) {
-            LOG_ERR("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname);
+            LOG_ERR("%s: failed reading name for entry %d from %s\n", __func__, i + 1, fname);
             return false;
         }
         name_as_vec[len] = 0;
-        std::string name{name_as_vec.data()};
+        std::string name{ name_as_vec.data() };
         auto & e = m_stats[std::move(name)];
-        int ncall;
-        in.read((char*)&ncall, sizeof(ncall));
-        int nval;
-        in.read((char *)&nval, sizeof(nval));
+        int32_t ncall = 0;
+        in.read((char *) &ncall, sizeof(ncall));
+        int32_t nval = 0;
+        in.read((char *) &nval, sizeof(nval));
         if (in.fail() || nval < 1) {
-            LOG_ERR("%s: failed reading number of values for entry %d\n",__func__,i);
+            LOG_ERR("%s: failed reading number of values for entry %d\n", __func__, i);
             m_stats = {};
             return false;
         }

         if (e.values.empty()) {
-            e.values.resize(nval, 0);
-            e.counts.resize(nval, 0);
+            e.values.resize(nval, 0.0f);
+            e.counts.resize(1, 0);
         }

         std::vector<float> tmp(nval);
-        in.read((char*)tmp.data(), nval*sizeof(float));
+        in.read((char *) tmp.data(), nval * sizeof(float));
         if (in.fail()) {
-            LOG_ERR("%s: failed reading data for entry %d\n",__func__,i);
+            LOG_ERR("%s: failed reading data for entry %d\n", __func__, i);
             m_stats = {};
             return false;
         }

-        // Recreate the state as expected by save_imatrix(), and corerct for weighted sum.
+        // Recreate the state as expected by save_imatrix(), and correct for weighted sum.
         for (int i = 0; i < nval; i++) {
-            e.values[i] += tmp[i];
-            e.counts[i] += ncall;
+            e.values[i] += tmp[i] * chunk_size;
         }
+        // The legacy format doesn't distinguish the counts for different experts
+        for (size_t j = 0; j < e.counts.size(); ++j) {
+            e.counts[j] += ncall * chunk_size;
+        }
-        e.ncall += ncall;
     }

+    {
+        // TODO: extract into its own method; this is also used by the GGUF-based format
+        // Calculate the last chunk count
+        int64_t max_count = 0;
+        for (const auto & stats : m_stats) {
+            for (int64_t count : stats.second.counts) {
+                if (count > max_count) {
+                    max_count = count;
+                }
+            }
+        }
+        m_last_chunk = max_count / (chunk_size);
+    }
+
+    {
+        // Read the number of calls the matrix was computed with
+        int32_t n_calls;
+        in.read((char *) &n_calls, sizeof(n_calls));
+        // ignore it because it's not important
+    }
+
+    // Read the dataset path to include it when writing to GGUF
+    if (!in.fail()){
+        int32_t len = 0;
+        in.read((char *) &len, sizeof(len));
+        if (!in.fail()) {
+            std::vector<char> dataset;
+            dataset.resize(len + 1, 0);
+            in.read(dataset.data(), len);
+            if (!in.fail()) {
+                m_datasets.push_back(dataset.data());
+            }
+        }
+    }
+
     return true;
 }

+// Using GGUF as the file format, for greater extensibility
+bool IMatrixCollector::load_imatrix(const char * file_name) {
+    struct ggml_context * ctx = nullptr;
+    struct gguf_init_params meta_gguf_params = {
+        /* .no_alloc = */ false, // the data is needed
+        /* .ctx      = */ &ctx,
+    };
+    struct gguf_context * ctx_gguf = gguf_init_from_file(file_name, meta_gguf_params);
+    if (!ctx_gguf) {
+        return this->load_imatrix_legacy(file_name);
+    }
+    const int32_t n_entries = gguf_get_n_tensors(ctx_gguf);
+    if (n_entries < 1) {
+        LOG_ERR("%s: no data in file %s\n", __func__, file_name);
+        gguf_free(ctx_gguf);
+        ggml_free(ctx);
+        return false;
+    }
+
+    const int64_t datasets_key = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_DATASETS);
+    if (datasets_key != -1 && gguf_get_arr_type(ctx_gguf, datasets_key) == GGUF_TYPE_STRING) {
+        const int64_t n = gguf_get_arr_n(ctx_gguf, datasets_key);
+        m_datasets.reserve(m_datasets.size() + n);
+        for (int64_t i = 0; i < n; ++i) {
+            m_datasets.push_back(gguf_get_arr_str(ctx_gguf, datasets_key, i));
+        }
+    }
+
+    const std::string in_sum2_suffix{ ".in_sum2" };
+    const std::string counts_suffix{ ".counts" };
+
+    // Could re-use m_stats instead, but this allows
+    // checking for completeness of *each* loaded imatrix file
+    // and also makes it easier to re-use a similar implementation in quantize.cpp
+    // Using an ordered map to get a deterministic iteration order.
+    std::map<std::string, std::pair<struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
+
+    for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
+        std::string name = cur->name;
+
+        if (name.empty()) { continue; }
+
+        if (string_remove_suffix(name, in_sum2_suffix)) {
+            // in_sum2
+            sums_counts_for[std::move(name)].first = cur;
+        } else if (string_remove_suffix(name, counts_suffix)) {
+            // counts
+            sums_counts_for[std::move(name)].second = cur;
+        } else {
+            // ignore other tensors
+        }
+    }
+
+    for (const auto & sc : sums_counts_for) {
+        const std::string & name = sc.first;
+        const struct ggml_tensor * in_sum2 = sc.second.first;
+        const struct ggml_tensor * counts  = sc.second.second;
+
+        if (!in_sum2 || !counts) {
+            LOG_ERR("%s: mismatched sums and counts for %s\n", __func__, name.c_str());
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            return false;
+        }
+
+        auto & e = m_stats[name];
+
+        int64_t nval = ggml_nelements(in_sum2);
+        if (e.values.empty()) {
+            e.values.resize(nval, 0.0f);
+        } else if ((size_t) nval != e.values.size()) {
+            LOG_ERR("%s: mismatched sums size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) nval, e.values.size());
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            return false;
+        }
+
+        int64_t ncounts = ggml_nelements(counts);
+        if (e.counts.empty()) {
+            e.counts.resize(ncounts, 0);
+        } else if (e.counts.size() == 1 && ncounts > 1) {
+            // broadcast, when loading an old imatrix
+            e.counts.resize(ncounts, e.counts[0]);
+        } else if ((size_t) ncounts != e.counts.size()) {
+            LOG_ERR("%s: mismatched counts size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) ncounts, e.counts.size());
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            return false;
+        }
+
+        // Recreate the state as expected by save_imatrix()
+        for (int64_t j = 0; j < nval; j++) {
+            e.values[j] += ((const float *) in_sum2->data)[j];
+        }
+        for (int64_t j = 0; j < ncounts; j++) {
+            e.counts[j] += std::lround(((const float *) counts->data)[j]);
+        }
+    }
+
+    // TODO: extract into its own method; this is also used by the legacy format
+    // Calculate the last chunk count
+    int64_t max_count = 0;
+    for (const auto & stats : m_stats) {
+        for (int64_t count : stats.second.counts) {
+            if (count > max_count) {
+                max_count = count;
+            }
+        }
+    }
+    m_last_chunk = max_count / (m_params.n_ctx / m_params.n_parallel);
+
+    gguf_free(ctx_gguf);
+    ggml_free(ctx);
+    return true;
+}
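Why the GGUF format stores raw .in_sum2 sums and .counts instead of pre-averaged values: combining several imatrix files then reduces to elementwise addition, which is exactly what the loader above does, whereas the legacy loader has to guess chunk_size to undo the stored weighting. A minimal sketch of that merge (types assumed to match the Stats struct above, names illustrative):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Entry {
        std::vector<float>   sums;   // raw per-element sums (".in_sum2")
        std::vector<int64_t> counts; // raw per-matrix counts (".counts")
    };

    // merging two imatrix files is exact: just add sums and counts
    static void merge(Entry & dst, const Entry & src) {
        for (size_t j = 0; j < dst.sums.size(); ++j)   { dst.sums[j]   += src.sums[j];   }
        for (size_t j = 0; j < dst.counts.size(); ++j) { dst.counts[j] += src.counts[j]; }
    }

    int main() {
        Entry a = { { 1.0f, 2.0f }, { 10 } };
        Entry b = { { 3.0f, 4.0f }, { 32 } };
        merge(a, b);
        std::printf("%.1f %.1f %lld\n", a.sums[0], a.sums[1], (long long) a.counts[0]);
        return 0;
    }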
@@ -428,12 +751,11 @@ static void process_logits(
     }
 }

-static bool compute_imatrix(llama_context * ctx, const common_params & params) {
+static bool compute_imatrix(llama_context * ctx, const common_params & params, const int32_t n_ctx) {
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);

     const bool add_bos = llama_vocab_get_add_bos(vocab);
-    const int n_ctx = llama_n_ctx(ctx);

     GGML_ASSERT(!llama_vocab_get_add_eos(vocab));

@@ -478,45 +800,61 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
     double nll = 0.0;
     double nll2 = 0.0;

-    LOG_INF("%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);
-
-    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
-
     const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+    const int n_seq = std::max(1, n_batch / n_ctx);
+
+    GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
+    GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
+
+    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
+
+    std::vector<float> logits;
+    if (params.compute_ppl && num_batches > 1) {
+        logits.reserve((size_t)n_ctx * n_vocab);
+    }

-    for (int i = 0; i < n_chunk; ++i) {
+    LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
+
+    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
+
+    for (int i = 0; i < n_chunk; i += n_seq) {
         const int start = i * n_ctx;
         const int end   = start + n_ctx;

-        std::vector<float> logits;
+        const int n_seq_batch = std::min(n_seq, n_chunk - i);

         const auto t_start = std::chrono::high_resolution_clock::now();

         // clear the KV cache
         llama_memory_clear(llama_get_memory(ctx), true);

-        llama_batch batch = llama_batch_init(n_batch, 0, 1);
-
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);

-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
-
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_vocab_bos(vocab);
-            }
-
             // clear the batch
             common_batch_clear(batch);
-            for (int i = 0; i < batch_size; i++) {
-                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                int seq_start = batch_start + seq*n_ctx;
+
+                // save original token and restore it after eval
+                const auto token_org = tokens[seq_start];
+
+                // add BOS token for the first batch of each chunk
+                if (add_bos && j == 0) {
+                    tokens[seq_start] = llama_vocab_bos(vocab);
+                }
+                for (int k = 0; k < batch_size; ++k) {
+                    // NOTE: specifying all logits to get activations for the output.weight tensor
+                    //       and also for the perplexity calculation.
+                    // TODO: only get outputs when (params.process_output || params.compute_ppl)
+                    //       (not possible when this skips FFN computation of the last layer)
+                    common_batch_add(batch, tokens[seq_start + k], j*n_batch + k, { seq }, true);
+                }
+
+                // restore the original token in case it was set to BOS
+                tokens[seq_start] = token_org;
             }

             if (llama_decode(ctx, batch)) {
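The batching change above evaluates n_seq chunks per llama_decode call, one sequence id per chunk. A standalone sketch of how chunk indices map to token ranges under the assumption n_batch == n_seq * n_ctx (so num_batches == 1), with made-up sizes:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int n_ctx = 512, n_seq = 4, n_chunk = 10; // assumed sizes
        for (int i = 0; i < n_chunk; i += n_seq) {
            const int start       = i * n_ctx;
            const int n_seq_batch = std::min(n_seq, n_chunk - i); // last pass may be partial
            for (int seq = 0; seq < n_seq_batch; seq++) {
                // chunk i + seq occupies its own sequence within one decode call
                std::printf("chunk %2d -> tokens [%5d, %5d)\n",
                            i + seq, start + seq*n_ctx, start + (seq + 1)*n_ctx);
            }
        }
        return 0;
    }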
@@ -525,23 +863,19 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
                 return false;
             }

-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
             if (params.compute_ppl && num_batches > 1) {
                 const auto * batch_logits = llama_get_logits(ctx);
                 logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
             }
         }

-        llama_batch_free(batch);
-
-        const auto t_end = std::chrono::high_resolution_clock::now();
-
         if (i == 0) {
+            llama_synchronize(ctx);
+            const auto t_end = std::chrono::high_resolution_clock::now();
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
+            int total_seconds = (int)(t_total * n_chunk / n_seq);
             if (total_seconds >= 60*60) {
                 LOG("%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
@@ -551,17 +885,27 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {

         if (params.compute_ppl) {
             const int first = n_ctx/2;
-            const auto * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-            count += n_ctx - first - 1;
-
-            LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
+
+                llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
+
+                process_logits(n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, nll, nll2,
+                        logit_history.data() + start + seq*n_ctx + first,
+                        prob_history.data() + start + seq*n_ctx + first);
+
+                count += n_ctx - first - 1;
+
+                LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
+            }
             fflush(stdout);
+
+            logits.clear();
         }
     }

     LOG("\n");

     if (params.compute_ppl) {
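For reference, the running figure printed by the LOG above is a perplexity: nll accumulates -log(p) over the scored tokens and the report is exp(nll / count). A tiny self-contained illustration with assumed token probabilities:

    #include <cmath>
    #include <cstdio>

    int main() {
        const double probs[] = {0.25, 0.5, 0.125}; // assumed per-token probabilities
        double nll = 0.0;
        int count = 0;
        for (double p : probs) {
            nll += -std::log(p);
            ++count;
            // running perplexity, same shape as the "[%d]%.4lf," log line above
            std::printf("[%d]%.4lf,", count, std::exp(nll / count));
        }
        std::printf("\n");
        return 0;
    }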
@@ -577,13 +921,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
         }
     }

+    llama_batch_free(batch);
+
     return true;
 }

 int main(int argc, char ** argv) {
     common_params params;

-    params.out_file = "imatrix.dat" ;
+    params.out_file = "imatrix.gguf";

     params.n_ctx = 512;
     params.escape = false;
@@ -594,7 +940,22 @@ int main(int argc, char ** argv) {

     common_init();

-    params.n_batch = std::min(params.n_batch, params.n_ctx);
+    const int32_t n_ctx = params.n_ctx;
+
+    if (n_ctx <= 0) {
+        LOG_ERR("%s: imatrix tool requires '--ctx-size' > 0\n", __func__);
+        return 1;
+    }
+
+    {
+        const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
+        const int32_t n_kv  = n_seq * n_ctx;
+
+        params.n_parallel = n_seq;
+        params.n_ctx      = n_kv;
+
+        params.n_batch = std::min(params.n_batch, n_kv);
+    }

     g_collector.set_params(params);

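The block above folds the user-facing --ctx-size and --batch-size into the KV-cache geometry: n_seq parallel sequences of n_ctx tokens each need n_seq * n_ctx KV cells. The same arithmetic in isolation, with example values (a sketch, not the tool's code):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        int32_t n_ctx = 512, n_batch = 2048;                 // assumed CLI values
        const int32_t n_seq = std::max(1, n_batch / n_ctx);  // 4 chunks per pass
        const int32_t n_kv  = n_seq * n_ctx;                 // 2048 KV cells total
        n_batch = std::min(n_batch, n_kv);
        std::printf("n_seq=%d n_kv=%d n_batch=%d\n", n_seq, n_kv, n_batch);
        return 0;
    }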
@@ -606,9 +967,23 @@ int main(int argc, char ** argv) {
         }
     }

-    if (params.in_files.size() > 1) {
-        LOG_INF("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str());
+    if (params.prompt.empty()) {
+        LOG_INF("No prompt provided; combining precomputed matrices only.\n");
+
+        if (params.in_files.empty()) {
+            LOG_ERR("Error: No prompt provided and no precomputed matrices (--in-file) to combine.\n");
+            return 1;
+        }
+
+        if (params.in_files.size() == 1) {
+            LOG_INF("%s : saving imatrix to '%s'\n", __func__, params.out_file.c_str());
+        } else if (params.in_files.size() > 1) {
+            LOG_INF("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str());
+        }

         g_collector.save_imatrix();

+        return 0;
     }

     llama_backend_init();
@@ -643,19 +1018,10 @@ int main(int argc, char ** argv) {
         LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     }

-    if (params.prompt.empty()) {
-        if (params.in_files.empty()) {
-            LOG_ERR("Error: No prompt provided and no precomputed matrices (--in-file) to combine.\n");
-            return 1;
-        }
-        LOG_INF("No prompt provided; combining precomputed matrices only.\n");
-    } else {
-        if (!compute_imatrix(ctx, params)) {
-            return 1;
-        }
+    if (!compute_imatrix(ctx, params, n_ctx)) {
+        return 1;
     }

     g_collector.save_imatrix();

     LOG("\n");

@@ -1,11 +1,13 @@
 #include "common.h"
 #include "llama.h"
+#include "gguf.h"

 #include <cstdio>
 #include <cstring>
 #include <vector>
 #include <string>
 #include <unordered_map>
+#include <map>
 #include <fstream>
 #include <cmath>
 #include <cctype>
@@ -68,6 +70,11 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix
 static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
 static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS  = "quantize.imatrix.chunks_count";

+// TODO: share with imatrix.cpp
+static const char * const LLM_KV_IMATRIX_DATASETS    = "imatrix.datasets";
+static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count";
+static const char * const LLM_KV_IMATRIX_CHUNK_SIZE  = "imatrix.chunk_size";
+
 static bool striequals(const char * a, const char * b) {
     while (*a && *b) {
         if (std::tolower(*a) != std::tolower(*b)) {
@@ -84,7 +91,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
     for (auto ch : ftype_str_in) {
         ftype_str.push_back(std::toupper(ch));
     }
-    for (auto & it : QUANT_OPTIONS) {
+    for (const auto & it : QUANT_OPTIONS) {
         if (striequals(it.name.c_str(), ftype_str.c_str())) {
             ftype = it.ftype;
             ftype_str_out = it.name;
@@ -93,7 +100,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
     }
     try {
         int ftype_int = std::stoi(ftype_str);
-        for (auto & it : QUANT_OPTIONS) {
+        for (const auto & it : QUANT_OPTIONS) {
             if (it.ftype == ftype_int) {
                 ftype = it.ftype;
                 ftype_str_out = it.name;
@@ -129,7 +136,7 @@ static void usage(const char * executable) {
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
     printf("Note: --include-weights and --exclude-weights cannot be used together\n");
     printf("\nAllowed quantization types:\n");
-    for (auto & it : QUANT_OPTIONS) {
+    for (const auto & it : QUANT_OPTIONS) {
         if (it.name != "COPY") {
             printf("  %2d  or  ", it.ftype);
         } else {
@@ -140,7 +147,7 @@ static void usage(const char * executable) {
     exit(1);
 }

-static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
+static int load_legacy_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
     std::ifstream in(imatrix_file.c_str(), std::ios::binary);
     if (!in) {
         printf("%s: failed to open %s\n",__func__, imatrix_file.c_str());
@@ -180,7 +187,9 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
             exit(1);
         }
         if (ncall > 0) {
-            for (auto& v : e) v /= ncall;
+            for (auto & v : e) {
+                v /= ncall;
+            }
         }

         if (getenv("LLAMA_TRACE")) {
@@ -188,7 +197,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
         }
     }

-    // latest imatrix version contains the dataset filename at the end of the file
+    // latest legacy imatrix version contains the dataset filename at the end of the file
     int m_last_call = 0;
     if (in.peek() != EOF) {
         in.read((char *)&m_last_call, sizeof(m_last_call));
@@ -196,15 +205,130 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
         in.read((char *)&dataset_len, sizeof(dataset_len));
         std::vector<char> dataset_as_vec(dataset_len);
         in.read(dataset_as_vec.data(), dataset_len);
-        imatrix_dataset.assign(dataset_as_vec.begin(), dataset_as_vec.end());
-        printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str());
+        imatrix_datasets.resize(1);
+        imatrix_datasets[0].assign(dataset_as_vec.begin(), dataset_as_vec.end());
+        printf("%s: imatrix dataset='%s'\n", __func__, imatrix_datasets[0].c_str());
     }
     printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_call);
     return m_last_call;
 }

+static int load_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
+
+    struct ggml_context * ctx = nullptr;
+    struct gguf_init_params meta_gguf_params = {
+        /* .no_alloc = */ false, // the data is needed
+        /* .ctx      = */ &ctx,
+    };
+    struct gguf_context * ctx_gguf = gguf_init_from_file(imatrix_file.c_str(), meta_gguf_params);
+    if (!ctx_gguf) {
+        fprintf(stderr, "%s: imatrix file '%s' is using old format\n", __func__, imatrix_file.c_str());
+        return load_legacy_imatrix(imatrix_file, imatrix_datasets, imatrix_data);
+    }
+    const int32_t n_entries = gguf_get_n_tensors(ctx_gguf);
+    if (n_entries < 1) {
+        fprintf(stderr, "%s: no data in file %s\n", __func__, imatrix_file.c_str());
+        gguf_free(ctx_gguf);
+        ggml_free(ctx);
+        exit(1);
+    }
+
+    const int dataset_idx     = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_DATASETS);
+    const int chunk_count_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT);
+    const int chunk_size_idx  = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE);
+    if (dataset_idx < 0 || chunk_count_idx < 0 || chunk_size_idx < 0) {
+        fprintf(stderr, "%s: missing imatrix metadata in file %s\n", __func__, imatrix_file.c_str());
+        gguf_free(ctx_gguf);
+        ggml_free(ctx);
+        exit(1);
+    }
+
+    const uint32_t chunk_size = gguf_get_val_u32(ctx_gguf, chunk_size_idx);
+
+    const std::string sums_suffix{ ".in_sum2" };
+    const std::string counts_suffix{ ".counts" };
+
+    // Using an ordered map to get a deterministic iteration order.
+    std::map<std::string, std::pair<struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
+
+    for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
+        std::string name = cur->name;
+
+        if (name.empty()) { continue; }
+
+        if (string_remove_suffix(name, sums_suffix)) {
+            // in_sum2
+            sums_counts_for[std::move(name)].first = cur;
+        } else if (string_remove_suffix(name, counts_suffix)) {
+            // counts
+            sums_counts_for[std::move(name)].second = cur;
+        } else {
+            // ignore other tensors
+        }
+    }
+
+    for (const auto & sc : sums_counts_for) {
+        const std::string & name = sc.first;
+        const struct ggml_tensor * sums   = sc.second.first;
+        const struct ggml_tensor * counts = sc.second.second;
+
+        if (!sums || !counts) {
+            fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str());
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            exit(1);
+        }
+
+        const int64_t ne0 = sums->ne[0];
+        const int64_t ne1 = sums->ne[1];
+
+        auto & e = imatrix_data[name];
+        e.resize(ggml_nelements(sums));
+        float max_count = 0.0f;
+        for (int64_t j = 0; j < ne1; ++j) {
+            const float count = ((const float *) counts->data)[j];
+            if (count > 0.0f) {
+                for (int64_t i = 0; i < ne0; ++i) {
+                    e[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count;
+                }
+            } else {
+                // Partial imatrix data, this tensor never got any input during calibration
+                for (int64_t i = 0; i < ne0; ++i) {
+                    e[j*ne0 + i] = 1;
+                }
+            }
+            if (count > max_count) {
+                max_count = count;
+            }
+        }
+        if (getenv("LLAMA_TRACE")) {
+            printf("%s: loaded data (size = %6d, n_tokens = %6d, n_chunks = %6d) for '%s'\n", __func__, int(e.size()), int(max_count), int(max_count / chunk_size), name.c_str());
+        }
+    }
+
+    int m_last_chunk = gguf_get_val_u32(ctx_gguf, chunk_count_idx);
+
+    int64_t n_datasets = gguf_get_arr_n(ctx_gguf, dataset_idx);
+    imatrix_datasets.reserve(n_datasets);
+    for (int64_t i = 0; i < n_datasets; ++i) {
+        imatrix_datasets.push_back(gguf_get_val_str(ctx_gguf, dataset_idx));
+    }
+    printf("%s: imatrix datasets=['%s'", __func__, imatrix_datasets[0].c_str());
+    for (size_t i = 1; i < imatrix_datasets.size(); ++i) {
+        printf(", '%s'", imatrix_datasets[i].c_str());
+    }
+    printf("]\n");
+
+    printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_chunk);
+
+    gguf_free(ctx_gguf);
+    ggml_free(ctx);
+
+    return m_last_chunk;
+}
+
 static int prepare_imatrix(const std::string & imatrix_file,
-        std::string & imatrix_dataset,
+        std::vector<std::string> & imatrix_dataset,
         const std::vector<std::string> & included_weights,
         const std::vector<std::string> & excluded_weights,
         std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
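Because the dataset name and chunk count end up in the quantized model's metadata via the kv_overrides below, a downstream tool can read them back with the plain gguf API. A hedged sketch of such a reader (it assumes the dataset override was written as a GGUF string, which matches LLAMA_KV_OVERRIDE_TYPE_STR):

    #include "gguf.h"
    #include <cstdio>

    int main(int argc, char ** argv) {
        if (argc < 2) { return 1; }
        struct gguf_init_params p = { /* .no_alloc = */ true, /* .ctx = */ nullptr };
        struct gguf_context * ctx = gguf_init_from_file(argv[1], p); // metadata only
        if (!ctx) { return 1; }
        const int64_t key = gguf_find_key(ctx, "quantize.imatrix.dataset");
        if (key >= 0 && gguf_get_kv_type(ctx, key) == GGUF_TYPE_STRING) {
            std::printf("imatrix dataset: %s\n", gguf_get_val_str(ctx, key));
        }
        gguf_free(ctx);
        return 0;
    }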
@@ -216,18 +340,21 @@ static int prepare_imatrix(const std::string & imatrix_file,
         return m_last_call;
     }
     if (!excluded_weights.empty()) {
-        for (auto& name : excluded_weights) {
-            for (auto it = imatrix_data.begin(); it != imatrix_data.end(); ) {
+        for (const auto & name : excluded_weights) {
+            for (auto it = imatrix_data.begin(); it != imatrix_data.end();) {
                 auto pos = it->first.find(name);
-                if (pos != std::string::npos) it = imatrix_data.erase(it);
-                else ++it;
+                if (pos != std::string::npos) {
+                    it = imatrix_data.erase(it);
+                } else {
+                    ++it;
+                }
             }
         }
     }
     if (!included_weights.empty()) {
         std::unordered_map<std::string, std::vector<float>> tmp;
-        for (auto& name : included_weights) {
-            for (auto& e : imatrix_data) {
+        for (const auto & name : included_weights) {
+            for (auto & e : imatrix_data) {
                 auto pos = e.first.find(name);
                 if (pos != std::string::npos) {
                     tmp.emplace(std::move(e));
@@ -396,9 +523,9 @@ int main(int argc, char ** argv) {
         usage(argv[0]);
     }

-    std::string imatrix_dataset;
+    std::vector<std::string> imatrix_datasets;
     std::unordered_map<std::string, std::vector<float>> imatrix_data;
-    int m_last_call = prepare_imatrix(imatrix_file, imatrix_dataset, included_weights, excluded_weights, imatrix_data);
+    int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, imatrix_data);
     if (!imatrix_data.empty()) {
         params.imatrix = &imatrix_data;
         {
@@ -409,11 +536,12 @@ int main(int argc, char ** argv) {
             kvo.val_str[127] = '\0';
             kv_overrides.emplace_back(std::move(kvo));
         }
-        if (!imatrix_dataset.empty()) {
+        if (!imatrix_datasets.empty()) {
             llama_model_kv_override kvo;
+            // TODO: list multiple datasets when there are more than one
             std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET);
             kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
-            strncpy(kvo.val_str, imatrix_dataset.c_str(), 127);
+            strncpy(kvo.val_str, imatrix_datasets[0].c_str(), 127);
             kvo.val_str[127] = '\0';
             kv_overrides.emplace_back(std::move(kvo));
         }

@@ -2581,12 +2581,14 @@ struct server_context {
                     continue;
                 }

-                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-                if (embd == NULL) {
+                const float * embd = nullptr;
+                if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                } else {
+                    embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                }
+
+                if (embd == nullptr) {
                     SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

                     res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
@@ -2594,12 +2596,12 @@ struct server_context {
                 }

                 // normalize only when there is pooling
                 // TODO: configurable
                 if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
                     common_embd_normalize(embd, embd_res.data(), n_embd, 2);
                     res->embedding.push_back(embd_res);
                     break;
                 } else {
-                    res->embedding.push_back({ embd, embd + n_embd });
+                    res->embedding.emplace_back(embd, embd + n_embd);
                 }
             }
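On the pooled branch above: common_embd_normalize(embd, out, n_embd, 2) is, to the best of my knowledge, a Euclidean (L2) normalization; with pooling disabled the per-token vectors are returned unnormalized. A minimal standalone equivalent for reference (the assumption is flagged in the comment):

    #include <cmath>
    #include <cstdio>

    // assumed equivalent of common_embd_normalize(..., /*embd_norm=*/2)
    static void l2_normalize(const float * in, float * out, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) { sum += (double) in[i] * in[i]; }
        const float norm = sum > 0.0 ? (float) (1.0 / std::sqrt(sum)) : 0.0f;
        for (int i = 0; i < n; ++i) { out[i] = in[i] * norm; }
    }

    int main() {
        const float embd[4] = { 1.0f, 2.0f, 2.0f, 0.0f };
        float out[4];
        l2_normalize(embd, out, 4);
        std::printf("%.3f %.3f %.3f %.3f\n", out[0], out[1], out[2], out[3]);
        return 0;
    }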