Compare commits


15 Commits

Author SHA1 Message Date
Douglas Hanley
b54afce9f4 mostly style fixes; fix KQ_mask comment 2024-03-09 13:03:46 -06:00
DAN™
03acc82a85 Clean-up GritLM sample code. 2024-03-09 07:44:25 -05:00
Douglas Hanley
bd3d9fbfed allow to toggle embedding mode 2024-03-07 11:55:27 -06:00
Douglas Hanley
f618e5060a add to gitignore 2024-03-07 01:38:30 -06:00
Douglas Hanley
1ab6aeeeee gritlm embeddings are back babeee 2024-03-07 01:37:08 -06:00
Douglas Hanley
97936078b7 rebase to new embed 2024-03-05 23:23:17 -06:00
Douglas Hanley
805ae529c4 comment out debug printing 2024-03-05 13:44:42 -06:00
Douglas Hanley
a71842d7ef tabs to spaces 2024-03-05 13:44:42 -06:00
Douglas Hanley
e79195fc53 gritlm results match 2024-03-05 13:44:41 -06:00
Douglas Hanley
4be8fb18ed add gritlm example 2024-03-05 13:42:51 -06:00
Jared Van Bortel
bd836944f8 quants : use MM256_SET_M128I consistently to fix gcc 7 build (#5889) 2024-03-05 11:56:37 -05:00
ExtReMLapin
3de31677d3 grammars : blacklists character control set (#5888)
* Prevent control characters from being served in json string

* Prevent control characters from being served in json string (array)
2024-03-05 18:33:08 +02:00
Georgi Gerganov
82cb31eb93 Revert "grammars : don't allow to output unescaped new line in string (#5885)"
This reverts commit b1a4e994fd.
2024-03-05 15:56:24 +02:00
ExtReMLapin
b1a4e994fd grammars : don't allow to output unescaped new line in string (#5885)
* Don't allow grammar json array to output unescaped new line in string

* Don't allow new line in json object string
2024-03-05 15:44:29 +02:00
0cc4m
61d1c88e15 Vulkan Improvements (#5835)
* Improve dequant shaders, add fast q4_0 dequant

* Optimize dmmv non-kquants for GCN

Remove unnecessary SPIR-V shader duplication

* Fix q4_0 dequant dispatch sizes

Fix backend free bug

* Optimize dequant shaders for q4_1, q5_0, q5_1 and q8_0

* Add unary and binary op shader templates

* Fix Vulkan check results

* Enable non-contiguous support for simple ops

* Add argsort

Basic q4_0 mmq shader and unit test

* Speed up q4_0 dequant code, enable mmq for q4_0

* Rework matmul pipeline selection

* Add soft_max alibi support

* Add q4_1, q5_0, q5_1 and q8_0 dequant mat mat mul shaders

* Add environment variable GGML_VK_FORCE_MAX_ALLOCATION_SIZE to limit max buffer size

Rename GGML_VULKAN_DISABLE_F16 to GGML_VK_DISABLE_F16 for consistency
2024-03-05 13:33:42 +01:00
14 changed files with 43770 additions and 46815 deletions

.gitignore (1 change)

@@ -45,6 +45,7 @@ models-mnt
/embedding
/gguf
/gguf-llama-simple
/gritlm
/imatrix
/infill
/libllama.so

Makefile

@@ -2,7 +2,7 @@
BUILD_TARGETS = \
main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o
# Binaries only useful for tests
TEST_TARGETS = \
@@ -720,6 +720,10 @@ embedding: examples/embedding/embedding.cpp ggml.o llama.o $(C
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
gritlm: examples/gritlm/gritlm.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

examples/CMakeLists.txt

@@ -20,6 +20,7 @@ else()
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(embedding)
add_subdirectory(finetune)
add_subdirectory(gritlm)
add_subdirectory(infill)
add_subdirectory(llama-bench)
add_subdirectory(llava)

examples/gritlm/CMakeLists.txt

@@ -0,0 +1,5 @@
set(TARGET gritlm)
add_executable(${TARGET} gritlm.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/gritlm/gritlm.cpp (new file, 243 lines)

@@ -0,0 +1,243 @@
#include "common.h"
#include "llama.h"
#include <string>
#include <vector>
#include <cmath>   // std::sqrt
#include <cstdio>  // std::printf
// #define GRIT_DEBUG
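// define GRIT_DEBUG to print the tokenization of each input and its normalized embedding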
static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
float dot = 0.0f;
for (uint64_t i = 0; i < v1.size(); ++i) {
dot += v1[i] * v2[i];
}
return dot;
}
static float norm(const std::vector<float> & v) {
return std::sqrt(dot_product(v, v));
}
static float cosine_similarity(const std::vector<float> & v1, const std::vector<float> & v2) {
return dot_product(v1, v2) / (norm(v1) * norm(v2));
}
static void normalize(const std::vector<float> & in, float * out) {
float inorm = norm(in);
for (uint64_t i = 0; i < in.size(); i++) {
out[i] = in[i] / inorm;
}
}
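// encode() embeds each sentence the GritLM way: the instruction is prepended to the
// sentence, the whole sequence is decoded once, the per-token embeddings of the
// sentence tokens (instruction tokens are excluded) are mean-pooled, and the result
// is L2-normalized.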
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
auto result = std::vector<std::vector<float>>{};
auto mdl = llama_get_model(ctx);
auto batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
for (uint64_t i = 0; i < sentences.size(); i++) {
llama_batch_clear(batch);
std::string input_string = instruction + sentences[i];
std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
auto n_toks = (int32_t)inputs.size();
// GritLM seems to have embed EOS = ""
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
// inputs.push_back(llama_token_eos(mdl));
// we want to ignore instruction tokens for mean pooling
std::vector<llama_token> inputs_instruct = llama_tokenize(mdl, instruction, true, false);
auto n_inst = (int32_t)inputs_instruct.size();
#ifdef GRIT_DEBUG
// debug tokens - these should match the tokens shown in the upstream GritLM sample
std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
});
std::printf("\n");
#endif
// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
}
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
// run model
llama_decode(ctx, batch);
// get embedding dimensions
uint64_t n_embd = llama_n_embd(mdl);
// allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f);
// sum up all token embeddings
for (int32_t k = n_inst; k < n_toks; k++) {
float * emb = llama_get_embeddings_ith(ctx, k);
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] += emb[j];
}
}
// divide by number of tokens (mean pooling)
uint64_t n_sent = n_toks - n_inst;
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] /= n_sent;
}
auto emb_norm = std::vector<float>(emb_unorm.size());
normalize(emb_unorm, emb_norm.data());
result.push_back(emb_norm);
#ifdef GRIT_DEBUG
// print out emb_norm
std::printf("embedding %ld: ", i);
for (uint64_t j = 0; j < n_embd; j++) {
std::printf("%.5f ", emb_norm[j]);
}
std::printf("\n\n");
#endif
}
llama_batch_free(batch);
return result;
}
static std::string aggregate_pieces(const std::vector<std::string> & pieces) {
// calculate total length required
size_t length = 0;
for (const auto & str : pieces) {
length += str.size();
}
// reserve memory
std::string result;
result.reserve(length);
// append pieces
for (const auto & str : pieces) {
result += str;
}
return result;
}
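// generate() runs plain greedy decoding: the prompt is fed in one batch, then the
// highest-probability token is picked from the last logits, appended to the output
// (and streamed if requested), and fed back as a single-token batch until EOS.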
static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
std::vector<std::string> pieces;
const llama_model * mdl = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(mdl);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
int32_t i_current_token = 0;
while (true) {
llama_batch_clear(bat);
for (auto i = 0; i < inputs.size(); i++) {
llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == inputs.size() - 1);
}
inputs.clear();
llama_decode(ctx, bat);
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
for (auto token = 0; token < candidates.size(); token++) {
candidates[token] = llama_token_data{ token, logits[token], 0.0f };
}
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
if (token == eos_token) {
break;
}
std::string piece = llama_token_to_piece(ctx, token);
if (stream) {
std::printf("%s", piece.c_str());
}
pieces.push_back(piece);
inputs.push_back(token);
}
llama_batch_free(bat);
return aggregate_pieces(pieces);
}
static std::string gritlm_instruction(const std::string & instruction) {
return !instruction.empty() ? "<|user|>\n" + instruction + "\n<|embed|>\n" : "<|embed|>\n";
}
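// GritLM prompt format: embedding inputs are wrapped as
//   "<|user|>\n{instruction}\n<|embed|>\n{text}"  (or just "<|embed|>\n{text}" with no instruction),
// while generation uses "<|user|>\n{prompt}\n<|assistant|>\n" as seen in main() below.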
int main(int argc, char * argv[])
{
gpt_params params;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
llama_model_params mparams = llama_model_params_from_gpt_params(params);
llama_context_params cparams = llama_context_params_from_gpt_params(params);
llama_backend_init();
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
// create new context - set to embedding mode
llama_context * embd_ctx = llama_new_context_with_model(mdl, cparams);
llama_set_embeddings(embd_ctx, true);
// create new context - default mode is causal
llama_context * causal_ctx = llama_new_context_with_model(mdl, cparams);
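// both contexts share the same loaded model weights: embd_ctx runs with non-causal attention
// and returns embeddings, while causal_ctx keeps the default causal attention for generation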
// samples taken from here: https://github.com/ContextualAI/gritlm#basic
// Embedding/Representation
{
std::string instruction = "Given a scientific paper title, retrieve the paper's abstract";
std::vector<std::string> queries = {
"Bitcoin: A Peer-to-Peer Electronic Cash System",
"Generative Representational Instruction Tuning",
};
std::vector<std::string> documents = {
"A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
"All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
};
// No need to add instruction for retrieval documents
std::vector<std::vector<float>> d_rep = encode(embd_ctx, documents, gritlm_instruction(""));
std::vector<std::vector<float>> q_rep = encode(embd_ctx, queries, gritlm_instruction(instruction));
float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]);
float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[0].c_str(), cosine_sim_q1_d0);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
}
// Generation
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
std::string response = generate(causal_ctx, prompt, true);
}
llama_free(embd_ctx);
llama_free(causal_ctx);
llama_free_model(mdl);
llama_backend_free();
return 0;
}

ggml-quants.c

@@ -51,6 +51,7 @@
#define UNUSED GGML_UNUSED
// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
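// MM256_SET_M128I(a, b) places b in the low 128-bit lane and a in the high lane,
// mirroring _mm256_set_m128i(a, b) on compilers that do provide it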
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
@@ -9563,7 +9564,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m128i odd_bits = _mm_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
const __m128i full_sign_bits = _mm_or_si128(partial_sign_bits, odd_bits);
const __m256i full_signs = _mm256_set_m128i(full_sign_bits, full_sign_bits);
const __m256i full_signs = MM256_SET_M128I(full_sign_bits, full_sign_bits);
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)(y[i].qs+32));
@@ -9585,8 +9586,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
const __m256i sc1 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1));
const __m256i sc2 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1));
const __m256i sc1 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1));
const __m256i sc2 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1));
const __m256i sum = _mm256_add_epi32(_mm256_madd_epi16(sc1, dot1), _mm256_madd_epi16(sc2, dot2));
@@ -9653,8 +9654,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);
const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);
__m256i signs;
signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
@@ -10551,10 +10552,10 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
@@ -10661,10 +10662,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;

File diff suppressed because it is too large.

File diff suppressed because it is too large.

ggml-vulkan.h

@@ -10,6 +10,7 @@ extern "C" {
#define GGML_VK_NAME "Vulkan"
#define GGML_VK_MAX_DEVICES 16
GGML_API void ggml_vk_instance_init(void);
GGML_API void ggml_vk_init_cpu_assist(void);
GGML_API void ggml_vk_preallocate_buffers_graph_cpu_assist(struct ggml_tensor * node);

File diff suppressed because it is too large.

grammars/json.gbnf

@@ -15,7 +15,7 @@ array ::=
string ::=
"\"" (
[^"\\] |
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws

grammars/json_arr.gbnf

@@ -24,7 +24,7 @@ array ::=
string ::=
"\"" (
[^"\\] |
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws

llama.cpp

@@ -1684,7 +1684,6 @@ struct llama_cparams {
bool embeddings;
bool offload_kqv;
enum llama_pooling_type pooling_type;
ggml_backend_sched_eval_callback cb_eval;
@@ -5014,8 +5013,8 @@ static struct ggml_tensor * llm_build_kqv(
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
}
#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE)
#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, and Kompute")
#if defined(GGML_USE_KOMPUTE)
#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute")
#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024")
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488")
if (hparams.f_max_alibi_bias > 0.0f) {
@@ -8030,7 +8029,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}
if (hparams.causal_attn) {
GGML_ASSERT(
(hparams.causal_attn || cparams.embeddings) &&
"non-causal attention with generative models is not supported"
);
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
// But if cparams.embeddings is set, the attention will be non-causal nonetheless.
if (!cparams.embeddings) {
const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens;
@@ -8055,8 +8061,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
} else {
// non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
// for models using the kv cache, the mask needs to match the kv cache size
const int64_t n_tokens = batch.n_tokens;
const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
@@ -8075,7 +8082,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
}
for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
}
}
}
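// A simplified sketch (not part of this diff) of the non-causal mask built above, assuming a
// single sequence: each of the n_tokens rows now spans n_stride = kv_self.n entries, the
// in-batch columns stay attendable and the unused KV-cache columns are masked off:
//
//     std::vector<float> mask(n_tokens*n_stride, -INFINITY);
//     for (int64_t j = 0; j < n_tokens; ++j) {     // query position
//         for (int64_t i = 0; i < n_tokens; ++i) { // key position within the batch
//             mask[j*n_stride + i] = 0.0f;         // same indexing as data[...] above
//         }
//     }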
@@ -13158,6 +13169,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
ctx->abort_callback_data = abort_callback_data;
}
void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
ctx->cparams.embeddings = embeddings;
}
struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens,

llama.h

@@ -641,6 +641,10 @@ extern "C" {
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
// Set whether the model is in embeddings mode or not
// If true, embeddings will be returned but logits will not
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
// Set abort callback
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
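
The new setter also makes it possible to reuse one context for both tasks instead of the two contexts created in gritlm.cpp above. A minimal sketch, assuming the encode/generate helpers, the documents and prompt inputs from the example, and an already initialized llama_context * ctx:

// sketch only: same flow as gritlm.cpp, but toggling a single context
llama_set_embeddings(ctx, true);   // embedding pass: non-causal attention, embeddings are returned
std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
llama_set_embeddings(ctx, false);  // generation pass: back to causal attention and logits
std::string response = generate(ctx, prompt, true);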