mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-19 14:13:22 +02:00
Compare commits
7 Commits
b3362
...
compilade/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ba06b2deb7 | ||
|
|
dd07a123b7 | ||
|
|
f4444d992c | ||
|
|
6b2a849d1f | ||
|
|
0f1a39f343 | ||
|
|
83321c6958 | ||
|
|
cc61948b1f |
42
Makefile
42
Makefile
@@ -554,7 +554,7 @@ endif # GGML_BLIS
|
||||
|
||||
ifndef GGML_NO_LLAMAFILE
|
||||
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
|
||||
OBJ_GGML += ggml/src/sgemm.o
|
||||
OBJ_GGML += ggml/src/llamafile/sgemm.o
|
||||
endif
|
||||
|
||||
ifdef GGML_RPC
|
||||
@@ -835,7 +835,8 @@ OBJ_GGML += \
|
||||
ggml/src/ggml.o \
|
||||
ggml/src/ggml-alloc.o \
|
||||
ggml/src/ggml-backend.o \
|
||||
ggml/src/ggml-quants.o
|
||||
ggml/src/ggml-quants.o \
|
||||
ggml/src/ggml-aarch64.o
|
||||
|
||||
OBJ_LLAMA = \
|
||||
src/llama.o \
|
||||
@@ -969,15 +970,22 @@ ggml/src/ggml-quants.o: \
|
||||
ggml/src/ggml-common.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
ggml/src/ggml-aarch64.o: \
|
||||
ggml/src/ggml-aarch64.c \
|
||||
ggml/include/ggml.h \
|
||||
ggml/src/ggml-aarch64.h \
|
||||
ggml/src/ggml-common.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
ggml/src/ggml-blas.o: \
|
||||
ggml/src/ggml-blas.cpp \
|
||||
ggml/include/ggml-blas.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
ifndef GGML_NO_LLAMAFILE
|
||||
ggml/src/sgemm.o: \
|
||||
ggml/src/sgemm.cpp \
|
||||
ggml/src/sgemm.h \
|
||||
ggml/src/llamafile/sgemm.o: \
|
||||
ggml/src/llamafile/sgemm.cpp \
|
||||
ggml/src/llamafile/sgemm.h \
|
||||
ggml/include/ggml.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
endif # GGML_NO_LLAMAFILE
|
||||
@@ -1505,15 +1513,17 @@ llama-q8dot: pocs/vdot/q8dot.cpp ggml/src/ggml.o \
|
||||
# Mark legacy binary targets as .PHONY so that they are always checked.
|
||||
.PHONY: main quantize perplexity embedding server finetune
|
||||
|
||||
# NOTE: We currently will always build the deprecation-warning `main` and `server` binaries to help users migrate.
|
||||
# Eventually we will want to remove these target from building all the time.
|
||||
main: examples/deprecation-warning/deprecation-warning.cpp
|
||||
ifneq (,$(wildcard main))
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
@echo "#########"
|
||||
@echo "WARNING: The 'main' binary is deprecated. Please use 'llama-cli' instead."
|
||||
@echo " Remove the 'main' binary to remove this warning."
|
||||
@echo "#########"
|
||||
endif
|
||||
@echo "NOTICE: The 'main' binary is deprecated. Please use 'llama-cli' instead."
|
||||
|
||||
server: examples/deprecation-warning/deprecation-warning.cpp
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
@echo "NOTICE: The 'server' binary is deprecated. Please use 'llama-server' instead."
|
||||
|
||||
quantize: examples/deprecation-warning/deprecation-warning.cpp
|
||||
ifneq (,$(wildcard quantize))
|
||||
@@ -1545,16 +1555,6 @@ ifneq (,$(wildcard embedding))
|
||||
@echo "#########"
|
||||
endif
|
||||
|
||||
server: examples/deprecation-warning/deprecation-warning.cpp
|
||||
ifneq (,$(wildcard server))
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
@echo "#########"
|
||||
@echo "WARNING: The 'server' binary is deprecated. Please use 'llama-server' instead."
|
||||
@echo " Remove the 'server' binary to remove this warning."
|
||||
@echo "#########"
|
||||
endif
|
||||
|
||||
finetune: examples/deprecation-warning/deprecation-warning.cpp
|
||||
ifneq (,$(wildcard finetune))
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
|
||||
@@ -10,6 +10,7 @@ var sources = [
|
||||
"ggml/src/ggml-alloc.c",
|
||||
"ggml/src/ggml-backend.c",
|
||||
"ggml/src/ggml-quants.c",
|
||||
"ggml/src/ggml-aarch64.c",
|
||||
]
|
||||
|
||||
var resources: [Resource] = []
|
||||
|
||||
@@ -28,6 +28,7 @@ In order to build llama.cpp you have four different options.
|
||||
```
|
||||
|
||||
- Notes:
|
||||
- For `Q4_0_4_4` quantization type build, add the `GGML_NO_LLAMAFILE=1` flag. For example, use `make GGML_NO_LLAMAFILE=1`.
|
||||
- For faster compilation, add the `-j` argument to run multiple jobs in parallel. For example, `make -j 8` will run 8 jobs in parallel.
|
||||
- For faster repeated compilation, install [ccache](https://ccache.dev/).
|
||||
- For debug builds, run `make LLAMA_DEBUG=1`
|
||||
@@ -41,6 +42,7 @@ In order to build llama.cpp you have four different options.
|
||||
|
||||
**Notes**:
|
||||
|
||||
- For `Q4_0_4_4` quantization type build, add the `-DGGML_LLAMAFILE=OFF` cmake option. For example, use `cmake -B build -DGGML_LLAMAFILE=OFF`.
|
||||
- For faster compilation, add the `-j` argument to run multiple jobs in parallel. For example, `cmake --build build --config Release -j 8` will run 8 jobs in parallel.
|
||||
- For faster repeated compilation, install [ccache](https://ccache.dev/).
|
||||
- For debug builds, there are two cases:
|
||||
|
||||
@@ -46,6 +46,9 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
|
||||
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 5.33G, +0.0569 ppl @ Llama-3-8B", },
|
||||
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 6.14G, +0.0217 ppl @ Llama-3-8B", },
|
||||
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 7.96G, +0.0026 ppl @ Llama-3-8B", },
|
||||
{ "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
|
||||
{ "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
|
||||
{ "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
|
||||
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, +0.0020 ppl @ Mistral-7B", },
|
||||
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", },
|
||||
{ "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", },
|
||||
|
||||
@@ -29,6 +29,7 @@ static void print_usage_information(const char * argv0, FILE * stream) {
|
||||
fprintf(stream, " -p PROMPT, --prompt PROMPT read prompt from the argument.\n");
|
||||
fprintf(stream, " --stdin read prompt from standard input.\n");
|
||||
fprintf(stream, " --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n");
|
||||
fprintf(stream, " --no-parse-special do not parse control tokens.\n");
|
||||
fprintf(stream, " --log-disable disable logs. Makes stderr quiet when loading the model.\n");
|
||||
fprintf(stream, " --show-count print the total number of tokens.\n");
|
||||
}
|
||||
@@ -195,6 +196,7 @@ int main(int raw_argc, char ** raw_argv) {
|
||||
// variables where to put any arguments we see.
|
||||
bool printing_ids = false;
|
||||
bool no_bos = false;
|
||||
bool no_parse_special = false;
|
||||
bool disable_logging = false;
|
||||
bool show_token_count = false;
|
||||
const char * model_path = NULL;
|
||||
@@ -229,6 +231,9 @@ int main(int raw_argc, char ** raw_argv) {
|
||||
else if (arg == "--no-bos") {
|
||||
no_bos = true;
|
||||
}
|
||||
else if (arg == "--no-parse-special") {
|
||||
no_parse_special = true;
|
||||
}
|
||||
else if (arg == "-p" || arg == "--prompt") {
|
||||
if (prompt_set) {
|
||||
fprintf(stderr, "Error: -p or --prompt specified multiple times.\n");
|
||||
@@ -359,9 +364,10 @@ int main(int raw_argc, char ** raw_argv) {
|
||||
|
||||
const bool model_wants_add_bos = llama_should_add_bos_token(model);
|
||||
const bool add_bos = model_wants_add_bos && !no_bos;
|
||||
const bool parse_special = !no_parse_special;
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
tokens = ::llama_tokenize(model, prompt, add_bos, true);
|
||||
tokens = ::llama_tokenize(model, prompt, add_bos, parse_special);
|
||||
|
||||
if (printing_ids) {
|
||||
printf("[");
|
||||
|
||||
@@ -104,7 +104,7 @@ option(GGML_ACCELERATE "ggml: enable Accelerate framework"
|
||||
option(GGML_BLAS "ggml: use BLAS" ${GGML_BLAS_DEFAULT})
|
||||
set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
|
||||
"ggml: BLAS library vendor")
|
||||
option(GGML_LLAMAFILE "ggml: use ggml SGEMM" OFF)
|
||||
option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF)
|
||||
|
||||
option(GGML_CUDA "ggml: use CUDA" OFF)
|
||||
option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
|
||||
|
||||
@@ -383,6 +383,9 @@ extern "C" {
|
||||
GGML_TYPE_F64 = 28,
|
||||
GGML_TYPE_IQ1_M = 29,
|
||||
GGML_TYPE_BF16 = 30,
|
||||
GGML_TYPE_Q4_0_4_4 = 31,
|
||||
GGML_TYPE_Q4_0_4_8 = 32,
|
||||
GGML_TYPE_Q4_0_8_8 = 33,
|
||||
GGML_TYPE_COUNT,
|
||||
};
|
||||
|
||||
@@ -424,6 +427,9 @@ extern "C" {
|
||||
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
|
||||
};
|
||||
|
||||
// available tensor operations:
|
||||
@@ -2406,6 +2412,12 @@ extern "C" {
|
||||
typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
|
||||
const void * GGML_RESTRICT y, size_t by, int nrc);
|
||||
typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr,
|
||||
int64_t k, int64_t bx);
|
||||
typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
|
||||
const void * GGML_RESTRICT y, int nr, int nc);
|
||||
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
|
||||
const void * GGML_RESTRICT y, int nr, int nc);
|
||||
|
||||
typedef struct {
|
||||
const char * type_name;
|
||||
@@ -2418,6 +2430,11 @@ extern "C" {
|
||||
ggml_vec_dot_t vec_dot;
|
||||
enum ggml_type vec_dot_type;
|
||||
int64_t nrows; // number of rows to process simultaneously;
|
||||
int64_t ncols; // number of columns to process simultaneously;
|
||||
int64_t interleave_blcksize; // interleave elements in blocks of interleave_blcksize;
|
||||
ggml_from_float_to_mat_t from_float_to_mat;
|
||||
ggml_gemv_t gemv;
|
||||
ggml_gemm_t gemm;
|
||||
} ggml_type_traits_t;
|
||||
|
||||
GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
|
||||
|
||||
@@ -238,12 +238,12 @@ if (GGML_BLAS)
|
||||
endif()
|
||||
|
||||
if (GGML_LLAMAFILE)
|
||||
message(STATUS "Using ggml SGEMM")
|
||||
message(STATUS "Using llamafile")
|
||||
|
||||
add_compile_definitions(GGML_USE_LLAMAFILE)
|
||||
|
||||
set(GGML_HEADERS_LLAMAFILE sgemm.h)
|
||||
set(GGML_SOURCES_LLAMAFILE sgemm.cpp)
|
||||
set(GGML_HEADERS_LLAMAFILE llamafile/sgemm.h)
|
||||
set(GGML_SOURCES_LLAMAFILE llamafile/sgemm.cpp)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA)
|
||||
@@ -1153,6 +1153,7 @@ add_library(ggml
|
||||
${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM}
|
||||
${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS}
|
||||
${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
|
||||
ggml-aarch64.c ggml-aarch64.h
|
||||
)
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
|
||||
2187
ggml/src/ggml-aarch64.c
Normal file
2187
ggml/src/ggml-aarch64.c
Normal file
File diff suppressed because it is too large
Load Diff
39
ggml/src/ggml-aarch64.h
Normal file
39
ggml/src/ggml-aarch64.h
Normal file
@@ -0,0 +1,39 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
|
||||
#pragma once
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
// GGML internal header
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Quantization
|
||||
void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t interleave_blcksize);
|
||||
|
||||
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
||||
size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
|
||||
// GEMV
|
||||
void ggml_gemv_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
// GEMM
|
||||
void ggml_gemm_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -199,6 +199,30 @@ typedef struct {
|
||||
} block_q8_1;
|
||||
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
|
||||
|
||||
typedef struct {
|
||||
ggml_half d[4]; // deltas for 4 q4_0 blocks
|
||||
uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks
|
||||
} block_q4_0x4;
|
||||
static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding");
|
||||
|
||||
typedef struct {
|
||||
ggml_half d[8]; // deltas for 8 q4_0 blocks
|
||||
uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks
|
||||
} block_q4_0x8;
|
||||
static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding");
|
||||
|
||||
typedef struct {
|
||||
ggml_half d[4]; // deltas for 4 q8_0 blocks
|
||||
int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks
|
||||
} block_q8_0x4;
|
||||
static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding");
|
||||
|
||||
typedef struct {
|
||||
ggml_half d[8]; // deltas for 8 q8_0 blocks
|
||||
int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks
|
||||
} block_q8_0x8;
|
||||
static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding");
|
||||
|
||||
//
|
||||
// Super-block quantization structures
|
||||
//
|
||||
|
||||
@@ -609,6 +609,10 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
|
||||
|
||||
#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)
|
||||
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
#include <arm_sve.h>
|
||||
#endif // __ARM_FEATURE_SVE
|
||||
|
||||
// precomputed f32 table for f16 (256 KB)
|
||||
// defined in ggml.c, initialized in ggml_init()
|
||||
extern float ggml_table_f32_f16[1 << 16];
|
||||
|
||||
@@ -3814,43 +3814,47 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
||||
}
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
|
||||
const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
|
||||
if (svcntb() == QK8_0) {
|
||||
const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
|
||||
const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
|
||||
|
||||
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
||||
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
||||
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
||||
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
||||
|
||||
assert(nb % 2 == 0); // TODO: handle odd nb
|
||||
assert(nb % 2 == 0); // TODO: handle odd nb
|
||||
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q4_0 * restrict x0 = &x[i + 0];
|
||||
const block_q4_0 * restrict x1 = &x[i + 1];
|
||||
const block_q8_0 * restrict y0 = &y[i + 0];
|
||||
const block_q8_0 * restrict y1 = &y[i + 1];
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q4_0 * restrict x0 = &x[i + 0];
|
||||
const block_q4_0 * restrict x1 = &x[i + 1];
|
||||
const block_q8_0 * restrict y0 = &y[i + 0];
|
||||
const block_q8_0 * restrict y1 = &y[i + 1];
|
||||
|
||||
// load x
|
||||
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
|
||||
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
|
||||
// load x
|
||||
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
|
||||
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
|
||||
|
||||
// 4-bit -> 8-bit
|
||||
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
|
||||
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
|
||||
// 4-bit -> 8-bit
|
||||
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
|
||||
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
|
||||
|
||||
// sub 8
|
||||
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
|
||||
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
|
||||
// sub 8
|
||||
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
|
||||
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
|
||||
|
||||
// load y
|
||||
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
||||
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
||||
// load y
|
||||
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
||||
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
||||
|
||||
// dot product
|
||||
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
// dot product
|
||||
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
}
|
||||
|
||||
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
||||
return;
|
||||
}
|
||||
|
||||
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
||||
#elif defined(__ARM_NEON)
|
||||
#endif
|
||||
#if defined(__ARM_NEON)
|
||||
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
||||
|
||||
@@ -5422,31 +5426,35 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
||||
}
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
||||
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
||||
if (svcntb() == QK8_0) {
|
||||
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
||||
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
||||
|
||||
assert(nb % 2 == 0); // TODO: handle odd nb
|
||||
assert(nb % 2 == 0); // TODO: handle odd nb
|
||||
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q8_0 * restrict x0 = &x[i + 0];
|
||||
const block_q8_0 * restrict x1 = &x[i + 1];
|
||||
const block_q8_0 * restrict y0 = &y[i + 0];
|
||||
const block_q8_0 * restrict y1 = &y[i + 1];
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q8_0 * restrict x0 = &x[i + 0];
|
||||
const block_q8_0 * restrict x1 = &x[i + 1];
|
||||
const block_q8_0 * restrict y0 = &y[i + 0];
|
||||
const block_q8_0 * restrict y1 = &y[i + 1];
|
||||
|
||||
// load x
|
||||
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
|
||||
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
|
||||
// load x
|
||||
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
|
||||
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
|
||||
|
||||
// load y
|
||||
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
||||
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
||||
// load y
|
||||
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
||||
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
||||
|
||||
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
}
|
||||
|
||||
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
||||
return;
|
||||
}
|
||||
|
||||
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
||||
#elif defined(__ARM_NEON)
|
||||
#endif
|
||||
#if defined(__ARM_NEON)
|
||||
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
||||
|
||||
@@ -14760,6 +14768,16 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
|
||||
} \
|
||||
}
|
||||
|
||||
#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
|
||||
const type * q = (const type *) (data); \
|
||||
for (size_t i = 0; i < (nb); ++i) { \
|
||||
for (size_t j = 0; j < (nr); ++j) { \
|
||||
if (!validate_fp16(q[i].d[j], i)) { \
|
||||
return false; \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
|
||||
if (type < 0 || type >= GGML_TYPE_COUNT) {
|
||||
fprintf(stderr, "%s: invalid type %d\n", __func__, type);
|
||||
@@ -14977,6 +14995,16 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
||||
{
|
||||
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
{
|
||||
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0_8_8:
|
||||
{
|
||||
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8);
|
||||
} break;
|
||||
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_I16:
|
||||
case GGML_TYPE_I32:
|
||||
|
||||
@@ -346,4 +346,10 @@ inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
|
||||
return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
|
||||
}
|
||||
|
||||
// Helper for accessing pointers with no warnings
|
||||
template <typename Tp, int dim>
|
||||
static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
|
||||
return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
}
|
||||
|
||||
#endif // GGML_SYCL_COMMON_HPP
|
||||
|
||||
@@ -158,7 +158,7 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1);
|
||||
dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1835,10 +1835,10 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q4_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q4_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q4_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q4_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q4_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -1870,10 +1870,10 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q4_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q4_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q4_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q4_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q4_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -1950,10 +1950,10 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q4_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q4_1_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q4_1_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q4_1_acc_ct1),
|
||||
get_pointer(tile_x_dm_q4_1_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -1985,10 +1985,10 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q4_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q4_1_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q4_1_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q4_1_acc_ct1),
|
||||
get_pointer(tile_x_dm_q4_1_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2065,10 +2065,10 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q5_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q5_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q5_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2100,10 +2100,10 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q5_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q5_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q5_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2180,10 +2180,10 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q5_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_1_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q5_1_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_1_acc_ct1),
|
||||
get_pointer(tile_x_dm_q5_1_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2215,10 +2215,10 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q5_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_1_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q5_1_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_1_acc_ct1),
|
||||
get_pointer(tile_x_dm_q5_1_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2295,10 +2295,10 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q8_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q8_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q8_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q8_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q8_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2330,10 +2330,10 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q8_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q8_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q8_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q8_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q8_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2412,11 +2412,11 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q2_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q2_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q2_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q2_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q2_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q2_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q2_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2450,11 +2450,11 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q2_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q2_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q2_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q2_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q2_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q2_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q2_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2537,12 +2537,12 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q3_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_qh_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q3_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_qh_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q3_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2578,12 +2578,12 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q3_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_qh_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q3_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_qh_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q3_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2663,11 +2663,11 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q4_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q4_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q4_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q4_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q4_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q4_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q4_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2701,11 +2701,11 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q4_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q4_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q4_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q4_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q4_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q4_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q4_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2784,11 +2784,11 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q5_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q5_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q5_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q5_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q5_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2822,11 +2822,11 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q5_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q5_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q5_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q5_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q5_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2905,11 +2905,11 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q6_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_acc_ct1.get_pointer(),
|
||||
tile_x_dm_acc_ct1.get_pointer(),
|
||||
tile_x_sc_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_acc_ct1),
|
||||
get_pointer(tile_x_dm_acc_ct1),
|
||||
get_pointer(tile_x_sc_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2943,11 +2943,11 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
mul_mat_q6_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_acc_ct1.get_pointer(),
|
||||
tile_x_dm_acc_ct1.get_pointer(),
|
||||
tile_x_sc_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_acc_ct1),
|
||||
get_pointer(tile_x_dm_acc_ct1),
|
||||
get_pointer(tile_x_sc_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -218,7 +218,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
s_sum_acc_ct1.get_pointer(), work_group_size);
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -265,7 +265,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
group_norm_f32(x, dst, group_size, ne_elements,
|
||||
eps_ct4, item_ct1,
|
||||
s_sum_acc_ct1.get_pointer(), work_group_size);
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -306,7 +306,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
s_sum_acc_ct1.get_pointer(), work_group_size);
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
|
||||
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
||||
nrows_y, scale, max_bias, m0,
|
||||
m1, n_head_log2, item_ct1,
|
||||
local_buf_acc.get_pointer());
|
||||
get_pointer(local_buf_acc));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
153
ggml/src/ggml.c
153
ggml/src/ggml.c
@@ -4,7 +4,7 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-quants.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include "ggml-aarch64.h"
|
||||
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
#include <malloc.h> // using malloc.h with MSC/MINGW
|
||||
@@ -37,12 +37,12 @@
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#ifdef __ARM_FEATURE_MATMUL_INT8
|
||||
#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
#undef GGML_USE_LLAMAFILE
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_LLAMAFILE
|
||||
#include "sgemm.h"
|
||||
#include <llamafile/sgemm.h>
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
@@ -692,6 +692,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
#else
|
||||
.nrows = 1,
|
||||
#endif
|
||||
.from_float_to_mat = quantize_mat_q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q8_1] = {
|
||||
.type_name = "q8_1",
|
||||
@@ -889,6 +890,54 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
|
||||
.vec_dot_type = GGML_TYPE_BF16,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q4_0_4_4] = {
|
||||
.type_name = "q4_0_4x4",
|
||||
.blck_size = QK4_0,
|
||||
.type_size = sizeof(block_q4_0),
|
||||
.is_quantized = true,
|
||||
.to_float = NULL,
|
||||
.from_float = NULL,
|
||||
.from_float_reference = NULL,
|
||||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
.ncols = 4,
|
||||
.interleave_blcksize = 4,
|
||||
.gemv = ggml_gemv_q4_0_4x4_q8_0,
|
||||
.gemm = ggml_gemm_q4_0_4x4_q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q4_0_4_8] = {
|
||||
.type_name = "q4_0_4x8",
|
||||
.blck_size = QK4_0,
|
||||
.type_size = sizeof(block_q4_0),
|
||||
.is_quantized = true,
|
||||
.to_float = NULL,
|
||||
.from_float = NULL,
|
||||
.from_float_reference = NULL,
|
||||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
.ncols = 4,
|
||||
.interleave_blcksize = 8,
|
||||
.gemv = ggml_gemv_q4_0_4x8_q8_0,
|
||||
.gemm = ggml_gemm_q4_0_4x8_q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q4_0_8_8] = {
|
||||
.type_name = "q4_0_8x8",
|
||||
.blck_size = QK4_0,
|
||||
.type_size = sizeof(block_q4_0),
|
||||
.is_quantized = true,
|
||||
.to_float = NULL,
|
||||
.from_float = NULL,
|
||||
.from_float_reference = NULL,
|
||||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
.ncols = 8,
|
||||
.interleave_blcksize = 8,
|
||||
.gemv = ggml_gemv_q4_0_8x8_q8_0,
|
||||
.gemm = ggml_gemm_q4_0_8x8_q8_0,
|
||||
}
|
||||
};
|
||||
|
||||
@@ -3188,6 +3237,9 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
||||
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = GGML_TYPE_Q4_0_4_8; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break;
|
||||
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
||||
}
|
||||
@@ -9432,6 +9484,9 @@ static void ggml_compute_forward_add(
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
case GGML_TYPE_Q4_0_8_8:
|
||||
{
|
||||
ggml_compute_forward_add_q_f32(params, dst);
|
||||
} break;
|
||||
@@ -9807,6 +9862,9 @@ static void ggml_compute_forward_add1(
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
case GGML_TYPE_Q4_0_8_8:
|
||||
{
|
||||
ggml_compute_forward_add1_q_f32(params, dst);
|
||||
} break;
|
||||
@@ -9932,6 +9990,9 @@ static void ggml_compute_forward_acc(
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
case GGML_TYPE_Q4_0_8_8:
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
@@ -12134,6 +12195,12 @@ static void ggml_compute_forward_mul_mat(
|
||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||
int64_t const vec_dot_num_rows = type_traits[type].nrows;
|
||||
int64_t const matmul_num_cols = type_traits[type].ncols;
|
||||
int64_t const interleave_blcksize = type_traits[type].interleave_blcksize;
|
||||
ggml_from_float_to_mat_t const from_float_to_mat
|
||||
= type_traits[vec_dot_type].from_float_to_mat;
|
||||
ggml_gemv_t const gemv = type_traits[type].gemv;
|
||||
ggml_gemm_t const gemm = type_traits[type].gemm;
|
||||
|
||||
GGML_ASSERT(ne0 == ne01);
|
||||
GGML_ASSERT(ne1 == ne11);
|
||||
@@ -12192,7 +12259,16 @@ UseGgmlGemm1:;
|
||||
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
||||
int64_t i11_processed = 0;
|
||||
if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
|
||||
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
||||
from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
|
||||
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
|
||||
4, ne10, interleave_blcksize);
|
||||
}
|
||||
i11_processed = ne11 - ne11 % 4;
|
||||
}
|
||||
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
||||
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
|
||||
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
|
||||
ne10);
|
||||
@@ -12273,6 +12349,28 @@ UseGgmlGemm2:;
|
||||
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
||||
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
|
||||
|
||||
if ((ggml_n_dims(src0) == 2) && gemv) {
|
||||
const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
|
||||
int64_t src0_start = (ith * ne01) / nth;
|
||||
int64_t src0_end = ((ith + 1) * ne01) / nth;
|
||||
src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
|
||||
src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
|
||||
if (src0_start >= src0_end) return;
|
||||
|
||||
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
||||
if (gemm && (ne11 > 3)) {
|
||||
gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
||||
}
|
||||
for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
|
||||
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
||||
@@ -12318,6 +12416,8 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
|
||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||
int64_t const matmul_num_cols = type_traits[type].ncols;
|
||||
ggml_gemv_t const gemv = type_traits[type].gemv;
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||
@@ -12403,6 +12503,34 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
const int64_t nr0 = ne01; // src0 rows
|
||||
const int64_t nr1 = cne1; // src1 rows
|
||||
|
||||
if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
|
||||
int64_t src0_cur_start = (ith * ne01) / nth;
|
||||
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
|
||||
src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
|
||||
src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
|
||||
if (src0_cur_start >= src0_cur_end) return;
|
||||
|
||||
for (int ir1 = 0; ir1 < nr1; ir1++) {
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
|
||||
const int id = row_mapping.i1; // selected expert index
|
||||
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = row_mapping.i2; // row index in src1
|
||||
|
||||
const int64_t i1 = id; // selected expert index
|
||||
const int64_t i2 = i12; // row
|
||||
|
||||
const char * src1_col = (const char *) wdata +
|
||||
(src1_cont || src1->type != vec_dot_type
|
||||
? (i11 + i12 * ne11) * row_size
|
||||
: (i11 * nb11 + i12 * nb12));
|
||||
|
||||
gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
|
||||
(const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// distribute the thread work across the inner or outer loop based on which one is larger
|
||||
|
||||
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
|
||||
@@ -12704,6 +12832,9 @@ static void ggml_compute_forward_out_prod(
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
case GGML_TYPE_Q4_0_8_8:
|
||||
{
|
||||
ggml_compute_forward_out_prod_q_f32(params, dst);
|
||||
} break;
|
||||
@@ -12889,6 +13020,9 @@ static void ggml_compute_forward_set(
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
case GGML_TYPE_Q4_0_8_8:
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
@@ -13148,6 +13282,9 @@ static void ggml_compute_forward_get_rows(
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
case GGML_TYPE_Q4_0_8_8:
|
||||
{
|
||||
ggml_compute_forward_get_rows_q(params, dst);
|
||||
} break;
|
||||
@@ -13734,6 +13871,9 @@ static void ggml_compute_forward_clamp(
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q8_K:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
case GGML_TYPE_Q4_0_8_8:
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_I16:
|
||||
case GGML_TYPE_I32:
|
||||
@@ -20457,6 +20597,9 @@ size_t ggml_quantize_chunk(
|
||||
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
size_t elemsize = sizeof(ggml_fp16_t);
|
||||
@@ -21759,8 +21902,6 @@ int ggml_cpu_has_neon(void) {
|
||||
|
||||
int ggml_cpu_has_sve(void) {
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
// TODO: Currently, SVE 256 bit is only supported.
|
||||
GGML_ASSERT(svcntb() == QK8_0);
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
|
||||
@@ -79,5 +79,4 @@ python -m twine upload dist/*
|
||||
```
|
||||
|
||||
## TODO
|
||||
- [ ] Add tests
|
||||
- [ ] Include conversion scripts as command line entry points in this package.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "gguf"
|
||||
version = "0.9.0"
|
||||
version = "0.9.1"
|
||||
description = "Read and write ML models in GGUF for GGML"
|
||||
authors = ["GGML <ggml@ggml.ai>"]
|
||||
packages = [
|
||||
|
||||
@@ -162,6 +162,9 @@ extern "C" {
|
||||
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
|
||||
|
||||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||
};
|
||||
|
||||
@@ -57,6 +57,12 @@
|
||||
#include <io.h>
|
||||
#endif
|
||||
|
||||
#if __cplusplus >= 202000L
|
||||
#define LU8(x) (const char*)(u8##x)
|
||||
#else
|
||||
#define LU8(x) u8##x
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
@@ -3782,6 +3788,9 @@ struct llama_model_loader {
|
||||
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
|
||||
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
|
||||
case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
|
||||
case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break;
|
||||
case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break;
|
||||
case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break;
|
||||
default:
|
||||
{
|
||||
LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
|
||||
@@ -4475,6 +4484,9 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: return "Q4_0_4_8";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: return "Q4_0_8_8";
|
||||
|
||||
default: return "unknown, may not work";
|
||||
}
|
||||
@@ -17762,6 +17774,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||
new_type = GGML_TYPE_IQ3_S;
|
||||
}
|
||||
else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 ||
|
||||
new_type == GGML_TYPE_Q4_0_8_8) {
|
||||
new_type = GGML_TYPE_Q4_0;
|
||||
}
|
||||
}
|
||||
} else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
|
||||
@@ -18074,6 +18090,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: default_type = GGML_TYPE_Q4_0_4_8; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: default_type = GGML_TYPE_Q4_0_8_8; break;
|
||||
|
||||
default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
|
||||
}
|
||||
@@ -18384,6 +18403,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
f32_data = (float *) f32_conv_buf.data();
|
||||
}
|
||||
|
||||
int chunk_size_multiplier = 1;
|
||||
if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || new_type == GGML_TYPE_Q4_0_8_8) {
|
||||
if ((new_type == GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) new_type = GGML_TYPE_Q4_0;
|
||||
else if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0;
|
||||
if (new_type == GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8;
|
||||
else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4;
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
|
||||
fflush(stdout);
|
||||
|
||||
@@ -18396,7 +18423,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
const int64_t nrows = tensor->ne[1];
|
||||
|
||||
static const int64_t min_chunk_size = 32 * 512;
|
||||
const int64_t chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
|
||||
const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)) *
|
||||
chunk_size_multiplier;
|
||||
|
||||
const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
|
||||
const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
|
||||
@@ -21511,12 +21539,12 @@ static int32_t llama_chat_apply_template_internal(
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else if (tmpl == "minicpm" || tmpl_contains(u8"<用户>")) {
|
||||
} else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) {
|
||||
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "user") {
|
||||
ss << u8"<用户>";
|
||||
ss << LU8("<用户>");
|
||||
ss << trim(message->content);
|
||||
ss << "<AI>";
|
||||
} else {
|
||||
@@ -21532,7 +21560,7 @@ static int32_t llama_chat_apply_template_internal(
|
||||
} else if (role == "user") {
|
||||
ss << "User: " << message->content << "\n\n";
|
||||
} else if (role == "assistant") {
|
||||
ss << "Assistant: " << message->content << u8"<|end▁of▁sentence|>";
|
||||
ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>");
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
|
||||
Reference in New Issue
Block a user