mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-23 16:37:33 +03:00
Compare commits
7 Commits
b2867
...
fix-conver
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
284870c868 | ||
|
|
ee52225067 | ||
|
|
614d3b914e | ||
|
|
30e70334f7 | ||
|
|
1c570d8bee | ||
|
|
948f4ec7c5 | ||
|
|
941de11759 |
@@ -240,23 +240,6 @@ class Model:
|
||||
return False
|
||||
|
||||
def write_tensors(self):
|
||||
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
|
||||
def np_fp32_to_bf16(n: np.ndarray):
|
||||
# force nan to quiet
|
||||
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
|
||||
# flush subnormals to zero
|
||||
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
|
||||
# round to nearest even
|
||||
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
|
||||
return n.astype(np.int16)
|
||||
|
||||
# Doing this row-wise is much, much faster than element-wise, hence the signature
|
||||
v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
|
||||
if self.lazy:
|
||||
# TODO: find a way to implicitly wrap np.vectorize functions
|
||||
# NOTE: the type is changed to reflect otypes passed to np.vectorize above
|
||||
v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
|
||||
|
||||
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
|
||||
|
||||
for name, data_torch in self.get_tensors():
|
||||
@@ -309,27 +292,31 @@ class Model:
|
||||
))
|
||||
|
||||
if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
|
||||
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
|
||||
if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
||||
data = gguf.quantize_bf16(data)
|
||||
assert data.dtype == np.int16
|
||||
data_qtype = gguf.GGMLQuantizationType.BF16
|
||||
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
|
||||
data = gguf.quantize_q8_0(data)
|
||||
assert data.dtype == np.uint8
|
||||
data_qtype = gguf.GGMLQuantizationType.Q8_0
|
||||
|
||||
else: # default to float16 for quantized tensors
|
||||
if data_dtype != np.float16:
|
||||
data = data.astype(np.float16)
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
||||
if data_dtype != np.float32:
|
||||
data = data.astype(np.float32)
|
||||
data = v_fp32_to_bf16(data.view(np.int32))
|
||||
assert data.dtype == np.int16
|
||||
data_qtype = gguf.GGMLQuantizationType.BF16
|
||||
|
||||
else: # by default, convert to float32
|
||||
if data_qtype is None: # by default, convert to float32
|
||||
if data_dtype != np.float32:
|
||||
data = data.astype(np.float32)
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
|
||||
assert data_qtype is not None
|
||||
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[data_qtype]
|
||||
# reverse shape to make it similar to the internal ggml dimension order
|
||||
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
|
||||
shape_str = f"""{{{', '.join(str(n) for n in reversed(
|
||||
(*data.shape[:-1], data.shape[-1] * data.dtype.itemsize // type_size * block_size))
|
||||
)}}}"""
|
||||
|
||||
# n_dims is implicit in the shape
|
||||
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
|
||||
@@ -859,6 +846,7 @@ class BaichuanModel(Model):
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||
if self.hparams["rope_scaling"].get("type") == "linear":
|
||||
@@ -981,6 +969,7 @@ class XverseModel(Model):
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||
if self.hparams["rope_scaling"].get("type") == "linear":
|
||||
@@ -1215,6 +1204,7 @@ class StableLMModel(Model):
|
||||
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
|
||||
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
|
||||
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
_q_norms: list[dict[str, Tensor]] | None = None
|
||||
_k_norms: list[dict[str, Tensor]] | None = None
|
||||
@@ -1591,6 +1581,7 @@ class QwenModel(Model):
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
|
||||
@Model.register("Qwen2ForCausalLM")
|
||||
@@ -1828,6 +1819,7 @@ class PlamoModel(Model):
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
|
||||
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
def shuffle_attn_q_weight(self, data_torch):
|
||||
assert data_torch.size() == (5120, 5120)
|
||||
@@ -2007,6 +1999,7 @@ in chat mode so that the conversation can end normally.")
|
||||
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
num_heads = self.hparams["num_attention_heads"]
|
||||
@@ -2415,25 +2408,15 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
def numpy(self) -> gguf.LazyNumpyTensor:
|
||||
dtype = self._dtype_map[self.dtype]
|
||||
return gguf.LazyNumpyTensor(
|
||||
meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
|
||||
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
|
||||
lazy=self._lazy,
|
||||
args=(self,),
|
||||
func=(lambda s: s[0].numpy())
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def eager_to_meta(cls, t: Tensor) -> Tensor:
|
||||
if t.is_meta:
|
||||
return t
|
||||
return t.detach().to("meta")
|
||||
|
||||
@classmethod
|
||||
def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
|
||||
m = m.detach()
|
||||
if not m.is_meta:
|
||||
m = m.to("meta")
|
||||
m.dtype = dtype
|
||||
return m
|
||||
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
|
||||
return torch.empty(size=shape, dtype=dtype, device="meta")
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
@@ -2464,8 +2447,8 @@ def parse_args() -> argparse.Namespace:
|
||||
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
|
||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
|
||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bigendian", action="store_true",
|
||||
@@ -2523,6 +2506,7 @@ def main() -> None:
|
||||
"f32": gguf.LlamaFileType.ALL_F32,
|
||||
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
||||
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
||||
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
|
||||
"auto": gguf.LlamaFileType.GUESSED,
|
||||
}
|
||||
|
||||
|
||||
@@ -1109,7 +1109,7 @@ class OutputFile:
|
||||
if metadata is not None and metadata.name is not None:
|
||||
name = metadata.name
|
||||
elif params.path_model is not None:
|
||||
name = str(params.path_model.parent).split("/")[-1]
|
||||
name = params.path_model.name
|
||||
elif params.n_ctx == 4096:
|
||||
# Heuristic detection of LLaMA v2 model
|
||||
name = "LLaMA v2"
|
||||
|
||||
@@ -300,14 +300,10 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (auto & image : params.image) {
|
||||
if (prompt_contains_image(params.prompt)) {
|
||||
auto ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
auto image_embed = load_image(ctx_llava, ¶ms, image);
|
||||
if (!image_embed) {
|
||||
std::cerr << "error: failed to load image " << image << ". Terminating\n\n";
|
||||
return 1;
|
||||
}
|
||||
auto image_embed = load_image(ctx_llava, ¶ms, "");
|
||||
|
||||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
@@ -316,7 +312,26 @@ int main(int argc, char ** argv) {
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
} else {
|
||||
for (auto & image : params.image) {
|
||||
auto ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
auto image_embed = load_image(ctx_llava, ¶ms, image);
|
||||
if (!image_embed) {
|
||||
std::cerr << "error: failed to load image " << image << ". Terminating\n\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_print_timings(ctx_llava->ctx_llama);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
}
|
||||
}
|
||||
|
||||
llama_free_model(model);
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -7,6 +7,8 @@ Also note that finetunes typically result in a higher perplexity value even thou
|
||||
|
||||
Within llama.cpp the perplexity of base models is used primarily to judge the quality loss from e.g. quantized models vs. FP16.
|
||||
The convention among contributors is to use the Wikitext-2 test set for testing unless noted otherwise (can be obtained with `scripts/get-wikitext-2.sh`).
|
||||
When numbers are listed all command line arguments and compilation options are left at their defaults unless noted otherwise.
|
||||
llama.cpp numbers are **not** directly comparable to those of other projects because the exact values depend strongly on the implementation details.
|
||||
|
||||
By default only the mean perplexity value and the corresponding uncertainty is calculated.
|
||||
The uncertainty is determined empirically by assuming a Gaussian distribution of the "correct" logits per and then applying error propagation.
|
||||
@@ -32,7 +34,13 @@ In addition to the KL divergence the following statistics are calculated with `-
|
||||
|
||||
## LLaMA 3 8b Scoreboard
|
||||
|
||||
Results are sorted by Kullback-Leibler divergence relative to FP16.
|
||||
| Revision | f364eb6f |
|
||||
|:---------|:-------------------|
|
||||
| Backend | CUDA |
|
||||
| CPU | AMD Epyc 7742 |
|
||||
| GPU | 1x NVIDIA RTX 4090 |
|
||||
|
||||
Results were generated using the CUDA backend and are sorted by Kullback-Leibler divergence relative to FP16.
|
||||
The "WT" importance matrices were created using varying numbers of Wikitext tokens and can be found [here](https://huggingface.co/JohannesGaessler/llama.cpp_importance_matrices/blob/main/imatrix-llama_3-8b-f16-2.7m_tokens.dat).
|
||||
|
||||
| Quantization | imatrix | Model size [GiB] | PPL | ΔPPL | KLD | Mean Δp | RMS Δp |
|
||||
@@ -89,6 +97,12 @@ K-quants score better on mean Δp than the legacy quants than e.g. KL divergence
|
||||
|
||||
## LLaMA 2 vs. LLaMA 3 Quantization comparison
|
||||
|
||||
| Revision | f364eb6f |
|
||||
|:---------|:-------------------|
|
||||
| Backend | CUDA |
|
||||
| CPU | AMD Epyc 7742 |
|
||||
| GPU | 1x NVIDIA RTX 4090 |
|
||||
|
||||
| Metric | L2 7b q2_K | L3 8b q2_K | L2 7b q4_K_M | L3 8b q4_K_M | L2 7b q6_K | L3 8b q6_K | L2 7b q8_0 | L3 8b q8_0 |
|
||||
|-----------------|---------------------|---------------------|---------------------|---------------------|---------------------|---------------------|---------------------|---------------------|
|
||||
| Mean PPL | 5.794552 ± 0.032298 | 9.751568 ± 0.063312 | 5.877078 ± 0.032781 | 6.407115 ± 0.039119 | 5.808494 ± 0.032425 | 6.253382 ± 0.038078 | 5.798542 ± 0.032366 | 6.234284 ± 0.037878 |
|
||||
@@ -107,6 +121,50 @@ K-quants score better on mean Δp than the legacy quants than e.g. KL divergence
|
||||
| RMS Δp | 9.762 ± 0.053 % | 21.421 ± 0.079 % | 3.252 ± 0.024 % | 5.519 ± 0.050 % | 1.339 ± 0.010 % | 2.295 ± 0.019 % | 0.618 ± 0.011 % | 1.198 ± 0.007 % |
|
||||
| Same top p | 85.584 ± 0.086 % | 71.138 ± 0.119 % | 94.665 ± 0.055 % | 91.901 ± 0.072 % | 97.520 ± 0.038 % | 96.031 ± 0.051 % | 98.846 ± 0.026 % | 97.674 ± 0.040 % |
|
||||
|
||||
## LLaMA 3 BF16 vs. FP16 comparison
|
||||
|
||||
| Revision | 83330d8c |
|
||||
|:---------|:--------------|
|
||||
| Backend | CPU |
|
||||
| CPU | AMD Epyc 7742 |
|
||||
| GPU | N/A |
|
||||
|
||||
Results were calculated with LLaMA 3 8b BF16 as `--kl-divergence-base` and LLaMA 3 8b FP16 as the `--model` for comparison.
|
||||
|
||||
| Metric | Value |
|
||||
|--------------------------------|--------------------------|
|
||||
| Mean PPL(Q) | 6.227711 ± 0.037833 |
|
||||
| Mean PPL(base) | 6.225194 ± 0.037771 |
|
||||
| Cor(ln(PPL(Q)), ln(PPL(base))) | 99.990% |
|
||||
| Mean ln(PPL(Q)/PPL(base)) | 0.000404 ± 0.000086 |
|
||||
| Mean PPL(Q)/PPL(base) | 1.000404 ± 0.000086 |
|
||||
| Mean PPL(Q)-PPL(base) | 0.002517 ± 0.000536 |
|
||||
| Mean KLD | 0.00002515 ± 0.00000020 |
|
||||
| Maximum KLD | 0.012206 |
|
||||
| 99.9% KLD | 0.000799 |
|
||||
| 99.0% KLD | 0.000222 |
|
||||
| 99.0% KLD | 0.000222 |
|
||||
| Median KLD | 0.000013 |
|
||||
| 10.0% KLD | -0.000002 |
|
||||
| 5.0% KLD | -0.000008 |
|
||||
| 1.0% KLD | -0.000023 |
|
||||
| Minimum KLD | -0.000059 |
|
||||
| Mean Δp | -0.0000745 ± 0.0003952 % |
|
||||
| Maximum Δp | 4.186% |
|
||||
| 99.9% Δp | 1.049% |
|
||||
| 99.0% Δp | 0.439% |
|
||||
| 95.0% Δp | 0.207% |
|
||||
| 90.0% Δp | 0.125% |
|
||||
| 75.0% Δp | 0.029% |
|
||||
| Median Δp | 0.000% |
|
||||
| 25.0% Δp | -0.030% |
|
||||
| 10.0% Δp | -0.126% |
|
||||
| 5.0% Δp | -0.207% |
|
||||
| 1.0% Δp | -0.434% |
|
||||
| 0.1% Δp | -1.016% |
|
||||
| Minimum Δp | -4.672% |
|
||||
| RMS Δp | 0.150 ± 0.001 % |
|
||||
| Same top p | 99.739 ± 0.013 % |
|
||||
|
||||
## Old Numbers
|
||||
|
||||
|
||||
@@ -15564,26 +15564,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
|
||||
const int64_t r2 = ne12/ne02;
|
||||
const int64_t r3 = ne13/ne03;
|
||||
|
||||
#if 0
|
||||
// use syclGemmEx
|
||||
{
|
||||
for (int i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int i12 = 0; i12 < ne12; ++i12) {
|
||||
int i03 = i13 / r3;
|
||||
int i02 = i12 / r2;
|
||||
|
||||
SYCL_CHECK(
|
||||
syclGemmEx(g_sycl_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , SYCL_R_16F, nb01/sizeof(half),
|
||||
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, SYCL_R_16F, nb11/sizeof(float),
|
||||
beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
|
||||
cu_compute_type,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||
@@ -15595,7 +15575,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
|
||||
nb11 / nb10, nb12 / nb10, beta,
|
||||
(char *)dst_t, cu_data_type, ne01, nb2 / nb0,
|
||||
ne12 * ne13, cu_compute_type)));
|
||||
g_sycl_handles[g_main_device]->wait();
|
||||
} else {
|
||||
const int ne23 = ne12*ne13;
|
||||
|
||||
@@ -15626,7 +15605,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
|
||||
nb02, nb03, nb12_scaled, nb13_scaled,
|
||||
nbd2, nbd3, r2, r3, item_ct1);
|
||||
});
|
||||
}).wait();
|
||||
});
|
||||
}
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||
*g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
|
||||
@@ -15637,9 +15616,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
|
||||
dpct::library_data_t::real_half, nb11 / nb10, beta,
|
||||
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
|
||||
cu_compute_type)));
|
||||
g_sycl_handles[g_main_device]->wait();
|
||||
}
|
||||
#endif
|
||||
|
||||
if (no_mixed_dtypes) {
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
|
||||
|
||||
@@ -2,5 +2,6 @@ from .constants import *
|
||||
from .lazy import *
|
||||
from .gguf_reader import *
|
||||
from .gguf_writer import *
|
||||
from .quants import *
|
||||
from .tensor_mapping import *
|
||||
from .vocab import *
|
||||
|
||||
@@ -13,6 +13,7 @@ from string import ascii_letters, digits
|
||||
import numpy as np
|
||||
|
||||
from .constants import (
|
||||
GGML_QUANT_SIZES,
|
||||
GGUF_DEFAULT_ALIGNMENT,
|
||||
GGUF_MAGIC,
|
||||
GGUF_VERSION,
|
||||
@@ -195,7 +196,7 @@ class GGUFWriter:
|
||||
return ((x + n - 1) // n) * n
|
||||
|
||||
def add_tensor_info(
|
||||
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32],
|
||||
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
|
||||
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
|
||||
) -> None:
|
||||
if self.state is not WriterState.EMPTY:
|
||||
@@ -208,10 +209,6 @@ class GGUFWriter:
|
||||
encoded_name = name.encode("utf-8")
|
||||
self.ti_data += self._pack("Q", len(encoded_name))
|
||||
self.ti_data += encoded_name
|
||||
n_dims = len(tensor_shape)
|
||||
self.ti_data += self._pack("I", n_dims)
|
||||
for i in range(n_dims):
|
||||
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
|
||||
if raw_dtype is None:
|
||||
if tensor_dtype == np.float16:
|
||||
dtype = GGMLQuantizationType.F16
|
||||
@@ -231,6 +228,15 @@ class GGUFWriter:
|
||||
raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
|
||||
else:
|
||||
dtype = raw_dtype
|
||||
if tensor_dtype == np.uint8:
|
||||
block_size, type_size = GGML_QUANT_SIZES[raw_dtype]
|
||||
if tensor_shape[-1] % type_size != 0:
|
||||
raise ValueError(f"Quantized tensor row size ({tensor_shape[-1]}) is not a multiple of {dtype.name} type size ({type_size})")
|
||||
tensor_shape = tuple(tensor_shape[:-1]) + (tensor_shape[-1] // type_size * block_size,)
|
||||
n_dims = len(tensor_shape)
|
||||
self.ti_data += self._pack("I", n_dims)
|
||||
for i in range(n_dims):
|
||||
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
|
||||
self.ti_data += self._pack("I", dtype)
|
||||
self.ti_data += self._pack("Q", self.offset_tensor)
|
||||
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Callable
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from numpy._typing import _Shape
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
|
||||
@@ -110,7 +111,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
||||
return o
|
||||
|
||||
@classmethod
|
||||
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]:
|
||||
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
|
||||
def wrapped_fn(*args, **kwargs):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
@@ -130,9 +131,14 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
||||
res = args[0]
|
||||
assert isinstance(res, cls)
|
||||
res = res._meta
|
||||
# allow operations to override the dtype
|
||||
# allow operations to override the dtype and shape
|
||||
if meta_noop is not True:
|
||||
res = cls.meta_with_dtype(res, meta_noop)
|
||||
if isinstance(meta_noop, tuple):
|
||||
dtype, shape = meta_noop
|
||||
assert callable(shape)
|
||||
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
|
||||
else:
|
||||
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
|
||||
|
||||
if isinstance(res, cls._tensor_type):
|
||||
def collect_replace(t: LazyBase):
|
||||
@@ -168,7 +174,12 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
||||
while _t._data is None:
|
||||
lt = _t._lazy.popleft()
|
||||
if lt._data is not None:
|
||||
raise ValueError(f"{lt} did not belong in the lazy queue")
|
||||
# Lazy tensor did not belong in the lazy queue.
|
||||
# Weirdly only happens with Bloom models...
|
||||
# likely because tensors aren't unique in the queue.
|
||||
# The final output is still the same as in eager mode,
|
||||
# so it's safe to ignore this.
|
||||
continue
|
||||
assert lt._func is not None
|
||||
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
|
||||
lt._data = lt._func(lt._args)
|
||||
@@ -183,12 +194,12 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
||||
|
||||
@classmethod
|
||||
def eager_to_meta(cls, t: Any) -> Any:
|
||||
return cls.meta_with_dtype(t, t.dtype)
|
||||
return cls.meta_with_dtype_and_shape(t.dtype, t.shape)
|
||||
|
||||
# must be overridden, meta tensor init is backend-specific
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass
|
||||
def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass
|
||||
|
||||
@classmethod
|
||||
def from_eager(cls, t: Any) -> Any:
|
||||
@@ -205,15 +216,15 @@ class LazyNumpyTensor(LazyBase):
|
||||
_tensor_type = np.ndarray
|
||||
|
||||
@classmethod
|
||||
def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]:
|
||||
def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: _Shape) -> np.ndarray[Any, Any]:
|
||||
# The initial idea was to use np.nan as the fill value,
|
||||
# but non-float types like np.int16 can't use that.
|
||||
# So zero it is.
|
||||
cheat = np.zeros(1, dtype)
|
||||
return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape))
|
||||
return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape))
|
||||
|
||||
def astype(self, dtype, *args, **kwargs):
|
||||
meta = type(self).meta_with_dtype(self._meta, dtype)
|
||||
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
|
||||
full_args = (self, dtype,) + args
|
||||
# very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
|
||||
return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
|
||||
|
||||
109
gguf-py/gguf/quants.py
Normal file
109
gguf-py/gguf/quants.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
from typing import Callable
|
||||
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from .constants import GGML_QUANT_SIZES, GGMLQuantizationType
|
||||
from .lazy import LazyNumpyTensor
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
|
||||
def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
|
||||
n = n.astype(np.float32, copy=False).view(np.int32)
|
||||
# force nan to quiet
|
||||
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
|
||||
# flush subnormals to zero
|
||||
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
|
||||
# round to nearest even
|
||||
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
|
||||
return n.astype(np.int16)
|
||||
|
||||
|
||||
# This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
|
||||
def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
|
||||
rows = arr.reshape((-1, arr.shape[-1]))
|
||||
osize = 1
|
||||
for dim in oshape:
|
||||
osize *= dim
|
||||
out = np.empty(shape=osize, dtype=otype)
|
||||
# compute over groups of 16 rows (arbitrary, but seems good for performance)
|
||||
n_groups = rows.shape[0] // 16
|
||||
np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
|
||||
return out.reshape(oshape)
|
||||
|
||||
|
||||
def __quantize_bf16_array(n: np.ndarray) -> np.ndarray:
|
||||
return __apply_over_grouped_rows(__compute_fp32_to_bf16, arr=n, otype=np.int16, oshape=n.shape)
|
||||
|
||||
|
||||
__quantize_bf16_lazy = LazyNumpyTensor._wrap_fn(__quantize_bf16_array, meta_noop=np.int16)
|
||||
|
||||
|
||||
def quantize_bf16(n: np.ndarray):
|
||||
if type(n) is LazyNumpyTensor:
|
||||
return __quantize_bf16_lazy(n)
|
||||
else:
|
||||
return __quantize_bf16_array(n)
|
||||
|
||||
|
||||
__q8_block_size, __q8_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q8_0]
|
||||
|
||||
|
||||
def can_quantize_to_q8_0(n: np.ndarray) -> bool:
|
||||
return n.shape[-1] % __q8_block_size == 0
|
||||
|
||||
|
||||
# round away from zero
|
||||
# ref: https://stackoverflow.com/a/59143326/22827863
|
||||
def np_roundf(n: np.ndarray) -> np.ndarray:
|
||||
a = abs(n)
|
||||
floored = np.floor(a)
|
||||
b = floored + np.floor(2 * (a - floored))
|
||||
return np.sign(n) * b
|
||||
|
||||
|
||||
def __quantize_q8_0_shape_change(s: tuple[int, ...]) -> tuple[int, ...]:
|
||||
return (*s[:-1], s[-1] // __q8_block_size * __q8_type_size)
|
||||
|
||||
|
||||
# Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
|
||||
def __quantize_q8_0_rows(n: np.ndarray) -> np.ndarray:
|
||||
shape = n.shape
|
||||
assert shape[-1] % __q8_block_size == 0
|
||||
|
||||
n_blocks = n.size // __q8_block_size
|
||||
|
||||
blocks = n.reshape((n_blocks, __q8_block_size)).astype(np.float32, copy=False)
|
||||
|
||||
d = abs(blocks).max(axis=1, keepdims=True) / 127
|
||||
with np.errstate(divide="ignore"):
|
||||
id = np.where(d == 0, 0, 1 / d)
|
||||
qs = np_roundf(blocks * id)
|
||||
|
||||
# (n_blocks, 2)
|
||||
d = d.astype(np.float16).view(np.uint8)
|
||||
# (n_blocks, block_size)
|
||||
qs = qs.astype(np.int8).view(np.uint8)
|
||||
|
||||
assert d.shape[1] + qs.shape[1] == __q8_type_size
|
||||
|
||||
return np.concatenate([d, qs], axis=1).reshape(__quantize_q8_0_shape_change(shape))
|
||||
|
||||
|
||||
def __quantize_q8_0_array(n: np.ndarray) -> np.ndarray:
|
||||
return __apply_over_grouped_rows(__quantize_q8_0_rows, arr=n, otype=np.uint8, oshape=__quantize_q8_0_shape_change(n.shape))
|
||||
|
||||
|
||||
__quantize_q8_0_lazy = LazyNumpyTensor._wrap_fn(
|
||||
__quantize_q8_0_array,
|
||||
meta_noop=(np.uint8, __quantize_q8_0_shape_change),
|
||||
)
|
||||
|
||||
|
||||
def quantize_q8_0(data: np.ndarray):
|
||||
if type(data) is LazyNumpyTensor:
|
||||
return __quantize_q8_0_lazy(data)
|
||||
else:
|
||||
return __quantize_q8_0_array(data)
|
||||
20
llama.cpp
20
llama.cpp
@@ -2805,6 +2805,11 @@ static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
|
||||
cache.do_defrag = true;
|
||||
}
|
||||
|
||||
static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
|
||||
// the FA kernels require padding to avoid extra runtime boundary checks
|
||||
return cparams.flash_attn ? 256u : 32u;
|
||||
}
|
||||
|
||||
//
|
||||
// model loading and saving
|
||||
//
|
||||
@@ -11510,7 +11515,8 @@ static int llama_decode_internal(
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
|
||||
const uint32_t pad = llama_kv_cache_get_padding(cparams);
|
||||
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
|
||||
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
||||
}
|
||||
}
|
||||
@@ -15511,6 +15517,11 @@ struct llama_context * llama_new_context_with_model(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
|
||||
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
|
||||
params.flash_attn = false;
|
||||
}
|
||||
|
||||
llama_context * ctx = new llama_context(*model);
|
||||
|
||||
const auto & hparams = model->hparams;
|
||||
@@ -15534,7 +15545,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
|
||||
|
||||
// this is necessary due to kv_self.n being padded later during inference
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
|
||||
|
||||
// with causal attention, the batch size is limited by the context size
|
||||
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
||||
@@ -15579,11 +15590,6 @@ struct llama_context * llama_new_context_with_model(
|
||||
}
|
||||
}
|
||||
|
||||
if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
|
||||
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
|
||||
cparams.flash_attn = false;
|
||||
}
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user