Compare commits

..

7 Commits

Author SHA1 Message Date
Georgi Gerganov
49af767fad build : add compile option to force use of MMQ kernels 2023-10-27 13:21:04 +03:00
Georgi Gerganov
a4e15a36e4 cuda : add CUDA_USE_TENSOR_CORES and GGML_CUDA_FORCE_MMQ macros 2023-10-26 16:00:48 +03:00
Georgi Gerganov
4c6744b526 cuda : remove duplicated cuBLAS GEMM code 2023-10-25 18:25:13 +03:00
Georgi Gerganov
a3c28439d3 cuda : fine-tune >= VOLTA params + use MMQ only for small batches 2023-10-25 15:07:34 +03:00
Georgi Gerganov
16b60dd75c cuda : add F32 sgemm branch 2023-10-25 14:00:21 +03:00
Georgi Gerganov
52af782608 cuda : new cublas gemm branch for multi-batch quantized src0 2023-10-25 13:14:24 +03:00
Georgi Gerganov
59d1232ea7 cuda : prints wip 2023-10-25 10:26:58 +03:00
4 changed files with 38 additions and 125 deletions

View File

@@ -1502,7 +1502,7 @@ struct llama_server_context
{
for (auto & slot : slots)
{
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty());
// empty prompt passed -> release the slot and send empty response
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt)

View File

@@ -95,8 +95,13 @@ int main(int argc, char ** argv) {
llama_batch batch = llama_batch_init(512, 0, 1);
// evaluate the initial prompt
for (size_t i = 0; i < tokens_list.size(); i++) {
llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
batch.n_tokens = tokens_list.size();
for (int32_t i = 0; i < batch.n_tokens; i++) {
batch.token[i] = tokens_list[i];
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
}
// llama_decode will output logits only for the last token of the prompt
@@ -143,10 +148,15 @@ int main(int argc, char ** argv) {
fflush(stdout);
// prepare the next batch
llama_batch_clear(batch);
batch.n_tokens = 0;
// push this new token for next evaluation
llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
batch.token [batch.n_tokens] = new_token_id;
batch.pos [batch.n_tokens] = n_cur;
batch.seq_id[batch.n_tokens] = 0;
batch.logits[batch.n_tokens] = true;
batch.n_tokens += 1;
n_decode += 1;
}

View File

@@ -8,9 +8,6 @@
#include <string>
#include <vector>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct seq_draft {
bool active = false;
bool drafting = false;
@@ -67,33 +64,6 @@ int main(int argc, char ** argv) {
params.n_gpu_layers = params.n_gpu_layers_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return 1;
}
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
llama_token_to_piece(ctx_tgt, i).c_str(),
llama_token_to_piece(ctx_dft, i).c_str());
return 1;
}
}
}
// tokenize the prompt
std::vector<llama_token> inp;
inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
@@ -257,7 +227,6 @@ int main(int argc, char ** argv) {
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode (ctx_dft, batch_dft);
++n_past_dft;
@@ -401,7 +370,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
}
// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
//LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
llama_decode(ctx_tgt, batch_tgt);
++n_past_tgt;
}

110
llama.cpp
View File

@@ -1578,14 +1578,12 @@ static void llama_kv_cache_seq_shift(
enum llama_fver {
GGUF_FILE_VERSION_V1 = 1,
GGUF_FILE_VERSION_V2 = 2,
GGUF_FILE_VERSION_V3 = 3,
};
static const char * llama_file_version_name(llama_fver version) {
switch (version) {
case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
case GGUF_FILE_VERSION_V2: return "GGUF V2";
case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)";
case GGUF_FILE_VERSION_V2: return "GGUF V2 (latest)";
}
return "unknown";
@@ -2695,8 +2693,8 @@ static void llm_load_tensors(
} break;
case LLM_ARCH_STARCODER:
{
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
// output
{
@@ -2747,19 +2745,19 @@ static void llm_load_tensors(
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend);
layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split);
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);
layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend);
layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split);
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);
layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split);
if (backend == GGML_BACKEND_GPU) {
vram_weights +=
@@ -4616,8 +4614,6 @@ static struct ggml_cgraph * llm_build_starcoder(
const float norm_eps = hparams.f_norm_eps;
const int n_gpu_layers = model.n_gpu_layers;
const int32_t n_tokens = batch.n_tokens;
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
@@ -4662,27 +4658,6 @@ static struct ggml_cgraph * llm_build_starcoder(
}
}
const int i_gpu_start = n_layer - n_gpu_layers;
(void) i_gpu_start;
// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop;
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer) {
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 1) {
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 2) {
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
{
// Compute position embeddings.
struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
@@ -4708,7 +4683,6 @@ static struct ggml_cgraph * llm_build_starcoder(
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
ggml_set_name(KQ_mask, "KQ_mask");
offload_func_kq(KQ_mask);
ggml_allocr_alloc(lctx.alloc, KQ_mask);
if (!ggml_allocr_is_measure(lctx.alloc)) {
float * data = (float *) KQ_mask->data;
@@ -4732,67 +4706,44 @@ static struct ggml_cgraph * llm_build_starcoder(
ggml_set_name(inpL, "inpL");
for (int il = 0; il < n_layer; ++il) {
offload_func_t offload_func = llama_nop;
#ifdef GGML_USE_CUBLAS
if (il >= i_gpu_start) {
offload_func = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
{
// Norm
cur = ggml_norm(ctx0, inpL, norm_eps);
offload_func(cur);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b);
offload_func(cur);
}
{
// Self Attention
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
offload_func_kq(cur);
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv);
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
offload_func_kq(cur);
struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd);
struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd);
struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa));
struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
struct ggml_tensor * tmpv = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
ggml_set_name(tmpq, "tmpq");
ggml_set_name(tmpk, "tmpk");
ggml_set_name(tmpv, "tmpv");
offload_func_kq(tmpq);
offload_func_kq(tmpk);
offload_func_v (tmpv);
struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
struct ggml_tensor * Qcur = tmpq;
struct ggml_tensor * Kcur = tmpk;
{
struct ggml_tensor * Vcur = ggml_transpose(ctx0, tmpv);
offload_func_v(Vcur);
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
offload_func_kq(k);
ggml_set_name(k, "k");
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
offload_func_v(v);
ggml_set_name(v, "v");
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
}
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
offload_func_kq(Q);
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)),
0, 2, 1, 3);
ggml_set_name(Q, "Q");
struct ggml_tensor * K =
@@ -4801,28 +4752,23 @@ static struct ggml_cgraph * llm_build_starcoder(
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
offload_func_kq(K);
ggml_set_name(K, "K");
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
offload_func_kq(KQ);
ggml_set_name(KQ, "KQ");
// KQ_scaled = KQ / sqrt(n_embd_head)
// KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled");
// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked");
// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");
// split cached V into n_head heads
@@ -4835,25 +4781,22 @@ static struct ggml_cgraph * llm_build_starcoder(
ggml_set_name(V, "V");
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
offload_func_v(KQV);
ggml_set_name(KQV, "KQV");
// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
offload_func_v(KQV_merged);
ggml_set_name(KQV_merged, "KQV_merged");
// cur = KQV_merged.contiguous().view(n_embd, n_tokens)
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous");
}
// Projection
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo);
offload_func(cur);
// Add the input
cur = ggml_add(ctx0, cur, inpL);
offload_func(cur);
struct ggml_tensor * inpFF = cur;
@@ -4862,36 +4805,27 @@ static struct ggml_cgraph * llm_build_starcoder(
// Norm
{
cur = ggml_norm(ctx0, inpFF, norm_eps);
offload_func_nr(cur);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
offload_func_nr(cur);
}
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
offload_func(cur);
// GELU activation
cur = ggml_gelu(ctx0, cur);
offload_func(cur);
// Projection
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
offload_func(cur);
}
inpL = ggml_add(ctx0, cur, inpFF);
}
// Output Norm
{
cur = ggml_norm(ctx0, inpL, norm_eps);
offload_func_nr(cur);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b);
ggml_set_name(cur, "result_norm");
}
ggml_set_name(cur, "result_norm");
cur = ggml_mul_mat(ctx0, model.output, cur);
ggml_set_name(cur, "result_output");