mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-05-28 17:27:26 +03:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fb1e70b59 | ||
|
|
d374e71e55 | ||
|
|
30af6e2b98 | ||
|
|
d7be46189f | ||
|
|
bc81d47aba | ||
|
|
0b246862b9 | ||
|
|
a919001134 | ||
|
|
48e7078ee0 | ||
|
|
bb771cbd2b | ||
|
|
7c48fb81ce | ||
|
|
91eb8f4fa0 | ||
|
|
d205df6812 | ||
|
|
e8d2567429 | ||
|
|
09e7b76c93 | ||
|
|
48e7eae41c | ||
|
|
c5229087a5 | ||
|
|
e31cdaa0eb |
101
.devops/zendnn.Dockerfile
Normal file
101
.devops/zendnn.Dockerfile
Normal file
@@ -0,0 +1,101 @@
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y gcc-13 g++-13 build-essential git cmake libssl-dev libomp-dev libnuma-dev python3 ca-certificates
|
||||
|
||||
ENV CC=gcc-13 CXX=g++-13
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_ZENDNN=ON && \
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so*" -exec cp -P {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
&& cp *.py /app/full \
|
||||
&& cp -r conversion /app/full \
|
||||
&& cp -r gguf-py /app/full \
|
||||
&& cp -r requirements /app/full \
|
||||
&& cp requirements.txt /app/full \
|
||||
&& cp .devops/tools.sh /app/full/tools.sh
|
||||
|
||||
## Base image
|
||||
FROM ubuntu:$UBUNTU_VERSION AS base
|
||||
|
||||
ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
ARG IMAGE_URL=https://github.com/ggml-org/llama.cpp
|
||||
ARG IMAGE_SOURCE=https://github.com/ggml-org/llama.cpp
|
||||
LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.version=$APP_VERSION \
|
||||
org.opencontainers.image.revision=$APP_REVISION \
|
||||
org.opencontainers.image.title="llama.cpp" \
|
||||
org.opencontainers.image.description="LLM inference in C/C++" \
|
||||
org.opencontainers.image.url=$IMAGE_URL \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 libnuma1 curl \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
&& find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
|
||||
&& find /var/cache -type f -delete
|
||||
|
||||
COPY --from=build /app/lib/ /app
|
||||
|
||||
### Full
|
||||
FROM base AS full
|
||||
|
||||
COPY --from=build /app/full /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
git \
|
||||
python3 \
|
||||
python3-pip \
|
||||
python3-wheel \
|
||||
&& pip install --break-system-packages --upgrade setuptools \
|
||||
&& pip install --break-system-packages -r requirements.txt \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
&& find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
|
||||
&& find /var/cache -type f -delete
|
||||
|
||||
ENTRYPOINT ["/app/tools.sh"]
|
||||
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ENTRYPOINT [ "/app/llama-cli" ]
|
||||
|
||||
### Server, Server only
|
||||
FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ]
|
||||
|
||||
ENTRYPOINT [ "/app/llama-server" ]
|
||||
@@ -2998,7 +2998,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
key_file.close();
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_KEY_FILE"));
|
||||
add_opt(common_arg(
|
||||
{"--ssl-key-file"}, "FNAME",
|
||||
"path to file a PEM-encoded SSL private key",
|
||||
|
||||
@@ -119,7 +119,8 @@ class ModelBase:
|
||||
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
|
||||
disable_mistral_community_chat_template: bool = False,
|
||||
sentence_transformers_dense_modules: bool = False,
|
||||
fuse_gate_up_exps: bool = False):
|
||||
fuse_gate_up_exps: bool = False,
|
||||
fp8_as_q8: bool = False):
|
||||
if type(self) is ModelBase or \
|
||||
type(self) is TextModel or \
|
||||
type(self) is MmprojModel:
|
||||
@@ -148,6 +149,8 @@ class ModelBase:
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
self._is_nvfp4 = False
|
||||
self._is_mxfp4 = False
|
||||
self._fp8_as_q8 = fp8_as_q8
|
||||
self._fp8_dequantized: set[str] = set()
|
||||
|
||||
# Apply heuristics to figure out typical tensor encoding based on first tensor's dtype
|
||||
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
|
||||
@@ -429,6 +432,8 @@ class ModelBase:
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
|
||||
tensors_to_remove.append(name)
|
||||
if self._fp8_as_q8:
|
||||
self._fp8_dequantized.add(weight_name)
|
||||
if name.endswith(".activation_scale"): # unused
|
||||
tensors_to_remove.append(name)
|
||||
if name.endswith("_activation_scale"): # Mistral-Small-4-119B-2602, unused
|
||||
@@ -440,6 +445,8 @@ class ModelBase:
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
|
||||
tensors_to_remove.append(name)
|
||||
if self._fp8_as_q8:
|
||||
self._fp8_dequantized.add(weight_name)
|
||||
if name.endswith(".qscale_act"):
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method == "gptq":
|
||||
@@ -483,6 +490,11 @@ class ModelBase:
|
||||
strategy = weight_config.get("strategy")
|
||||
assert strategy == "channel" or strategy == "block"
|
||||
assert weight_config.get("group_size") is None # didn't find a model using this yet
|
||||
is_fp8 = (
|
||||
quant_format == "float-quantized"
|
||||
and weight_config.get("type") == "float"
|
||||
and weight_config.get("num_bits") == 8
|
||||
)
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale"):
|
||||
weight_name = name.removesuffix("_scale")
|
||||
@@ -490,6 +502,8 @@ class ModelBase:
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), block_size)
|
||||
tensors_to_remove.append(name)
|
||||
if self._fp8_as_q8 and is_fp8:
|
||||
self._fp8_dequantized.add(weight_name)
|
||||
elif quant_format == "pack-quantized":
|
||||
assert weight_config.get("strategy") == "group"
|
||||
assert weight_config.get("type", "int") == "int"
|
||||
@@ -524,10 +538,18 @@ class ModelBase:
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale"):
|
||||
weight_name = name.removesuffix("_scale")
|
||||
if weight_name not in self.model_tensors:
|
||||
tensors_to_remove.append(name)
|
||||
continue
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
is_fp8_weight = False
|
||||
if self._fp8_as_q8:
|
||||
is_fp8_weight = w().dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), None)
|
||||
tensors_to_remove.append(name)
|
||||
if is_fp8_weight:
|
||||
self._fp8_dequantized.add(weight_name)
|
||||
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method is not None:
|
||||
@@ -615,8 +637,10 @@ class ModelBase:
|
||||
return [(new_name, data_torch)]
|
||||
|
||||
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
|
||||
del name, new_name, bid, n_dims # unused
|
||||
|
||||
del new_name, bid # unused
|
||||
# Force FP8-original tensors to Q8_0 when requested; Q8_0 is faster than F16/BF16.
|
||||
if self._fp8_as_q8 and name in self._fp8_dequantized and n_dims >= 2:
|
||||
return gguf.GGMLQuantizationType.Q8_0
|
||||
return False
|
||||
|
||||
# some models need extra generated tensors (like rope_freqs)
|
||||
@@ -791,7 +815,7 @@ class ModelBase:
|
||||
if quant_algo != "NVFP4":
|
||||
if nvfp4_compressed_tensors:
|
||||
quant_algo = "NVFP4"
|
||||
elif any(v.get("quant_algo") == "NVFP4" for v in quant_layers.values() if isinstance(v, dict)):
|
||||
elif any(str(v.get("quant_algo")).endswith("NVFP4") for v in quant_layers.values() if isinstance(v, dict)):
|
||||
quant_algo = "NVFP4"
|
||||
|
||||
self._is_nvfp4 = quant_algo == "NVFP4"
|
||||
@@ -2417,10 +2441,9 @@ class MmprojModel(ModelBase):
|
||||
raise KeyError(f"could not find any of: {keys}")
|
||||
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
del bid, name, n_dims # unused
|
||||
if ".patch_embd.weight" in new_name or ".patch_merger.weight" in new_name:
|
||||
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
|
||||
return False
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
|
||||
class LazyTorchTensor(gguf.LazyBase):
|
||||
|
||||
@@ -148,6 +148,10 @@ def parse_args() -> argparse.Namespace:
|
||||
"--fuse-gate-up-exps", action="store_true",
|
||||
help="Fuse gate_exps and up_exps tensors into a single gate_up_exps tensor for MoE models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp8-as-q8", action="store_true",
|
||||
help="Store tensors dequantized from FP8 as Q8_0 instead of BF16/F16.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.print_supported_models and args.model is None:
|
||||
@@ -264,7 +268,8 @@ def main() -> None:
|
||||
small_first_shard=args.no_tensor_first_split,
|
||||
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
|
||||
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
|
||||
fuse_gate_up_exps=args.fuse_gate_up_exps
|
||||
fuse_gate_up_exps=args.fuse_gate_up_exps,
|
||||
fp8_as_q8=args.fp8_as_q8,
|
||||
)
|
||||
|
||||
if args.vocab_only:
|
||||
|
||||
@@ -273,67 +273,51 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
const int sve_register_length = svcntb() * 8; //get vector length
|
||||
const int ggml_f16_epr = sve_register_length / 16; // running when 16
|
||||
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
|
||||
const int ggml_f16_epr = svcnth();
|
||||
const int ggml_f16_step = 8 * ggml_f16_epr;
|
||||
const int np = n - (n % ggml_f16_step);
|
||||
const int np2 = n - (n % ggml_f16_epr);
|
||||
|
||||
const int np= (n & ~(ggml_f16_step - 1));
|
||||
svfloat16_t sum1 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum2 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum3 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum4 = svdup_n_f16(0.0f);
|
||||
svfloat32_t sum1_lo = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum1_hi = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum2_lo = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum2_hi = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum3_lo = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum3_hi = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum4_lo = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum4_hi = svdup_n_f32(0.0f);
|
||||
|
||||
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
|
||||
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
|
||||
sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);
|
||||
|
||||
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
|
||||
sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);
|
||||
|
||||
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
|
||||
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
|
||||
sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);
|
||||
|
||||
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
|
||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||
sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);
|
||||
|
||||
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
|
||||
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
|
||||
sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);
|
||||
|
||||
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
|
||||
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
|
||||
sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);
|
||||
|
||||
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
|
||||
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
|
||||
sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);
|
||||
|
||||
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
|
||||
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
|
||||
sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8);
|
||||
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0), GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0));
|
||||
ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1), GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1));
|
||||
ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2), GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2));
|
||||
ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3), GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3));
|
||||
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4), GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4));
|
||||
ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5), GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5));
|
||||
ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6), GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6));
|
||||
ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7), GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7));
|
||||
}
|
||||
|
||||
const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8
|
||||
for (int k = np; k < np2; k += ggml_f16_epr) {
|
||||
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
|
||||
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
|
||||
sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);
|
||||
for (int i = np; i < np2; i += ggml_f16_epr) {
|
||||
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i, 0), GGML_F16x_VEC_LOAD(y + i, 0));
|
||||
}
|
||||
|
||||
if (np2 < n) {
|
||||
svbool_t pg = svwhilelt_b16(np2, n);
|
||||
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
|
||||
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
|
||||
const svbool_t pg = svwhilelt_b16(np2, n);
|
||||
const svfloat16_t rx = svld1_f16(pg, (const __fp16 *)(x + np2));
|
||||
const svfloat16_t ry = svld1_f16(pg, (const __fp16 *)(y + np2));
|
||||
|
||||
sum1 = svmad_f16_x(pg, hx, hy, sum1);
|
||||
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, rx, ry);
|
||||
}
|
||||
GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
|
||||
|
||||
sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum2_lo);
|
||||
sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum2_hi);
|
||||
sum3_lo = svadd_f32_m(DEFAULT_PG32, sum3_lo, sum4_lo);
|
||||
sum3_hi = svadd_f32_m(DEFAULT_PG32, sum3_hi, sum4_hi);
|
||||
sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum3_lo);
|
||||
sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum3_hi);
|
||||
|
||||
sumf = ggml_sve_sum_f32x2(sum1_lo, sum1_hi);
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
#if defined(__riscv_zvfh)
|
||||
int vl = __riscv_vsetvlmax_e32m2();
|
||||
|
||||
@@ -14,6 +14,35 @@
|
||||
// floating point type used to accumulate sums
|
||||
typedef double ggml_float;
|
||||
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
inline static void ggml_sve_f16_fma_widened(
|
||||
svfloat32_t * acc_lo,
|
||||
svfloat32_t * acc_hi,
|
||||
svfloat16_t x,
|
||||
svfloat16_t y) {
|
||||
#if defined(__ARM_FEATURE_SVE2)
|
||||
*acc_lo = svmlalb_f32(*acc_lo, x, y);
|
||||
*acc_hi = svmlalt_f32(*acc_hi, x, y);
|
||||
#else
|
||||
// Plain SVE fallback path if SVE2 instructions not available
|
||||
svfloat16_t x_even = svtrn1_f16(x, x);
|
||||
svfloat16_t x_odd = svtrn2_f16(x, x);
|
||||
|
||||
svfloat16_t y_even = svtrn1_f16(y, y);
|
||||
svfloat16_t y_odd = svtrn2_f16(y, y);
|
||||
|
||||
svbool_t pg = svptrue_b32();
|
||||
|
||||
*acc_lo = svmla_f32_x(pg, *acc_lo, svcvt_f32_f16_x(pg, x_even), svcvt_f32_f16_x(pg, y_even));
|
||||
*acc_hi = svmla_f32_x(pg, *acc_hi, svcvt_f32_f16_x(pg, x_odd), svcvt_f32_f16_x(pg, y_odd));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static ggml_float ggml_sve_sum_f32x2(svfloat32_t sum_lo, svfloat32_t sum_hi) {
|
||||
return (ggml_float) (svaddv_f32(svptrue_b32(), sum_lo) + svaddv_f32(svptrue_b32(), sum_hi));
|
||||
}
|
||||
#endif
|
||||
|
||||
#define GGML_GELU_FP16
|
||||
#define GGML_GELU_QUICK_FP16
|
||||
|
||||
@@ -122,108 +151,61 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
||||
#if defined(GGML_SIMD)
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
|
||||
const int sve_register_length = svcntb() * 8;
|
||||
const int ggml_f16_epr = sve_register_length / 16; // running when 16
|
||||
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
|
||||
const int ggml_f16_epr = svcnth();
|
||||
const int ggml_f16_step = 2 * ggml_f16_epr;
|
||||
int np = n - (n % ggml_f16_step);
|
||||
int np2 = n - (n % ggml_f16_epr);
|
||||
|
||||
int np = (n & ~(ggml_f16_step - 1));
|
||||
|
||||
svfloat16_t sum_00 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_01 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_02 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_03 = svdup_n_f16(0.0f);
|
||||
|
||||
svfloat16_t sum_10 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_11 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_12 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_13 = svdup_n_f16(0.0f);
|
||||
|
||||
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
|
||||
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
|
||||
svfloat32_t sum_0_0_lo = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum_0_0_hi = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum_0_1_lo = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum_0_1_hi = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum_1_0_lo = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum_1_0_hi = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum_1_1_lo = svdup_n_f32(0.0f);
|
||||
svfloat32_t sum_1_1_hi = svdup_n_f32(0.0f);
|
||||
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
|
||||
const svfloat16_t ay0 = GGML_F16x_VEC_LOAD(y + i, 0);
|
||||
const svfloat16_t ax00 = GGML_F16x_VEC_LOAD(x[0] + i, 0);
|
||||
const svfloat16_t ax01 = GGML_F16x_VEC_LOAD(x[1] + i, 0);
|
||||
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
|
||||
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
|
||||
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
|
||||
ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax00, ay0);
|
||||
ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax01, ay0);
|
||||
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
|
||||
const svfloat16_t ay1 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 0);
|
||||
const svfloat16_t ax10 = GGML_F16x_VEC_LOAD(x[0] + i + 1 * ggml_f16_epr, 0);
|
||||
const svfloat16_t ax11 = GGML_F16x_VEC_LOAD(x[1] + i + 1 * ggml_f16_epr, 0);
|
||||
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
|
||||
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
|
||||
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
|
||||
|
||||
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
|
||||
|
||||
ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
|
||||
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
|
||||
ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
|
||||
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
|
||||
|
||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||
|
||||
ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3);
|
||||
sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4);
|
||||
ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3);
|
||||
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4);
|
||||
|
||||
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
|
||||
|
||||
ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4);
|
||||
|
||||
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5);
|
||||
ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4);
|
||||
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5);
|
||||
|
||||
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
|
||||
|
||||
ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5);
|
||||
|
||||
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6);
|
||||
ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5);
|
||||
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6);
|
||||
|
||||
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
|
||||
|
||||
ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6);
|
||||
|
||||
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7);
|
||||
ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6);
|
||||
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7);
|
||||
|
||||
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
|
||||
|
||||
ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7);
|
||||
|
||||
sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8);
|
||||
ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7);
|
||||
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8);
|
||||
ggml_sve_f16_fma_widened(&sum_0_1_lo, &sum_0_1_hi, ax10, ay1);
|
||||
ggml_sve_f16_fma_widened(&sum_1_1_lo, &sum_1_1_hi, ax11, ay1);
|
||||
}
|
||||
|
||||
const int np2 = (n & ~(ggml_f16_epr - 1));
|
||||
for (int k = np; k < np2; k += ggml_f16_epr) {
|
||||
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
|
||||
for (int i = np; i < np2; i += ggml_f16_epr) {
|
||||
const svfloat16_t ry = GGML_F16x_VEC_LOAD(y + i, 0);
|
||||
const svfloat16_t rx0 = GGML_F16x_VEC_LOAD(x[0] + i, 0);
|
||||
const svfloat16_t rx1 = GGML_F16x_VEC_LOAD(x[1] + i, 0);
|
||||
|
||||
svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0);
|
||||
sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry);
|
||||
rx = GGML_F16x_VEC_LOAD(x[1] + k, 0);
|
||||
sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry);
|
||||
ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, rx0, ry);
|
||||
ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, rx1, ry);
|
||||
}
|
||||
|
||||
if (np2 < n) {
|
||||
svbool_t pg = svwhilelt_b16(np2, n);
|
||||
svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
|
||||
svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));
|
||||
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
|
||||
const svbool_t pg = svwhilelt_b16(np2, n);
|
||||
const svfloat16_t ay = svld1_f16(pg, (const __fp16 *)(y + np2));
|
||||
const svfloat16_t ax0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
|
||||
const svfloat16_t ax1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));
|
||||
|
||||
sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00);
|
||||
sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10);
|
||||
ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax0, ay);
|
||||
ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax1, ay);
|
||||
}
|
||||
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
|
||||
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
|
||||
|
||||
svfloat32_t sum_0_lo = svadd_f32_x(DEFAULT_PG32, sum_0_0_lo, sum_0_1_lo);
|
||||
svfloat32_t sum_0_hi = svadd_f32_x(DEFAULT_PG32, sum_0_0_hi, sum_0_1_hi);
|
||||
svfloat32_t sum_1_lo = svadd_f32_x(DEFAULT_PG32, sum_1_0_lo, sum_1_1_lo);
|
||||
svfloat32_t sum_1_hi = svadd_f32_x(DEFAULT_PG32, sum_1_0_hi, sum_1_1_hi);
|
||||
sumf[0] = ggml_sve_sum_f32x2(sum_0_lo, sum_0_hi);
|
||||
sumf[1] = ggml_sve_sum_f32x2(sum_1_lo, sum_1_hi);
|
||||
np = n;
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
#if defined(__riscv_zvfh)
|
||||
|
||||
@@ -472,7 +472,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
|
||||
const int i = 8 * (threadIdx.x % (nbatch_fa/8));
|
||||
|
||||
cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
|
||||
cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + int64_t(j_vram)*stride_mask + i);
|
||||
}
|
||||
} else if constexpr (oob_check) {
|
||||
#pragma unroll
|
||||
@@ -488,7 +488,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
|
||||
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[int64_t(j_vram)*stride_mask + i] : half(0.0f);
|
||||
}
|
||||
}
|
||||
} else if constexpr (nbatch_fa < 2*warp_size) {
|
||||
@@ -505,7 +505,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
|
||||
const int i = threadIdx.x % (warp_size/cols_per_warp);
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + int64_t(j_vram)*stride_mask + 2*i);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
@@ -521,7 +521,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
|
||||
const int i = i0 + 2*threadIdx.x;
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + int64_t(j_vram)*stride_mask + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2570,6 +2570,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
|
||||
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
|
||||
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
|
||||
use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
}
|
||||
} else {
|
||||
@@ -2578,6 +2579,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
|
||||
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
|
||||
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
|
||||
use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
}
|
||||
|
||||
@@ -4992,8 +4994,14 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
|
||||
GGML_UNUSED(dev);
|
||||
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context;
|
||||
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));
|
||||
|
||||
return prop.integrated
|
||||
? GGML_BACKEND_DEVICE_TYPE_IGPU
|
||||
: GGML_BACKEND_DEVICE_TYPE_GPU;
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||
|
||||
@@ -63,6 +63,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
|
||||
|
||||
enum mmvq_parameter_table_id {
|
||||
MMVQ_PARAMETERS_GENERIC = 0,
|
||||
MMVQ_PARAMETERS_TURING,
|
||||
MMVQ_PARAMETERS_GCN,
|
||||
MMVQ_PARAMETERS_RDNA2,
|
||||
MMVQ_PARAMETERS_RDNA3_0,
|
||||
@@ -78,6 +79,8 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
|
||||
return MMVQ_PARAMETERS_RDNA2;
|
||||
#elif defined(GCN) || defined(CDNA)
|
||||
return MMVQ_PARAMETERS_GCN;
|
||||
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING && __CUDA_ARCH__ < GGML_CUDA_CC_AMPERE
|
||||
return MMVQ_PARAMETERS_TURING;
|
||||
#else
|
||||
return MMVQ_PARAMETERS_GENERIC;
|
||||
#endif
|
||||
@@ -96,6 +99,9 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
|
||||
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
return MMVQ_PARAMETERS_GCN;
|
||||
}
|
||||
if (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING && ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_AMPERE) {
|
||||
return MMVQ_PARAMETERS_TURING;
|
||||
}
|
||||
return MMVQ_PARAMETERS_GENERIC;
|
||||
}
|
||||
|
||||
@@ -271,6 +277,53 @@ int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
|
||||
return MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
if (GGML_CUDA_CC_IS_CDNA1(cc)) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
return ne11 <= 7;
|
||||
case GGML_TYPE_Q5_1:
|
||||
return ne11 <= 7;
|
||||
case GGML_TYPE_Q8_0:
|
||||
return ne11 <= 6;
|
||||
case GGML_TYPE_Q2_K:
|
||||
return ne11 <= 4;
|
||||
case GGML_TYPE_Q3_K:
|
||||
return ne11 <= 3;
|
||||
case GGML_TYPE_Q4_K:
|
||||
return ne11 <= 2;
|
||||
case GGML_TYPE_Q5_K:
|
||||
return ne11 <= 3;
|
||||
case GGML_TYPE_Q6_K:
|
||||
return ne11 <= 4;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
return ne11 <= 5;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
return ne11 <= 6;
|
||||
default:
|
||||
return ne11 <= MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
}
|
||||
switch (type) { // tuned for CDNA2
|
||||
case GGML_TYPE_Q2_K:
|
||||
return ne11 <= 5;
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
return ne11 <= 3;
|
||||
case GGML_TYPE_Q6_K:
|
||||
return ne11 <= 5;
|
||||
default:
|
||||
return ne11 <= MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
}
|
||||
return ne11 <= MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
// Device constexpr: returns the max batch size for the current arch+type at compile time.
|
||||
template <ggml_type type>
|
||||
static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() {
|
||||
@@ -370,11 +423,38 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
if (table_id == MMVQ_PARAMETERS_TURING) {
|
||||
if (ncols_dst == 1) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return 2;
|
||||
default:
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
switch (ncols_dst) {
|
||||
case 2:
|
||||
case 3:
|
||||
case 4:
|
||||
return 4;
|
||||
case 5:
|
||||
case 6:
|
||||
case 7:
|
||||
case 8:
|
||||
return 2;
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) {
|
||||
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
|
||||
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN || table_id == MMVQ_PARAMETERS_TURING) {
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
return small_k ? nwarps : 1;
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
||||
|
||||
bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11);
|
||||
|
||||
// Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID,
|
||||
// based on the quantization type and GPU architecture (compute capability).
|
||||
int get_mmvq_mmid_max_batch(ggml_type type, int cc);
|
||||
|
||||
@@ -58,15 +58,16 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
|
||||
if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-queue.c
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
)
|
||||
|
||||
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
|
||||
set_source_files_properties(
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
PROPERTIES COMPILE_OPTIONS "-mhmx"
|
||||
)
|
||||
|
||||
|
||||
@@ -22,6 +22,16 @@
|
||||
// Must be multiple of 32
|
||||
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
|
||||
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
|
||||
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
// This is a bit of a hack because the compiler is strugling to properly inline
|
||||
// the default hvx_vec_f32_to_f16 with output into the local array.
|
||||
static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
|
||||
@@ -54,8 +64,8 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
||||
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)));
|
||||
HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p));
|
||||
rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
||||
@@ -105,10 +115,10 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
|
||||
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
|
||||
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
|
||||
HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)));
|
||||
HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)));
|
||||
HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p));
|
||||
HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p));
|
||||
HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p));
|
||||
HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p));
|
||||
|
||||
HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
|
||||
return hvx_vec_reduce_sum_f32x4(rsum0123);
|
||||
@@ -123,7 +133,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
|
||||
const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
const size_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector sums; // initialize at j = 0
|
||||
HVX_Vector sums = Q6_V_vzero();
|
||||
const size_t stride_x_4 = stride_x * 4;
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
|
||||
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
|
||||
@@ -132,8 +142,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
|
||||
x += stride_x_4;
|
||||
}
|
||||
|
||||
sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums);
|
||||
return Q6_Vsf_equals_Vqf32(sums);
|
||||
return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums);
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * s (F16)
|
||||
@@ -268,11 +277,10 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t *
|
||||
uint32_t i = 0;
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; ++i) {
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
|
||||
vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs);
|
||||
}
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
|
||||
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -438,25 +446,44 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
// Process in sub-blocks of 32 (VLEN_FP32)
|
||||
HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32];
|
||||
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
|
||||
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
for (uint32_t iv = 0; ic < current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
// 1. Compute scores
|
||||
HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale);
|
||||
|
||||
// 2. Softcap
|
||||
if (factx->logit_softcap != 0.0f) {
|
||||
scores = hvx_vec_tanh_f32(scores);
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
scores = HVX_OP_MUL_F32(scores, logit_cap);
|
||||
}
|
||||
|
||||
// 3. Mask
|
||||
if (mask) {
|
||||
const __fp16 * mp = m_base + ic;
|
||||
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
||||
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
||||
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
|
||||
// Multiplying -INFINITY (0xFC00) by a slope in VhfVhf instructions can incorrectly produce NaN on v79.
|
||||
// Clamp -INFINITY to the max negative fp16 finite value (-65504.0f).
|
||||
HVX_Vector vinf = Q6_Vh_vsplat_R(0xFC00);
|
||||
HVX_Vector vmin = Q6_Vh_vsplat_R(0xFBFF);
|
||||
HVX_VectorPred is_inf = Q6_Q_vcmp_eq_VhVh(m_vals_f16, vinf);
|
||||
m_vals_f16 = Q6_V_vmux_QVV(is_inf, vmin, m_vals_f16);
|
||||
|
||||
#if __HVX_ARCH__ >= 79
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
||||
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
||||
scores = Q6_Vsf_vadd_VsfVsf(add_val, scores);
|
||||
#else
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
||||
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
||||
scores = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores));
|
||||
#endif
|
||||
}
|
||||
|
||||
// Mask out invalid lanes for leftover handling
|
||||
uint32_t valid_lanes = current_block_size - ic;
|
||||
if (valid_lanes < VLEN_FP32) {
|
||||
HVX_VectorPred valid_pred = Q6_Q_vsetq_R(valid_lanes * 4); // 4 bytes per fp32 lane
|
||||
scores = Q6_V_vmux_QVV(valid_pred, scores, hvx_vec_splat_f32(-INFINITY));
|
||||
}
|
||||
|
||||
sb_scores[iv] = scores;
|
||||
@@ -466,78 +493,55 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
{
|
||||
// 4. Online Softmax Update
|
||||
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
|
||||
HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec));
|
||||
HVX_Vector diff_vec = HVX_OP_SUB_F32(M_vec, M_new_vec);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
||||
M_vec = M_new_vec;
|
||||
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
|
||||
for (uint32_t ic2 = 0, iv = 0; ic2 < current_block_size; ic2 += VLEN_FP32, ++iv) {
|
||||
HVX_Vector scores = sb_scores[iv];
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
HVX_Vector scores_shifted = HVX_OP_SUB_F32(scores, M_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(scores_shifted);
|
||||
|
||||
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
|
||||
p_sum_vec = HVX_OP_ADD_F32(p_sum_vec, P);
|
||||
|
||||
// 5. Accumulate V
|
||||
__fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16];
|
||||
hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0));
|
||||
|
||||
float __attribute__((aligned(128))) P_arr[VLEN_FP32];
|
||||
hvx_vec_store_a(P_arr, 128, P);
|
||||
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
if (cur_ic >= current_block_size) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (cur_ic + 1 == current_block_size) {
|
||||
// Odd leftover, process single row
|
||||
if (P_arr[j] != 0.0f) {
|
||||
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, (p_arr + j), DV);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Avoid NaN * 0.0 = NaN for uninitialized V cache rows.
|
||||
// Check the f32 values to safely avoid strict aliasing violations.
|
||||
if (P_arr[j] == 0.0f && P_arr[j + 1] == 0.0f) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV);
|
||||
}
|
||||
}
|
||||
|
||||
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
|
||||
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
|
||||
}
|
||||
|
||||
if (ic < current_block_size) {
|
||||
// Sync scalars for leftover/next block if needed
|
||||
float M = hvx_vec_get_f32(M_vec);
|
||||
float S = hvx_vec_get_f32(S_vec);
|
||||
|
||||
// Leftover
|
||||
for (; ic < current_block_size; ++ic) {
|
||||
float s_val;
|
||||
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
|
||||
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
|
||||
if (factx->logit_softcap != 0.0f) {
|
||||
s_val = factx->logit_softcap * tanhf(s_val);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
const float m_val = m_base[ic];
|
||||
s_val += slope * m_val;
|
||||
}
|
||||
|
||||
const float Mold = M;
|
||||
__fp16 vs = 1.0f;
|
||||
|
||||
if (s_val > M) {
|
||||
M = s_val;
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
float ms = hvx_vec_get_f32(ms_vec);
|
||||
S = S * ms + vs;
|
||||
} else {
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
|
||||
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
||||
S += vs;
|
||||
}
|
||||
|
||||
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
|
||||
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV);
|
||||
}
|
||||
|
||||
M_vec = hvx_vec_splat_f32(M);
|
||||
S_vec = hvx_vec_splat_f32(S);
|
||||
S_vec = HVX_OP_ADD_F32(HVX_OP_MUL_F32(S_vec, ms_vec), p_sum_vec);
|
||||
}
|
||||
|
||||
// Issue DMA for next+1 block (if exists)
|
||||
@@ -599,8 +603,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// dst is permuted
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
|
||||
// dst is permuted: [DV, n_heads, n_tokens, n_seq]
|
||||
// head stride is nb[1], token stride is nb[2], batch stride is nb[3]
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + i2 * dst->nb[1] + i1 * dst->nb[2] + i3 * dst->nb[3];
|
||||
|
||||
if (dst->type == HTP_TYPE_F32) {
|
||||
hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
@@ -623,8 +628,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
// HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV
|
||||
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) {
|
||||
// HMX path: head_dim multiple of 32, F16 KV
|
||||
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0) {
|
||||
int ret = hmx_flash_attn_ext(octx);
|
||||
if (ret == HTP_STATUS_OK) {
|
||||
return ret;
|
||||
|
||||
@@ -1248,9 +1248,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
if (DK % 32 != 0 || DV % 32 != 0) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
if (neq1 < 32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
// GQA factor
|
||||
const uint32_t n_kv_heads = k->ne[2];
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hex-fastdiv.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#include "hvx-utils.h"
|
||||
@@ -187,45 +188,44 @@ next_nc:
|
||||
// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles
|
||||
// of the same 32 packed bytes.
|
||||
static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) {
|
||||
(void)vlut_cvt;
|
||||
HVX_Vector vq = hvx_vmemu(packed_32);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
|
||||
HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
|
||||
// q4x4x2 stores two int4 values per byte. Keep only the selected nibble.
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
// Shuffle before LUT
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
// Use standard vlut16 (not _nomatch) to avoid stale-register NaN.
|
||||
// _nomatch retains the previous destination-register value for colliding
|
||||
// indices, but the C intrinsic doesn't model the implicit read so the
|
||||
// compiler may allocate a register containing garbage/NaN.
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_hf = Q6_V_lo_W(vp);
|
||||
|
||||
HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8);
|
||||
HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_int8));
|
||||
HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
|
||||
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
|
||||
}
|
||||
|
||||
// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using
|
||||
// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls.
|
||||
// full HVX vector width.
|
||||
// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes.
|
||||
static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
const uint8_t *packed_128, bool upper_nibbles,
|
||||
const __fp16 *scales_4, const HVX_Vector vlut_cvt) {
|
||||
// Load all 128 packed bytes (4 contiguous 32-byte groups)
|
||||
(void)vlut_cvt;
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
// Shuffle before LUT
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8);
|
||||
|
||||
// Full-width vlut16: 128 byte lookups -> 128 fp16 results in a VectorPair
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp); // [group0: 32 fp16 | group1: 32 fp16]
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16]
|
||||
HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_int8);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp_int16);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp_int16);
|
||||
|
||||
v_lo = Q6_Vhf_equals_Vh(v_lo);
|
||||
v_hi = Q6_Vhf_equals_Vh(v_hi);
|
||||
|
||||
// Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b
|
||||
HVX_Vector vscale = hvx_vmemu(scales_4);
|
||||
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
|
||||
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));
|
||||
@@ -233,13 +233,12 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
// Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter
|
||||
HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */
|
||||
v_hi /* group2 already in [0:63] */ };
|
||||
HVX_Vector_x2 r = { v_lo, v_hi };
|
||||
return r;
|
||||
}
|
||||
|
||||
static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) {
|
||||
(void)vlut_cvt;
|
||||
HVX_Vector vq = hvx_vmemu(packed_32);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_dm = hvx_vmemu(scale_offset);
|
||||
@@ -248,9 +247,9 @@ static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32
|
||||
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_hf = Q6_V_lo_W(vp);
|
||||
|
||||
HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_quants));
|
||||
HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
|
||||
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets));
|
||||
}
|
||||
@@ -258,16 +257,18 @@ static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32
|
||||
static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx(
|
||||
const uint8_t *packed_128, bool upper_nibbles,
|
||||
const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) {
|
||||
(void)vlut_cvt;
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_quants);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp_int16);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp_int16);
|
||||
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp);
|
||||
v_lo = Q6_Vhf_equals_Vh(v_lo);
|
||||
v_hi = Q6_Vhf_equals_Vh(v_hi);
|
||||
|
||||
HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4);
|
||||
HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2);
|
||||
@@ -287,6 +288,45 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx(
|
||||
return r;
|
||||
}
|
||||
|
||||
// LUT-based dequantizers for non-linear IQ4_NL format.
|
||||
static inline HVX_Vector dequantize_x4x2_iq4_nl_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_32);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_hf = Q6_V_lo_W(vp);
|
||||
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
|
||||
}
|
||||
|
||||
static inline HVX_Vector_x2 dequantize_x4x2_iq4_nl_x4groups_hvx(
|
||||
const uint8_t *packed_128, bool upper_nibbles,
|
||||
const __fp16 *scales_4, const HVX_Vector vlut_cvt) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp);
|
||||
|
||||
HVX_Vector vscale = hvx_vmemu(scales_4);
|
||||
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
|
||||
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));
|
||||
|
||||
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
HVX_Vector_x2 r = { v_lo, v_hi };
|
||||
return r;
|
||||
}
|
||||
|
||||
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
|
||||
static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) {
|
||||
HVX_Vector vq = hvx_vmemu(quants_32);
|
||||
@@ -374,122 +414,176 @@ static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t *
|
||||
return r;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
__fp16 *dst;
|
||||
const uint8_t *src;
|
||||
int n_cols;
|
||||
int k_block;
|
||||
size_t row_stride;
|
||||
int weight_type;
|
||||
int n_tot_tiles;
|
||||
int n_tiles_per_task;
|
||||
int n_tasks;
|
||||
int n_k_tiles;
|
||||
struct fastdiv_values n_k_tiles_div;
|
||||
} x4x2_dequantize_state_t;
|
||||
|
||||
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
|
||||
// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes.
|
||||
// Output: vtcm_dst in tile-major FP16 layout.
|
||||
static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
__fp16 *restrict vtcm_dst,
|
||||
const uint8_t *restrict vtcm_src,
|
||||
int n_cols, int k_block,
|
||||
size_t row_stride, int weight_type,
|
||||
|
||||
#define DEFINE_DEQUANTIZE_Q4_TASK(suffix, lut_name, helper_prefix, dblk_size, scale_step) \
|
||||
static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix( \
|
||||
const x4x2_dequantize_state_t *state, \
|
||||
int start_tile, int end_tile) { \
|
||||
\
|
||||
const int n_k_tiles = state->n_k_tiles; \
|
||||
const int qrow_size = (unsigned)state->k_block / 2; \
|
||||
const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; \
|
||||
const HVX_Vector vlut_cvt = hvx_vmem(lut_name); \
|
||||
\
|
||||
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); \
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); \
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); \
|
||||
\
|
||||
unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); \
|
||||
unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); \
|
||||
\
|
||||
for (unsigned t = start_tile; t < (unsigned)end_tile; ) { \
|
||||
if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } \
|
||||
\
|
||||
if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { \
|
||||
unsigned blk_idx = ((kt * 32) / QK_Q4_0x4x2); \
|
||||
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; \
|
||||
bool upper = (sub_blk_base >= 4); \
|
||||
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); \
|
||||
unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk_base * (scale_step); \
|
||||
\
|
||||
__fp16 *tile_bases[4]; \
|
||||
for (unsigned g = 0; g < 4; g++) { \
|
||||
tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; \
|
||||
} \
|
||||
\
|
||||
HVX_Vector v_off = v_scat_base; \
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \
|
||||
\
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { \
|
||||
const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \
|
||||
const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \
|
||||
\
|
||||
HVX_Vector_x2 dv0 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \
|
||||
r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); \
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); \
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
|
||||
\
|
||||
HVX_Vector_x2 dv1 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \
|
||||
r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); \
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); \
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); \
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
|
||||
} \
|
||||
\
|
||||
for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } \
|
||||
t += 4; kt += 4; \
|
||||
continue; \
|
||||
} \
|
||||
\
|
||||
__fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; \
|
||||
{ \
|
||||
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; \
|
||||
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; \
|
||||
bool upper = (sub_blk >= 4); \
|
||||
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; \
|
||||
unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk * (scale_step); \
|
||||
\
|
||||
HVX_Vector v_off = v_scat_base; \
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \
|
||||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; \
|
||||
\
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { \
|
||||
const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \
|
||||
const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \
|
||||
\
|
||||
HVX_Vector v0 = dequantize_x4x2_##helper_prefix##_group_hvx( \
|
||||
r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \
|
||||
HVX_Vector v1 = (row1 < (unsigned)state->n_cols) \
|
||||
? dequantize_x4x2_##helper_prefix##_group_hvx( \
|
||||
r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) \
|
||||
: Q6_V_vzero(); \
|
||||
\
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); \
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); \
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
|
||||
} \
|
||||
(void) *(volatile HVX_Vector *)(tile_base); \
|
||||
} \
|
||||
++t; ++kt; \
|
||||
} \
|
||||
\
|
||||
if (start_tile < end_tile) { \
|
||||
(void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \
|
||||
int start = task_id * state->n_tiles_per_task; \
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \
|
||||
} \
|
||||
}
|
||||
|
||||
DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16))
|
||||
DEFINE_DEQUANTIZE_Q4_TASK(q4_1, q4_1_to_fp16_lut, q4_1, 32, 4)
|
||||
DEFINE_DEQUANTIZE_Q4_TASK(iq4_nl, iq4_nl_to_fp16_lut, iq4_nl, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16))
|
||||
|
||||
static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(
|
||||
const x4x2_dequantize_state_t *state,
|
||||
int start_tile, int end_tile) {
|
||||
|
||||
const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS;
|
||||
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL);
|
||||
const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1);
|
||||
const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block;
|
||||
const int n_k_tiles = state->n_k_tiles;
|
||||
const int qrow_size = state->k_block;
|
||||
const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div;
|
||||
const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut);
|
||||
|
||||
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
|
||||
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
|
||||
(weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) :
|
||||
hvx_vmem(q4_0_to_fp16_lut);
|
||||
|
||||
// vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
|
||||
// Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128
|
||||
// maps to K-rows 2i and 2i+1. Column offset (n*4) added per row.
|
||||
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes)
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
|
||||
|
||||
unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index
|
||||
unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index
|
||||
for (unsigned t = start_tile; t < end_tile; ) {
|
||||
if (kt >= n_k_tiles) { kt = 0; ct++; }
|
||||
unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div);
|
||||
unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div);
|
||||
|
||||
// --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
|
||||
if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
|
||||
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
|
||||
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
|
||||
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
|
||||
unsigned scale_off = qrow_size + blk_idx * dblk_size
|
||||
+ sub_blk_base * scale_step;
|
||||
for (unsigned t = start_tile; t < (unsigned)end_tile; ) {
|
||||
if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; }
|
||||
|
||||
__fp16 *tile_bases[4];
|
||||
for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }
|
||||
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
|
||||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
|
||||
|
||||
if (is_q4_1) {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
} else {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); }
|
||||
t += 4; kt += 4;
|
||||
continue;
|
||||
}
|
||||
|
||||
// --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales ---
|
||||
if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
|
||||
// Batch-4 fast path for MXFP4
|
||||
if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) {
|
||||
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
|
||||
int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4
|
||||
int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32;
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes
|
||||
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales
|
||||
int packed_off = blk_idx * (QK_MXFP4x4x2 / 2);
|
||||
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE;
|
||||
|
||||
__fp16 * tile_bases[4];
|
||||
for (int g = 0; g < 4; g++) {
|
||||
tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS;
|
||||
tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
const uint8_t * r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t * r1 = vtcm_src + row1 * row_stride;
|
||||
const uint8_t * r0 = state->src + row0 * state->row_stride;
|
||||
const uint8_t * r1 = state->src + row1 * state->row_stride;
|
||||
|
||||
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector_x4 dv0, dv1;
|
||||
dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8);
|
||||
if (row1 < n_cols) {
|
||||
if (row1 < state->n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8);
|
||||
} else {
|
||||
@@ -510,58 +604,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
(void) *(volatile HVX_Vector *) (tile_bases[g]);
|
||||
}
|
||||
|
||||
t += 4;
|
||||
t += 4; kt += 4;
|
||||
continue;
|
||||
}
|
||||
|
||||
// --- Single-tile fallback ---
|
||||
__fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
if (is_q4) {
|
||||
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
|
||||
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
|
||||
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
|
||||
unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step;
|
||||
|
||||
HVX_Vector v_off = v_scat_base; // reset to column 0
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
|
||||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
|
||||
if (is_q4_1) {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector v1 = (row1 < n_cols)
|
||||
? dequantize_x4x2_q4_1_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
|
||||
: Q6_V_vzero();
|
||||
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
} else {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector v1 = (row1 < n_cols)
|
||||
? dequantize_x4x2_q4_0_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
|
||||
: Q6_V_vzero();
|
||||
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
}
|
||||
(void) *(volatile HVX_Vector *)(tile_base);
|
||||
} else if (weight_type == HTP_TYPE_MXFP4) {
|
||||
// Single-tile fallback
|
||||
__fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS;
|
||||
{
|
||||
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
|
||||
int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
@@ -573,15 +622,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
|
||||
const uint8_t * r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t * r1 = vtcm_src + row1 * row_stride;
|
||||
const uint8_t * r0 = state->src + row0 * state->row_stride;
|
||||
const uint8_t * r1 = state->src + row1 * state->row_stride;
|
||||
|
||||
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8);
|
||||
HVX_Vector v1;
|
||||
if (row1 < n_cols) {
|
||||
if (row1 < state->n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8);
|
||||
} else {
|
||||
@@ -594,23 +642,59 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
(void) *(volatile HVX_Vector *) (tile_base);
|
||||
} else {
|
||||
// Q8_0
|
||||
}
|
||||
++t; ++kt;
|
||||
}
|
||||
|
||||
if (start_tile < end_tile) {
|
||||
(void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS);
|
||||
}
|
||||
}
|
||||
|
||||
static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end);
|
||||
}
|
||||
}
|
||||
|
||||
static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
|
||||
const x4x2_dequantize_state_t *state,
|
||||
int start_tile, int end_tile) {
|
||||
|
||||
const int n_k_tiles = state->n_k_tiles;
|
||||
const int qrow_size = state->k_block;
|
||||
const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div;
|
||||
|
||||
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
|
||||
|
||||
unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div);
|
||||
unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div);
|
||||
|
||||
for (unsigned t = start_tile; t < (unsigned)end_tile; ) {
|
||||
if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; }
|
||||
|
||||
__fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS;
|
||||
{
|
||||
int blk_idx = (kt * 32) / QK_Q8_0x4x2;
|
||||
int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32;
|
||||
int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32;
|
||||
int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);
|
||||
|
||||
HVX_Vector v_off = v_scat_base; // reset to column 0
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
|
||||
const uint8_t *r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row1 * row_stride;
|
||||
const uint8_t *r0 = state->src + row0 * state->row_stride;
|
||||
const uint8_t *r1 = state->src + row1 * state->row_stride;
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
|
||||
HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero();
|
||||
HVX_Vector v1 = (row1 < state->n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero();
|
||||
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
@@ -622,50 +706,31 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
++t; ++kt;
|
||||
}
|
||||
|
||||
// Drain HVX scatter write buffer: a vmem load on the same HW thread retires
|
||||
// all pending scatter entries to VTCM. Without this, the main thread's HMX
|
||||
// reads may see stale data because atomic_fetch_sub (release) only orders
|
||||
// regular stores, not the HVX scatter buffer.
|
||||
if (start_tile < end_tile) {
|
||||
(void) *(volatile HVX_Vector *)(vtcm_dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS);
|
||||
(void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS);
|
||||
}
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
__fp16 *dst;
|
||||
const uint8_t *src;
|
||||
int n_cols;
|
||||
int k_block;
|
||||
size_t row_stride;
|
||||
int weight_type;
|
||||
int n_tot_tiles;
|
||||
int n_tiles_per_task;
|
||||
int n_tasks;
|
||||
} x4x2_dequantize_state_t;
|
||||
|
||||
static void dequantize_x4x2_worker_loop(unsigned int n, unsigned int i, void *data) {
|
||||
static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
state->dst, state->src, state->n_cols, state->k_block,
|
||||
state->row_stride, state->weight_type, start, end);
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end);
|
||||
}
|
||||
}
|
||||
|
||||
static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
|
||||
struct htp_context *ctx, __fp16 *vtcm_dst,
|
||||
const void *vtcm_src, int n_cols, int k_block,
|
||||
size_t row_stride, int weight_type) {
|
||||
size_t row_stride, int weight_type,
|
||||
int n_k_tiles, struct fastdiv_values n_k_tiles_div,
|
||||
worker_callback_t dequant_worker_fn) {
|
||||
|
||||
assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
|
||||
assert(k_block % HMX_FP16_TILE_N_COLS == 0);
|
||||
|
||||
size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
|
||||
size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
|
||||
size_t n_tot_tiles = n_col_tiles * n_k_tiles;
|
||||
|
||||
size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads);
|
||||
@@ -680,8 +745,10 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
|
||||
state.k_block = k_block;
|
||||
state.row_stride = row_stride;
|
||||
state.weight_type = weight_type;
|
||||
state.n_k_tiles = n_k_tiles;
|
||||
state.n_k_tiles_div = n_k_tiles_div;
|
||||
|
||||
worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads);
|
||||
worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, ctx->n_threads);
|
||||
}
|
||||
|
||||
// --- End x4x2 dequantizers ---
|
||||
@@ -978,6 +1045,20 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
|
||||
return -1;
|
||||
}
|
||||
|
||||
worker_callback_t dequant_worker_fn = NULL;
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break;
|
||||
case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break;
|
||||
case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break;
|
||||
case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break;
|
||||
case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
|
||||
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
|
||||
const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles);
|
||||
|
||||
// --- Dynamic VTCM layout ---
|
||||
const size_t vec_dot_size = k * sizeof(__fp16);
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
@@ -1070,7 +1151,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
|
||||
{
|
||||
// B0: wait for DMA, dequant weight chunk 0
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn);
|
||||
|
||||
// A1: issue DMA for weight chunk 1
|
||||
const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
|
||||
@@ -1089,7 +1170,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
|
||||
// B1: DMA pop + dequant (runs in parallel with C0 on HMX worker)
|
||||
if (1 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1131,7 +1212,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
|
||||
// B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1})
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -860,6 +860,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
|
||||
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
|
||||
vk_pipeline pipeline_sum_rows_f32;
|
||||
vk_pipeline pipeline_fwht_f32[4];
|
||||
vk_pipeline pipeline_cumsum_f32;
|
||||
vk_pipeline pipeline_cumsum_small_f32;
|
||||
vk_pipeline pipeline_cumsum_multipass1_f32;
|
||||
@@ -1150,6 +1151,13 @@ struct vk_op_push_constants {
|
||||
float param4;
|
||||
};
|
||||
|
||||
struct vk_op_fwht_push_constants {
|
||||
uint32_t n_rows;
|
||||
uint32_t src_offset;
|
||||
uint32_t dst_offset;
|
||||
float scale;
|
||||
};
|
||||
|
||||
struct vk_op_count_experts_push_constants {
|
||||
uint32_t ne00;
|
||||
uint32_t ne01;
|
||||
@@ -2055,6 +2063,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
|
||||
GGML_UNUSED(src3);
|
||||
}
|
||||
|
||||
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_fwht_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
|
||||
p.src_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
||||
p.dst_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(src2);
|
||||
GGML_UNUSED(src3);
|
||||
}
|
||||
|
||||
struct ggml_backend_vk_buffer_context {
|
||||
vk_device_ref device;
|
||||
vk_buffer dev_buffer;
|
||||
@@ -2095,9 +2112,9 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
|
||||
const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
std::string type = device ? "device" : "host";
|
||||
auto it = allocations.find(buf->buffer);
|
||||
total_device -= device ? it->second : 0;
|
||||
total_host -= device ? 0 : it->second;
|
||||
if (it != allocations.end()) {
|
||||
total_device -= device ? it->second : 0;
|
||||
total_host -= device ? 0 : it->second;
|
||||
VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
|
||||
allocations.erase(it);
|
||||
} else {
|
||||
@@ -4982,6 +4999,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||
// Intel Arc B390 was observed segfaulting with this shader.
|
||||
if (device->subgroup_basic && device->subgroup_shuffle && device->vendor_id != VK_VENDOR_ID_INTEL) {
|
||||
int idx = 0;
|
||||
for (uint32_t n : {64, 128, 256, 512}) {
|
||||
if (device->subgroup_size <= n) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_f32", fwht_f32_len, fwht_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { device->subgroup_size, n }, 1, true, true, device->subgroup_size);
|
||||
}
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
|
||||
@@ -7233,7 +7260,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
|
||||
const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
|
||||
const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
|
||||
for (uint64_t i0 = 0; i0 < ne0; i0++) {
|
||||
slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 });
|
||||
slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 });
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8741,6 +8768,68 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
||||
}, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
|
||||
}
|
||||
|
||||
static int ggml_vk_fwht_pipeline_idx(int64_t n) {
|
||||
switch (n) {
|
||||
case 64: return 0;
|
||||
case 128: return 1;
|
||||
case 256: return 2;
|
||||
case 512: return 3;
|
||||
default: return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_can_use_fwht(const ggml_backend_vk_context * ctx, const ggml_tensor * src1, const ggml_tensor * dst) {
|
||||
if (ctx->num_additional_fused_ops != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ggml_get_op_params_i32(dst, 1) != GGML_HINT_SRC0_IS_HADAMARD) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int idx = ggml_vk_fwht_pipeline_idx(src1->ne[0]);
|
||||
if (idx < 0 || ctx->device->pipeline_fwht_f32[idx] == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(src1)) {
|
||||
return false;
|
||||
}
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void ggml_vk_fwht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
const int idx = ggml_vk_fwht_pipeline_idx(src->ne[0]);
|
||||
vk_pipeline pipeline = ctx->device->pipeline_fwht_f32[idx];
|
||||
|
||||
const uint32_t rows_per_workgroup = 4;
|
||||
const uint32_t n_rows = (uint32_t)ggml_nrows(src);
|
||||
const uint32_t max_workgroups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
|
||||
|
||||
const uint32_t total_workgroups = CEIL_DIV(n_rows, rows_per_workgroup);
|
||||
const uint32_t workgroups_x = std::min(total_workgroups, max_workgroups_x);
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
|
||||
const vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src, true);
|
||||
const vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
|
||||
|
||||
vk_op_fwht_push_constants pc = {
|
||||
n_rows,
|
||||
0,
|
||||
0,
|
||||
1.0f / std::sqrt((float)src->ne[0]),
|
||||
};
|
||||
init_pushconst_tensor_offsets(ctx, pc, src, nullptr, nullptr, nullptr, dst);
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, { workgroups_x, 1, 1 });
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
||||
ggml_tensor * dst = cgraph->nodes[node_idx];
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
@@ -8774,6 +8863,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||
|
||||
m_offset += cur_M_size;
|
||||
}
|
||||
} else if (ggml_vk_can_use_fwht(ctx, src1, dst)) {
|
||||
ggml_vk_fwht(ctx, subctx, src1, dst);
|
||||
} else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
|
||||
// detect 0213 permutation, and batch size of 1
|
||||
src0->nb[0] <= src0->nb[2] &&
|
||||
|
||||
69
ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp
Normal file
69
ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp
Normal file
@@ -0,0 +1,69 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
||||
|
||||
layout(constant_id = 0) const uint WARP_SIZE = 32;
|
||||
layout(constant_id = 1) const uint N = 128;
|
||||
|
||||
layout(push_constant) uniform parameter
|
||||
{
|
||||
uint n_rows;
|
||||
uint src_offset;
|
||||
uint dst_offset;
|
||||
float scale;
|
||||
};
|
||||
|
||||
layout(binding = 0, std430) readonly buffer A { float data_a[]; };
|
||||
layout(binding = 1, std430) writeonly buffer D { float data_d[]; };
|
||||
|
||||
const uint EL_W = N / WARP_SIZE;
|
||||
|
||||
void main() {
|
||||
const uint lane = gl_SubgroupInvocationID;
|
||||
for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
|
||||
row < n_rows;
|
||||
row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) {
|
||||
const uint row_offset = row * N;
|
||||
|
||||
float reg[EL_W];
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < EL_W; ++i) {
|
||||
reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane] * scale;
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (uint h = 1; h < WARP_SIZE; h <<= 1) {
|
||||
[[unroll]]
|
||||
for (uint j = 0; j < EL_W; ++j) {
|
||||
const float val = reg[j];
|
||||
const float val2 = subgroupShuffleXor(val, h);
|
||||
reg[j] = (lane & h) == 0 ? val + val2 : val2 - val;
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (uint h = WARP_SIZE; h < N; h <<= 1) {
|
||||
const uint step = h / WARP_SIZE;
|
||||
[[unroll]]
|
||||
for (uint j = 0; j < EL_W; j += 2 * step) {
|
||||
[[unroll]]
|
||||
for (uint k = 0; k < step; ++k) {
|
||||
const float x = reg[j + k];
|
||||
const float y = reg[j + k + step];
|
||||
reg[j + k] = x + y;
|
||||
reg[j + k + step] = x - y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < EL_W; ++i) {
|
||||
data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -934,6 +934,7 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
|
||||
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("fwht_f32", "fwht.comp", {});
|
||||
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
|
||||
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
114
models/templates/ibm-granite-granite-4.1.jinja
Normal file
114
models/templates/ibm-granite-granite-4.1.jinja
Normal file
@@ -0,0 +1,114 @@
|
||||
{%- set tools_system_message_prefix = 'You are a helpful assistant with access to the following tools. You may call one or more tools to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>' %}
|
||||
{%- set tools_system_message_suffix = '\n</tools>\n\nFor each tool call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.' %}
|
||||
{%- set documents_system_message_prefix = 'You are a helpful assistant with access to the following documents. You may use one or more documents to assist with the user query.\n\nYou are given a list of documents within <documents></documents> XML tags:\n<documents>' %}
|
||||
{%- set documents_system_message_suffix = '\n</documents>\n\nWrite the response to the user\'s input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.' %}
|
||||
{%- if available_tools is defined and available_tools %}
|
||||
{%- set tools = available_tools %}
|
||||
{%- endif %}
|
||||
{%- set ns = namespace(tools_system_message=tools_system_message_prefix,
|
||||
documents_system_message=documents_system_message_prefix,
|
||||
system_message=''
|
||||
) %}
|
||||
{%- if tools %}
|
||||
{%- for tool in tools %}
|
||||
{%- set ns.tools_system_message = ns.tools_system_message + '\n' + (tool | tojson) %}
|
||||
{%- endfor %}
|
||||
{%- set ns.tools_system_message = ns.tools_system_message + tools_system_message_suffix %}
|
||||
{%- else %}
|
||||
{%- set ns.tools_system_message = '' %}
|
||||
{%- endif %}
|
||||
{%- if documents %}
|
||||
{%- for document in documents %}
|
||||
{%- set ns.documents_system_message = ns.documents_system_message + '\n' + (document | tojson) %}
|
||||
{%- endfor %}
|
||||
{%- set ns.documents_system_message = ns.documents_system_message + documents_system_message_suffix %}
|
||||
{%- else %}
|
||||
{%- set ns.documents_system_message = '' %}
|
||||
{%- endif %}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{%- if messages[0].content is string %}
|
||||
{%- set ns.system_message = messages[0].content %}
|
||||
{%- elif messages[0].content is iterable %}
|
||||
{%- for entry in messages[0].content %}
|
||||
{%- if entry.type== 'text' %}
|
||||
{%- if ns.system_message != '' %}
|
||||
{%- set ns.system_message = ns.system_message + '\n' %}
|
||||
{%- endif %}
|
||||
{%- set ns.system_message = ns.system_message + entry.text %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- if tools and documents %}
|
||||
{%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message + '\n\n' + ns.documents_system_message %}
|
||||
{%- elif tools %}
|
||||
{%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message %}
|
||||
{%- elif documents %}
|
||||
{%- set ns.system_message = ns.system_message + '\n\n' + ns.documents_system_message %}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{%- if tools and documents %}
|
||||
{%- set ns.system_message = ns.tools_system_message + '\n\n' + ns.documents_system_message %}
|
||||
{%- elif tools %}
|
||||
{%- set ns.system_message = ns.tools_system_message %}
|
||||
{%- elif documents %}
|
||||
{%- set ns.system_message = ns.documents_system_message %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if ns.system_message %}
|
||||
{{- '<|start_of_role|>system<|end_of_role|>' + ns.system_message + '<|end_of_text|>\n' }}
|
||||
{%- endif %}
|
||||
{%- for message in messages %}
|
||||
{%- set content = namespace(val='') %}
|
||||
{%- if message.content is string %}
|
||||
{%- set content.val = message.content %}
|
||||
{%- else %}
|
||||
{%- if message.content is iterable %}
|
||||
{%- for entry in message.content %}
|
||||
{%- if entry.type== 'text' %}
|
||||
{%- if content.val != '' %}
|
||||
{%- set content.val = content.val + '\n' %}
|
||||
{%- endif %}
|
||||
{%- set content.val = content.val + entry.text %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if (message.role == 'user') or (message.role == 'system' and not loop.first) %}
|
||||
{{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val + '<|end_of_text|>\n' }}
|
||||
{%- elif message.role == 'assistant' %}
|
||||
{{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if (loop.first and content.val) or (not loop.first) %}
|
||||
{{- '\n' }}
|
||||
{%- endif %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<tool_call>\n{"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '", "arguments": ' }}
|
||||
{%- if tool_call.arguments is string %}
|
||||
{{- tool_call.arguments }}
|
||||
{%- else %}
|
||||
{{- tool_call.arguments | tojson }}
|
||||
{%- endif %}
|
||||
{{- '}\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|end_of_text|>\n' }}
|
||||
{%- elif message.role == 'tool' %}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != 'tool') %}
|
||||
{{- '<|start_of_role|>user<|end_of_role|>' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- content.val }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != 'tool') %}
|
||||
{{- '<|end_of_text|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|start_of_role|>assistant<|end_of_role|>' }}
|
||||
{%- endif %}
|
||||
@@ -62,6 +62,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
|
||||
{ "granite", LLM_CHAT_TEMPLATE_GRANITE_3_X },
|
||||
{ "granite-4.0", LLM_CHAT_TEMPLATE_GRANITE_4_0 },
|
||||
{ "granite-4.1", LLM_CHAT_TEMPLATE_GRANITE_4_1 },
|
||||
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
|
||||
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
|
||||
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
|
||||
@@ -194,7 +195,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_RWKV_WORLD;
|
||||
} else if (tmpl_contains("<|start_of_role|>")) {
|
||||
if (tmpl_contains("<tool_call>") || tmpl_contains("<tools>")) {
|
||||
return LLM_CHAT_TEMPLATE_GRANITE_4_0;
|
||||
if (tmpl_contains("g4_default_system_message")) {
|
||||
return LLM_CHAT_TEMPLATE_GRANITE_4_0;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_GRANITE_4_1;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_GRANITE_3_X;
|
||||
} else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) {
|
||||
@@ -651,6 +655,20 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|start_of_role|>assistant<|end_of_role|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_4_1) {
|
||||
// IBM Granite 4.1 template
|
||||
for (const auto & message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "assistant_tool_call") {
|
||||
ss << "<|start_of_role|>assistant<|end_of_role|><|tool_call|>";
|
||||
} else {
|
||||
ss << "<|start_of_role|>" << role << "<|end_of_role|>";
|
||||
}
|
||||
ss << message->content << "<|end_of_text|>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|start_of_role|>assistant<|end_of_role|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
|
||||
// GigaChat template
|
||||
bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
|
||||
|
||||
@@ -41,6 +41,7 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_RWKV_WORLD,
|
||||
LLM_CHAT_TEMPLATE_GRANITE_3_X,
|
||||
LLM_CHAT_TEMPLATE_GRANITE_4_0,
|
||||
LLM_CHAT_TEMPLATE_GRANITE_4_1,
|
||||
LLM_CHAT_TEMPLATE_GIGACHAT,
|
||||
LLM_CHAT_TEMPLATE_MEGREZ,
|
||||
LLM_CHAT_TEMPLATE_YANDEX,
|
||||
|
||||
@@ -8318,6 +8318,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 1, 128));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 64, 1, 64));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 256, 1, 256));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 512, 1, 512));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 32, 128));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 4, 128, {2, 3}));
|
||||
|
||||
|
||||
@@ -618,6 +618,16 @@ int main_automated_tests(void) {
|
||||
},
|
||||
{
|
||||
/* .name= */ "ibm-granite/granite-4.0 (tool call)",
|
||||
/* .template_str= */ "{%- for message in messages %}\n {%- if message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- else %}\n {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}\n{# <tool_call> <tools> g4_default_system_message #}",
|
||||
/* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What is the weather?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|><|tool_call|><tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}\n</tool_call><|end_of_text|>\n<|start_of_role|>tool_response<|end_of_role|>{\"temperature\": 72}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
|
||||
/* .expected_output_jinja= */ "",
|
||||
/* .bos_token= */ "",
|
||||
/* .eos_token= */ "",
|
||||
/* .supported_with_jinja= */ true,
|
||||
/* .extra_conversation= */ {{"user", "What is the weather?"}, {"assistant_tool_call", "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}\n</tool_call>"}, {"tool_response", "{\"temperature\": 72}"}},
|
||||
},
|
||||
{
|
||||
/* .name= */ "ibm-granite/granite-4.1 (tool call)",
|
||||
/* .template_str= */ "{%- for message in messages %}\n {%- if message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- else %}\n {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}\n{# <tool_call> <tools> #}",
|
||||
/* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What is the weather?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|><|tool_call|><tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}\n</tool_call><|end_of_text|>\n<|start_of_role|>tool_response<|end_of_role|>{\"temperature\": 72}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
|
||||
/* .expected_output_jinja= */ "",
|
||||
|
||||
@@ -2914,6 +2914,21 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
// IBM Granite 4.1 (same format as 4.0)
|
||||
auto tst = peg_tester("models/templates/ibm-granite-granite-4.1.jinja", detailed_debug);
|
||||
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
// ByteDance-Seed-OSS (reasoning and tool calling model)
|
||||
auto tst = peg_tester("models/templates/ByteDance-Seed-OSS.jinja", detailed_debug);
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "../src/llama-model-saver.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <cstdint>
|
||||
@@ -497,6 +498,7 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
|
||||
};
|
||||
|
||||
std::vector<device_config> dev_configs;
|
||||
size_t max_device_label_length = 4;
|
||||
{
|
||||
std::vector<ggml_backend_dev_t> devices_meta;
|
||||
{
|
||||
@@ -504,6 +506,7 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
|
||||
for (size_t i = 0; i < device_count; i++) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
dev_configs.emplace_back(std::vector<ggml_backend_dev_t>{dev}, ggml_backend_dev_description(dev), LLAMA_SPLIT_MODE_LAYER);
|
||||
max_device_label_length = std::max(max_device_label_length, dev_configs.back().label.length());
|
||||
|
||||
// cpu-based devices cannot be used in tensor split mode
|
||||
if (ggml_backend_dev_buffer_type(dev) != ggml_backend_cpu_buffer_type()) {
|
||||
@@ -515,10 +518,26 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
|
||||
dev_configs.emplace_back(devices_meta, "Meta", LLAMA_SPLIT_MODE_TENSOR);
|
||||
}
|
||||
|
||||
size_t max_arch_name_length = 0;
|
||||
for (const llm_arch & arch : llm_arch_all()) {
|
||||
max_arch_name_length = std::max(max_arch_name_length, strlen(llm_arch_name(arch)));
|
||||
}
|
||||
|
||||
const std::string template_header = std::string("|%" + std::to_string(max_arch_name_length) + "s|%") + std::to_string(max_device_label_length) + "s|%6s|%15s|%9s|\n";
|
||||
const std::string template_row = std::string("|%" + std::to_string(max_arch_name_length) + "s|%") + std::to_string(max_device_label_length) + "s|%6s|%15s %10s|%20s|\n";
|
||||
|
||||
bool all_ok = true;
|
||||
common_log_flush(common_log_main());
|
||||
printf("|%16s|%30s|%6s|%15s|%9s|\n", "Model arch.", "Device", "Config", "NMSE vs. CPU", "Roundtrip");
|
||||
printf("|----------------|------------------------------|------|---------------|---------|\n");
|
||||
printf(template_header.c_str(), "Model arch.", "Device", "Config", "NMSE vs. CPU", "Roundtrip");
|
||||
printf("|");
|
||||
for (size_t i = 0; i < max_arch_name_length; i++) {
|
||||
printf("-");
|
||||
}
|
||||
printf("|");
|
||||
for (size_t i = 0; i < max_device_label_length; i++) {
|
||||
printf("-");
|
||||
}
|
||||
printf("|------|---------------|---------|\n");
|
||||
for (const llm_arch & arch : llm_arch_all()) {
|
||||
if (arch == LLM_ARCH_UNKNOWN) {
|
||||
continue;
|
||||
@@ -595,7 +614,7 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
|
||||
}
|
||||
}
|
||||
|
||||
printf("|%16s|%30s|%6s|%15s %10s|%20s|\n", llm_arch_name(arch), dc.label.c_str(),
|
||||
printf(template_row.c_str(), llm_arch_name(arch), dc.label.c_str(),
|
||||
config_name.c_str(), status_nmse.c_str(), nmse_str, status_roundtrip.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -923,7 +923,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
||||
}
|
||||
|
||||
if (i0 == i1) {
|
||||
LOG_ERR("%s : task %zu does not fit in the context window (requires %lu tokens)\n", __func__, i0, hs_data[i0].required_tokens);
|
||||
LOG_ERR("%s : task %zu does not fit in the context window (requires %zu tokens)\n", __func__, i0, hs_data[i0].required_tokens);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1216,7 +1216,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
|
||||
}
|
||||
|
||||
if (i0 == i1) {
|
||||
LOG_ERR("%s : task %zu does not fit in the context window (requires %lu tokens)\n", __func__, i0, data[i0].required_tokens);
|
||||
LOG_ERR("%s : task %zu does not fit in the context window (requires %zu tokens)\n", __func__, i0, data[i0].required_tokens);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1595,7 +1595,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
||||
}
|
||||
|
||||
if (i0 == i1) {
|
||||
LOG_ERR("%s : task %zu does not fit in the context window (requires %lu tokens)\n", __func__, i0, tasks[i0].required_tokens);
|
||||
LOG_ERR("%s : task %zu does not fit in the context window (requires %zu tokens)\n", __func__, i0, tasks[i0].required_tokens);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -201,7 +201,7 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
|
||||
| `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
|
||||
| `--api-key KEY` | API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)<br/>(env: LLAMA_API_KEY) |
|
||||
| `--api-key-file FNAME` | path to file containing API keys (default: none) |
|
||||
| `--api-key-file FNAME` | path to file containing API keys, one per line (default: none)<br/>(env: LLAMA_ARG_API_KEY_FILE) |
|
||||
| `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
|
||||
| `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate<br/>(env: LLAMA_ARG_SSL_CERT_FILE) |
|
||||
| `--chat-template-kwargs STRING` | sets additional params for the json template parser, must be a valid json object string, e.g. '{"key1":"value1","key2":"value2"}'<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_KWARGS) |
|
||||
|
||||
@@ -5,9 +5,9 @@
|
||||
|
||||
#include <cpp-httplib/httplib.h>
|
||||
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
@@ -21,7 +21,7 @@ public:
|
||||
};
|
||||
|
||||
server_http_context::server_http_context()
|
||||
: pimpl(std::make_unique<server_http_context::Impl>())
|
||||
: pimpl(std::make_unique<Impl>())
|
||||
{}
|
||||
|
||||
server_http_context::~server_http_context() = default;
|
||||
@@ -62,7 +62,7 @@ struct gcp_params {
|
||||
}
|
||||
|
||||
static std::string getenv(const char * name, const std::string & default_value, bool ensure_leading_slash = false) {
|
||||
const char * value = std::getenv(name);
|
||||
const auto * value = std::getenv(name);
|
||||
if (value == nullptr || value[0] == '\0') {
|
||||
return default_value;
|
||||
}
|
||||
@@ -94,15 +94,15 @@ bool server_http_context::init(const common_params & params) {
|
||||
auto & srv = pimpl->srv;
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
||||
if (!params.ssl_file_key.empty() && !params.ssl_file_cert.empty()) {
|
||||
SRV_INF("running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
|
||||
srv.reset(
|
||||
new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
|
||||
srv = std::make_unique<httplib::SSLServer>(
|
||||
params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()
|
||||
);
|
||||
is_ssl = true;
|
||||
} else {
|
||||
SRV_INF("%s", "running without SSL\n");
|
||||
srv.reset(new httplib::Server());
|
||||
srv = std::make_unique<httplib::Server>();
|
||||
}
|
||||
#else
|
||||
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
||||
@@ -150,7 +150,7 @@ bool server_http_context::init(const common_params & params) {
|
||||
// set timeouts and change hostname and port
|
||||
srv->set_read_timeout (params.timeout_read);
|
||||
srv->set_write_timeout(params.timeout_write);
|
||||
srv->set_socket_options([reuse_port = params.reuse_port](socket_t sock) {
|
||||
srv->set_socket_options([reuse_port = params.reuse_port](const socket_t sock) {
|
||||
httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
|
||||
if (reuse_port) {
|
||||
#ifdef SO_REUSEPORT
|
||||
@@ -162,8 +162,8 @@ bool server_http_context::init(const common_params & params) {
|
||||
});
|
||||
|
||||
if (params.api_keys.size() == 1) {
|
||||
auto key = params.api_keys[0];
|
||||
std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
|
||||
const auto key = params.api_keys[0];
|
||||
const std::string substr = key.substr(std::max(static_cast<int>(key.length() - 4), 0));
|
||||
SRV_INF("api_keys: ****%s\n", substr.c_str());
|
||||
} else if (params.api_keys.size() > 1) {
|
||||
SRV_INF("api_keys: %zu keys loaded\n", params.api_keys.size());
|
||||
@@ -203,7 +203,7 @@ bool server_http_context::init(const common_params & params) {
|
||||
}
|
||||
|
||||
// remove the "Bearer " prefix if needed
|
||||
std::string prefix = "Bearer ";
|
||||
static std::string prefix = "Bearer ";
|
||||
if (req_api_key.substr(0, prefix.size()) == prefix) {
|
||||
req_api_key = req_api_key.substr(prefix.size());
|
||||
}
|
||||
@@ -232,11 +232,10 @@ bool server_http_context::init(const common_params & params) {
|
||||
};
|
||||
|
||||
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
|
||||
bool ready = is_ready.load();
|
||||
if (!ready) {
|
||||
if (!is_ready.load()) {
|
||||
#if defined(LLAMA_UI_HAS_ASSETS)
|
||||
auto tmp = string_split<std::string>(req.path, '.');
|
||||
if (req.path == "/" || (tmp.size() > 0 && tmp.back() == "html")) {
|
||||
if (const auto tmp = string_split<std::string>(req.path, '.');
|
||||
req.path == "/" || (!tmp.empty() && tmp.back() == "html")) {
|
||||
if (const llama_ui_asset * a = llama_ui_find_asset("loading.html")) {
|
||||
res.status = 503;
|
||||
res.set_content(reinterpret_cast<const char*>(a->data), a->size, "text/html; charset=utf-8");
|
||||
@@ -284,17 +283,17 @@ bool server_http_context::init(const common_params & params) {
|
||||
return httplib::Server::HandlerResponse::Unhandled;
|
||||
});
|
||||
|
||||
int n_threads_http = params.n_threads_http;
|
||||
auto n_threads_http = params.n_threads_http;
|
||||
if (n_threads_http < 1) {
|
||||
// +4 threads for monitoring, health and some threads reserved for MCP and other tasks in the future
|
||||
n_threads_http = std::max(params.n_parallel + 4, (int32_t) std::thread::hardware_concurrency() - 1);
|
||||
n_threads_http = std::max(params.n_parallel + 4, static_cast<int32_t>(std::thread::hardware_concurrency() - 1));
|
||||
}
|
||||
SRV_INF("using %d threads for HTTP server\n", n_threads_http);
|
||||
srv->new_task_queue = [n_threads_http] {
|
||||
// spawn n_threads_http fixed thread (always alive), while allow up to 1024 max possible additional threads
|
||||
// when n_threads_http is used, server will create new "dynamic" threads that will be destroyed after processing each request
|
||||
// ref: https://github.com/yhirose/cpp-httplib/pull/2368
|
||||
size_t max_threads = (size_t)n_threads_http + 1024;
|
||||
const auto max_threads = static_cast<size_t>(n_threads_http + 1024);
|
||||
return new httplib::ThreadPool(n_threads_http, max_threads);
|
||||
};
|
||||
|
||||
@@ -310,20 +309,26 @@ bool server_http_context::init(const common_params & params) {
|
||||
// register static assets routes
|
||||
if (!params.public_path.empty()) {
|
||||
// Set the base directory for serving static files
|
||||
bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path);
|
||||
if (!is_found) {
|
||||
if (const auto is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path); !is_found) {
|
||||
SRV_ERR("static assets path not found: %s\n", params.public_path.c_str());
|
||||
return 1;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
#if defined(LLAMA_UI_HAS_ASSETS)
|
||||
auto serve_asset = [](const std::string & name, const char * mime, bool with_isolation_headers) {
|
||||
return [name, mime, with_isolation_headers](const httplib::Request & /*req*/, httplib::Response & res) {
|
||||
return [name, mime, with_isolation_headers](const httplib::Request & req, httplib::Response & res) {
|
||||
const llama_ui_asset * a = llama_ui_find_asset(name.c_str());
|
||||
if (!a) {
|
||||
res.status = 404;
|
||||
return false;
|
||||
}
|
||||
res.set_header("ETag", a->etag);
|
||||
// Check If-None-Match for conditional GET (304 Not Modified)
|
||||
if (const std::string & inm = req.get_header_value("If-None-Match");
|
||||
!inm.empty() && inm == a->etag) {
|
||||
res.status = 304;
|
||||
return false;
|
||||
}
|
||||
if (with_isolation_headers) {
|
||||
// COEP and COOP headers, required by pyodide (python interpreter)
|
||||
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
|
||||
@@ -346,9 +351,9 @@ bool server_http_context::init(const common_params & params) {
|
||||
bool server_http_context::start() {
|
||||
// Bind and listen
|
||||
|
||||
auto & srv = pimpl->srv;
|
||||
bool was_bound = false;
|
||||
bool is_sock = false;
|
||||
const auto & srv = pimpl->srv;
|
||||
auto was_bound = false;
|
||||
auto is_sock = false;
|
||||
if (string_ends_with(std::string(hostname), ".sock")) {
|
||||
is_sock = true;
|
||||
SRV_INF("%s", "setting address family to AF_UNIX\n");
|
||||
@@ -360,7 +365,7 @@ bool server_http_context::start() {
|
||||
SRV_INF("%s", "binding port with default address family\n");
|
||||
// bind HTTP listen port
|
||||
if (port == 0) {
|
||||
int bound_port = srv->bind_to_any_port(hostname);
|
||||
const auto bound_port = srv->bind_to_any_port(hostname);
|
||||
was_bound = (bound_port >= 0);
|
||||
if (was_bound) {
|
||||
port = bound_port;
|
||||
@@ -376,7 +381,7 @@ bool server_http_context::start() {
|
||||
}
|
||||
|
||||
// run the HTTP server in a thread
|
||||
thread = std::thread([this]() { pimpl->srv->listen_after_bind(); });
|
||||
thread = std::thread([this] { pimpl->srv->listen_after_bind(); });
|
||||
srv->wait_until_ready();
|
||||
|
||||
listening_address = is_sock ? string_format("unix://%s", hostname.c_str())
|
||||
@@ -433,13 +438,13 @@ static void process_handler_response(server_http_req_ptr && request, server_http
|
||||
if (response->is_stream()) {
|
||||
res.status = response->status;
|
||||
set_headers(res, response->headers);
|
||||
std::string content_type = response->content_type;
|
||||
const std::string content_type = response->content_type;
|
||||
// convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
|
||||
std::shared_ptr<server_http_req> q_ptr = std::move(request);
|
||||
std::shared_ptr<server_http_res> r_ptr = std::move(response);
|
||||
const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
|
||||
std::shared_ptr q_ptr = std::move(request);
|
||||
std::shared_ptr r_ptr = std::move(response);
|
||||
const auto chunked_content_provider = [response = r_ptr](size_t, const httplib::DataSink & sink) -> bool {
|
||||
std::string chunk;
|
||||
bool has_next = response->next(chunk);
|
||||
const bool has_next = response->next(chunk);
|
||||
if (!chunk.empty()) {
|
||||
if (!sink.write(chunk.data(), chunk.size())) {
|
||||
return false;
|
||||
@@ -550,7 +555,7 @@ static std::string path_to_gcp_format(const std::string & path) {
|
||||
if (c == '/' || c == '-' || c == '_') {
|
||||
cap = true;
|
||||
} else {
|
||||
result += cap ? (char)std::toupper(c) : (char)c;
|
||||
result += static_cast<char>(cap ? std::toupper(c) : c);
|
||||
cap = false;
|
||||
}
|
||||
}
|
||||
@@ -574,7 +579,7 @@ static json parse_gcp_predict_response(const server_http_res_ptr & res) {
|
||||
}
|
||||
}
|
||||
|
||||
void server_http_context::register_gcp_compat() {
|
||||
void server_http_context::register_gcp_compat() const {
|
||||
const gcp_params gcp;
|
||||
|
||||
if (!gcp.enabled) {
|
||||
@@ -595,7 +600,7 @@ void server_http_context::register_gcp_compat() {
|
||||
}
|
||||
|
||||
if (!gcp.path_health.empty()) {
|
||||
auto health_handler = handlers.find("/health");
|
||||
const auto health_handler = handlers.find("/health");
|
||||
GGML_ASSERT(health_handler != handlers.end());
|
||||
get(gcp.path_health, health_handler->second);
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ struct server_http_context {
|
||||
|
||||
std::string path_prefix;
|
||||
std::string hostname;
|
||||
int port;
|
||||
int port = 8080;
|
||||
bool is_ssl = false;
|
||||
|
||||
server_http_context();
|
||||
@@ -88,7 +88,7 @@ struct server_http_context {
|
||||
|
||||
// Register the Google Cloud Platform (Vertex AI) compat (AIP_PREDICT_ROUTE env var, or /predict)
|
||||
// Must be called AFTER all other API routes are registered
|
||||
void register_gcp_compat();
|
||||
void register_gcp_compat() const;
|
||||
|
||||
// for debugging
|
||||
std::string listening_address;
|
||||
|
||||
@@ -9,6 +9,19 @@
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
// Computes FNV-1a hash of the data
|
||||
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
|
||||
const uint64_t fnv_prime = 0x100000001b3ULL;
|
||||
uint64_t hash = 0xcbf29ce484222325ULL;
|
||||
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
hash ^= data[i];
|
||||
hash *= fnv_prime;
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
static bool read_file(const std::string & path, std::vector<unsigned char> & out) {
|
||||
std::ifstream f(path, std::ios::binary | std::ios::ate);
|
||||
@@ -95,6 +108,7 @@ int main(int argc, char ** argv) {
|
||||
" const char * name;\n"
|
||||
" const unsigned char * data;\n"
|
||||
" size_t size;\n"
|
||||
" const char * etag;\n"
|
||||
"};\n\n"
|
||||
"const llama_ui_asset * llama_ui_find_asset(const char * name);\n";
|
||||
|
||||
@@ -110,14 +124,18 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
cpp += fmt("static const unsigned char asset_%d_data[] = {", i);
|
||||
append_bytes_hex(cpp, bytes);
|
||||
cpp += fmt("};\nstatic const size_t asset_%d_size = %lu;\n\n",
|
||||
const auto hash = fnv_hash(bytes.data(), bytes.size());
|
||||
|
||||
cpp += fmt("};\nstatic const size_t asset_%d_size = %lu;\n",
|
||||
i, static_cast<unsigned long>(bytes.size()));
|
||||
cpp += fmt("static const char asset_%d_etag[] = \"\\\"0x%016lx\\\"\";\n\n",
|
||||
i, static_cast<unsigned long>(hash));
|
||||
}
|
||||
|
||||
cpp += "static const llama_ui_asset g_assets[] = {\n";
|
||||
for (int i = 0; i < n_assets; i++) {
|
||||
const char * name = argv[3 + i * 2];
|
||||
cpp += fmt(" { \"%s\", asset_%d_data, asset_%d_size },\n", name, i, i);
|
||||
cpp += fmt(" { \"%s\", asset_%d_data, asset_%d_size, asset_%d_etag },\n",
|
||||
argv[3 + i * 2], i, i, i);
|
||||
}
|
||||
cpp += "};\n\n";
|
||||
|
||||
|
||||
Reference in New Issue
Block a user