Compare commits

...

4 Commits

Author SHA1 Message Date
Georgi Gerganov
f0fea264b0 cont : rand hadamard matrices 2026-03-27 20:11:47 +02:00
Georgi Gerganov
7711b3a36a cont : rotate caches separately + support non-power-of-2 head sizes 2026-03-27 14:07:38 +02:00
Georgi Gerganov
832e32639f cont : rotate V more + refactor 2026-03-27 11:29:16 +02:00
Georgi Gerganov
e5aa067d68 llama : rotate activations for better quantization 2026-03-26 19:04:04 +02:00
4 changed files with 257 additions and 1 deletions

View File

@@ -52,6 +52,59 @@ static bool can_reuse_kq_mask(
// impl
// true iff n is a positive power of two (1, 2, 4, ...)
// note: the guard on n > 0 is required — without it, n == 0 is wrongly accepted
// (0 & -1 == 0) and n == INT_MIN triggers signed-overflow UB in (n - 1)
static bool ggml_is_power_of_2(int n) {
    return n > 0 && (n & (n - 1)) == 0;
}
// Fill `data` with H copies of an orthonormal n x n Walsh-Hadamard rotation
// matrix (n must be a power of two). Head 0 holds the plain Hadamard matrix;
// every further head is a copy with randomly sign-flipped columns, which keeps
// the rotation orthonormal while decorrelating the heads. A fixed RNG seed is
// used so the matrices are identical on every call.
static void set_input_hadamard(float * data, int n, int H) {
    assert(ggml_is_power_of_2(n));

    // Sylvester construction: start from the 1x1 seed and keep doubling the
    // block size until the full n x n matrix is filled in
    data[0] = 1.0 / sqrtf(n);
    for (int blk = 1; blk < n; blk *= 2) {
        for (int r = 0; r < blk; r++) {
            for (int c = 0; c < blk; c++) {
                const float v = data[r*n + c];
                data[(r + blk)*n +  c       ] =  v;
                data[ r       *n + (c + blk)] =  v;
                data[(r + blk)*n + (c + blk)] = -v;
            }
        }
    }

    // replicate to the remaining heads, flipping the sign of whole columns at
    // random; the seed is fixed so the result is deterministic across runs
    srand(1242);
    for (int head = 1; head < H; head++) {
        float * dst = data + head*n*n;
        for (int c = 0; c < n; c++) {
            const float sign = (rand() % 2) ? 1.0f : -1.0f;
            for (int r = 0; r < n; r++) {
                dst[r*n + c] = sign*data[r*n + c];
            }
        }
    }
}
// Apply the rotation matrix `rot` (nrot x nrot) to `cur` along its first
// dimension: the first dim is split into chunks of nrot, each chunk is
// multiplied by `rot`, and the tensor is reshaped back to its original shape.
static ggml_tensor * ggml_rotate_hadamard(
        ggml_context * ctx,
        ggml_tensor * cur,
        ggml_tensor * rot) {
    const auto nrot = rot->ne[0];

    // view ne[0] as [nrot, ne[0]/nrot] so the mat-mul rotates each chunk
    ggml_tensor * tmp = ggml_reshape_4d(ctx, cur, nrot, cur->ne[0]/nrot, cur->ne[1], cur->ne[2]);
    tmp = ggml_mul_mat(ctx, rot, tmp);

    // restore the original shape of `cur`
    return ggml_reshape_4d(ctx, tmp, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
}
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) {
const int64_t n_tokens = ubatch->n_tokens;
@@ -429,6 +482,22 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
if (self_rotk) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_rotk->buffer));
float * data = (float *) self_rotk->data;
set_input_hadamard(data, self_rotk->ne[0], self_rotk->ne[2]);
}
if (self_rotv) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_rotv->buffer));
float * data = (float *) self_rotv->data;
set_input_hadamard(data, self_rotv->ne[0], self_rotv->ne[2]);
}
}
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
@@ -476,6 +545,22 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
if (self_rotk) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_rotk->buffer));
float * data = (float *) self_rotk->data;
set_input_hadamard(data, self_rotk->ne[0], self_rotk->ne[2]);
}
if (self_rotv) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_rotv->buffer));
float * data = (float *) self_rotv->data;
set_input_hadamard(data, self_rotv->ne[0], self_rotv->ne[2]);
}
}
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
@@ -532,6 +617,22 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
if (inp_attn->self_rotk) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_attn->self_rotk->buffer));
float * data = (float *) inp_attn->self_rotk->data;
set_input_hadamard(data, inp_attn->self_rotk->ne[0], inp_attn->self_rotk->ne[2]);
}
if (inp_attn->self_rotv) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_attn->self_rotv->buffer));
float * data = (float *) inp_attn->self_rotv->data;
set_input_hadamard(data, inp_attn->self_rotv->ne[0], inp_attn->self_rotv->ne[2]);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
@@ -630,6 +731,22 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
if (inp_attn->self_rotk) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_attn->self_rotk->buffer));
float * data = (float *) inp_attn->self_rotk->data;
set_input_hadamard(data, inp_attn->self_rotk->ne[0], inp_attn->self_rotk->ne[2]);
}
if (inp_attn->self_rotv) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_attn->self_rotv->buffer));
float * data = (float *) inp_attn->self_rotv->data;
set_input_hadamard(data, inp_attn->self_rotv->ne[0], inp_attn->self_rotv->ne[2]);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
@@ -2003,12 +2120,52 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
{
const bool can_rotk =
!hparams.is_n_embd_k_gqa_variable() &&
hparams.n_embd_head_k() % 64 == 0 &&
ggml_is_quantized(mctx_cur->type_k());
if (can_rotk) {
int nrot = 64;
//do {
// nrot *= 2;
//} while (hparams.n_embd_head_k() % nrot == 0);
//nrot /= 2;
inp->self_rotk = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, nrot, nrot, hparams.n_head_kv());
ggml_set_input(inp->self_rotk);
} else {
inp->self_rotk = nullptr;
}
const bool can_rotv =
!hparams.is_n_embd_v_gqa_variable() &&
hparams.n_embd_head_v() % 64 == 0 &&
ggml_is_quantized(mctx_cur->type_v());
if (can_rotv) {
int nrot = 64;
// TODO: I think we can afford to rotate the V more compared to Q and K - to be confirmed
// ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
//do {
// nrot *= 2;
//} while (hparams.n_embd_head_v() % nrot == 0);
//nrot /= 2;
inp->self_rotv = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nrot, nrot);
ggml_set_input(inp->self_rotv);
} else {
inp->self_rotv = nullptr;
}
}
return inp;
}
@@ -2034,6 +2191,15 @@ ggml_tensor * llm_graph_context::build_attn(
int il) const {
GGML_ASSERT(v_mla == nullptr);
if (inp->self_rotk) {
q_cur = ggml_rotate_hadamard(ctx0, q_cur, inp->self_rotk);
k_cur = ggml_rotate_hadamard(ctx0, k_cur, inp->self_rotk);
}
if (inp->self_rotv) {
v_cur = ggml_rotate_hadamard(ctx0, v_cur, inp->self_rotv);
}
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
@@ -2061,6 +2227,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (inp->self_rotv) {
cur = ggml_rotate_hadamard(ctx0, cur, inp->self_rotv);
}
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
@@ -2171,6 +2341,18 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v_mla,
float kq_scale,
int il) const {
if (inp->self_rotk) {
q_cur = ggml_rotate_hadamard(ctx0, q_cur, inp->self_rotk);
if (k_cur) {
k_cur = ggml_rotate_hadamard(ctx0, k_cur, inp->self_rotk);
}
}
if (inp->self_rotv) {
if (v_cur) {
v_cur = ggml_rotate_hadamard(ctx0, v_cur, inp->self_rotv);
}
}
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
@@ -2211,6 +2393,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (inp->self_rotv) {
cur = ggml_rotate_hadamard(ctx0, cur, inp->self_rotv);
}
if (wo) {
cur = build_lora_mm(wo, cur);
}
@@ -2315,6 +2501,48 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
}
{
const bool can_rotk =
!hparams.is_n_embd_k_gqa_variable() &&
hparams.n_embd_head_k() % 64 == 0 &&
ggml_is_quantized(mctx_cur->get_base()->type_k());
if (can_rotk) {
int nrot = 64;
//do {
// nrot *= 2;
//} while (hparams.n_embd_head_k() % nrot == 0);
//nrot /= 2;
inp->self_rotk = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, nrot, nrot, hparams.n_head_kv());
ggml_set_input(inp->self_rotk);
} else {
inp->self_rotk = nullptr;
}
const bool can_rotv =
!hparams.is_n_embd_v_gqa_variable() &&
hparams.n_embd_head_v() % 64 == 0 &&
ggml_is_quantized(mctx_cur->get_base()->type_v());
if (can_rotv) {
int nrot = 64;
// TODO: I think we can afford to rotate the V more compared to Q and K - to be confirmed
// ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
//do {
// nrot *= 2;
//} while (hparams.n_embd_head_v() % nrot == 0);
//nrot /= 2;
inp->self_rotv = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nrot, nrot);
ggml_set_input(inp->self_rotv);
} else {
inp->self_rotv = nullptr;
}
}
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}

View File

@@ -308,6 +308,9 @@ public:
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_rotk = nullptr;
ggml_tensor * self_rotv = nullptr;
// note: these have to be copies because in order to be able to reuse a graph, its inputs
// need to carry these parameters with them. otherwise, they can point to freed
// llm_graph_params from a previous batch, causing stack-use-after-return
@@ -384,6 +387,9 @@ public:
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_rotk = nullptr;
ggml_tensor * self_rotv = nullptr;
const llama_hparams hparams;
const llama_cparams cparams;

View File

@@ -1004,6 +1004,14 @@ bool llama_kv_cache::get_has_shift() const {
return result;
}
// type of the K cache tensors
// note: reports the type of the first layer's K tensor — presumably
// representative of all layers; confirm for caches with per-layer types
ggml_type llama_kv_cache::type_k() const {
    return layers[0].k->type;
}
// type of the V cache tensors
// note: reports the type of the first layer's V tensor — presumably
// representative of all layers; confirm for caches with per-layer types
ggml_type llama_kv_cache::type_v() const {
    return layers[0].v->type;
}
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;
@@ -2239,6 +2247,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
return n_kv;
}
// forward to the underlying KV cache: type of the K tensors
ggml_type llama_kv_cache_context::type_k() const {
    return kv->type_k();
}
// forward to the underlying KV cache: type of the V tensors
ggml_type llama_kv_cache_context::type_v() const {
    return kv->type_v();
}
ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
}

View File

@@ -152,6 +152,9 @@ public:
bool get_has_shift() const;
ggml_type type_k() const;
ggml_type type_v() const;
//
// graph_build API
//
@@ -328,6 +331,9 @@ public:
uint32_t get_n_kv() const;
ggml_type type_k() const;
ggml_type type_v() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;