mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-03-05 14:33:24 +02:00
Compare commits
16 Commits
b8190
...
compilade/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ef41855cf | ||
|
|
f88a4b9398 | ||
|
|
4be1a5d44b | ||
|
|
6ffa46d8f4 | ||
|
|
3126b5ee4e | ||
|
|
e097d98a22 | ||
|
|
5712aa895f | ||
|
|
d3fcb0e90e | ||
|
|
614b95a88d | ||
|
|
c3738cfcef | ||
|
|
791bd97b3c | ||
|
|
d921057027 | ||
|
|
562aa42c12 | ||
|
|
e996f3aef8 | ||
|
|
e7b7ed8ab1 | ||
|
|
c4b630f25d |
@@ -11,6 +11,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from hashlib import sha256
|
||||
@@ -76,6 +77,14 @@ class ModelType(IntEnum):
|
||||
AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelTensorInfo:
|
||||
load: Callable[[], Tensor]
|
||||
size: int # in elements
|
||||
src_type: str
|
||||
auto_qtype: gguf.GGMLQuantizationType | None = None
|
||||
|
||||
|
||||
class ModelBase:
|
||||
_model_classes: dict[ModelType, dict[str, type[ModelBase]]] = {
|
||||
ModelType.TEXT: {},
|
||||
@@ -84,14 +93,16 @@ class ModelBase:
|
||||
|
||||
dir_model: Path
|
||||
ftype: gguf.LlamaFileType
|
||||
ftype_guessed: bool
|
||||
fname_out: Path
|
||||
is_big_endian: bool
|
||||
endianess: gguf.GGUFEndian
|
||||
use_temp_file: bool
|
||||
use_reflinks: bool
|
||||
lazy: bool
|
||||
dry_run: bool
|
||||
hparams: dict[str, Any]
|
||||
model_tensors: dict[str, Callable[[], Tensor]]
|
||||
model_tensors: dict[str, ModelTensorInfo]
|
||||
gguf_writer: gguf.GGUFWriter
|
||||
model_name: str | None
|
||||
metadata_override: Path | None
|
||||
@@ -116,7 +127,8 @@ class ModelBase:
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
|
||||
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
|
||||
disable_mistral_community_chat_template: bool = False,
|
||||
sentence_transformers_dense_modules: bool = False):
|
||||
sentence_transformers_dense_modules: bool = False,
|
||||
use_reflinks: bool = False):
|
||||
if type(self) is ModelBase or \
|
||||
type(self) is TextModel or \
|
||||
type(self) is MmprojModel:
|
||||
@@ -127,10 +139,12 @@ class ModelBase:
|
||||
|
||||
self.dir_model = dir_model
|
||||
self.ftype = ftype
|
||||
self.ftype_guessed = ftype == gguf.LlamaFileType.GUESSED
|
||||
self.fname_out = fname_out
|
||||
self.is_big_endian = is_big_endian
|
||||
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
||||
self.use_temp_file = use_temp_file
|
||||
self.use_reflinks = use_reflinks
|
||||
self.lazy = not eager or (remote_hf_model_id is not None)
|
||||
self.dry_run = dry_run
|
||||
self.remote_hf_model_id = remote_hf_model_id
|
||||
@@ -141,22 +155,40 @@ class ModelBase:
|
||||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
|
||||
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
|
||||
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
|
||||
_, first_tensor = next(self.get_tensors())
|
||||
if first_tensor.dtype == torch.float16:
|
||||
logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_F16
|
||||
else:
|
||||
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
|
||||
|
||||
self.dequant_model()
|
||||
|
||||
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||
# find out the most common type
|
||||
hist: dict[gguf.GGMLQuantizationType, int] = {}
|
||||
for t in self.model_tensors.values():
|
||||
if t.auto_qtype is not None:
|
||||
if t.auto_qtype not in hist:
|
||||
hist[t.auto_qtype] = 0
|
||||
hist[t.auto_qtype] += t.size
|
||||
max_qtype = gguf.GGMLQuantizationType.F32
|
||||
max_size = 0
|
||||
for qtype, size in hist.items():
|
||||
if size > max_size:
|
||||
max_qtype = qtype
|
||||
max_size = size
|
||||
# TODO: add more types if they're used as auto_qtype
|
||||
if max_qtype == gguf.GGMLQuantizationType.F32:
|
||||
self.ftype = gguf.LlamaFileType.ALL_F32
|
||||
elif max_qtype == gguf.GGMLQuantizationType.F16:
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_F16
|
||||
elif max_qtype == gguf.GGMLQuantizationType.BF16:
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
|
||||
elif max_qtype == gguf.GGMLQuantizationType.Q8_0:
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_Q8_0
|
||||
elif max_qtype == gguf.GGMLQuantizationType.Q4_1:
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_Q4_1
|
||||
elif max_qtype == gguf.GGMLQuantizationType.TQ1_0:
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_TQ1_0
|
||||
|
||||
# Configure GGUF Writer
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
|
||||
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
|
||||
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
|
||||
use_reflinks=self.use_reflinks)
|
||||
|
||||
# Mistral specific
|
||||
self.disable_mistral_community_chat_template = disable_mistral_community_chat_template
|
||||
@@ -175,8 +207,8 @@ class ModelBase:
|
||||
return None
|
||||
raise KeyError(f"could not find any of: {keys}")
|
||||
|
||||
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
|
||||
tensors: dict[str, Callable[[], Tensor]] = {}
|
||||
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, ModelTensorInfo]:
|
||||
tensors: dict[str, ModelTensorInfo] = {}
|
||||
|
||||
if remote_hf_model_id is not None:
|
||||
is_safetensors = True
|
||||
@@ -184,7 +216,14 @@ class ModelBase:
|
||||
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
|
||||
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
|
||||
for name, remote_tensor in remote_tensors.items():
|
||||
tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
|
||||
dtype = LazyTorchTensor._dtype_str_map[remote_tensor.dtype]
|
||||
qtype = LazyTorchTensor._qtype_map.get(dtype)
|
||||
tensors[name] = ModelTensorInfo(
|
||||
load=lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r),
|
||||
size=math.prod(remote_tensor.shape),
|
||||
src_type=str(dtype),
|
||||
auto_qtype=qtype,
|
||||
)
|
||||
|
||||
return tensors
|
||||
|
||||
@@ -218,8 +257,7 @@ class ModelBase:
|
||||
logger.info(f"gguf: indexing model part '{part_name}'")
|
||||
ctx: ContextManager[Any]
|
||||
if is_safetensors:
|
||||
from safetensors import safe_open
|
||||
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
|
||||
ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name, reflink=self.use_reflinks))
|
||||
else:
|
||||
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
||||
|
||||
@@ -228,19 +266,28 @@ class ModelBase:
|
||||
|
||||
for name in model_part.keys():
|
||||
if is_safetensors:
|
||||
data: gguf.utility.LocalTensor = model_part[name]
|
||||
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
|
||||
size = math.prod(data.shape)
|
||||
if self.lazy:
|
||||
data = model_part.get_slice(name)
|
||||
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
|
||||
data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data) # noqa: E731
|
||||
else:
|
||||
data = model_part.get_tensor(name)
|
||||
data_gen = lambda data=data: data # noqa: E731
|
||||
data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
|
||||
else:
|
||||
data = model_part[name]
|
||||
data_torch: Tensor = model_part[name]
|
||||
size = data_torch.numel()
|
||||
dtype = data_torch.dtype
|
||||
if self.lazy:
|
||||
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
|
||||
data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data) # noqa: E731
|
||||
else:
|
||||
data_gen = lambda data=data: data # noqa: E731
|
||||
tensors[name] = data_gen
|
||||
data_gen = lambda data=data_torch: data # noqa: E731
|
||||
qtype = LazyTorchTensor._qtype_map.get(dtype)
|
||||
tensors[name] = ModelTensorInfo(
|
||||
load=data_gen,
|
||||
size=size,
|
||||
src_type=str(dtype),
|
||||
auto_qtype=qtype,
|
||||
)
|
||||
|
||||
# verify tensor name presence and identify potentially missing files
|
||||
if len(tensor_names_from_index) > 0:
|
||||
@@ -261,7 +308,7 @@ class ModelBase:
|
||||
|
||||
def dequant_model(self):
|
||||
tensors_to_remove: list[str] = []
|
||||
new_tensors: dict[str, Callable[[], Tensor]] = {}
|
||||
new_tensors: dict[str, ModelTensorInfo] = {}
|
||||
|
||||
if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
|
||||
quant_method = quant_config.get("quant_method")
|
||||
@@ -339,7 +386,12 @@ class ModelBase:
|
||||
weight_name = name.removesuffix("_scale")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
|
||||
self.model_tensors[weight_name] = ModelTensorInfo(
|
||||
load=lambda w=w, s=s: dequant_bitnet(w.load(), s.load()),
|
||||
size=w.size,
|
||||
src_type="bitnet",
|
||||
auto_qtype=gguf.GGMLQuantizationType.TQ1_0,
|
||||
)
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method == "fp8":
|
||||
for name in self.model_tensors.keys():
|
||||
@@ -347,9 +399,17 @@ class ModelBase:
|
||||
weight_name = name.removesuffix("_scale_inv")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
|
||||
# TODO: change to FP8 once natively supported
|
||||
auto_qtype = s.auto_qtype if s.auto_qtype is not gguf.GGMLQuantizationType.F32 else gguf.GGMLQuantizationType.BF16
|
||||
self.model_tensors[weight_name] = ModelTensorInfo(
|
||||
load=lambda w=w, s=s: dequant_simple(w.load(), s.load()),
|
||||
size=w.size,
|
||||
src_type=w.src_type,
|
||||
auto_qtype=auto_qtype,
|
||||
)
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method == "gptq":
|
||||
bits = quant_config["bits"]
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".qweight"):
|
||||
base_name = name.removesuffix(".qweight")
|
||||
@@ -357,10 +417,13 @@ class ModelBase:
|
||||
qweight = self.model_tensors[base_name + ".qweight"]
|
||||
qzeros = self.model_tensors[base_name + ".qzeros"]
|
||||
scales = self.model_tensors[base_name + ".scales"]
|
||||
new_tensors[base_name + ".weight"] = (
|
||||
lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
|
||||
g(), w(), z(), s()
|
||||
)
|
||||
new_tensors[base_name + ".weight"] = ModelTensorInfo(
|
||||
load=lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
|
||||
g.load(), w.load(), z.load(), s.load()
|
||||
),
|
||||
size=qweight.size, # TODO: use more accurate value
|
||||
src_type=f"GPTQ-{bits}bit",
|
||||
auto_qtype=gguf.GGMLQuantizationType.Q8_0 if bits == 8 else gguf.GGMLQuantizationType.Q4_1,
|
||||
)
|
||||
tensors_to_remove += [
|
||||
base_name + n
|
||||
@@ -382,8 +445,8 @@ class ModelBase:
|
||||
self.model_tensors[name] = value
|
||||
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
for name, gen in self.model_tensors.items():
|
||||
yield name, gen()
|
||||
for name, t in self.model_tensors.items():
|
||||
yield name, t.load()
|
||||
|
||||
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
|
||||
if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
||||
@@ -438,10 +501,12 @@ class ModelBase:
|
||||
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
||||
continue
|
||||
|
||||
old_dtype = data_torch.dtype
|
||||
tensor_info = self.model_tensors.get(name)
|
||||
old_dtype: str = tensor_info.src_type if tensor_info is not None else str(data_torch.dtype)
|
||||
|
||||
# convert any unsupported data types to float32
|
||||
if data_torch.dtype not in (torch.float16, torch.float32):
|
||||
# TODO: handle pre-quantized tensors for repacking
|
||||
if data_torch.dtype not in (torch.float16, torch.bfloat16, torch.float32):
|
||||
data_torch = data_torch.to(torch.float32)
|
||||
|
||||
# use the first number-like part of the tensor name as the block id
|
||||
@@ -452,8 +517,18 @@ class ModelBase:
|
||||
break
|
||||
|
||||
for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
|
||||
# TODO: why do we squeeze here?
|
||||
# data = data_torch.squeeze().numpy()
|
||||
old_qtype = LazyTorchTensor._qtype_map[data_torch.dtype]
|
||||
|
||||
# workaround BF16 not being supported by Numpy
|
||||
if data_torch.dtype == torch.bfloat16:
|
||||
# Need a contiguous last dimension otherwise byte view doesn't work
|
||||
# (problem can be reproduced with DeepSeek-V2-Lite-Chat)
|
||||
data_torch = data_torch.contiguous().view(torch.uint8)
|
||||
|
||||
# if data ends up empty, it means data_torch was a scalar tensor -> restore
|
||||
if len(data_torch.shape) == 0:
|
||||
data_torch = data_torch.reshape(1)
|
||||
|
||||
data = data_torch.numpy()
|
||||
|
||||
n_dims = len(data.shape)
|
||||
@@ -512,7 +587,9 @@ class ModelBase:
|
||||
|
||||
# No override (data_qtype is False), or wants to be quantized (data_qtype is True)
|
||||
if isinstance(data_qtype, bool):
|
||||
if self.ftype == gguf.LlamaFileType.ALL_F32:
|
||||
if self.ftype_guessed:
|
||||
data_qtype = old_qtype if tensor_info is None or tensor_info.auto_qtype is None else tensor_info.auto_qtype
|
||||
elif self.ftype == gguf.LlamaFileType.ALL_F32:
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
@@ -527,12 +604,18 @@ class ModelBase:
|
||||
else:
|
||||
raise ValueError(f"Unknown file type: {self.ftype.name}")
|
||||
|
||||
try:
|
||||
data = gguf.quants.quantize(data, data_qtype)
|
||||
except gguf.QuantError as e:
|
||||
logger.warning("%s, %s", e, "falling back to F16")
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
data = gguf.quants.quantize(data, data_qtype)
|
||||
if old_qtype != data_qtype:
|
||||
if old_qtype not in (
|
||||
gguf.GGMLQuantizationType.F32,
|
||||
gguf.GGMLQuantizationType.F16,
|
||||
):
|
||||
data = gguf.quants.dequantize(data, old_qtype)
|
||||
try:
|
||||
data = gguf.quants.quantize(data, data_qtype)
|
||||
except gguf.QuantError as e:
|
||||
logger.warning("%s, %s", e, "falling back to F16")
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
data = gguf.quants.quantize(data, data_qtype)
|
||||
|
||||
shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
|
||||
|
||||
@@ -4705,7 +4788,7 @@ class Plamo2Model(TextModel):
|
||||
del bid # unused
|
||||
|
||||
if name.endswith(".A_log"):
|
||||
data_torch = -torch.exp(data_torch)
|
||||
data_torch = -torch.exp(data_torch.float())
|
||||
elif name.endswith(".dt_bias"):
|
||||
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
|
||||
elif name.endswith(".dt_norm_weight"):
|
||||
@@ -6229,7 +6312,7 @@ class MambaModel(TextModel):
|
||||
|
||||
if name.endswith(".A_log"):
|
||||
logger.debug("A_log --> A ==> " + new_name)
|
||||
data_torch = -torch.exp(data_torch)
|
||||
data_torch = -torch.exp(data_torch.float())
|
||||
|
||||
# [4 1 8192 1] -> [4 8192 1 1]
|
||||
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
|
||||
@@ -6334,7 +6417,7 @@ class Mamba2Model(TextModel):
|
||||
|
||||
if name.endswith(".A_log"):
|
||||
logger.debug("A_log --> A ==> " + new_name)
|
||||
data_torch = -torch.exp(data_torch)
|
||||
data_torch = -torch.exp(data_torch.float())
|
||||
|
||||
yield (new_name, data_torch)
|
||||
|
||||
@@ -6434,7 +6517,7 @@ class JambaModel(TextModel):
|
||||
|
||||
if name.endswith(".A_log"):
|
||||
logger.debug("A_log --> A ==> " + new_name)
|
||||
data_torch = -torch.exp(data_torch)
|
||||
data_torch = -torch.exp(data_torch.float())
|
||||
|
||||
yield (new_name, data_torch)
|
||||
|
||||
@@ -9983,12 +10066,20 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
"F8_E5M2": torch.float8_e5m2,
|
||||
}
|
||||
|
||||
_qtype_map: dict[torch.dtype, gguf.GGMLQuantizationType] = {
|
||||
torch.float64: gguf.GGMLQuantizationType.F64,
|
||||
torch.float32: gguf.GGMLQuantizationType.F32,
|
||||
torch.float16: gguf.GGMLQuantizationType.F16,
|
||||
torch.bfloat16: gguf.GGMLQuantizationType.BF16,
|
||||
}
|
||||
|
||||
def numpy(self) -> gguf.LazyNumpyTensor:
|
||||
dtype = self._dtype_map[self.dtype]
|
||||
return gguf.LazyNumpyTensor(
|
||||
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
|
||||
args=(self,),
|
||||
func=(lambda s: s.numpy())
|
||||
func=(lambda s: s.numpy()),
|
||||
ranges=self._ranges,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -10002,6 +10093,16 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
|
||||
return cast(torch.Tensor, lazy)
|
||||
|
||||
@classmethod
|
||||
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
|
||||
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
|
||||
dtype = cls._dtype_str_map[tensor.dtype]
|
||||
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
|
||||
dtype = cls._dtype_str_map[t.dtype]
|
||||
shape = t.shape
|
||||
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r), ranges=(t.data_range,))
|
||||
return cast(torch.Tensor, lazy)
|
||||
|
||||
@classmethod
|
||||
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
|
||||
dtype = cls._dtype_str_map[remote_tensor.dtype]
|
||||
@@ -10020,7 +10121,27 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
if func is torch.Tensor.numpy:
|
||||
return args[0].numpy()
|
||||
|
||||
return cls._wrap_fn(func)(*args, **kwargs)
|
||||
result = cls._wrap_fn(func)(*args, **kwargs)
|
||||
|
||||
def get_dim(index: int, key: str = "dim", default: int = 0, args=args, kwargs=kwargs) -> int:
|
||||
# TODO: handle negative dim
|
||||
if len(args) > index:
|
||||
return args[index]
|
||||
else:
|
||||
return kwargs.get(key, default)
|
||||
|
||||
# Track file ranges
|
||||
# TODO: handle tensor splits (with torch.split, torch.chunk, and torch.__getitem__)
|
||||
if isinstance(result, LazyTorchTensor):
|
||||
if isinstance(args[0], LazyTorchTensor):
|
||||
if func is torch.Tensor.to and not isinstance(args[1], torch.dtype):
|
||||
result._ranges = args[0]._ranges
|
||||
if func is torch.stack and get_dim(1) == 0:
|
||||
if all(isinstance(t, LazyTorchTensor) and len(t._ranges) > 0 for t in args[0]):
|
||||
# collect ranges of all stacked tensors
|
||||
result._ranges = tuple(r for t in args[0] for r in t._ranges)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
@@ -10035,8 +10156,8 @@ def parse_args() -> argparse.Namespace:
|
||||
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
|
||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="auto",
|
||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for mostly unchanged types",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bigendian", action="store_true",
|
||||
@@ -10055,6 +10176,10 @@ def parse_args() -> argparse.Namespace:
|
||||
"--no-lazy", action="store_true",
|
||||
help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reflink", action="store_true",
|
||||
help="(Experimental) Use copy-on-write reflinks when possible (e.g. on BTRFS, XFS, ZFS, etc.). File alignment and padding will differ compared to not using this option. Should be very fast when source model layout is compatible enough.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name", type=str, default=None,
|
||||
help="name of the model",
|
||||
@@ -10249,7 +10374,8 @@ def main() -> None:
|
||||
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
||||
small_first_shard=args.no_tensor_first_split,
|
||||
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
|
||||
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
|
||||
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
|
||||
use_reflinks=args.reflink,
|
||||
)
|
||||
|
||||
if args.vocab_only:
|
||||
|
||||
@@ -42,8 +42,8 @@ void ggml_print_backtrace(void);
|
||||
# define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
// required for mmap as gguf only guarantees 32-byte alignment
|
||||
#define TENSOR_ALIGNMENT 32
|
||||
// required for mmap as gguf converted with reflinks from safetensors only guarantees 8-byte alignment
|
||||
#define TENSOR_ALIGNMENT 8
|
||||
|
||||
// static_assert should be a #define, but if it's not,
|
||||
// fall back to the _Static_assert C11 keyword.
|
||||
|
||||
@@ -624,14 +624,16 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
ctx->size = 0;
|
||||
for (size_t i = 0; i < ctx->info.size(); ++i) {
|
||||
const gguf_tensor_info & ti = ctx->info[i];
|
||||
if (ti.offset != ctx->size) {
|
||||
// alignment offset only exists for GGUF converted with reflinks
|
||||
const size_t align_offset = ti.offset % ctx->alignment;
|
||||
if (ti.offset - align_offset != ctx->size) {
|
||||
GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
|
||||
__func__, ti.t.name, ti.offset, ctx->size);
|
||||
__func__, ti.t.name, ti.offset, ctx->size + align_offset);
|
||||
GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
|
||||
size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t) + align_offset, ctx->alignment);
|
||||
if (SIZE_MAX - ctx->size < padded_size) {
|
||||
GGML_LOG_ERROR("%s: tensor '%s' size overflow, cannot accumulate size %zu + %zu\n",
|
||||
__func__, ti.t.name, ctx->size, padded_size);
|
||||
|
||||
@@ -29,6 +29,7 @@ from .constants import (
|
||||
ExpertGatingFuncType,
|
||||
)
|
||||
|
||||
from .lazy import best_extra_offset, count_reflinkable_size
|
||||
from .quants import quant_shape_from_byte_shape
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -84,14 +85,16 @@ class GGUFWriter:
|
||||
|
||||
def __init__(
|
||||
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
|
||||
use_reflinks = False, # opportunistically attempt to use copy-on-write
|
||||
):
|
||||
self.fout = None
|
||||
self.path = Path(path) if path else None
|
||||
self.arch = arch
|
||||
self.endianess = endianess
|
||||
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
|
||||
self.use_temp_file = use_temp_file
|
||||
self.use_reflinks = use_reflinks
|
||||
self.use_temp_file = False if self.use_reflinks else use_temp_file
|
||||
self.temp_file = None
|
||||
self.tensors = [{}]
|
||||
self.kv_data = [{}]
|
||||
@@ -178,13 +181,28 @@ class GGUFWriter:
|
||||
self.fout = [open(filename, "wb") for filename in filenames]
|
||||
self.state = WriterState.EMPTY
|
||||
|
||||
if self.use_reflinks:
|
||||
# reflinks require alignment to the filesystem blocks
|
||||
block_size = os.stat(self.path.parent).st_blksize
|
||||
# necessary to get an appropriate data start offset when padding for reflinks;
|
||||
# using the real alignment (8 bytes, from safetensors) would result in an unusable base data offset
|
||||
self.data_alignment = block_size
|
||||
# for all shards to allow reading them on their own
|
||||
for i, kv in enumerate(self.kv_data):
|
||||
# insert at the start of the key-values
|
||||
if Keys.General.ALIGNMENT in kv:
|
||||
del kv[Keys.General.ALIGNMENT]
|
||||
self.kv_data[i] = {Keys.General.ALIGNMENT: GGUFValue(block_size, GGUFValueType.UINT32), **kv}
|
||||
|
||||
def print_plan(self) -> list[Path]:
|
||||
logger.info("Writing the following files:")
|
||||
assert self.path is not None
|
||||
filenames = self.format_shard_names(self.path)
|
||||
assert len(filenames) == len(self.tensors)
|
||||
for name, tensors in zip(filenames, self.tensors):
|
||||
logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
|
||||
total_size = sum(ti.nbytes for ti in tensors.values())
|
||||
reflinkable_size = count_reflinkable_size((name, ti.tensor) for name, ti in tensors.items()) if self.use_reflinks else 0
|
||||
logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(total_size)}{', reflinked = ' + GGUFWriter.format_n_bytes_to_str(total_size - reflinkable_size) if self.use_reflinks else ''}")
|
||||
|
||||
if self.dry_run:
|
||||
logger.info("Dry run, not writing files")
|
||||
@@ -257,14 +275,18 @@ class GGUFWriter:
|
||||
offset_tensor = 0
|
||||
|
||||
for name, ti in tensors.items():
|
||||
extra_offset = 0
|
||||
if self.use_reflinks:
|
||||
extra_offset = best_extra_offset(ti.tensor, offset_tensor)
|
||||
|
||||
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
|
||||
n_dims = len(ti.shape)
|
||||
ti_data += self._pack("I", n_dims)
|
||||
for j in range(n_dims):
|
||||
ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
|
||||
ti_data += self._pack("I", ti.dtype)
|
||||
ti_data += self._pack("Q", offset_tensor)
|
||||
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
|
||||
ti_data += self._pack("Q", offset_tensor + extra_offset)
|
||||
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes + extra_offset, self.data_alignment)
|
||||
|
||||
fout.write(ti_data)
|
||||
fout.flush()
|
||||
@@ -392,7 +414,7 @@ class GGUFWriter:
|
||||
def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
|
||||
pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
|
||||
if pad != 0:
|
||||
fp.write(bytes([0] * pad))
|
||||
fp.write(b"\x00" * pad)
|
||||
|
||||
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
||||
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
|
||||
@@ -418,7 +440,7 @@ class GGUFWriter:
|
||||
|
||||
self.write_padding(fout, fout.tell())
|
||||
tensor.tofile(fout)
|
||||
self.write_padding(fout, tensor.nbytes)
|
||||
self.write_padding(fout, fout.tell())
|
||||
|
||||
self.state = WriterState.WEIGHTS
|
||||
|
||||
@@ -458,7 +480,7 @@ class GGUFWriter:
|
||||
shard_bar.update(ti.nbytes)
|
||||
if bar is not None:
|
||||
bar.update(ti.nbytes)
|
||||
self.write_padding(fout, ti.nbytes)
|
||||
self.write_padding(fout, fout.tell())
|
||||
ti.tensor = None
|
||||
else:
|
||||
self.temp_file.seek(0)
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
from io import BufferedReader, BufferedWriter
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from .utility import LocalTensorRange
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,10 +27,11 @@ class LazyMeta(ABCMeta):
|
||||
return type(self)._wrap_fn(
|
||||
(lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
|
||||
use_self=self,
|
||||
data_noop=name in ("view", "reshape", "squeeze", "unsqueeze", "contiguous"),
|
||||
)
|
||||
elif isinstance(meta_attr, self._tensor_type):
|
||||
# e.g. self.T with torch.Tensor should still be wrapped
|
||||
return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
|
||||
return type(self)._wrap_fn(lambda s: getattr(s, name), use_self=self)()
|
||||
else:
|
||||
# no need to wrap non-tensor properties,
|
||||
# and they likely don't depend on the actual contents of the tensor
|
||||
@@ -39,8 +47,9 @@ class LazyMeta(ABCMeta):
|
||||
def wrapped_special_op(self, *args, **kwargs):
|
||||
return type(self)._wrap_fn(
|
||||
getattr(type(self)._tensor_type, op_name),
|
||||
use_self=self,
|
||||
meta_noop=meta_noop,
|
||||
)(self, *args, **kwargs)
|
||||
)(*args, **kwargs)
|
||||
return wrapped_special_op
|
||||
|
||||
# special methods bypass __getattr__, so they need to be added manually
|
||||
@@ -76,14 +85,16 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
||||
_args: tuple
|
||||
_kwargs: dict[str, Any]
|
||||
_func: Callable[[Any], Any] | None
|
||||
_ranges: tuple[LocalTensorRange, ...]
|
||||
|
||||
def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
|
||||
def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None, ranges: tuple[LocalTensorRange, ...] = ()):
|
||||
super().__init__()
|
||||
self._meta = meta
|
||||
self._data = data
|
||||
self._args = args
|
||||
self._kwargs = kwargs if kwargs is not None else {}
|
||||
self._func = func
|
||||
self._ranges = ranges
|
||||
assert self._func is not None or self._data is not None
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
@@ -107,7 +118,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
||||
return o
|
||||
|
||||
@classmethod
|
||||
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
|
||||
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False, data_noop: bool = False) -> Callable[[Any], Any]:
|
||||
def wrapped_fn(*args, **kwargs):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
@@ -116,6 +127,8 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
||||
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
|
||||
# TODO: maybe handle tensors in kwargs too
|
||||
|
||||
ranges = use_self._ranges if use_self is not None and data_noop else ()
|
||||
|
||||
if isinstance(meta_noop, bool) and not meta_noop:
|
||||
try:
|
||||
res = fn(*meta_args, **kwargs)
|
||||
@@ -138,7 +151,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
||||
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
|
||||
|
||||
if isinstance(res, cls._tensor_type):
|
||||
return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
|
||||
return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn, ranges=ranges)
|
||||
elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
|
||||
# share the evaluation between lazy tuple elements
|
||||
shared_args: list = [args, None]
|
||||
@@ -202,6 +215,7 @@ class LazyNumpyTensor(LazyBase):
|
||||
_tensor_type = np.ndarray
|
||||
|
||||
shape: tuple[int, ...] # Makes the type checker happy in quants.py
|
||||
nbytes: int
|
||||
|
||||
@classmethod
|
||||
def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
|
||||
@@ -214,10 +228,154 @@ class LazyNumpyTensor(LazyBase):
|
||||
def astype(self, dtype, *args, **kwargs):
|
||||
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
|
||||
full_args = (self, dtype,) + args
|
||||
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
|
||||
ranges = self._ranges if self._meta.dtype == dtype else ()
|
||||
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)), ranges=ranges)
|
||||
|
||||
def tofile(self, *args, **kwargs):
|
||||
eager = LazyNumpyTensor.to_eager(self)
|
||||
return eager.tofile(*args, **kwargs)
|
||||
def tofile(self, fid, *args, **kwargs):
    """Write the tensor to ``fid``.

    When ``fid`` is a ``BufferedWriter`` and this tensor still tracks the
    source file ranges it comes from, those ranges are copied directly with
    ``copy_tensor_ranges``. Otherwise the tensor is evaluated and the
    resulting array's own ``tofile`` is called with the same arguments.
    """
    has_file_ranges = len(self._ranges) > 0
    if not (has_file_ranges and isinstance(fid, BufferedWriter)):
        # no usable file ranges: materialize the data and write it normally
        materialized = LazyNumpyTensor.to_eager(self)
        return materialized.tofile(fid, *args, **kwargs)

    return copy_tensor_ranges(self, fid)
|
||||
|
||||
# TODO: __array_function__
|
||||
|
||||
|
||||
# For aligning blocks when reflinking
|
||||
def best_extra_offset(t: np.ndarray | LazyNumpyTensor | None, current_offset: int) -> int:
    """Choose how many padding bytes to put before a tensor for block alignment.

    Every source range whose offset is a multiple of 8 and whose block size
    is positive votes, weighted by its size in bytes, for the remainder of
    its offset modulo its block size. The remainder with the largest total
    weight is returned (the first one seen wins ties).

    Returns 0 when ``t`` is not a ``LazyNumpyTensor`` or when no range voted.

    ``current_offset`` is only checked, never used in the result: it must be
    a multiple of the largest block size among the voting ranges.
    """
    if not isinstance(t, LazyNumpyTensor):
        # no file ranges, no need for an offset
        return 0

    # total bytes per (offset % block_size) remainder
    bytes_by_remainder: dict[int, int] = {}
    largest_block_size = 0

    for r in t._ranges:
        # Ensure minimal alignment is 8 bytes (common with safetensors)
        # and that the block size is valid
        if r.offset % 8 != 0 or r.block_size <= 0:
            continue
        remainder = r.offset % r.block_size
        bytes_by_remainder[remainder] = bytes_by_remainder.get(remainder, 0) + r.size
        largest_block_size = max(largest_block_size, r.block_size)

    chosen_offset = 0
    chosen_bytes = 0
    for remainder, total_bytes in bytes_by_remainder.items():
        if total_bytes > chosen_bytes:
            chosen_offset, chosen_bytes = remainder, total_bytes

    if largest_block_size > 0:
        # the offset needs to be aligned properly
        # or else there's probably a block size mismatch
        assert current_offset % largest_block_size == 0, current_offset % largest_block_size

    return chosen_offset
|
||||
|
||||
|
||||
def count_reflinkable_size(tensors: Iterable[tuple[str, np.ndarray | LazyNumpyTensor | None]]) -> int:
    """Sum the bytes of all source ranges that line up with their tensor's preferred alignment.

    Returns 0 when ``os.copy_file_range`` does not exist. Only
    ``LazyNumpyTensor`` items with at least one file range are considered,
    and within them only ranges with a positive block size. A range counts
    when its offset modulo its block size equals the value chosen by
    ``best_extra_offset``; the other ranges are reported in a debug message.
    """
    if not hasattr(os, "copy_file_range"):
        return 0

    total = 0
    for name, t in tensors:
        if not isinstance(t, LazyNumpyTensor) or len(t._ranges) == 0:
            continue

        wanted_remainder = best_extra_offset(t, 0)
        with_blocks = [r for r in t._ranges if r.block_size > 0]
        aligned = [r for r in with_blocks if r.offset % r.block_size == wanted_remainder]

        total += sum(r.size for r in aligned)

        misaligned = len(with_blocks) - len(aligned)
        if misaligned > 0:
            logger.debug(f"{name} misaligned for reflinking, fallback to copy for {misaligned} of {len(t._ranges)} parts")

    return total
|
||||
|
||||
|
||||
def _copy_exact(src: BufferedReader, fout: BufferedWriter, src_offset: int, dst_offset: int, size: int) -> None:
    # Copy exactly `size` bytes from `src_offset` in `src` to `dst_offset` in `fout`
    # through the buffered file objects, one bounded chunk at a time.
    chunk_size = 1024 * 1024
    src.seek(src_offset)
    fout.seek(dst_offset)
    remaining = size
    while remaining > 0:
        buf = src.read(min(remaining, chunk_size))
        if not buf:
            raise OSError(f"Unexpected end of file in {src.name}: {remaining} bytes missing at offset {src_offset + size - remaining}")
        fout.write(buf)
        remaining -= len(buf)


def _copy_file_range_upto(src_fd: int, dst_fd: int, count: int, src_offset: int, dst_offset: int) -> int:
    # os.copy_file_range may copy fewer bytes than requested, so keep calling it
    # until `count` bytes are copied or the source has nothing more to give.
    # Returns the number of bytes actually copied.
    copied = 0
    while copied < count:
        n = os.copy_file_range(src_fd, dst_fd, count - copied, src_offset + copied, dst_offset + copied)
        if n <= 0:
            break
        copied += n
    return copied


def _copy_range(src: BufferedReader, fout: BufferedWriter, src_offset: int, dst_offset: int, size: int, padded_size: int) -> None:
    # Copy `size` bytes with os.copy_file_range, asking for `padded_size` (>= size) bytes
    # so that whole blocks can be handed to the filesystem. Whatever part of the
    # required `size` bytes could not be copied this way is copied with plain reads/writes.
    try:
        copied = _copy_file_range_upto(src.fileno(), fout.fileno(), padded_size, src_offset, dst_offset)
    except OSError:
        # fallback when there's a problem (e.g. cross-filesystem copies)
        copied = 0
    if copied < size:
        _copy_exact(src, fout, src_offset + copied, dst_offset + copied, size - copied)


# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
# to make it more likely that copy-on-write is used where possible.
# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
#
# Falls back to a chunked read/write copy when os.copy_file_range is not present
# or when it fails.
def copy_tensor_ranges(t: LazyNumpyTensor, fout: BufferedWriter):
    """Write the bytes of ``t``'s source file ranges to ``fout``, in order.

    Writing starts at the current position of ``fout``, after the zero
    padding chosen by ``best_extra_offset``. When the function returns,
    ``fout`` is positioned right after the last byte of tensor data.

    In the block-aligned path the request to ``os.copy_file_range`` is
    rounded up to a block boundary, so bytes past the end of a range may
    land after it in ``fout``; they are overwritten by whatever is written
    next at that position.

    Raises ``OSError`` when a source file ends before a range is complete.
    """
    ranges = t._ranges
    assert len(ranges) > 0
    dst_offset = fout.tell()
    extra_offset = best_extra_offset(t, dst_offset)

    if extra_offset > 0:
        # initial padding
        fout.write(b"\x00" * extra_offset)

    dst_offset += extra_offset
    start_offset = dst_offset

    has_copy_file_range = hasattr(os, "copy_file_range")

    src_files: dict[Path, BufferedReader] = {}
    try:
        # open every source file once, before anything is copied
        for r in ranges:
            if r.filename not in src_files:
                src_files[r.filename] = open(r.filename, "rb")

        for r in ranges:
            src = src_files[r.filename]

            if not has_copy_file_range:
                # not using reflinks, fallback when os.copy_file_range is not supported
                _copy_exact(src, fout, r.offset, dst_offset, r.size)
                dst_offset += r.size
                continue

            if r.block_size > 0 and (r.offset % r.block_size) == (start_offset % r.block_size):
                # Attempting to align copies for reflinking

                # Block  0,      1,      2,      3,      4,
                # |___0000|0000000|0001111|1111111|111____|
                #
                # 1. block 0 is partially overwritten with contents from range[0]
                # 2. blocks 1 and 2 are copied from range[0] using os.copy_file_range
                # 3. block 2 is partially overwritten with contents from range[1]
                # 4. blocks 3 and 4 are copied from range[1] using os.copy_file_range
                # (repeated for further ranges)
                if dst_offset % r.block_size == 0:
                    extra_size = 0
                else:
                    extra_size = r.block_size - (dst_offset % r.block_size)
                    extra_size = min(extra_size, r.size)

                # bytes up to the next block boundary go through a regular write
                _copy_exact(src, fout, r.offset, dst_offset, extra_size)
                dst_offset += extra_size
                if extra_size == r.size:
                    continue

                assert dst_offset % r.block_size == 0, dst_offset % r.block_size

                offset_src = r.offset + extra_size
                offset_src_end = r.offset + r.size
                if offset_src_end % r.block_size != 0:
                    # round the request up so that it ends on a block boundary
                    offset_src_end += r.block_size - (offset_src_end % r.block_size)

                remaining_size = r.size - extra_size
                _copy_range(src, fout, offset_src, dst_offset, remaining_size, offset_src_end - offset_src)
                dst_offset += remaining_size
            else:
                # not trying to use reflinks, but still using os.copy_file_range for speed
                _copy_range(src, fout, r.offset, dst_offset, r.size, r.size)
                dst_offset += r.size
    finally:
        # close the sources even when a copy fails
        for f in src_files.values():
            f.close()

    fout.seek(dst_offset)
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fill_templated_filename(filename: str, output_type: str | None) -> str:
|
||||
@@ -177,6 +182,10 @@ class SafetensorRemote:
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
|
||||
|
||||
# order by name (same as default safetensors behavior)
|
||||
# ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
|
||||
res = dict(sorted(res.items(), key=lambda t: t[0]))
|
||||
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@@ -266,3 +275,82 @@ class SafetensorRemote:
|
||||
if os.environ.get("HF_TOKEN"):
|
||||
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
|
||||
return headers
|
||||
|
||||
|
||||
@dataclass
class LocalTensorRange:
    # A contiguous span of bytes inside a file on the local filesystem.
    filename: Path  # file that contains the bytes
    block_size: int  # SafetensorsLocal stores the file's st_blksize here when reflinks are requested, -1 otherwise
    offset: int  # position of the first byte, counted from the start of the file
    size: int  # length of the span, in bytes
|
||||
|
||||
|
||||
@dataclass
class LocalTensor:
    """A tensor entry of a local file: declared dtype and shape plus where its bytes are."""

    dtype: str
    shape: tuple[int, ...]
    data_range: LocalTensorRange

    def mmap_bytes(self) -> np.ndarray:
        """Memory-map the raw bytes of the tensor.

        Returns a flat ``uint8`` array of ``data_range.size`` elements which
        starts at ``data_range.offset`` in ``data_range.filename``.

        The mapping is copy-on-write: the source file is opened read-only,
        so no write permission is needed, and writes to the returned array
        stay in memory and never reach the file.
        """
        return np.memmap(
            self.data_range.filename,
            mode="c",
            offset=self.data_range.offset,
            shape=self.data_range.size,
        )
|
||||
|
||||
|
||||
class SafetensorsLocal:
    """
    Read a safetensors file from the local filesystem.

    Custom parsing gives a bit more control over the memory usage.
    The official safetensors library doesn't expose file ranges.

    Only the header is read on construction; each entry of ``tensors``
    records where the bytes of a tensor are located in the file.
    Used as a context manager, it yields the ``tensors`` dict.
    """

    ALIGNMENT = 8  # bytes

    tensors: dict[str, LocalTensor]

    def __init__(self, filename: Path, *, reflink: bool = False):
        """Parse the header of ``filename``.

        With ``reflink=True`` the preferred I/O block size of the file is
        stored in every tensor range; otherwise the block size is -1.

        Raises ``ValueError`` when the file is shorter than its declared
        header, when the header is not valid JSON, or when a tensor entry
        lacks one of ``dtype``, ``shape`` or ``data_offsets``.
        """
        stat = os.stat(filename)
        # using the preferred block size to signal whether reflinks are desired when copying
        # (st_blksize only exists on some platforms; -1 means no block alignment)
        block_size = getattr(stat, "st_blksize", -1) if reflink else -1
        with open(filename, "rb") as f:
            metadata_length = int.from_bytes(f.read(8), byteorder='little')
            file_size = stat.st_size
            if file_size < 8 + metadata_length:
                raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")

            metadata_str = f.read(metadata_length).decode('utf-8')
            try:
                metadata = json.loads(metadata_str)
            except json.JSONDecodeError as e:
                raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}") from e

            data_start_offset = f.tell()

        # NOTE(review): the start of the data is rounded up to ALIGNMENT; this only
        # changes anything when the header length is not already padded - confirm
        # this matches files written by other safetensors implementations
        alignment = self.ALIGNMENT
        if data_start_offset % alignment != 0:
            data_start_offset += alignment - (data_start_offset % alignment)

        tensors: dict[str, LocalTensor] = {}
        for name, meta in metadata.items():
            if name == "__metadata__":
                # ignore metadata, it's not a tensor
                continue

            try:
                dtype = meta["dtype"]
                shape = tuple(meta["shape"])
                data_offsets = meta["data_offsets"]
            except KeyError as e:
                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") from e

            tensors[name] = LocalTensor(
                dtype=dtype,
                shape=shape,
                data_range=LocalTensorRange(
                    filename=filename,
                    block_size=block_size,
                    # data_offsets are relative to the start of the data section
                    offset=data_start_offset + data_offsets[0],
                    size=data_offsets[1] - data_offsets[0],
                ),
            )

        # order by name (same as default safetensors behavior)
        # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
        self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))

    def __enter__(self, *args, **kwargs):
        del args, kwargs  # unused
        return self.tensors

    def __exit__(self, *args, **kwargs):
        # nothing to release: the file is closed as soon as the header is parsed
        del args, kwargs  # unused
|
||||
|
||||
Reference in New Issue
Block a user