mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-05 13:53:23 +02:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8f071fadd | ||
|
|
0bf47a1dbb | ||
|
|
dd62dcfab9 | ||
|
|
d0660f237a | ||
|
|
fe6a9882ac |
@@ -3435,7 +3435,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-format"}, "FORMAT",
|
||||
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
|
||||
|
||||
@@ -29,12 +29,29 @@ if 'NO_LOCAL_GGUF' not in os.environ:
|
||||
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
|
||||
import gguf
|
||||
from gguf.vocab import MistralTokenizerType, MistralVocab
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
|
||||
_mistral_common_installed = True
|
||||
_mistral_import_error_msg = ""
|
||||
except ImportError:
|
||||
_MISTRAL_COMMON_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||
_MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
||||
_mistral_common_installed = False
|
||||
TokenizerVersion = None
|
||||
Tekkenizer = None
|
||||
SentencePieceTokenizer = None
|
||||
_mistral_import_error_msg = (
|
||||
"Mistral format requires `mistral-common` to be installed. Please run "
|
||||
"`pip install mistral-common[image,audio]` to install it."
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger("hf-to-gguf")
|
||||
@@ -73,10 +90,8 @@ class ModelBase:
|
||||
use_temp_file: bool
|
||||
lazy: bool
|
||||
dry_run: bool
|
||||
part_names: list[str]
|
||||
is_safetensors: bool
|
||||
hparams: dict[str, Any]
|
||||
tensor_names: set[str] | None
|
||||
model_tensors: dict[str, Callable[[], Tensor]]
|
||||
gguf_writer: gguf.GGUFWriter
|
||||
model_name: str | None
|
||||
metadata_override: Path | None
|
||||
@@ -107,6 +122,9 @@ class ModelBase:
|
||||
type(self) is MmprojModel:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
|
||||
if self.is_mistral_format and not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
|
||||
self.dir_model = dir_model
|
||||
self.ftype = ftype
|
||||
self.fname_out = fname_out
|
||||
@@ -117,25 +135,8 @@ class ModelBase:
|
||||
self.dry_run = dry_run
|
||||
self.remote_hf_model_id = remote_hf_model_id
|
||||
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
|
||||
if remote_hf_model_id is not None:
|
||||
self.is_safetensors = True
|
||||
|
||||
def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
|
||||
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
|
||||
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
|
||||
self.tensor_names = set(name for name in remote_tensors.keys())
|
||||
for name, remote_tensor in remote_tensors.items():
|
||||
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
|
||||
|
||||
self.get_tensors = get_remote_tensors
|
||||
else:
|
||||
prefix = "model" if not self.is_mistral_format else "consolidated"
|
||||
self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
|
||||
self.is_safetensors = len(self.part_names) > 0
|
||||
if not self.is_safetensors:
|
||||
self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
|
||||
self.tensor_names = None
|
||||
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
|
||||
self.metadata_override = metadata_override
|
||||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
@@ -151,6 +152,8 @@ class ModelBase:
|
||||
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
|
||||
|
||||
self.dequant_model()
|
||||
|
||||
# Configure GGUF Writer
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
|
||||
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
|
||||
@@ -172,67 +175,215 @@ class ModelBase:
|
||||
return None
|
||||
raise KeyError(f"could not find any of: {keys}")
|
||||
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
tensor_names_from_parts: set[str] = set()
|
||||
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
|
||||
tensors: dict[str, Callable[[], Tensor]] = {}
|
||||
|
||||
if remote_hf_model_id is not None:
|
||||
is_safetensors = True
|
||||
|
||||
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
|
||||
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
|
||||
for name, remote_tensor in remote_tensors.items():
|
||||
tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
|
||||
|
||||
return tensors
|
||||
|
||||
prefix = "model" if not self.is_mistral_format else "consolidated"
|
||||
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
|
||||
is_safetensors: bool = len(part_names) > 0
|
||||
if not is_safetensors:
|
||||
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
|
||||
tensor_names_from_index: set[str] = set()
|
||||
|
||||
if not self.is_mistral_format:
|
||||
index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
|
||||
index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
|
||||
index_name += ".index.json"
|
||||
index_file = self.dir_model / index_name
|
||||
|
||||
if index_file.is_file():
|
||||
self.tensor_names = set()
|
||||
logger.info(f"gguf: loading model weight map from '{index_name}'")
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
index: dict[str, Any] = json.load(f)
|
||||
weight_map = index.get("weight_map")
|
||||
if weight_map is None or not isinstance(weight_map, dict):
|
||||
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
|
||||
self.tensor_names.update(weight_map.keys())
|
||||
tensor_names_from_index.update(weight_map.keys())
|
||||
else:
|
||||
self.tensor_names = tensor_names_from_parts
|
||||
weight_map = {}
|
||||
else:
|
||||
self.tensor_names = tensor_names_from_parts
|
||||
weight_map = {}
|
||||
|
||||
for part_name in self.part_names:
|
||||
logger.info(f"gguf: loading model part '{part_name}'")
|
||||
for part_name in part_names:
|
||||
logger.info(f"gguf: indexing model part '{part_name}'")
|
||||
ctx: ContextManager[Any]
|
||||
if self.is_safetensors:
|
||||
if is_safetensors:
|
||||
from safetensors import safe_open
|
||||
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
|
||||
else:
|
||||
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
||||
|
||||
with ctx as model_part:
|
||||
tensor_names_from_parts.update(model_part.keys())
|
||||
assert model_part is not None
|
||||
|
||||
for name in model_part.keys():
|
||||
if self.is_safetensors:
|
||||
if is_safetensors:
|
||||
if self.lazy:
|
||||
data = model_part.get_slice(name)
|
||||
data = LazyTorchTensor.from_safetensors_slice(data)
|
||||
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
|
||||
else:
|
||||
data = model_part.get_tensor(name)
|
||||
data_gen = lambda data=data: data # noqa: E731
|
||||
else:
|
||||
data = model_part[name]
|
||||
if self.lazy:
|
||||
data = LazyTorchTensor.from_eager(data)
|
||||
yield name, data
|
||||
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
|
||||
else:
|
||||
data_gen = lambda data=data: data # noqa: E731
|
||||
tensors[name] = data_gen
|
||||
|
||||
# verify tensor name presence and identify potentially missing files
|
||||
if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
|
||||
missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
|
||||
extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
|
||||
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
|
||||
if len(extra) == 0 and len(missing_files) > 0:
|
||||
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
|
||||
f"Missing tensors: {missing}")
|
||||
if len(tensor_names_from_index) > 0:
|
||||
tensor_names_from_parts = set(tensors.keys())
|
||||
if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
|
||||
missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
|
||||
extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
|
||||
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
|
||||
if len(extra) == 0 and len(missing_files) > 0:
|
||||
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
|
||||
f"Missing tensors: {missing}")
|
||||
else:
|
||||
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
|
||||
f"Missing tensors: {missing}\n"
|
||||
f"Extra tensors: {extra}")
|
||||
|
||||
return tensors
|
||||
|
||||
def dequant_model(self):
|
||||
tensors_to_remove: list[str] = []
|
||||
new_tensors: dict[str, Callable[[], Tensor]] = {}
|
||||
|
||||
if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
|
||||
quant_method = quant_config.get("quant_method")
|
||||
|
||||
def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
|
||||
weight = weight.view(torch.uint8)
|
||||
orig_shape = weight.shape
|
||||
|
||||
shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
|
||||
data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
|
||||
data = data & 3
|
||||
data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
|
||||
|
||||
# The scale is inverted
|
||||
return data / scale.float()
|
||||
|
||||
def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
|
||||
scale = scale.float()
|
||||
|
||||
if (weight_block_size := quant_config.get("weight_block_size")):
|
||||
# TODO: make sure it's a list of integers
|
||||
for i, size in enumerate(weight_block_size):
|
||||
scale = scale.repeat_interleave(size, i)
|
||||
# unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
|
||||
scale = scale[tuple(slice(0, size) for size in weight.shape)]
|
||||
|
||||
return weight.float() * scale
|
||||
|
||||
# ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
|
||||
def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
|
||||
bits = quant_config["bits"]
|
||||
assert bits in (2, 3, 4, 8)
|
||||
assert qweight.dtype == qzeros.dtype
|
||||
maxq = (2 ** bits) - 1
|
||||
weight = None
|
||||
zeros = None
|
||||
pack_dtype_bits = qweight.dtype.itemsize * 8
|
||||
|
||||
if bits in [2, 4, 8]:
|
||||
pack_factor = pack_dtype_bits // bits
|
||||
wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
|
||||
if self.lazy:
|
||||
wf = LazyTorchTensor.from_eager(wf)
|
||||
|
||||
zeros = torch.bitwise_right_shift(
|
||||
qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
|
||||
wf.unsqueeze(0)
|
||||
).to(torch.int16 if bits == 8 else torch.int8)
|
||||
zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
|
||||
|
||||
weight = torch.bitwise_and(
|
||||
torch.bitwise_right_shift(
|
||||
qweight.unsqueeze(1).expand(-1, pack_factor, -1),
|
||||
wf.unsqueeze(-1)
|
||||
).to(torch.int16 if bits == 8 else torch.int8),
|
||||
maxq
|
||||
)
|
||||
elif bits == 3:
|
||||
raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
|
||||
|
||||
assert weight is not None
|
||||
assert zeros is not None
|
||||
|
||||
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
|
||||
|
||||
# gptq_v2 doesn't need to offset zeros
|
||||
if quant_config.get("checkpoint_format", "gptq") == "gptq":
|
||||
zeros += 1
|
||||
|
||||
return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
|
||||
|
||||
if quant_method == "bitnet":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale"):
|
||||
weight_name = name.removesuffix("_scale")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method == "fp8":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale_inv"):
|
||||
weight_name = name.removesuffix("_scale_inv")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method == "gptq":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".qweight"):
|
||||
base_name = name.removesuffix(".qweight")
|
||||
g_idx = self.model_tensors[base_name + ".g_idx"]
|
||||
qweight = self.model_tensors[base_name + ".qweight"]
|
||||
qzeros = self.model_tensors[base_name + ".qzeros"]
|
||||
scales = self.model_tensors[base_name + ".scales"]
|
||||
new_tensors[base_name + ".weight"] = (
|
||||
lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
|
||||
g(), w(), z(), s()
|
||||
)
|
||||
)
|
||||
tensors_to_remove += [
|
||||
base_name + n
|
||||
for n in (
|
||||
".g_idx",
|
||||
".qzeros",
|
||||
".qweight",
|
||||
".scales",
|
||||
)
|
||||
]
|
||||
else:
|
||||
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
|
||||
f"Missing tensors: {missing}\n"
|
||||
f"Extra tensors: {extra}")
|
||||
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
|
||||
|
||||
for name in tensors_to_remove:
|
||||
if name in self.model_tensors:
|
||||
del self.model_tensors[name]
|
||||
|
||||
for name, value in new_tensors.items():
|
||||
self.model_tensors[name] = value
|
||||
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
for name, gen in self.model_tensors.items():
|
||||
yield name, gen()
|
||||
|
||||
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
|
||||
if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
||||
@@ -1363,8 +1514,8 @@ class MmprojModel(ModelBase):
|
||||
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
|
||||
|
||||
# preprocessor config
|
||||
image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
|
||||
image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
|
||||
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
|
||||
image_std = _MISTRAL_COMMON_DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
|
||||
|
||||
self.gguf_writer.add_vision_image_mean(image_mean)
|
||||
self.gguf_writer.add_vision_image_std(image_std)
|
||||
@@ -2033,6 +2184,9 @@ class LlamaModel(TextModel):
|
||||
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
|
||||
|
||||
def _set_vocab_mistral(self):
|
||||
if not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
|
||||
vocab = MistralVocab(self.dir_model)
|
||||
logger.info(
|
||||
f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
|
||||
@@ -4358,27 +4512,6 @@ class CodeShellModel(TextModel):
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(1.0)
|
||||
|
||||
_has_tok_embd = False
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
|
||||
tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
|
||||
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
# assuming token_embd.weight is seen before output.weight
|
||||
if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
|
||||
# even though the tensor file(s) does not contain the word embeddings they are still in the weight map
|
||||
if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
|
||||
logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
|
||||
self.tensor_names.remove("transformer.wte.weight")
|
||||
elif new_name == tok_embd_name:
|
||||
self._has_tok_embd = True
|
||||
|
||||
return [(new_name, data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("InternLM2ForCausalLM")
|
||||
class InternLM2Model(TextModel):
|
||||
@@ -9212,7 +9345,7 @@ class MistralModel(LlamaModel):
|
||||
|
||||
@staticmethod
|
||||
def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
|
||||
assert TokenizerVersion is not None, "mistral_common is not installed"
|
||||
assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
|
||||
assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), (
|
||||
f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}"
|
||||
)
|
||||
@@ -9594,6 +9727,8 @@ def main() -> None:
|
||||
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
|
||||
|
||||
is_mistral_format = args.mistral_format
|
||||
if is_mistral_format and not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
disable_mistral_community_chat_template = args.disable_mistral_community_chat_template
|
||||
|
||||
with torch.inference_mode():
|
||||
|
||||
@@ -14,12 +14,12 @@ except ImportError:
|
||||
SentencePieceProcessor = None
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from mistral_common.tokens.tokenizers.utils import (
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
|
||||
_filter_valid_tokenizer_files,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
except ImportError:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
mistral-common>=1.8.3
|
||||
|
||||
-r ./requirements-convert_legacy_llama.txt
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
|
||||
@@ -6,3 +6,8 @@ target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "AIX")
|
||||
# AIX's flock() function comes from libbsd.a
|
||||
target_link_libraries(${TARGET} PRIVATE -lbsd)
|
||||
endif()
|
||||
|
||||
@@ -76,9 +76,11 @@ struct mtmd_cli_context {
|
||||
|
||||
mtmd::bitmaps bitmaps;
|
||||
|
||||
// note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
|
||||
// so here we don't need to keep track of chat history
|
||||
// chat template
|
||||
common_chat_templates_ptr tmpls;
|
||||
std::vector<common_chat_msg> chat_history;
|
||||
bool use_jinja = false;
|
||||
// TODO: support for --system-prompt with /clear command
|
||||
|
||||
// support for legacy templates (models not having EOT token)
|
||||
llama_tokens antiprompt_tokens;
|
||||
@@ -108,6 +110,8 @@ struct mtmd_cli_context {
|
||||
}
|
||||
|
||||
tmpls = common_chat_templates_init(model, params.chat_template);
|
||||
use_jinja = params.use_jinja;
|
||||
chat_history.clear();
|
||||
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str());
|
||||
|
||||
init_vision_context(params);
|
||||
@@ -193,19 +197,33 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::string generated_text = common_detokenize(ctx.lctx, generated_tokens);
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.content = generated_text;
|
||||
ctx.chat_history.push_back(std::move(msg));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
|
||||
common_chat_templates_inputs tmpl_inputs;
|
||||
tmpl_inputs.messages = {msg};
|
||||
tmpl_inputs.add_generation_prompt = true;
|
||||
tmpl_inputs.use_jinja = false; // jinja is buggy here
|
||||
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
|
||||
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
|
||||
static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & new_msg) {
|
||||
LOG_DBG("chat_add_and_format: new_msg.role='%s', new_msg.content='%s'\n",
|
||||
new_msg.role.c_str(), new_msg.content.c_str());
|
||||
auto formatted = common_chat_format_single(ctx.tmpls.get(), ctx.chat_history,
|
||||
new_msg, new_msg.role == "user",
|
||||
ctx.use_jinja);
|
||||
ctx.chat_history.push_back(new_msg);
|
||||
return formatted;
|
||||
}
|
||||
|
||||
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) {
|
||||
bool add_bos = ctx.chat_history.empty();
|
||||
auto formatted_chat = chat_add_and_format(ctx, msg);
|
||||
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str());
|
||||
|
||||
mtmd_input_text text;
|
||||
text.text = formatted_chat.prompt.c_str();
|
||||
text.text = formatted_chat.c_str();
|
||||
text.add_special = add_bos;
|
||||
text.parse_special = true;
|
||||
|
||||
@@ -303,7 +321,7 @@ int main(int argc, char ** argv) {
|
||||
return 1; // error is already printed by libmtmd
|
||||
}
|
||||
}
|
||||
if (eval_message(ctx, msg, true)) {
|
||||
if (eval_message(ctx, msg)) {
|
||||
return 1;
|
||||
}
|
||||
if (!g_is_interrupted && generate_response(ctx, n_predict)) {
|
||||
@@ -322,7 +340,6 @@ int main(int argc, char ** argv) {
|
||||
LOG("\n /quit or /exit exit the program");
|
||||
LOG("\n");
|
||||
|
||||
bool is_first_msg = true;
|
||||
std::string content;
|
||||
|
||||
while (!g_is_interrupted) {
|
||||
@@ -342,7 +359,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
if (line == "/clear") {
|
||||
ctx.n_past = 0;
|
||||
llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS
|
||||
ctx.chat_history.clear();
|
||||
llama_memory_clear(llama_get_memory(ctx.lctx), true);
|
||||
LOG("Chat history cleared\n\n");
|
||||
continue;
|
||||
}
|
||||
@@ -367,7 +385,7 @@ int main(int argc, char ** argv) {
|
||||
common_chat_msg msg;
|
||||
msg.role = "user";
|
||||
msg.content = content;
|
||||
int ret = eval_message(ctx, msg, is_first_msg);
|
||||
int ret = eval_message(ctx, msg);
|
||||
if (ret) {
|
||||
return 1;
|
||||
}
|
||||
@@ -376,7 +394,6 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
content.clear();
|
||||
is_first_msg = false;
|
||||
}
|
||||
}
|
||||
if (g_is_interrupted) LOG("\nInterrupted by user\n");
|
||||
|
||||
@@ -13,5 +13,11 @@ endif ()
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "AIX")
|
||||
# AIX's flock() function comes from libbsd.a
|
||||
target_link_libraries(${TARGET} PRIVATE -lbsd)
|
||||
endif()
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} ${LLAMA_RUN_EXTRA_LIBS})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -5714,6 +5714,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
clean_up();
|
||||
t.join();
|
||||
llama_memory_breakdown_print(ctx_server.ctx);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user