mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-19 14:13:22 +02:00
Compare commits
15 Commits
gg/remove-
...
b3134
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef52d1d16a | ||
|
|
14f83526cd | ||
|
|
6fe42d073f | ||
|
|
148995e5e5 | ||
|
|
4bfe50f741 | ||
|
|
bdcb8f4222 | ||
|
|
c2ce6c47e4 | ||
|
|
b61eb9644d | ||
|
|
396b18dfec | ||
|
|
864a99e7a0 | ||
|
|
fd5ea0f897 | ||
|
|
c28a83902c | ||
|
|
d9da0e4986 | ||
|
|
1f0dabda8d | ||
|
|
af4ae502dd |
@@ -2,4 +2,4 @@
|
||||
- [ ] Review Complexity : Low
|
||||
- [ ] Review Complexity : Medium
|
||||
- [ ] Review Complexity : High
|
||||
- [ ] I have read the [contributing guidelines](CONTRIBUTING.md)
|
||||
- [ ] I have read the [contributing guidelines](https://github.com/ggerganov/llama.cpp/blob/master/CONTRIBUTING.md)
|
||||
9
.github/workflows/build.yml
vendored
9
.github/workflows/build.yml
vendored
@@ -13,7 +13,7 @@ on:
|
||||
paths: ['.github/workflows/**', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m']
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
paths: ['**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m']
|
||||
paths: ['.github/workflows/build.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m']
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
@@ -684,7 +684,7 @@ jobs:
|
||||
cmake --build build --config ${{ matrix.build }} -j $(nproc)
|
||||
|
||||
windows-latest-cmake:
|
||||
runs-on: windows-latest
|
||||
runs-on: windows-2019
|
||||
|
||||
env:
|
||||
OPENBLAS_VERSION: 0.3.23
|
||||
@@ -829,7 +829,7 @@ jobs:
|
||||
name: llama-bin-win-${{ matrix.build }}.zip
|
||||
|
||||
windows-latest-cmake-cuda:
|
||||
runs-on: windows-latest
|
||||
runs-on: windows-2019
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -843,8 +843,9 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- uses: Jimver/cuda-toolkit@v0.2.11
|
||||
- name: Install CUDA toolkit
|
||||
id: cuda-toolkit
|
||||
uses: Jimver/cuda-toolkit@v0.2.15
|
||||
with:
|
||||
cuda: ${{ matrix.cuda }}
|
||||
method: 'network'
|
||||
|
||||
6
.github/workflows/server.yml
vendored
6
.github/workflows/server.yml
vendored
@@ -16,11 +16,9 @@ on:
|
||||
branches:
|
||||
- master
|
||||
paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*']
|
||||
pull_request_target:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*']
|
||||
schedule:
|
||||
- cron: '2 4 * * *'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }}
|
||||
@@ -115,7 +113,7 @@ jobs:
|
||||
|
||||
|
||||
server-windows:
|
||||
runs-on: windows-latest
|
||||
runs-on: windows-2019
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
@@ -402,12 +402,26 @@ if (LLAMA_CUBLAS)
|
||||
endif()
|
||||
|
||||
if (LLAMA_CUDA)
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
|
||||
|
||||
find_package(CUDAToolkit)
|
||||
if (CUDAToolkit_FOUND)
|
||||
message(STATUS "CUDA found")
|
||||
|
||||
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
|
||||
# 52 == lowest CUDA 12 standard
|
||||
# 60 == f16 CUDA intrinsics
|
||||
# 61 == integer CUDA intrinsics
|
||||
# 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
|
||||
if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
|
||||
set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
|
||||
else()
|
||||
set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
|
||||
#set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||
|
||||
enable_language(CUDA)
|
||||
|
||||
set(GGML_HEADERS_CUDA ggml-cuda.h)
|
||||
@@ -472,21 +486,6 @@ if (LLAMA_CUDA)
|
||||
else()
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
|
||||
endif()
|
||||
|
||||
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
|
||||
# 52 == lowest CUDA 12 standard
|
||||
# 60 == f16 CUDA intrinsics
|
||||
# 61 == integer CUDA intrinsics
|
||||
# 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
|
||||
if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
|
||||
set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
|
||||
else()
|
||||
set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
|
||||
#set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||
|
||||
else()
|
||||
message(WARNING "CUDA not found")
|
||||
endif()
|
||||
|
||||
@@ -40,7 +40,7 @@ static std::string build_repetition(const std::string & item_rule, int min_items
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::string SPACE_RULE = "\" \"?";
|
||||
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
|
||||
|
||||
struct BuiltinRule {
|
||||
std::string content;
|
||||
@@ -57,7 +57,7 @@ std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
|
||||
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
|
||||
{"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
|
||||
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
|
||||
{"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
|
||||
{"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
|
||||
{"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
|
||||
{"null", {"\"null\" space", {}}},
|
||||
};
|
||||
|
||||
@@ -29,9 +29,8 @@ class BuiltinRule:
|
||||
self.content = content
|
||||
self.deps = deps or []
|
||||
|
||||
# whitespace is constrained to a single space char to prevent model "running away" in
|
||||
# whitespace. Also maybe improves generation quality?
|
||||
SPACE_RULE = '" "?'
|
||||
# Constraining spaces to prevent model "running away".
|
||||
SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}'
|
||||
|
||||
PRIMITIVE_RULES = {
|
||||
'boolean' : BuiltinRule('("true" | "false") space', []),
|
||||
@@ -43,7 +42,7 @@ PRIMITIVE_RULES = {
|
||||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
|
||||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
|
||||
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
|
||||
'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})', []),
|
||||
'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
|
||||
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
|
||||
'null' : BuiltinRule('"null" space', []),
|
||||
}
|
||||
|
||||
@@ -1033,6 +1033,27 @@ struct markdown_printer : public printer {
|
||||
if (field == "n_gpu_layers") {
|
||||
return 3;
|
||||
}
|
||||
if (field == "n_threads") {
|
||||
return 7;
|
||||
}
|
||||
if (field == "n_batch") {
|
||||
return 7;
|
||||
}
|
||||
if (field == "n_ubatch") {
|
||||
return 8;
|
||||
}
|
||||
if (field == "type_k" || field == "type_v") {
|
||||
return 6;
|
||||
}
|
||||
if (field == "split_mode") {
|
||||
return 5;
|
||||
}
|
||||
if (field == "flash_attn") {
|
||||
return 2;
|
||||
}
|
||||
if (field == "use_mmap") {
|
||||
return 4;
|
||||
}
|
||||
if (field == "test") {
|
||||
return 13;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// WARNING: This file was ported from json_schema_to_grammar.py, please fix bugs / add features there first.
|
||||
const SPACE_RULE = '" "?';
|
||||
const SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}';
|
||||
|
||||
function _buildRepetition(itemRule, minItems, maxItems, opts={}) {
|
||||
if (minItems === 0 && maxItems === 1) {
|
||||
@@ -41,7 +41,7 @@ const PRIMITIVE_RULES = {
|
||||
object : new BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
|
||||
array : new BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
|
||||
uuid : new BuiltinRule('"\\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\\"" space', []),
|
||||
char : new BuiltinRule(`[^"\\\\] | "\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F]{4})`, []),
|
||||
char : new BuiltinRule(`[^"\\\\\\x7F\\x00-\\x1F] | [\\\\] (["\\\\bfnrt] | "u" [0-9a-fA-F]{4})`, []),
|
||||
string : new BuiltinRule(`"\\"" char* "\\"" space`, ['char']),
|
||||
null : new BuiltinRule('"null" space', []),
|
||||
};
|
||||
|
||||
@@ -147,7 +147,7 @@ struct server_slot {
|
||||
int32_t n_prompt_tokens = 0;
|
||||
int32_t n_prompt_tokens_processed = 0;
|
||||
|
||||
json prompt;
|
||||
std::string prompt;
|
||||
|
||||
// when a task is submitted, we first tokenize the prompt and store it here
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
@@ -822,13 +822,8 @@ struct server_context {
|
||||
continue;
|
||||
}
|
||||
|
||||
// skip the slot if it does not contains prompt
|
||||
if (!slot.prompt.is_string()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// current slot's prompt
|
||||
std::string slot_prompt = slot.prompt.get<std::string>();
|
||||
std::string slot_prompt = slot.prompt;
|
||||
|
||||
// length of the current slot's prompt
|
||||
int slot_prompt_len = slot_prompt.size();
|
||||
@@ -958,13 +953,16 @@ struct server_context {
|
||||
if (!task.infill) {
|
||||
const auto & prompt = data.find("prompt");
|
||||
if (prompt == data.end()) {
|
||||
send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
||||
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
} else {
|
||||
slot.prompt = *prompt;
|
||||
}
|
||||
if (slot.prompt.is_array() && slot.prompt.size() == 0) {
|
||||
send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST);
|
||||
|
||||
if (prompt->is_string()) {
|
||||
slot.prompt = prompt->get<std::string>();
|
||||
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) {
|
||||
slot.prompt = prompt->at(0).get<std::string>();
|
||||
} else {
|
||||
send_error(task, "\"prompt\" must be a string or an array of strings", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1582,14 +1580,18 @@ struct server_context {
|
||||
switch (task.type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
{
|
||||
int id_slot = json_value(task.data, "id_slot", -1);
|
||||
std::string prompt = json_value(task.data, "prompt", std::string());
|
||||
const int id_slot = json_value(task.data, "id_slot", -1);
|
||||
|
||||
server_slot * slot;
|
||||
|
||||
if (id_slot != -1) {
|
||||
slot = get_slot_by_id(id_slot);
|
||||
} else {
|
||||
std::string prompt;
|
||||
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
|
||||
json_value(task.data, "prompt", std::string());
|
||||
}
|
||||
|
||||
slot = get_available_slot(prompt);
|
||||
}
|
||||
|
||||
|
||||
@@ -886,7 +886,7 @@ static bool alloc_tensor_range(struct ggml_context * ctx,
|
||||
fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
|
||||
#endif
|
||||
for (size_t i = 0; i < *n_buffers; i++) {
|
||||
ggml_backend_buffer_free(*buffers[i]);
|
||||
ggml_backend_buffer_free((*buffers)[i]);
|
||||
}
|
||||
free(*buffers);
|
||||
return false;
|
||||
|
||||
@@ -139,6 +139,7 @@
|
||||
#define CC_PASCAL 600
|
||||
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
||||
#define CC_VOLTA 700
|
||||
#define CC_TURING 750
|
||||
#define CC_AMPERE 800
|
||||
#define CC_OFFSET_AMD 1000000
|
||||
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
|
||||
@@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
|
||||
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
||||
#endif // defined(GGML_USE_HIPBLAS)
|
||||
|
||||
#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
||||
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
||||
#define FP16_AVAILABLE
|
||||
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
||||
|
||||
#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
#define FP16_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
||||
#define INT8_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return cc >= CC_PASCAL && cc != 610;
|
||||
@@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
|
||||
}
|
||||
|
||||
static bool int8_mma_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && cc >= CC_TURING;
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
static __device__ void no_device_code(
|
||||
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
|
||||
@@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
#pragma unroll
|
||||
@@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
|
||||
return __float2half(fmaxf(__half2float(a), __half2float(b)));
|
||||
|
||||
@@ -74,7 +74,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
||||
|
||||
const int sumi = __dp4a(v, u, 0);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
@@ -122,7 +122,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
||||
|
||||
const int sumi = __dp4a(v, u, 0);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
@@ -181,7 +181,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
||||
|
||||
const int sumi = __dp4a(v, u, 0);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
@@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||
|
||||
const int sumi = __dp4a(v, u, 0);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
@@ -314,7 +314,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
||||
GGML_UNUSED(Q_q8);
|
||||
GGML_UNUSED(Q_ds_v);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_h2 = (const half2 *) Q_v;
|
||||
|
||||
@@ -407,7 +407,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
|
||||
const int q0 = x[ib].qs[iqs];
|
||||
const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return ((half) d)*((half) q);
|
||||
}
|
||||
@@ -428,7 +428,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__
|
||||
const int q0 = x[ib].qs[iqs];
|
||||
const int q = ((q0 >> (4*shift)) & 0x0F);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return __low2half(dm)*((half) q) + __high2half(dm);
|
||||
}
|
||||
@@ -453,7 +453,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
|
||||
const int qh = ((qh0 >> idq) << 4) & 0x10;
|
||||
const int q = (ql | qh) - 16;
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return ((half) d)*((half) q);
|
||||
}
|
||||
@@ -478,7 +478,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
|
||||
const int qh = ((qh0 >> idq) << 4) & 0x10;
|
||||
const int q = (ql | qh);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return __low2half(dm)*((half) q) + __high2half(dm);
|
||||
}
|
||||
@@ -497,7 +497,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
|
||||
const T d = x[ib].d;
|
||||
const int q = x[ib].qs[iqs];
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return ((half) d)*((half) q);
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
@@ -40,7 +40,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
#if FP16_MMA_AVAILABLE
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
#include <mma.h>
|
||||
#endif
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
|
||||
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
|
||||
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if FP16_MMA_AVAILABLE
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
||||
|
||||
161
ggml-cuda/mma.cuh
Normal file
161
ggml-cuda/mma.cuh
Normal file
@@ -0,0 +1,161 @@
|
||||
#include "common.cuh"
|
||||
|
||||
struct mma_int_A_I16K4 {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int K = 4;
|
||||
static constexpr int ne = 2;
|
||||
|
||||
int x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l%2) * (I/2) + threadIdx.x / K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int /* l */) {
|
||||
const int ret = threadIdx.x % K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_A_I16K8 {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int K = 8;
|
||||
static constexpr int ne = 4;
|
||||
|
||||
int x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int l) {
|
||||
const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_B_J8K4 {
|
||||
static constexpr int J = 8;
|
||||
static constexpr int K = 4;
|
||||
static constexpr int ne = 1;
|
||||
|
||||
int x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int /* l */) {
|
||||
const int ret = threadIdx.x / K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int /* l */) {
|
||||
const int ret = threadIdx.x % K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_B_J8K8 {
|
||||
static constexpr int J = 8;
|
||||
static constexpr int K = 8;
|
||||
static constexpr int ne = 2;
|
||||
|
||||
int x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int /* l */) {
|
||||
const int ret = threadIdx.x / (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int l) {
|
||||
const int ret = l * (K/2) + threadIdx.x % (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_C_I16J8 {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int J = 8;
|
||||
static constexpr int ne = 4;
|
||||
|
||||
int x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
#if __CUDA_ARCH__ >= CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
|
||||
#endif // __CUDA_ARCH__ >= CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
#if __CUDA_ARCH__ >= CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
|
||||
#else
|
||||
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "r"(mma_A.x[2]), "r"(mma_B.x[1]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[3]), "r"(mma_B.x[1]));
|
||||
#endif // __CUDA_ARCH__ >= CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
};
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "common.cuh"
|
||||
#include "vecdotq.cuh"
|
||||
#include "mma.cuh"
|
||||
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
@@ -14,6 +15,7 @@ typedef void (*load_tiles_mmq_t)(
|
||||
typedef void (*vec_dot_mmq_t)(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0);
|
||||
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
|
||||
|
||||
struct block_q8_1_mmq {
|
||||
half2 ds[4];
|
||||
@@ -141,15 +143,15 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
const float * x_dmf = (const float *) x_dm;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
const float * x_df = (const float *) x_dm;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
@@ -170,12 +172,76 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
|
||||
}
|
||||
|
||||
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
|
||||
(&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
|
||||
(&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
|
||||
y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const float * x_df = (const float *) x_dm;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
mma_A A;
|
||||
float dA[mma_C::ne/2];
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = k0 + mma_A::get_k(l) % QI4_0;
|
||||
const int shift = 4*(mma_A::get_k(l) / QI4_0);
|
||||
|
||||
A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
|
||||
}
|
||||
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
half2 dsB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
|
||||
}
|
||||
|
||||
C.mma_K8(A, B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
||||
@@ -215,7 +281,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
@@ -249,6 +315,70 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
mma_A A;
|
||||
half2 dmA[mma_C::ne/2];
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = k0 + mma_A::get_k(l) % QI4_0;
|
||||
const int shift = 4*(mma_A::get_k(l) / QI4_0);
|
||||
|
||||
A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
|
||||
}
|
||||
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
half2 dsB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
|
||||
}
|
||||
|
||||
C.mma_K8(A, B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
|
||||
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
||||
@@ -308,7 +438,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
@@ -343,6 +473,68 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const float * x_df = (const float *) x_dm;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const float * y_df = (const float *) y;
|
||||
|
||||
mma_A A;
|
||||
float dA[mma_C::ne/2];
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
|
||||
|
||||
A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
|
||||
}
|
||||
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
float dB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
|
||||
}
|
||||
|
||||
C.mma_K8(A, B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
@@ -400,7 +592,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
@@ -434,6 +626,69 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
mma_A A;
|
||||
half2 dmA[mma_C::ne/2];
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
|
||||
|
||||
A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
|
||||
}
|
||||
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
half2 dsB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
|
||||
}
|
||||
|
||||
C.mma_K8(A, B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
|
||||
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
||||
@@ -475,7 +730,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
@@ -500,6 +755,69 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const float * x_df = (const float *) x_dm;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const float * y_df = (const float *) y;
|
||||
|
||||
mma_A A;
|
||||
float dA[mma_C::ne/2];
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = k0 + mma_A::get_k(l);
|
||||
|
||||
A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
|
||||
}
|
||||
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
float dB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = k0 + mma_B::get_k(l);
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
|
||||
}
|
||||
|
||||
C.mma_K8(A, B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
||||
@@ -771,7 +1089,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
@@ -797,6 +1115,97 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
mma_A A[2];
|
||||
int scA[mma_C::ne/2][2];
|
||||
int mA[mma_C::ne/2][2];
|
||||
half2 dmA[mma_C::ne/2];
|
||||
#pragma unroll
|
||||
for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = k0 + mma_A::get_k(l);
|
||||
|
||||
A[kvdr/4].x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
|
||||
const uint8_t * m = sc + 8;
|
||||
|
||||
scA[l][kvdr/4] = sc[kvdr/4];
|
||||
mA[l][kvdr/4] = m[kvdr/4];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
float tmpd[mma_C::ne] = {0.0f};
|
||||
float tmpm[mma_C::ne] = {0.0f};
|
||||
|
||||
#pragma unroll
|
||||
for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
half2 dsB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
|
||||
}
|
||||
|
||||
C.mma_K8(A[kvdr/4], B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
|
||||
tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
||||
@@ -870,7 +1279,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
@@ -896,6 +1305,97 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
mma_A A[2];
|
||||
int scA[mma_C::ne/2][2];
|
||||
int mA[mma_C::ne/2][2];
|
||||
half2 dmA[mma_C::ne/2];
|
||||
#pragma unroll
|
||||
for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
|
||||
|
||||
A[kvdr/4].x[l] = x_ql[i*(QR5_K*WARP_SIZE + 1) + k];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
|
||||
const uint8_t * m = sc + 8;
|
||||
|
||||
scA[l][kvdr/4] = sc[kvdr/4];
|
||||
mA[l][kvdr/4] = m[kvdr/4];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
float tmpd[mma_C::ne] = {0.0f};
|
||||
float tmpm[mma_C::ne] = {0.0f};
|
||||
|
||||
#pragma unroll
|
||||
for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
half2 dsB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
|
||||
}
|
||||
|
||||
C.mma_K8(A[kvdr/4], B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
|
||||
tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
||||
@@ -962,7 +1462,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
@@ -989,6 +1489,148 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K4 mma_A;
|
||||
typedef mma_int_B_J8K4 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const float * x_df = (const float *) x_dm;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const float * y_df = (const float *) y;
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
mma_A A[4];
|
||||
int scA[mma_C::ne/2][4];
|
||||
float dA[mma_C::ne/2];
|
||||
#pragma unroll
|
||||
for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
|
||||
|
||||
A[kvdr/2 + 0].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + 0];
|
||||
A[kvdr/2 + 1].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
|
||||
|
||||
scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0];
|
||||
scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
float tmp[mma_C::ne] = {0.0f};
|
||||
|
||||
#pragma unroll
|
||||
for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
|
||||
mma_C C[2];
|
||||
mma_B B[2];
|
||||
float dB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
|
||||
|
||||
B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
|
||||
B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
|
||||
}
|
||||
|
||||
C[0].mma_K4(A[kvdr/2 + 0], B[0]);
|
||||
C[1].mma_K4(A[kvdr/2 + 1], B[1]);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
|
||||
|
||||
if (j >= ne1) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
||||
const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
|
||||
|
||||
if (need_check && i >= ne0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const int i0 = threadIdx.y*mma_C::I;
|
||||
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
|
||||
|
||||
if (j >= ne1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
|
||||
|
||||
if (need_check && i >= ne0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
|
||||
@@ -998,35 +1640,65 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
|
||||
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
|
||||
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
|
||||
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
|
||||
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
|
||||
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
@@ -1034,6 +1706,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
|
||||
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
@@ -1041,27 +1714,46 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
|
||||
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
|
||||
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
|
||||
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
|
||||
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
static int mmq_need_sum(const ggml_type type_x) {
|
||||
@@ -1118,6 +1810,7 @@ static __global__ void mul_mat_q(
|
||||
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
|
||||
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
|
||||
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
|
||||
constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back;
|
||||
|
||||
constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
|
||||
|
||||
@@ -1137,7 +1830,7 @@ static __global__ void mul_mat_q(
|
||||
|
||||
const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
|
||||
|
||||
float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
|
||||
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
|
||||
|
||||
for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
|
||||
|
||||
@@ -1164,25 +1857,7 @@ static __global__ void mul_mat_q(
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
|
||||
|
||||
if (j >= ne1) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
||||
const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
|
||||
|
||||
if (need_check && i >= ne0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
|
||||
}
|
||||
}
|
||||
write_back(sum, dst, ne0, ne1);
|
||||
}
|
||||
|
||||
struct mmq_args {
|
||||
@@ -1256,10 +1931,10 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
|
||||
launch_mul_mat_q<type, 8, 4>(args, stream);
|
||||
break;
|
||||
case 16:
|
||||
launch_mul_mat_q<type, 16, 8>(args, stream);
|
||||
launch_mul_mat_q<type, 16, 4>(args, stream);
|
||||
break;
|
||||
case 24:
|
||||
launch_mul_mat_q<type, 24, 8>(args, stream);
|
||||
launch_mul_mat_q<type, 24, 4>(args, stream);
|
||||
break;
|
||||
case 32:
|
||||
launch_mul_mat_q<type, 32, 8>(args, stream);
|
||||
|
||||
@@ -13089,10 +13089,12 @@ void *ggml_sycl_host_malloc(size_t size) try {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ggml_sycl_set_device(g_main_device);
|
||||
dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
|
||||
|
||||
void * ptr = nullptr;
|
||||
//allow to use dpct::get_in_order_queue() for host malloc
|
||||
dpct::err0 err = CHECK_TRY_ERROR(
|
||||
ptr = (void *)sycl::malloc_host(size, dpct::get_in_order_queue()));
|
||||
ptr = (void *)sycl::malloc_host(size, *main_stream));
|
||||
|
||||
if (err != 0) {
|
||||
// clear the error
|
||||
@@ -13113,8 +13115,9 @@ catch (sycl::exception const &exc) {
|
||||
}
|
||||
|
||||
void ggml_sycl_host_free(void *ptr) try {
|
||||
//allow to use dpct::get_in_order_queue() for host malloc
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue())));
|
||||
ggml_sycl_set_device(g_main_device);
|
||||
dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *main_stream)));
|
||||
}
|
||||
catch (sycl::exception const &exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -150,7 +150,7 @@ struct vk_device {
|
||||
vk_pipeline pipeline_relu_f32;
|
||||
vk_pipeline pipeline_diag_mask_inf_f32;
|
||||
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
||||
vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
|
||||
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
|
||||
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
|
||||
vk_pipeline pipeline_argsort_f32;
|
||||
vk_pipeline pipeline_sum_rows_f32;
|
||||
@@ -283,26 +283,15 @@ struct vk_op_diag_mask_push_constants {
|
||||
|
||||
struct vk_op_rope_push_constants {
|
||||
uint32_t ncols;
|
||||
uint32_t n_dims;
|
||||
float freq_scale;
|
||||
uint32_t p_delta_rows;
|
||||
float freq_base;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float corr_dims[4];
|
||||
};
|
||||
|
||||
struct vk_op_rope_neox_push_constants {
|
||||
uint32_t ncols;
|
||||
uint32_t ndims;
|
||||
float freq_scale;
|
||||
uint32_t p_delta_rows;
|
||||
float freq_base;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float corr_dims[4];
|
||||
float corr_dims[2];
|
||||
float theta_scale;
|
||||
float inv_ndims;
|
||||
uint32_t has_freq_facs;
|
||||
uint32_t has_ff;
|
||||
};
|
||||
|
||||
struct vk_op_soft_max_push_constants {
|
||||
@@ -1534,11 +1523,11 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
@@ -3905,10 +3894,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
} else {
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_rope_f32;
|
||||
return ctx->device->pipeline_rope_norm_f32;
|
||||
}
|
||||
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
return ctx->device->pipeline_rope_f16;
|
||||
return ctx->device->pipeline_rope_norm_f16;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
@@ -4152,24 +4141,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
} else if (op == GGML_OP_ROPE) {
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const bool is_neox = mode & 2;
|
||||
|
||||
if (is_neox) {
|
||||
// Empty src2 is possible in rope, but the shader needs a buffer
|
||||
vk_subbuffer subbuf_z;
|
||||
if (use_src2) {
|
||||
subbuf_z = { d_Z, z_buf_offset, z_sz };
|
||||
} else {
|
||||
subbuf_z = { d_X, 0, d_X->size };
|
||||
}
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
// Empty src2 is possible in rope, but the shader needs a buffer
|
||||
vk_subbuffer subbuf_z;
|
||||
if (use_src2) {
|
||||
subbuf_z = { d_Z, z_buf_offset, z_sz };
|
||||
} else {
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
subbuf_z = { d_X, 0, d_X->size };
|
||||
}
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
} else if (use_src2) {
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_Z, z_buf_offset, z_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
||||
@@ -4391,7 +4372,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
|
||||
|
||||
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
// const int mode = ((int32_t *) dst->op_params)[2];
|
||||
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
const float freq_base = ((float *) dst->op_params)[5];
|
||||
@@ -4401,28 +4382,16 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
|
||||
const float beta_fast = ((float *) dst->op_params)[9];
|
||||
const float beta_slow = ((float *) dst->op_params)[10];
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
|
||||
#pragma message("TODO: update rope NORM mode to match NEOX mode")
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
if (is_neox) {
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
const float inv_ndims = -1.0f / n_dims;
|
||||
ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
|
||||
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
|
||||
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f}, theta_scale, inv_ndims,
|
||||
src2 != nullptr,
|
||||
});
|
||||
} else {
|
||||
ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
|
||||
(uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1],
|
||||
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f}
|
||||
});
|
||||
}
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
|
||||
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
|
||||
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
|
||||
src2 != nullptr,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
@@ -6070,7 +6039,13 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(
|
||||
std::cerr << "ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")" << std::endl;
|
||||
#endif
|
||||
ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
|
||||
vk_buffer dev_buffer = ggml_vk_create_buffer_device(ctx->ctx, size);
|
||||
|
||||
vk_buffer dev_buffer = nullptr;
|
||||
try {
|
||||
dev_buffer = ggml_vk_create_buffer_device(ctx->ctx, size);
|
||||
} catch (const vk::SystemError& e) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->ctx, std::move(dev_buffer), ctx->name);
|
||||
|
||||
@@ -6466,7 +6441,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
|
||||
// return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
// } break;
|
||||
case GGML_OP_ROPE:
|
||||
return true;
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
|
||||
@@ -2400,7 +2400,7 @@ void main() {
|
||||
"""
|
||||
|
||||
# ROPE
|
||||
rope_src = """
|
||||
rope_norm_src = """
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
@@ -2408,17 +2408,21 @@ rope_src = """
|
||||
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer Y {int data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
layout (binding = 1) readonly buffer Y {int data_pos[];};
|
||||
layout (binding = 2) readonly buffer Z {float data_ff[];};
|
||||
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint ncols;
|
||||
uint n_dims;
|
||||
float freq_scale;
|
||||
uint p_delta_rows;
|
||||
float freq_base;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float corr_dims[4];
|
||||
float corr_dims[2];
|
||||
float theta_scale;
|
||||
uint has_ff;
|
||||
} p;
|
||||
|
||||
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
||||
@@ -2450,14 +2454,24 @@ void main() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (col >= p.n_dims) {
|
||||
const uint i = row*p.ncols + col;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i = row*p.ncols + col;
|
||||
const uint i2 = row/p.p_delta_rows;
|
||||
|
||||
const int pos = data_b[i2];
|
||||
const float theta_base = pos * pow(p.freq_base, -float(col)/p.ncols);
|
||||
const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base, col, cos_theta, sin_theta);
|
||||
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = float(data_a[i + 0]);
|
||||
const float x1 = float(data_a[i + 1]);
|
||||
@@ -2475,22 +2489,21 @@ rope_neox_src = """
|
||||
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer Y {int data_b[];};
|
||||
layout (binding = 2) readonly buffer Z {float data_freq_factors[];};
|
||||
layout (binding = 1) readonly buffer Y {int data_pos[];};
|
||||
layout (binding = 2) readonly buffer Z {float data_ff[];};
|
||||
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint ncols;
|
||||
uint ndims;
|
||||
uint n_dims;
|
||||
float freq_scale;
|
||||
uint p_delta_rows;
|
||||
float freq_base;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float corr_dims[4];
|
||||
float corr_dims[2];
|
||||
float theta_scale;
|
||||
float inv_ndims;
|
||||
uint has_freq_facs;
|
||||
uint has_ff;
|
||||
} p;
|
||||
|
||||
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
||||
@@ -2522,11 +2535,8 @@ void main() {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint ib = col / p.ndims;
|
||||
const uint ic = col % p.ndims;
|
||||
|
||||
if (ib > 0) {
|
||||
const uint i = row*p.ncols + ib*p.ndims + ic;
|
||||
if (col >= p.n_dims) {
|
||||
const uint i = row*p.ncols + col;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
@@ -2534,29 +2544,27 @@ void main() {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i = row*p.ncols + ib*p.ndims + ic/2;
|
||||
const uint i = row*p.ncols + col/2;
|
||||
const uint i2 = row/p.p_delta_rows;
|
||||
|
||||
const int pos = data_b[i2];
|
||||
const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
|
||||
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;
|
||||
const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base, ic, cos_theta, sin_theta);
|
||||
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = float(data_a[i + 0]);
|
||||
const float x1 = float(data_a[i + p.ndims/2]);
|
||||
const float x1 = float(data_a[i + p.n_dims/2]);
|
||||
|
||||
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
data_d[i + p.ndims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
"""
|
||||
|
||||
argsort_src = """
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#define BLOCK_SIZE 1024
|
||||
#define ASC 0
|
||||
|
||||
@@ -3039,8 +3047,8 @@ async def main():
|
||||
tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("soft_max_f32_f16", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float16_t", "C_TYPE": "float16_t", "D_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
|
||||
tasks.append(string_to_spv("rope_norm_f32", rope_norm_src, {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("rope_norm_f16", rope_norm_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
|
||||
|
||||
tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
|
||||
|
||||
@@ -94,6 +94,8 @@ This guide provides a brief overview. Check out the GBNF files in this directory
|
||||
./main -m <model> --grammar-file grammars/some-grammar.gbnf -p 'Some prompt'
|
||||
```
|
||||
|
||||
`llama.cpp` can also convert JSON schemas to grammars either ahead of time or at each request, see below.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
Grammars currently have performance gotchas (see https://github.com/ggerganov/llama.cpp/issues/4218).
|
||||
@@ -103,3 +105,40 @@ Grammars currently have performance gotchas (see https://github.com/ggerganov/ll
|
||||
A common pattern is to allow repetitions of a pattern `x` up to N times.
|
||||
|
||||
While semantically correct, the syntax `x? x? x?.... x?` (with N repetitions) may result in extremely slow sampling. Instead, you can write `x{0,N}` (or `(x (x (x ... (x)?...)?)?)?` w/ N-deep nesting in earlier llama.cpp versions).
|
||||
|
||||
## Using GBNF grammars
|
||||
|
||||
You can use GBNF grammars:
|
||||
|
||||
- In the [server](../examples/server)'s completion endpoints, passed as the `grammar` body field
|
||||
- In the [main](../examples/main) CLI, passed as the `--grammar` & `--grammar-file` flags
|
||||
- With the [gbnf-validator](../examples/gbnf-validator) tool, to test them against strings.
|
||||
|
||||
## JSON Schemas → GBNF
|
||||
|
||||
`llama.cpp` supports converting a subset of https://json-schema.org/ to GBNF grammars:
|
||||
|
||||
- In the [server](../examples/server):
|
||||
- For any completion endpoints, passed as the `json_schema` body field
|
||||
- For the `/chat/completions` endpoint, passed inside the `result_format` body field (e.g. `{"type", "json_object", "schema": {"items": {}}}`)
|
||||
- In the [main](../examples/main) CLI, passed as the `--json` / `-j` flag
|
||||
- To convert to a grammar ahead of time:
|
||||
- in CLI, with [json_schema_to_grammar.py](../examples/json_schema_to_grammar.py)
|
||||
- in JavaScript with [json-schema-to-grammar.mjs](../examples/server/public/json-schema-to-grammar.mjs) (this is used by the [server](../examples/server)'s Web UI)
|
||||
|
||||
Take a look at [tests](../../tests/test-json-schema-to-grammar.cpp) to see which features are likely supported (you'll also find usage examples in https://github.com/ggerganov/llama.cpp/pull/5978, https://github.com/ggerganov/llama.cpp/pull/6659 & https://github.com/ggerganov/llama.cpp/pull/6555).
|
||||
|
||||
Here is also a non-exhaustive list of **unsupported** features:
|
||||
|
||||
- `additionalProperties`: to be fixed in https://github.com/ggerganov/llama.cpp/pull/7840
|
||||
- `minimum`, `exclusiveMinimum`, `maximum`, `exclusiveMaximum`
|
||||
- `integer` constraints to be implemented in https://github.com/ggerganov/llama.cpp/pull/7797
|
||||
- Remote `$ref`s in the C++ version (Python & JavaScript versions fetch https refs)
|
||||
- Mixing `properties` w/ `anyOf` / `oneOf` in the same type (https://github.com/ggerganov/llama.cpp/issues/7703)
|
||||
- `string` formats `uri`, `email`
|
||||
- [`contains`](https://json-schema.org/draft/2020-12/json-schema-core#name-contains) / `minContains`
|
||||
- `uniqueItems`
|
||||
- `$anchor` (cf. [dereferencing](https://json-schema.org/draft/2020-12/json-schema-core#name-dereferencing))
|
||||
- [`not`](https://json-schema.org/draft/2020-12/json-schema-core#name-not)
|
||||
- [Conditionals](https://json-schema.org/draft/2020-12/json-schema-core#name-keywords-for-applying-subsche) `if` / `then` / `else` / `dependentSchemas`
|
||||
- [`patternProperties`](https://json-schema.org/draft/2020-12/json-schema-core#name-patternproperties)
|
||||
|
||||
@@ -16,10 +16,10 @@ array ::=
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\\x7F\x00-\x1F] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
ws ::= | " " | "\n" [ \t]{0,20}
|
||||
|
||||
@@ -25,10 +25,10 @@ array ::=
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\\x7F\x00-\x1F] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
ws ::= | " " | "\n" [ \t]{0,20}
|
||||
|
||||
@@ -105,14 +105,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= object
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
@@ -135,7 +135,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
date-time ::= date "T" time
|
||||
date-time-string ::= "\"" date-time "\"" space
|
||||
root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 "]" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
|
||||
time-string ::= "\"" time "\"" space
|
||||
tuple-0 ::= date-string
|
||||
@@ -152,9 +152,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"type": "string"
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char* "\"" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -166,9 +166,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minLength": 1
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char+ "\"" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -180,9 +180,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minLength": 3
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{3,} "\"" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -194,9 +194,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxLength": 3
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{0,3} "\"" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -209,9 +209,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxLength": 4
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{1,4} "\"" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -223,7 +223,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("true" | "false") space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -236,7 +236,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= ("-"? integral-part) space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -248,7 +248,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"foo\""
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -260,7 +260,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "123"
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -272,7 +272,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]"
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -283,9 +283,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"prefixItems": [{ "type": "string" }]
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "[" space string "]" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -297,12 +297,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"prefixItems": [{ "type": "string" }, { "type": "number" }]
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "[" space string "," space number "]" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -317,7 +317,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -333,7 +333,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space boolean ("," space boolean)+ "]" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -349,7 +349,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space boolean? "]" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -365,7 +365,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space (boolean ("," space boolean)?)? "]" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -386,7 +386,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
item ::= number | integer
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "[" space item ("," space item){2,4} "]" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -399,7 +399,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" "ab" "c"? "d"* "ef" "g"+ ("hij")? "kl" "\"" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -412,7 +412,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" "[]{}()|+*?" "\"" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -425,7 +425,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" "\"" "\"" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -440,7 +440,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
dot ::= [^\x0A\x0D]
|
||||
root ::= "\"" ("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot "\"" space
|
||||
root-1 ::= [0-9]
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -466,9 +466,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
a-kv ::= "\"a\"" space ":" space string
|
||||
b-kv ::= "\"b\"" space ":" space string
|
||||
c-kv ::= "\"c\"" space ":" space string
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -486,9 +486,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space string
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space (a-kv )? "}" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -510,9 +510,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
b-kv ::= "\"b\"" space ":" space string
|
||||
b-rest ::= ( "," space c-kv )?
|
||||
c-kv ::= "\"c\"" space ":" space string
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -534,11 +534,11 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
a-kv ::= "\"a\"" space ":" space string
|
||||
b-kv ::= "\"b\"" space ":" space string
|
||||
c-kv ::= "\"c\"" space ":" space string
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
d-kv ::= "\"d\"" space ":" space string
|
||||
d-rest ::= ( "," space c-kv )?
|
||||
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -554,12 +554,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
additional-kv ::= string ":" space additional-value
|
||||
additional-kvs ::= additional-kv ( "," space additional-kv )*
|
||||
additional-value ::= "[" space (number ("," space number)*)? "]" space
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space (additional-kvs )? "}" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -574,14 +574,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= object
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
@@ -596,14 +596,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= object
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
@@ -618,7 +618,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "{" space "}" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -637,12 +637,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
a-kv ::= "\"a\"" space ":" space number
|
||||
additional-kv ::= string ":" space string
|
||||
additional-kvs ::= additional-kv ( "," space additional-kv )*
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space a-kv ( "," space ( additional-kvs ) )? "}" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -662,12 +662,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
a-rest ::= additional-kvs
|
||||
additional-kv ::= string ":" space number
|
||||
additional-kvs ::= additional-kv ( "," space additional-kv )*
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space (a-kv a-rest | additional-kvs )? "}" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -690,12 +690,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
additional-kvs ::= additional-kv ( "," space additional-kv )*
|
||||
b-kv ::= "\"b\"" space ":" space number
|
||||
b-rest ::= additional-kvs
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space a-kv ( "," space ( b-kv b-rest | additional-kvs ) )? "}" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -721,11 +721,11 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
}
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
foo ::= "{" space foo-a-kv "}" space
|
||||
foo-a-kv ::= "\"a\"" space ":" space string
|
||||
root ::= foo
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -759,7 +759,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= alternative-0 | alternative-1
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -803,7 +803,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -851,7 +851,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
number-number-kv ::= "\"number\"" space ":" space number-number
|
||||
number-number-root-kv ::= "\"root\"" space ":" space number
|
||||
root ::= "{" space number-kv "}" space
|
||||
space ::= " "?
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
}
|
||||
@@ -870,7 +870,7 @@ int main() {
|
||||
}
|
||||
});
|
||||
|
||||
if (getenv("LLAMA_PYTHON_AVAILABLE") || (std::system("python --version") == 0)) {
|
||||
if (getenv("LLAMA_PYTHON_AVAILABLE") || (std::system("python -c \"import sys; exit(1) if sys.version_info < (3, 8) else print('Python version is sufficient')\"") == 0)) {
|
||||
test_all("Python", [](const TestCase & tc) {
|
||||
write("test-json-schema-input.tmp", tc.schema);
|
||||
tc.verify_status(std::system(
|
||||
@@ -878,7 +878,7 @@ int main() {
|
||||
tc.verify(read("test-grammar-output.tmp"));
|
||||
});
|
||||
} else {
|
||||
fprintf(stderr, "\033[33mWARNING: Python not found, skipping Python JSON schema -> grammar tests.\n\033[0m");
|
||||
fprintf(stderr, "\033[33mWARNING: Python not found (min version required is 3.8), skipping Python JSON schema -> grammar tests.\n\033[0m");
|
||||
}
|
||||
|
||||
if (getenv("LLAMA_NODE_AVAILABLE") || (std::system("node --version") == 0)) {
|
||||
|
||||
Reference in New Issue
Block a user