mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-23 16:37:33 +03:00
Compare commits
5 Commits
b8061
...
gg/kv-mask
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2751545b9 | ||
|
|
e08d3ac323 | ||
|
|
490f6f70c0 | ||
|
|
bac56aef91 | ||
|
|
6628f5186a |
@@ -200,42 +200,6 @@ uint32_t llama_hparams::n_layer_kv() const {
|
||||
return res;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
|
||||
assert(p0 >= 0 && p1 >= 0);
|
||||
|
||||
switch (swa_type) {
|
||||
case LLAMA_SWA_TYPE_NONE:
|
||||
{
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_STANDARD:
|
||||
{
|
||||
if (p1 - p0 >= (int32_t) n_swa) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_CHUNKED:
|
||||
{
|
||||
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
|
||||
|
||||
if (p0 < pos_chunk_start) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_SYMMETRIC:
|
||||
{
|
||||
const int32_t half_n_swa = (int32_t) n_swa / 2;
|
||||
const int32_t pos_diff = p1 - p0;
|
||||
|
||||
// Mask if outside the symmetric window
|
||||
if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool llama_hparams::use_mrope() const {
|
||||
return rope_sections[0] > 0 && rope_sections[1] > 0;
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "llama.h"
|
||||
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
|
||||
// bump if necessary
|
||||
#define LLAMA_MAX_LAYERS 512
|
||||
@@ -274,9 +275,45 @@ struct llama_hparams {
|
||||
uint32_t n_layer_kv() const;
|
||||
|
||||
// note that this function uses different SWA parameters from those in the hparams
|
||||
// note: inlined on purpose for performance reasons
|
||||
// TODO: think of a better place for this function
|
||||
// TODO: pack the SWA params in a struct?
|
||||
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
|
||||
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
|
||||
assert(p0 >= 0 && p1 >= 0);
|
||||
|
||||
switch (swa_type) {
|
||||
case LLAMA_SWA_TYPE_NONE:
|
||||
{
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_STANDARD:
|
||||
{
|
||||
if (p1 - p0 >= (int32_t) n_swa) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_CHUNKED:
|
||||
{
|
||||
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
|
||||
|
||||
if (p0 < pos_chunk_start) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_SYMMETRIC:
|
||||
{
|
||||
const int32_t half_n_swa = (int32_t) n_swa / 2;
|
||||
const int32_t pos_diff = p1 - p0;
|
||||
|
||||
// Mask if outside the symmetric window
|
||||
if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
bool use_mrope() const;
|
||||
};
|
||||
|
||||
@@ -852,7 +852,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
|
||||
const llama_seq_id seq_id_cell = cells.seq_get(idx);
|
||||
|
||||
// SWA mask
|
||||
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
||||
if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
||||
can_use = true;
|
||||
}
|
||||
}
|
||||
@@ -1237,6 +1237,197 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
|
||||
}
|
||||
}
|
||||
|
||||
struct args_set_input_kq_mask {
|
||||
const llama_hparams & hparams;
|
||||
const llama_ubatch * ubatch;
|
||||
|
||||
const std::vector<llama_kv_cells> & v_cells;
|
||||
const std::vector<uint32_t> & seq_to_stream;
|
||||
|
||||
uint32_t n_swa;
|
||||
llama_swa_type swa_type;
|
||||
|
||||
int64_t n_kv;
|
||||
int64_t n_stream;
|
||||
int64_t n_tps;
|
||||
};
|
||||
|
||||
template<bool causal, bool swa, bool is_2d, bool alibi>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
//const auto & hparams = args.hparams;
|
||||
const auto & ubatch = args.ubatch;
|
||||
|
||||
const auto & v_cells = args.v_cells;
|
||||
const auto & seq_to_stream = args.seq_to_stream;
|
||||
|
||||
const uint32_t n_swa = args.n_swa;
|
||||
const llama_swa_type swa_type = args.swa_type;
|
||||
|
||||
const int64_t n_kv = args.n_kv;
|
||||
const int64_t n_stream = args.n_stream;
|
||||
const int64_t n_tps = args.n_tps;
|
||||
|
||||
// the min position in the batch for each sequence
|
||||
llama_pos seq_pos_min[LLAMA_MAX_SEQ];
|
||||
std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
|
||||
|
||||
for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
|
||||
seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
|
||||
}
|
||||
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
// bookeeping of the KQ mask cells that could change for other tokens of the same sequence
|
||||
std::unordered_map<llama_seq_id, uint32_t> seq_srct;
|
||||
std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
|
||||
|
||||
for (uint32_t ii = 0; ii < n_tps; ++ii) {
|
||||
const uint32_t i = s*n_tps + ii;
|
||||
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
|
||||
const auto & cells = v_cells.at(seq_to_stream[seq_id]);
|
||||
|
||||
llama_pos p0 = -1;
|
||||
const llama_pos p1 = ubatch->pos[i];
|
||||
|
||||
// for M-RoPE
|
||||
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
||||
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
||||
|
||||
const uint64_t idst = n_kv*i;
|
||||
|
||||
// for tokens of the same sequence, the mask is mostly the same, so we can reuse it
|
||||
// the only cells that could change are the ones that are with similar positions as the
|
||||
// ones in the batch (i.e. due to causal masking, SWA, etc.)
|
||||
// keep track of those cells and shortcut the loop to save time
|
||||
// note: this optimization is not compatible with Alibi position encoding
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18842
|
||||
bool prev = false;
|
||||
|
||||
auto & idxs = seq_idxs[seq_id];
|
||||
|
||||
if (!alibi) {
|
||||
if (seq_srct.find(seq_id) != seq_srct.end()) {
|
||||
const uint32_t srct = seq_srct[seq_id];
|
||||
|
||||
const uint64_t idst_prev = n_kv*srct;
|
||||
|
||||
std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
|
||||
|
||||
prev = true;
|
||||
} else {
|
||||
idxs.clear();
|
||||
idxs.reserve(ubatch->n_tokens + n_swa + 32);
|
||||
|
||||
seq_srct[seq_id] = i;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t jj = 0; jj < n_kv; ++jj) {
|
||||
uint32_t j = jj;
|
||||
|
||||
// we have an exiting mask for this sequence -> update just seq_idxs
|
||||
if (!alibi) {
|
||||
if (prev) {
|
||||
if (jj >= idxs.size()) {
|
||||
break;
|
||||
}
|
||||
|
||||
j = idxs[jj];
|
||||
}
|
||||
}
|
||||
|
||||
if (cells.is_empty(j)) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
// mask the token if not the same sequence
|
||||
if (!cells.seq_has(j, seq_id)) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
p0 = cells.pos_get(j);
|
||||
|
||||
if (!alibi) {
|
||||
if (!prev) {
|
||||
// record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
|
||||
if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
|
||||
idxs.push_back(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (causal) {
|
||||
// mask future tokens
|
||||
if (p0 > p1) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
// M-RoPE causal mask
|
||||
if (is_2d) {
|
||||
if (p0 == p1) {
|
||||
const auto & p0_ext = cells.ext_get(j);
|
||||
|
||||
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
||||
goto skip;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// apply SWA if any
|
||||
if (swa) {
|
||||
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
||||
goto skip;
|
||||
}
|
||||
}
|
||||
|
||||
if (alibi) {
|
||||
data[idst + j] = -std::abs(p0 - p1);
|
||||
} else {
|
||||
data[idst + j] = 0.0f;
|
||||
}
|
||||
|
||||
continue;
|
||||
skip:
|
||||
data[idst + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool causal, bool swa, bool is_2d>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool alibi = args.hparams.use_alibi;
|
||||
if (alibi) {
|
||||
set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool causal, bool swa>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool is_2d = args.ubatch->is_pos_2d();
|
||||
if (is_2d) {
|
||||
set_input_kq_mask_impl<causal, swa, true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<causal, swa, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool causal>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
|
||||
if (swa) {
|
||||
set_input_kq_mask_impl<causal, true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<causal, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
const uint32_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
@@ -1251,74 +1442,110 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
||||
// n_tps == n_tokens_per_stream
|
||||
const int64_t n_tps = n_tokens/n_stream;
|
||||
|
||||
std::fill(data, data + ggml_nelements(dst), -INFINITY);
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
||||
// Causal mask:
|
||||
// xxx-------
|
||||
// xxxx------
|
||||
// xxxxx-----
|
||||
// Non-causal mask:
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
||||
// TODO: optimize this section
|
||||
for (uint32_t h = 0; h < 1; ++h) {
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
for (uint32_t ii = 0; ii < n_tps; ++ii) {
|
||||
const uint32_t i = s*n_tps + ii;
|
||||
const args_set_input_kq_mask args = {
|
||||
/*.hparams =*/ hparams,
|
||||
/*.ubatch =*/ ubatch,
|
||||
/*.v_cells =*/ v_cells,
|
||||
/*.seq_to_stream =*/ seq_to_stream,
|
||||
/*.n_swa =*/ n_swa,
|
||||
/*.swa_type =*/ swa_type,
|
||||
/*.n_kv =*/ n_kv,
|
||||
/*.n_stream =*/ n_stream,
|
||||
/*.n_tps =*/ n_tps,
|
||||
};
|
||||
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
if (causal_attn) {
|
||||
set_input_kq_mask_impl<true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<false>(args, data);
|
||||
}
|
||||
|
||||
const auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
// the old reference implementation
|
||||
{
|
||||
std::vector<float> data2(n_tokens*n_kv);
|
||||
std::fill(data2.begin(), data2.end(), -INFINITY);
|
||||
|
||||
const llama_pos p1 = ubatch->pos[i];
|
||||
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
||||
// Causal mask:
|
||||
// xxx-------
|
||||
// xxxx------
|
||||
// xxxxx-----
|
||||
// Non-causal mask:
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
||||
// TODO: optimize this section
|
||||
for (uint32_t h = 0; h < 1; ++h) {
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
for (uint32_t ii = 0; ii < n_tps; ++ii) {
|
||||
const uint32_t i = s*n_tps + ii;
|
||||
|
||||
// for M-RoPE
|
||||
const bool is_2d = ubatch->is_pos_2d();
|
||||
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
||||
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
|
||||
const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
|
||||
const auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
|
||||
for (uint32_t j = 0; j < n_kv; ++j) {
|
||||
if (cells.is_empty(j)) {
|
||||
continue;
|
||||
}
|
||||
const llama_pos p1 = ubatch->pos[i];
|
||||
|
||||
// mask the token if not the same sequence
|
||||
if (!cells.seq_has(j, seq_id)) {
|
||||
continue;
|
||||
}
|
||||
// for M-RoPE
|
||||
const bool is_2d = ubatch->is_pos_2d();
|
||||
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
||||
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
||||
|
||||
const llama_pos p0 = cells.pos_get(j);
|
||||
const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
|
||||
|
||||
// mask future tokens
|
||||
if (causal_attn && p0 > p1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// M-RoPE causal mask
|
||||
if (causal_attn && is_2d && p0 == p1) {
|
||||
const auto & p0_ext = cells.ext_get(j);
|
||||
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
||||
for (uint32_t j = 0; j < n_kv; ++j) {
|
||||
if (cells.is_empty(j)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// apply SWA if any
|
||||
if (is_masked_swa(p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
// mask the token if not the same sequence
|
||||
if (!cells.seq_has(j, seq_id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
||||
const llama_pos p0 = cells.pos_get(j);
|
||||
|
||||
// mask future tokens
|
||||
if (causal_attn && p0 > p1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// M-RoPE causal mask
|
||||
if (causal_attn && is_2d && p0 == p1) {
|
||||
const auto & p0_ext = cells.ext_get(j);
|
||||
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// apply SWA if any
|
||||
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
data2[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check data and data2 are equal
|
||||
for (int i = 0; i < n_tokens*n_kv; ++i) {
|
||||
if (data[i] != data2[i]) {
|
||||
printf("data[%d] = %f, data2[%d] = %f\n", i, data[i], i, data2[i]);
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//const int64_t t_end = ggml_time_us();
|
||||
|
||||
//LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
|
||||
}
|
||||
|
||||
void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
@@ -1483,10 +1710,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||
return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
|
||||
}
|
||||
|
||||
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||
GGML_UNUSED(flags);
|
||||
|
||||
|
||||
@@ -257,8 +257,6 @@ private:
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
|
||||
|
||||
ggml_tensor * build_rope_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
|
||||
@@ -259,3 +259,8 @@ target_link_libraries(${TEST_TARGET} PRIVATE llama)
|
||||
|
||||
llama_build_and_test(test-alloc.cpp)
|
||||
target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
|
||||
|
||||
# test-kq-mask
|
||||
set(LLAMA_TEST_NAME test-kq-mask)
|
||||
llama_build_and_test(test-kq-mask.cpp)
|
||||
target_include_directories(test-kq-mask PRIVATE ${PROJECT_SOURCE_DIR}/src)
|
||||
|
||||
507
tests/test-kq-mask.cpp
Normal file
507
tests/test-kq-mask.cpp
Normal file
@@ -0,0 +1,507 @@
|
||||
// run:
|
||||
// g++ -std=c++17 -I ../src/ -I ../include -I ../ggml/include ../tests/test-kq-mask.cpp && ./a.out
|
||||
|
||||
#include "llama-hparams.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-kv-cells.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// Populate random data into cells before tests
|
||||
static void populate_random_cells(std::vector<llama_kv_cells> & cells, std::mt19937 & gen) {
|
||||
std::uniform_int_distribution<int> pos_dist(0, cells.size() + 64);
|
||||
for (auto & c : cells) {
|
||||
// c is already resized to n_kv by caller
|
||||
for (uint32_t i = 0; i < c.size(); ++i) {
|
||||
// randomly decide if cell is used
|
||||
if (std::uniform_int_distribution<int>(0, 3)(gen) == 0) {
|
||||
continue; // leave empty ~25%
|
||||
}
|
||||
llama_pos p = pos_dist(gen);
|
||||
c.pos_set(i, p);
|
||||
// random ext
|
||||
llama_kv_cell_ext ext{ (llama_pos) std::uniform_int_distribution<int>(0, 10)(gen),
|
||||
(llama_pos) std::uniform_int_distribution<int>(0, 10)(gen) };
|
||||
c.ext_set(i, ext);
|
||||
// assign a random seq id
|
||||
llama_seq_id sid =
|
||||
(llama_seq_id) std::uniform_int_distribution<int>(0, std::min(10, (int) LLAMA_MAX_SEQ - 1))(gen);
|
||||
c.seq_add(i, sid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Simplified llama_hparams structure for testing
|
||||
// Simplified args_set_input_kq_mask structure
|
||||
struct args_set_input_kq_mask {
|
||||
llama_hparams & hparams;
|
||||
const llama_ubatch * ubatch;
|
||||
|
||||
const std::vector<llama_kv_cells> & v_cells;
|
||||
const std::vector<uint32_t> & seq_to_stream;
|
||||
|
||||
uint32_t n_swa;
|
||||
llama_swa_type swa_type;
|
||||
|
||||
int64_t n_kv;
|
||||
int64_t n_stream;
|
||||
int64_t n_tps;
|
||||
};
|
||||
|
||||
// Old implementation of set_input_kq_mask_impl
|
||||
template <bool causal, bool swa, bool is_2d, bool alibi>
|
||||
static void set_input_kq_mask_impl_old(const args_set_input_kq_mask & args, float * data) {
|
||||
const auto & ubatch = args.ubatch;
|
||||
|
||||
const auto & v_cells = args.v_cells;
|
||||
const auto & seq_to_stream = args.seq_to_stream;
|
||||
|
||||
const uint32_t n_swa = args.n_swa;
|
||||
const llama_swa_type swa_type = args.swa_type;
|
||||
|
||||
const int64_t n_kv = args.n_kv;
|
||||
const int64_t n_stream = args.n_stream;
|
||||
const int64_t n_tps = args.n_tps;
|
||||
|
||||
std::fill(data, data + n_kv*ubatch->n_tokens, -INFINITY);
|
||||
|
||||
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
||||
// Causal mask:
|
||||
// xxx-------
|
||||
// xxxx------
|
||||
// xxxxx-----
|
||||
// Non-causal mask:
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
||||
// TODO: optimize this section
|
||||
for (uint32_t h = 0; h < 1; ++h) {
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
for (uint32_t ii = 0; ii < n_tps; ++ii) {
|
||||
const uint32_t i = s*n_tps + ii;
|
||||
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
|
||||
const auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
|
||||
const llama_pos p1 = ubatch->pos[i];
|
||||
|
||||
// for M-RoPE
|
||||
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
||||
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
||||
|
||||
const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
|
||||
|
||||
for (uint32_t j = 0; j < n_kv; ++j) {
|
||||
if (cells.is_empty(j)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// mask the token if not the same sequence
|
||||
if (!cells.seq_has(j, seq_id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_pos p0 = cells.pos_get(j);
|
||||
|
||||
// mask future tokens
|
||||
if (causal && p0 > p1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// M-RoPE causal mask
|
||||
if (causal && is_2d && p0 == p1) {
|
||||
const auto & p0_ext = cells.ext_get(j);
|
||||
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// apply SWA if any
|
||||
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
data[idst + j] = alibi ? -std::abs(p0 - p1) : 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// New implementation of set_input_kq_mask_impl (with the optimization from PR #18842)
|
||||
template <bool causal, bool swa, bool is_2d, bool alibi>
|
||||
static void set_input_kq_mask_impl_new(const args_set_input_kq_mask & args, float * data) {
|
||||
const auto & ubatch = args.ubatch;
|
||||
|
||||
const auto & v_cells = args.v_cells;
|
||||
const auto & seq_to_stream = args.seq_to_stream;
|
||||
|
||||
const uint32_t n_swa = args.n_swa;
|
||||
const llama_swa_type swa_type = args.swa_type;
|
||||
|
||||
const int64_t n_kv = args.n_kv;
|
||||
const int64_t n_stream = args.n_stream;
|
||||
const int64_t n_tps = args.n_tps;
|
||||
|
||||
// the min position in the batch for each sequence
|
||||
llama_pos seq_pos_min[LLAMA_MAX_SEQ];
|
||||
std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
|
||||
|
||||
for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
|
||||
seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
|
||||
}
|
||||
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
// bookeeping of the KQ mask cells that could change for other tokens of the same sequence
|
||||
std::unordered_map<llama_seq_id, uint32_t> seq_srct;
|
||||
std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
|
||||
|
||||
for (uint32_t ii = 0; ii < n_tps; ++ii) {
|
||||
const uint32_t i = s * n_tps + ii;
|
||||
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
|
||||
const auto & cells = v_cells.at(seq_to_stream[seq_id]);
|
||||
|
||||
llama_pos p0 = -1;
|
||||
const llama_pos p1 = ubatch->pos[i];
|
||||
|
||||
// for M-RoPE
|
||||
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens * 2] : 0;
|
||||
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
||||
|
||||
const uint64_t idst = n_kv * i;
|
||||
|
||||
// for tokens of the same sequence, the mask is mostly the same, so we can reuse it
|
||||
// the only cells that could change are the ones that are with similar positions as the
|
||||
// ones in the batch (i.e. due to causal masking, SWA, etc.)
|
||||
// keep track of those cells and shortcut the loop to save time
|
||||
// note: this optimization is not compatible with Alibi position encoding
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18842
|
||||
bool prev = false;
|
||||
|
||||
auto & idxs = seq_idxs[seq_id];
|
||||
|
||||
if (!alibi) {
|
||||
if (seq_srct.find(seq_id) != seq_srct.end()) {
|
||||
const uint32_t srct = seq_srct[seq_id];
|
||||
|
||||
const uint64_t idst_prev = n_kv * srct;
|
||||
|
||||
std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
|
||||
|
||||
prev = true;
|
||||
} else {
|
||||
idxs.clear();
|
||||
idxs.reserve(ubatch->n_tokens + n_swa + 32);
|
||||
|
||||
seq_srct[seq_id] = i;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t jj = 0; jj < n_kv; ++jj) {
|
||||
uint32_t j = jj;
|
||||
|
||||
// we have an exiting mask for this sequence -> update just seq_idxs
|
||||
if (!alibi) {
|
||||
if (prev) {
|
||||
if (jj >= idxs.size()) {
|
||||
break;
|
||||
}
|
||||
|
||||
j = idxs[jj];
|
||||
}
|
||||
}
|
||||
|
||||
if (cells.is_empty(j)) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
// mask the token if not the same sequence
|
||||
if (!cells.seq_has(j, seq_id)) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
p0 = cells.pos_get(j);
|
||||
|
||||
if (!alibi) {
|
||||
if (!prev) {
|
||||
// record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
|
||||
if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
|
||||
idxs.push_back(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (causal) {
|
||||
// mask future tokens
|
||||
if (p0 > p1) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
// M-RoPE causal mask
|
||||
if (is_2d) {
|
||||
if (p0 == p1) {
|
||||
const auto & p0_ext = cells.ext_get(j);
|
||||
|
||||
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
||||
goto skip;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// apply SWA if any
|
||||
if (swa) {
|
||||
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
||||
goto skip;
|
||||
}
|
||||
}
|
||||
|
||||
if (alibi) {
|
||||
data[idst + j] = -std::abs(p0 - p1);
|
||||
} else {
|
||||
data[idst + j] = 0.0f;
|
||||
}
|
||||
|
||||
continue;
|
||||
skip:
|
||||
data[idst + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wrapper functions to call the implementations
|
||||
template <bool causal, bool swa, bool is_2d>
|
||||
static void set_input_kq_mask_impl_old_wrapper(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool alibi = args.hparams.use_alibi;
|
||||
if (alibi) {
|
||||
set_input_kq_mask_impl_old<causal, swa, is_2d, true>(args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl_old<causal, swa, is_2d, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool causal, bool swa, bool is_2d>
|
||||
static void set_input_kq_mask_impl_new_wrapper(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool alibi = args.hparams.use_alibi;
|
||||
if (alibi) {
|
||||
set_input_kq_mask_impl_new<causal, swa, is_2d, true>(args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl_new<causal, swa, is_2d, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool causal, bool swa>
|
||||
static void set_input_kq_mask_impl_old_wrapper(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool is_2d = args.ubatch->is_pos_2d();
|
||||
if (is_2d) {
|
||||
set_input_kq_mask_impl_old_wrapper<causal, swa, true>(args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl_old_wrapper<causal, swa, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool causal, bool swa>
|
||||
static void set_input_kq_mask_impl_new_wrapper(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool is_2d = args.ubatch->is_pos_2d();
|
||||
if (is_2d) {
|
||||
set_input_kq_mask_impl_new_wrapper<causal, swa, true>(args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl_new_wrapper<causal, swa, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool causal>
|
||||
static void set_input_kq_mask_impl_old_wrapper(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
|
||||
if (swa) {
|
||||
set_input_kq_mask_impl_old_wrapper<causal, true>(args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl_old_wrapper<causal, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool causal>
|
||||
static void set_input_kq_mask_impl_new_wrapper(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
|
||||
if (swa) {
|
||||
set_input_kq_mask_impl_new_wrapper<causal, true>(args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl_new_wrapper<causal, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
// Simple test function
|
||||
static void test_kq_mask_impl() {
|
||||
std::cout << "Testing set_input_kq_mask implementations...\n";
|
||||
|
||||
// Parameter space (kept small for test speed)
|
||||
const std::vector<int> test_n_kv = { 64, 512, 2048, 8192 };
|
||||
const std::vector<int> test_n_stream = { 1, 2, 4 };
|
||||
const std::vector<int> test_n_tokens = { 1, 8, 64, 128, 512 };
|
||||
|
||||
// Random generator
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
|
||||
int total_tests = 0;
|
||||
int passed_tests = 0;
|
||||
|
||||
// Helper to run a single configuration and compare old vs new
|
||||
auto run_case = [&](bool causal, bool alibi, llama_swa_type swa_type, int n_swa,
|
||||
args_set_input_kq_mask args, std::vector<float> & data_old,
|
||||
std::vector<float> & data_new) {
|
||||
args.hparams.use_alibi = alibi;
|
||||
args.swa_type = swa_type;
|
||||
args.n_swa = n_swa;
|
||||
|
||||
// call appropriate wrappers based on causal flag
|
||||
if (causal) {
|
||||
set_input_kq_mask_impl_old_wrapper<true>(args, data_old.data());
|
||||
set_input_kq_mask_impl_new_wrapper<true>(args, data_new.data());
|
||||
} else {
|
||||
set_input_kq_mask_impl_old_wrapper<false>(args, data_old.data());
|
||||
set_input_kq_mask_impl_new_wrapper<false>(args, data_new.data());
|
||||
}
|
||||
|
||||
// compare
|
||||
bool match = true;
|
||||
for (size_t i = 0; i < data_old.size(); ++i) {
|
||||
if (data_old[i] != data_new[i]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
++total_tests;
|
||||
if (match) {
|
||||
++passed_tests;
|
||||
std::cout << "✓ Test passed: " << (causal ? "causal" : "non-causal") << ", "
|
||||
<< (alibi ? "ALIBI" : "no ALIBI") << ", " << (swa_type == LLAMA_SWA_TYPE_NONE ? "no SWA" : "SWA")
|
||||
<< ", n_kv=" << args.n_kv << ", n_stream=" << args.n_stream
|
||||
<< ", n_tokens=" << args.n_tps * args.n_stream << "\n";
|
||||
} else {
|
||||
std::cout << "✗ Test failed: " << (causal ? "causal" : "non-causal") << ", "
|
||||
<< (alibi ? "ALIBI" : "no ALIBI") << ", " << (swa_type == LLAMA_SWA_TYPE_NONE ? "no SWA" : "SWA")
|
||||
<< ", n_kv=" << args.n_kv << ", n_stream=" << args.n_stream
|
||||
<< ", n_tokens=" << args.n_tps * args.n_stream << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
// Main loops over dimensions
|
||||
for (int n_kv : test_n_kv) {
|
||||
for (int n_stream : test_n_stream) {
|
||||
for (int n_tokens : test_n_tokens) {
|
||||
for (int n_pos : {1, 3}) {
|
||||
if (n_tokens > n_kv) {
|
||||
continue; // unrealistic
|
||||
}
|
||||
if (n_stream > 1 && n_tokens % n_stream != 0) {
|
||||
continue; // must divide evenly
|
||||
}
|
||||
|
||||
// Prepare random test data
|
||||
std::vector<llama_pos> test_pos(n_tokens*n_pos);
|
||||
std::vector<int32_t> test_n_seq_id(n_tokens);
|
||||
std::vector<llama_seq_id> test_seq_id_data(n_tokens);
|
||||
std::vector<llama_seq_id *> test_seq_id(n_tokens);
|
||||
|
||||
std::uniform_int_distribution<> pos_dist(std::max(0, n_kv - 2 * n_tokens), n_kv);
|
||||
std::uniform_int_distribution<> seq_dist(0, std::min(LLAMA_MAX_SEQ, n_stream - 1));
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
for (int p = 0; p < n_pos; ++p) {
|
||||
test_pos[i*n_pos + p] = pos_dist(gen);
|
||||
}
|
||||
test_n_seq_id[i] = 1;
|
||||
test_seq_id_data[i] = seq_dist(gen);
|
||||
test_seq_id[i] = &test_seq_id_data[i];
|
||||
}
|
||||
|
||||
// Build ubatch
|
||||
llama_ubatch ubatch{};
|
||||
ubatch.n_tokens = n_tokens;
|
||||
ubatch.n_seq_tokens = n_tokens;
|
||||
ubatch.n_seqs = n_stream;
|
||||
ubatch.n_seqs_unq = n_stream;
|
||||
ubatch.n_pos = n_pos;
|
||||
ubatch.pos = test_pos.data();
|
||||
ubatch.n_seq_id = test_n_seq_id.data();
|
||||
ubatch.seq_id = test_seq_id.data();
|
||||
ubatch.seq_id_unq = test_seq_id_data.data();
|
||||
|
||||
// Dummy hparams (will be mutated per case)
|
||||
llama_hparams hparams{};
|
||||
|
||||
// Cells per stream
|
||||
std::vector<llama_kv_cells> cells(n_stream);
|
||||
for (int s = 0; s < n_stream; ++s) {
|
||||
cells[s].resize(n_kv);
|
||||
}
|
||||
|
||||
// Populate random data into cells
|
||||
static std::random_device rd;
|
||||
static std::mt19937 gen(rd());
|
||||
populate_random_cells(cells, gen);
|
||||
|
||||
// seq_to_stream mapping
|
||||
std::vector<uint32_t> seq_to_stream(LLAMA_MAX_SEQ, 0);
|
||||
for (int s = 0; s < n_stream; ++s) {
|
||||
seq_to_stream[s] = s;
|
||||
}
|
||||
|
||||
// Base args (will be copied/modified for each case)
|
||||
args_set_input_kq_mask base_args = {
|
||||
.hparams = hparams,
|
||||
.ubatch = &ubatch,
|
||||
.v_cells = cells,
|
||||
.seq_to_stream = seq_to_stream,
|
||||
.n_swa = 0,
|
||||
.swa_type = LLAMA_SWA_TYPE_NONE,
|
||||
.n_kv = n_kv,
|
||||
.n_stream = n_stream,
|
||||
.n_tps = n_tokens / n_stream,
|
||||
};
|
||||
|
||||
// Output buffers
|
||||
std::vector<float> data_old(n_tokens * n_kv);
|
||||
std::vector<float> data_new(n_tokens * n_kv);
|
||||
|
||||
// 1) causal, no SWA, no ALIBI
|
||||
run_case(true, false, LLAMA_SWA_TYPE_NONE, 0, base_args, data_old, data_new);
|
||||
// 2) causal, SWA, no ALIBI
|
||||
run_case(true, false, LLAMA_SWA_TYPE_STANDARD, 128, base_args, data_old, data_new);
|
||||
// 3) non‑causal, no SWA, no ALIBI
|
||||
run_case(false, false, LLAMA_SWA_TYPE_NONE, 0, base_args, data_old, data_new);
|
||||
// 4) non‑causal, SWA, no ALIBI
|
||||
run_case(false, false, LLAMA_SWA_TYPE_STANDARD, 128, base_args, data_old, data_new);
|
||||
// 5) causal, ALIBI, no SWA
|
||||
run_case(true, true, LLAMA_SWA_TYPE_NONE, 0, base_args, data_old, data_new);
|
||||
// 6) non‑causal, ALIBI, no SWA
|
||||
run_case(false, true, LLAMA_SWA_TYPE_NONE, 0, base_args, data_old, data_new);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Test completed. Passed: " << passed_tests << "/" << total_tests << "\n";
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_kq_mask_impl();
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user