Compare commits

...

10 Commits

Author SHA1 Message Date
Georgi Gerganov
a73ccf1aa3 llama : replace (permute + reshape + view_1d) with (view_3d) (#2538)
ggml-ci
2023-08-17 10:47:09 +03:00
drbh
7cf54e1f74 tests : adds simple llama grammar tests (#2618)
* adds simple llama grammar tests

* fix lint and add Makefile

* 0 terminate code_points

* avoid dangling pointers in candidate cleanup

* cleanup grammar at end of test
2023-08-17 10:41:01 +03:00
Shouzheng Liu
a872a2b28e ggml-alloc : fix discrepency between measure&eval (#2639)
The GGML memory allocator consistently places a tensor within the
optimal-fit memory block, which is the smallest block capable of
accommodating the tensor's size. During the measurement phase, the final
block is generously sized, ensuring it never qualifies as the
optimal-fit block as long as there exists another block capable of
accommodating the tensor. Nevertheless, in the evaluation phase, the
last block is constrained in size and could potentially qualify as the
optimal-fit block. Consequently, there exists the possibility of a
tensor being allocated to a different region during evaluation, leading
to more memory fragmentation in our scratch buffer.

This recent commit guarantees uniform behavior of the allocator across
both the measurement and evaluation phases, eliminating discrepancies
between the two.
2023-08-17 10:35:53 +03:00
Kolen Cheung
0919a0f73d cmake : install ggml-meta.metal if LLAMA_METAL (#2449) 2023-08-16 23:09:49 +03:00
Jhen-Jie Hong
ed53db86c3 metal : print error of load pipeline state (#2564)
* metal : print error of load pipeline state

* metal : return null if load pipeline failed
2023-08-16 23:09:03 +03:00
Shouzheng Liu
fc8ef549e5 metal : enable ggml-alloc (#2627)
* metal: enable ggml-alloc

Make ggml-alloc work with concurrently dispatch.

* style-fix

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-08-16 23:08:28 +03:00
Shouzheng Liu
bf83bff674 metal : matrix-matrix multiplication kernel (#2615)
* metal: matrix-matrix multiplication kernel

This commit removes MPS and uses custom matrix-matrix multiplication
kernels for all quantization types. This commit also adds grouped-query
attention to support llama2 70B.

* metal: fix performance degradation from gqa

Integers are slow on the GPU, and 64-bit divides are extremely slow.
In the context of GQA, we introduce a 64-bit divide that cannot be
optimized out by the compiler, which results in a decrease of ~8% in
inference performance. This commit fixes that issue by calculating a
part of the offset with a 32-bit divide. Naturally, this limits the
size of a single matrix to ~4GB. However, this limitation should
suffice for the near future.

* metal: fix bugs for GQA and perplexity test.

I mixed up ne02 and nb02 in previous commit.
2023-08-16 23:07:04 +03:00
Georgi Gerganov
b5ffb2849d scripts : add helper script to get wikitext 2023-08-15 10:05:25 +03:00
Jhen-Jie Hong
3ebb00935f server : add missing /json-schema-to-grammar.mjs (#2616)
fixes #2611
2023-08-15 06:14:14 +08:00
Jhen-Jie Hong
d783f7982e metal : return null instead of exit(1) (#2573) 2023-08-14 16:37:39 +03:00
13 changed files with 1047 additions and 675 deletions

View File

@@ -296,7 +296,6 @@ if (LLAMA_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
@@ -313,7 +312,6 @@ if (LLAMA_METAL)
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK}
)
endif()
@@ -571,6 +569,16 @@ install(
WORLD_READ
WORLD_EXECUTE
DESTINATION ${CMAKE_INSTALL_BINDIR})
if (LLAMA_METAL)
install(
FILES ggml-metal.metal
PERMISSIONS
OWNER_READ
OWNER_WRITE
GROUP_READ
WORLD_READ
DESTINATION ${CMAKE_INSTALL_BINDIR})
endif()
#
# programs, examples and tests

View File

@@ -2,7 +2,7 @@
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test
# Binaries only useful for tests
TEST_TARGETS = tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
default: $(BUILD_TARGETS)
@@ -283,7 +283,7 @@ endif # LLAMA_CLBLAST
ifdef LLAMA_METAL
CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
CXXFLAGS += -DGGML_USE_METAL
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
OBJS += ggml-metal.o
endif # LLAMA_METAL
@@ -412,6 +412,9 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o
vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
tests/test-grammar-parser: tests/test-grammar-parser.cpp examples/grammar-parser.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)

View File

@@ -15,6 +15,7 @@
#include "index.html.hpp"
#include "index.js.hpp"
#include "completion.js.hpp"
#include "json-schema-to-grammar.mjs.hpp"
#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
@@ -1218,6 +1219,12 @@ int main(int argc, char **argv)
res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript");
return false; });
// this is only called if no index.html is found in the public --path
svr.Get("/json-schema-to-grammar.mjs", [](const Request &, Response &res)
{
res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript");
return false; });
svr.Post("/completion", [&llama](const Request &req, Response &res)
{
auto lock = llama.lock();

View File

@@ -14,8 +14,6 @@
with pkgs.darwin.apple_sdk_11_0.frameworks; [
Accelerate
MetalKit
MetalPerformanceShaders
MetalPerformanceShadersGraph
]
else if isAarch32 && isDarwin then
with pkgs.darwin.apple_sdk.frameworks; [

View File

@@ -67,6 +67,8 @@ struct ggml_allocr {
struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
size_t max_size;
bool measure;
int parse_seq[GGML_MAX_NODES];
bool has_parse_seq;
#ifdef GGML_ALLOCATOR_DEBUG
struct ggml_tensor * allocated_tensors[1024];
@@ -111,10 +113,10 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
size_t max_avail = 0;
// find the best fitting free block
// find the best fitting free block besides the last block
int best_fit_block = -1;
size_t best_fit_size = SIZE_MAX;
for (int i = 0; i < alloc->n_free_blocks; i++) {
for (int i = 0; i < alloc->n_free_blocks - 1; i++) {
struct free_block * block = &alloc->free_blocks[i];
max_avail = MAX(max_avail, block->size);
if (block->size >= size && block->size <= best_fit_size) {
@@ -126,10 +128,17 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
AT_PRINTF("block %d\n", best_fit_block);
if (best_fit_block == -1) {
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
__func__, size, max_avail);
GGML_ASSERT(!"not enough space in the buffer");
// the last block is our last resort
struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
if (block->size >= size) {
best_fit_block = alloc->n_free_blocks - 1;
max_avail = MAX(max_avail, block->size);
} else {
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
__func__, size, max_avail);
GGML_ASSERT(!"not enough space in the buffer");
return;
}
}
struct free_block * block = &alloc->free_blocks[best_fit_block];
void * addr = block->addr;
@@ -229,6 +238,17 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_t
alloc->n_free_blocks++;
}
void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) {
int pos = 0;
for (int i = 0; i < n; i++) {
if (list[i] != -1) {
alloc->parse_seq[pos] = list[i];
pos++;
}
}
alloc->has_parse_seq = true;
}
void ggml_allocr_reset(struct ggml_allocr * alloc) {
alloc->n_free_blocks = 1;
size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
@@ -248,6 +268,8 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment)
/*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ false,
/*.parse_seq = */ {0},
/*.has_parse_seq = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ = {0},
#endif
@@ -275,6 +297,8 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
/*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ true,
/*.parse_seq = */ {0},
/*.has_parse_seq = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ = {0},
#endif
@@ -473,7 +497,13 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
allocate_node(alloc, input);
}
}
for (int i = 0; i < gf->n_nodes; i++) {
for (int ind = 0; ind < gf->n_nodes; ind++) {
int i;
if (alloc->has_parse_seq) {
i = alloc->parse_seq[ind];
} else {
i = ind;
}
struct ggml_tensor * node = gf->nodes[i];
// allocate parents (leafs)

View File

@@ -10,6 +10,10 @@ extern "C" {
GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
// tell the allocator to parse nodes following the order described in the list
// you should call this if your graph are optimized to execute out-of-order
GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n);
GGML_API void ggml_allocr_free(struct ggml_allocr * alloc);
GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc);
GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc);

View File

@@ -63,10 +63,13 @@ void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
// try to find operations that can be run concurrently in the graph
// you should run it again if the topology of your graph changes
void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
// if the graph has been optimized for concurrently dispatch
bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
// if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
// output the concur_list for ggml_alloc
int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
// same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel

View File

@@ -5,7 +5,6 @@
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#undef MIN
#undef MAX
@@ -79,6 +78,14 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -110,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
ctx->n_buffers = 0;
ctx->concur_list_len = 0;
// determine if we can use MPS
if (MPSSupportsMTLDevice(ctx->device)) {
fprintf(stderr, "%s: using MPS\n", __func__);
} else {
fprintf(stderr, "%s: not using MPS\n", __func__);
GGML_ASSERT(false && "MPS not supported");
}
#if 0
// compile from source string and show compile log
@@ -126,7 +126,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
exit(1);
return NULL;
}
}
#else
@@ -144,7 +144,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
exit(1);
return NULL;
}
#ifdef GGML_QKK_64
@@ -156,17 +156,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
#endif
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
exit(1);
return NULL;
}
}
#endif
// load kernels
{
NSError * error = nil;
#define GGML_METAL_ADD_KERNEL(name) \
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
if (error) { \
fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
return NULL; \
}
GGML_METAL_ADD_KERNEL(add);
GGML_METAL_ADD_KERNEL(add_row);
@@ -196,6 +201,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -228,11 +241,12 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
ctx->n_cb = n_cb;
}
bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
if (ctx->concur_list_len) {
return true;
}
return false;
int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
return ctx->concur_list_len;
}
int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
return ctx->concur_list;
}
// finds the Metal buffer that contains the tensor data on the GPU device
@@ -375,7 +389,7 @@ void ggml_metal_get_tensor(
void ggml_metal_graph_find_concurrency(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
struct ggml_cgraph * gf, bool check_mem) {
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
int nodes_unused[GGML_MAX_CONCUR];
@@ -422,7 +436,7 @@ void ggml_metal_graph_find_concurrency(
}
}
}
if (exe_flag) {
if (exe_flag && check_mem) {
// check if nodes[i]'s data will be overwritten by a node before nodes[i].
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
int64_t data_start = (int64_t) gf->nodes[i]->data;
@@ -506,7 +520,7 @@ void ggml_metal_graph_compute(
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = nil;
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
@@ -515,10 +529,6 @@ void ggml_metal_graph_compute(
const int i = has_concur ? ctx->concur_list[ind] : ind;
if (i == -1) {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
continue;
}
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue;
}
@@ -592,10 +602,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_ADD:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -613,10 +619,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_MUL:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -634,10 +636,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_SCALE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const float scale = *(const float *) src1->data;
[encoder setComputePipelineState:ctx->pipeline_scale];
@@ -653,10 +651,6 @@ void ggml_metal_graph_compute(
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_SILU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_silu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -667,10 +661,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_UNARY_OP_RELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_relu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -681,10 +671,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_UNARY_OP_GELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_gelu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -701,10 +687,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_SOFT_MAX:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int nth = 32;
[encoder setComputePipelineState:ctx->pipeline_soft_max];
@@ -719,10 +701,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_DIAG_MASK_INF:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int n_past = ((int32_t *)(dst->op_params))[0];
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
@@ -740,53 +718,43 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne00 == ne10);
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
uint gqa = ne12/ne02;
GGML_ASSERT(ne03 == ne13);
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
if (encoder != nil) {
[encoder endEncoding];
encoder = nil;
src1t == GGML_TYPE_F32 &&
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00%32 == 0 &&
ne11 > 1) {
switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
}
MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
// for F32 x F32 we use MPS
MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
initWithDevice:ctx->device transposeLeft:false transposeRight:true
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
// we need to do ne12 multiplications
// TODO: is there a way to do this in parallel - currently very slow ..
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
for (int64_t i02 = 0; i02 < ne12; ++i02) {
size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
size_t offs_src1_cur = offs_src1 + i02*nb12;
size_t offs_dst_cur = offs_dst + i02*nb2;
MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
}
} else {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
else {
int nth0 = 32;
int nth1 = 1;
@@ -885,23 +853,24 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif
}
else if (src0t == GGML_TYPE_Q5_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -910,10 +879,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_GET_ROWS:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
@@ -939,10 +904,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_RMS_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
@@ -962,10 +923,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const float eps = 1e-5f;
const int nth = 256;
@@ -984,10 +941,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_ALIBI:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
GGML_ASSERT((src0t == GGML_TYPE_F32));
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
@@ -1027,10 +980,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_ROPE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
@@ -1071,10 +1020,6 @@ void ggml_metal_graph_compute(
case GGML_OP_CPY:
case GGML_OP_CONT:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int nth = 32;
switch (src0t) {

File diff suppressed because it is too large Load Diff

View File

@@ -63,7 +63,7 @@ static void llama_log_callback_default(llama_log_level level, const char * text,
#define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__)
#if !defined(GGML_USE_CUBLAS) && !defined(GGML_USE_METAL)
#if !defined(GGML_USE_CUBLAS)
#include "ggml-alloc.h"
#define LLAMA_USE_ALLOCATOR
#else
@@ -1609,11 +1609,11 @@ static struct ggml_cgraph * llama_build_graph(
ggml_set_name(Q, "Q");
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd_gqa, il*n_ctx*ggml_element_size(kv_self.k)*n_embd_gqa),
n_embd_head, n_head_kv, n_past + N),
0, 2, 1, 3);
ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_past + N, n_head_kv,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
offload_func_kq(K);
ggml_set_name(K, "K");
@@ -1642,9 +1642,9 @@ static struct ggml_cgraph * llama_build_graph(
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_past + N, n_embd_head, n_head_kv,
n_ctx*ggml_element_size(kv_self.v),
n_ctx*ggml_element_size(kv_self.v)*n_embd_head,
n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa*il);
ggml_element_size(kv_self.v)*n_ctx,
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
offload_func_v(V);
ggml_set_name(V, "V");
@@ -1845,11 +1845,7 @@ static bool llama_eval_internal(
#endif
#ifdef GGML_USE_METAL
if (lctx.ctx_metal && N == 1) {
// TODO: disabled until #2413 is resolved
//if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
// ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf);
//}
if (lctx.ctx_metal) {
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
ggml_metal_graph_compute(lctx.ctx_metal, gf);
ggml_metal_get_tensor (lctx.ctx_metal, res);
@@ -1857,22 +1853,6 @@ static bool llama_eval_internal(
ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
}
} else {
// IMPORTANT:
// Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla
// ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX
// coprocessor.
//
// When we implement Matrix x Matrix Metal multiplication, we can avoid this branch.
// But for now, we have focused only on Matrix x Vector Metal multiplication.
//
// TODO: avoid these syncs via shared memory (ref #1696)
//
if (lctx.ctx_metal) {
// We need to sync the GPU KV cache with the CPU KV cache
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.k);
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
}
ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
}
#else
@@ -3303,7 +3283,18 @@ struct llama_context * llama_new_context_with_model(
int n_past = hparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);
#ifdef GGML_USE_METAL
if (params.n_gpu_layers > 0) {
ctx->ctx_metal = ggml_metal_init(1);
if (!ctx->ctx_metal) {
LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__);
llama_free(ctx);
return NULL;
}
ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false);
ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
}
#endif
// measure memory requirements for the graph
size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;
@@ -3321,6 +3312,11 @@ struct llama_context * llama_new_context_with_model(
ctx->buf_alloc.resize(alloc_size);
ctx->alloc = ggml_allocr_new(ctx->buf_alloc.addr, ctx->buf_alloc.size, tensor_alignment);
#ifdef GGML_USE_METAL
if (ctx->ctx_metal) {
ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
}
#endif
}
#else
ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
@@ -3335,7 +3331,6 @@ struct llama_context * llama_new_context_with_model(
#ifdef GGML_USE_METAL
if (params.n_gpu_layers > 0) {
// this allocates all Metal resources and memory buffers
ctx->ctx_metal = ggml_metal_init(1);
void * data_ptr = NULL;
size_t data_size = 0;
@@ -3364,8 +3359,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.addr, ctx->kv_self.buf.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.addr, ctx->buf_alloc.size, 0));
#undef LLAMA_METAL_CHECK_BUF
}
#endif

View File

@@ -0,0 +1,3 @@
#!/bin/bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip

View File

@@ -12,5 +12,6 @@ llama_add_test(test-quantize-perf.cpp)
llama_add_test(test-sampling.cpp)
llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
llama_add_test(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp)
llama_add_test(test-llama-grammar.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/common.cpp)
llama_add_test(test-grad0.cpp) # SLOW
# llama_add_test(test-opt.cpp) # SLOW

View File

@@ -0,0 +1,403 @@
#ifdef NDEBUG
#undef NDEBUG
#endif
#include "llama.cpp"
#include "examples/common.cpp"
#include "examples/grammar-parser.cpp"
#include <cassert>
int main()
{
grammar_parser::parse_state parsed_grammar;
std::vector<std::pair<std::string, uint32_t>> expected = {
{"expr", 2},
{"expr_6", 6},
{"expr_7", 7},
{"ident", 8},
{"ident_10", 10},
{"num", 9},
{"num_11", 11},
{"root", 0},
{"root_1", 1},
{"root_5", 5},
{"term", 4},
{"ws", 3},
{"ws_12", 12},
};
std::vector<std::vector<llama_grammar_element>> expected_rules = {
{{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
{
{LLAMA_GRETYPE_RULE_REF, 2},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_RULE_REF, 4},
{LLAMA_GRETYPE_CHAR, 10},
{LLAMA_GRETYPE_END, 0},
},
{{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
{{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
{
{LLAMA_GRETYPE_RULE_REF, 8},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_RULE_REF, 9},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_CHAR, 40},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_RULE_REF, 2},
{LLAMA_GRETYPE_CHAR, 41},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_END, 0},
},
{{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
{
{LLAMA_GRETYPE_CHAR, 45},
{LLAMA_GRETYPE_CHAR_ALT, 43},
{LLAMA_GRETYPE_CHAR_ALT, 42},
{LLAMA_GRETYPE_CHAR_ALT, 47},
{LLAMA_GRETYPE_RULE_REF, 4},
{LLAMA_GRETYPE_END, 0},
},
{{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
{
{LLAMA_GRETYPE_CHAR, 97},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
{LLAMA_GRETYPE_RULE_REF, 10},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_END, 0},
},
{{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
{
{LLAMA_GRETYPE_CHAR, 97},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
{LLAMA_GRETYPE_CHAR_ALT, 48},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
{LLAMA_GRETYPE_CHAR_ALT, 95},
{LLAMA_GRETYPE_RULE_REF, 10},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_END, 0},
},
{
{LLAMA_GRETYPE_CHAR, 48},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
{LLAMA_GRETYPE_RULE_REF, 11},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_CHAR, 48},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
{LLAMA_GRETYPE_END, 0},
},
{
{LLAMA_GRETYPE_CHAR, 32},
{LLAMA_GRETYPE_CHAR_ALT, 9},
{LLAMA_GRETYPE_CHAR_ALT, 10},
{LLAMA_GRETYPE_RULE_REF, 12},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_END, 0},
},
};
for (auto pair : expected)
{
parsed_grammar.symbol_ids[pair.first] = pair.second;
}
for (auto rule : expected_rules)
{
parsed_grammar.rules.push_back({});
for (auto element : rule)
{
parsed_grammar.rules.back().push_back(element);
}
}
llama_grammar *grammar = NULL;
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
{
{LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_CHAR, 97},
},
{
{LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
{LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
{LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_CHAR, 40},
},
{
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_CHAR, 97},
},
{
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_CHAR, 40},
}};
auto index = 0;
for (auto stack : grammar->stacks)
{
// compare stack to expected_stack
for (uint32_t i = 0; i < stack.size(); i++)
{
auto element = stack[i];
auto expected_element = expected_stacks[index][i];
// pretty print error message before asserting
if (expected_element.type != element->type || expected_element.value != element->value)
{
fprintf(stderr, "index: %d\n", index);
fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value);
fprintf(stderr, "actual_element: %d, %d\n", element->type, element->value);
fprintf(stderr, "expected_element != actual_element\n");
}
assert(expected_element.type == element->type && expected_element.value == element->value);
}
index++;
}
std::vector<std::vector<const llama_grammar_element *>> next_stacks;
std::vector<llama_grammar_candidate> next_candidates;
next_candidates.resize(24);
for (size_t i = 0; i < 24; ++i)
{
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
cp[0] = 37 + i;
cp[1] = 0;
next_candidates[i] = {i, cp};
}
std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{11, 48},
{12, 49},
{13, 50},
{14, 51},
{15, 52},
{16, 53},
{17, 54},
{18, 55},
{19, 56},
{20, 57},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{11, 48},
{12, 49},
{13, 50},
{14, 51},
{15, 52},
{16, 53},
{17, 54},
{18, 55},
{19, 56},
{20, 57},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{11, 48},
{12, 49},
{13, 50},
{14, 51},
{15, 52},
{16, 53},
{17, 54},
{18, 55},
{19, 56},
{20, 57},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{11, 48},
{12, 49},
{13, 50},
{14, 51},
{15, 52},
{16, 53},
{17, 54},
{18, 55},
{19, 56},
{20, 57},
{21, 58},
{22, 59},
{23, 60},
},
};
std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates);
std::vector<std::vector<llama_grammar_candidate>> all_rejects;
for (std::size_t count = 0; count < grammar->stacks.size(); ++count)
{
rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates);
all_rejects.push_back(rejects);
}
index = 0;
for (auto rej : all_rejects)
{
for (uint32_t i = 0; i < rej.size(); i++)
{
auto element = rej[i];
auto expected_element = expected_reject[index][i];
assert(element.index == expected_element.first && *element.code_points == expected_element.second);
}
index++;
}
for (auto &candidate : next_candidates)
{
delete[] candidate.code_points;
candidate.code_points = nullptr;
}
delete grammar;
return 0;
}