mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-05 13:53:23 +02:00
cont : inline verification
This commit is contained in:
@@ -1462,6 +1462,87 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
||||
set_input_kq_mask_impl<false>(args, data);
|
||||
}
|
||||
|
||||
// the old reference implementation: recompute the mask the slow way and verify that it matches the result above
|
||||
{
    // Inline verification of the mask written to `data` above:
    // rebuild the mask with the old reference algorithm into a scratch buffer (`data2`)
    // and abort on the first element that differs.

    // every entry starts out masked (-INFINITY); entries that are visible are overwritten below
    std::vector<float> data2(n_tokens*n_kv, -INFINITY);

    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
    //   Causal mask:
    //      xxx-------
    //      xxxx------
    //      xxxxx-----
    //   Non-causal mask:
    //      xxxxx-----
    //      xxxxx-----
    //      xxxxx-----
    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
    // TODO: optimize this section

    // for M-RoPE (does not depend on the token, so query it once instead of once per token)
    const bool is_2d = ubatch->is_pos_2d();

    // note: the reference only ever produced a single head (h == 0), so there is no head loop
    for (uint32_t s = 0; s < n_stream; ++s) {
        for (uint32_t ii = 0; ii < n_tps; ++ii) {
            // index of the token inside the ubatch
            const uint32_t i = s*n_tps + ii;

            // only the first sequence id of the token is used (see the assumption above)
            const llama_seq_id seq_id = ubatch->seq_id[i][0];

            const auto & cells = v_cells[seq_to_stream[seq_id]];

            const llama_pos p1 = ubatch->pos[i];

            // extra M-RoPE position components of the token (0 when positions are not 2D)
            const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
            const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens]   : 0;

            // offset of this token's row in the mask; computed in 64 bits to avoid overflow
            const uint64_t idst = static_cast<uint64_t>(n_kv)*i;

            for (uint32_t j = 0; j < n_kv; ++j) {
                // empty cells stay masked
                if (cells.is_empty(j)) {
                    continue;
                }

                // mask the token if not the same sequence
                if (!cells.seq_has(j, seq_id)) {
                    continue;
                }

                const llama_pos p0 = cells.pos_get(j);

                // mask future tokens
                if (causal_attn && p0 > p1) {
                    continue;
                }

                // M-RoPE causal mask
                if (causal_attn && is_2d && p0 == p1) {
                    const auto & p0_ext = cells.ext_get(j);
                    if (p0_ext.is_2d_gt(p1_x, p1_y)) {
                        continue;
                    }
                }

                // apply SWA if any
                if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
                    continue;
                }

                // visible cell: 0, or the negated position distance when ALiBi is used
                data2[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
            }
        }
    }

    // check that data and data2 are equal
    // (size_t index bounded by the buffer size: no int overflow or mixed-width comparison)
    for (size_t i = 0; i < data2.size(); ++i) {
        if (data[i] != data2[i]) {
            // report on stderr so the diagnostic is not stuck in a stdout buffer when we abort
            fprintf(stderr, "data[%zu] = %f, data2[%zu] = %f\n", i, data[i], i, data2[i]);
            GGML_ASSERT(false);
        }
    }
}
|
||||
|
||||
//const int64_t t_end = ggml_time_us();
|
||||
|
||||
//LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
|
||||
|
||||
Reference in New Issue
Block a user