Compare commits

...

31 Commits

Author SHA1 Message Date
Georgi Gerganov
59fee24c72 recurrent : rework graph inputs + add TODOs
ggml-ci
2025-06-18 09:29:51 +03:00
Gabe Goodhart
faf41199c0 refactor: Use a common build_recurrent_state method that is cache-agnostic
This reduces the code duplication between the different build_rs impls and
also retains a similar signature to the previous build_recurrent_state
method while standardizing on the input-dispatched build_rs implementation.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:19 -06:00
Gabe Goodhart
5046d412ef fix: Fix initialization of child states
Since initially writing this PR, the logic in the child state types changed
such that using the "init full" signature and keeping the ubatches on the
parent struct no longer worked.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:19 -06:00
Gabe Goodhart
9db44a2a63 fix: Fix resize vs reserve and skip null tensors in size computation
https://github.com/ggml-org/llama.cpp/pull/13979/files#r2149469788

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-Authored-By: @younesbelkada
2025-06-17 14:54:19 -06:00
Gabe Goodhart
11cd80d5de feat: Overhaul build_recurrent_state / build_inp_s_copy to match attention pattern
https://github.com/ggml-org/llama.cpp/pull/13979/files#r2141701738

This is a big overhaul to bring consistency between how inputs and per-
layer components are created for attention layers and recurrent layers. The
main changes are:

- Rename class llm_graph_input_s_copy -> llm_graph_input_rs
- Add a corresponding llm_graph_input_rs_hybrid_recurrent
- Rename build_inp_s_copy -> build_rs_inp_recurrent
- Add a corresponding build_rs_inp_hybrid_recurrent
- Rename build_recurrent_state -> build_rs to match build_attn w/
llm_graph_input_rs android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input
- Add a corresponding overload of build_rs w/
llm_graph_input_rs_hybrid_recurrent android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input
- Add a llm_graph_input_attn_kv_hybrid_recurrent analogous to
llm_graph_input_attn_kv_unified
- Add a build_attn override that takes
llm_graph_input_attn_kv_hybrid_recurrent android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input

This makes the two paradigms fully consistent. The main drawback is the
code duplication in the build_attn and build_rs implementations where the
only difference between implementations is how they cast the memory state.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:19 -06:00
Gabe Goodhart
4ec4e6a801 refactor: Use llama_memory_state_ptr for child states in hybrid memory state
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:19 -06:00
Gabe Goodhart
7ba463b38c fix: Remove llama_model_is_hybrid_Recurrent public API
https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2141728423

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:19 -06:00
Gabe Goodhart
1510016ea4 fix: Remove logits_all after rebase
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:19 -06:00
Gabe Goodhart
d8c929ff5d feat: Allow custom layer filters for hybrid recurrent
This should help support architectures like Falcon H1 where there is
overlap between layers that need attention and recurrent caches.

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140748922

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:19 -06:00
Gabe Goodhart
d5d7628b5f refactor: Remove n_embd_k/v_gqa from recurrent cache
This is no longer needed now that there are separate implementations

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140825128

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:19 -06:00
Gabe Goodhart
b42c8b43cf refactor: Remove layer index from n_embd_k/v_s
Now that it's not used at all in the unified cache, we don't need to use
the layer index to zero it out for attention layers.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:19 -06:00
Gabe Goodhart
1dd12133cd refactor: Remove n_embd_k/v_s from unified cache
No longer needed now that unified isn't also supporting recurrent

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140761069

Branch: HybridRecurrentCache
2025-06-17 14:54:18 -06:00
Gabe Goodhart
833dfb54ae fix: Use per-layer n_embd_k/v_s calls for mamba (1) layers
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
f6d5f055c6 fix: Remove errant virtual destructor leftover from previous impl attempt
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
9c1a604af8 fix: Update clear signature for data argument after rebase
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
de9297fd5e fix: Add missing padding to n_ctx for hybrid cache construction
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
911e694476 fix: Fix status for init_update sig for recurrent cache state
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
d3699366e6 fix: Update recurrent cache for changes to remove intermediate kv_cache interface
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
a9b5fe98ad fix: Fix logic for initializing inputs and attn layers for hybrid caches
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
e3c1631556 feat: Support hybrid recurrent in llama-graph
NOTE: I intentionally did not add support for s_mask since it will be going
away soon

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
cf03d4ae5c fix: Fix shift logic to defer to unified cache
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
6c6ec0003a fix: Fix wrong bool condition for split equal in hybrid cache
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
423c89401d feat: Construct hybrid recurrent cache for hybrid recurrent models
This includes a refactor of the create_memory logic to avoid needing to use
the arch enum explicitly unless a model needs explicit cache instantiation
logic beyond the standard logic for recurrent, hybrid, unified, and iswa.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
c71eaa37a0 feat: First pass at llama_kv_cache_hybrid_recurrent
This follows the pattern in iswa where the two child caches are held
explicitly to support the case where a model requires a single attention
cache and a single recurrent cache where each layer uses exactly one of the
caches.

This is a rewrite of the more generic approach in the original hybrid cache
PR: https://github.com/ggml-org/llama.cpp/pull/13276

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
13332a7554 fix: Use per-layer sizing everywhere in kv caches
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
40e9187892 feat: Add layer filter to recurrent cache
Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
fb26e95ae7 refactor: rename *_is_hybrid -> *_is_hybrid_recurrent
The implementation of the hybrid cache intentionally does not specify the
types of the child caches, so there was a naming mismatch with these
predicate functions that used "hybrid" to imply "hybrid recurrent."

Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
fc9e0b576e feat: Auto-fill hparams.recurrent_layer_arr based on whether the model is recurrent
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:18 -06:00
Gabe Goodhart
05f1958080 feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:17 -06:00
Gabe Goodhart
5e2f2c3876 feat: Add c++ side constants for attention layer indices hparam
Branch: GraniteFour
2025-06-17 14:54:17 -06:00
Gabe Goodhart
ec8fe17b1a feat: Add llama_model_is_hybrid API call
Also, split llama_model_is_recurrent into llm_arch_is_recurrent in
llama-arch with llama_model_is_recurrent delegating to
llm_arch_is_recurrent. The same split is done for hybird. This is needed
because there are places where the llama_model has not yet been initialized
but we need to check if the model is recurrent (specifically for the
per-layer recurrent check array in hparams).

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-06-17 14:54:17 -06:00
13 changed files with 859 additions and 227 deletions

View File

@@ -23,6 +23,7 @@ add_library(llama
llama-kv-cache-unified.cpp
llama-kv-cache-unified-iswa.cpp
llama-kv-cache-recurrent.cpp
llama-kv-cache-hybrid-recurrent.cpp
llama-memory.cpp
llama-mmap.cpp
llama-model-loader.cpp

View File

@@ -147,6 +147,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
{ LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -1816,3 +1817,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
return LLM_TENSOR_INFOS.at(tensor);
}
bool llm_arch_is_recurrent(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
return true;
default:
return false;
}
}
bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
// TODO: There are currently no hybrid models! Once there are, this will be
// the place to identify them
switch (arch) {
default:
return false;
}
}

View File

@@ -151,6 +151,7 @@ enum llm_kv {
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
LLM_KV_ATTENTION_LAYER_INDICES,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -439,3 +440,6 @@ const char * llm_arch_name(llm_arch arch);
llm_arch llm_arch_from_string(const std::string & name);
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
bool llm_arch_is_recurrent(const llm_arch& arch);
bool llm_arch_is_hybrid_recurrent(const llm_arch& arch);

View File

@@ -7,6 +7,7 @@
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache-recurrent.h"
#include "llama-kv-cache-hybrid-recurrent.h"
#include <cassert>
#include <cmath>
@@ -238,7 +239,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
const int64_t n_kv = kv_state->get_n_kv();
@@ -403,6 +404,24 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
kv_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
const int64_t n_kv = kv_state->get_state_recurrent()->get_n_kv();
if (s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
int32_t * data = (int32_t *) s_copy->data;
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) {
data[i] = kv_state->get_state_recurrent()->s_copy(i);
}
}
}
//
// llm_graph_context
//
@@ -961,23 +980,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
return cur;
}
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
const auto n_kv = kv_state->get_n_kv();
auto & cur = inp->s_copy;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
ggml_set_input(cur);
res->add_input(std::move(inp));
return cur;
}
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
@@ -1047,6 +1049,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
return pos_bias;
}
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, kv_state);
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
const auto n_kv = inp->kv_state->get_state_attn()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
{
const auto n_kv = kv_state->get_state_recurrent()->get_n_kv();
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
ggml_set_input(inp->s_copy);
}
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q,
@@ -1291,36 +1320,6 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
{
const auto n_kv = kv_state->get_base()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
const auto n_kv = kv_state->get_swa()->get_n_kv();
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
}
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
@@ -1430,20 +1429,99 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
ggml_tensor * llm_graph_context::build_recurrent_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const auto n_kv = kv_state->get_n_kv();
const auto kv_head = kv_state->get_head();
const auto rs_zero = kv_state->get_rs_z();
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
// store to KV cache
{
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
}
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
{
const auto n_kv = kv_state->get_base()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
const auto n_kv = kv_state->get_swa()->get_n_kv();
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
}
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_rs(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
uint32_t n_kv,
uint32_t kv_head,
uint32_t kv_size,
int32_t rs_zero,
bool avoid_copies) const {
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
// Clear a single state which will then be copied to the other cleared states.
// Note that this is a no-op when the view is zero-sized.
@@ -1474,10 +1552,47 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
return output_states;
}
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
const auto n_kv = kv_state->get_n_kv();
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
ggml_set_input(inp->s_copy);
return (llm_graph_input_rs *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
}
ggml_tensor * llm_graph_context::build_rs(
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
}
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
const llama_ubatch & ubatch,
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1487,8 +1602,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
ggml_tensor * token_shift = build_recurrent_state(
gf, token_shift_all, state_copy,
ggml_tensor * token_shift = build_rs(
inp, gf, token_shift_all,
hparams.n_embd_k_s(), n_seqs);
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

View File

@@ -22,6 +22,7 @@ struct llama_memory_state_i;
class llama_kv_cache_unified_state;
class llama_kv_cache_unified_iswa_state;
class llama_kv_cache_recurrent_state;
class llama_kv_cache_hybrid_recurrent_state;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@@ -188,10 +189,10 @@ public:
const llama_cparams & cparams;
};
class llm_graph_input_s_copy : public llm_graph_input_i {
class llm_graph_input_rs : public llm_graph_input_i {
public:
llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
virtual ~llm_graph_input_s_copy() = default;
llm_graph_input_rs(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
virtual ~llm_graph_input_rs() = default;
void set_input(const llama_ubatch * ubatch) override;
@@ -300,6 +301,33 @@ public:
const llama_cross * cross = nullptr;
};
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_hybrid_recurrent_state * kv_state) :
hparams(hparams),
cparams(cparams),
kv_state(kv_state) {
}
virtual ~llm_graph_input_mem_hybrid() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_hybrid_recurrent_state * kv_state;
};
//
// llm_graph_result
//
@@ -508,13 +536,14 @@ struct llm_graph_context {
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
ggml_tensor * build_inp_pos_bucket_dec() const;
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
//
// attention
//
@@ -589,22 +618,62 @@ struct llm_graph_context {
float kq_scale,
int il) const;
ggml_tensor * build_attn(
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
//
// recurrent
//
ggml_tensor * build_recurrent_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;
// TODO: avoid notion of "kv"
// TODO: move this implementation to llama_kv_cache_recurrent.
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
// `llama_kv_cache_recurrent`
ggml_tensor * build_rs(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
uint32_t n_kv,
uint32_t kv_head,
uint32_t kv_size,
int32_t rs_zero,
bool avoid_copies = false) const;
llm_graph_input_rs * build_rs_inp() const;
ggml_tensor * build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;
ggml_tensor * build_rs(
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;
ggml_tensor * build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
const llama_ubatch & ubatch,
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const;
ggml_tensor * build_rwkv_token_shift_store(

View File

@@ -86,6 +86,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
return ssm_d_state * ssm_d_inner;
}
bool llama_hparams::recurrent_layer(uint32_t il) const {
return recurrent_layer_arr[il];
}
bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) {
return swa_layers[il];

View File

@@ -115,6 +115,9 @@ struct llama_hparams {
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
// for hybrid state space models
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
bool ssm_dt_b_c_rms = false;
float f_clamp_kqv = 0.0f;
@@ -186,6 +189,9 @@ struct llama_hparams {
// dimension of the recurrent state embeddings
uint32_t n_embd_v_s() const;
// whether or not the given layer is recurrent (for hybrid models)
bool recurrent_layer(uint32_t il) const;
bool is_swa(uint32_t il) const;
};

View File

@@ -0,0 +1,250 @@
#include "llama-kv-cache-hybrid-recurrent.h"
#include "llama-impl.h"
#include "llama-model.h"
#include "llama-context.h"
//
// llama_kv_cache_hybrid_recurrent
//
llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
const llama_model & model,
/* attn */
ggml_type attn_type_k,
ggml_type attn_type_v,
bool attn_v_trans,
uint32_t attn_kv_size,
uint32_t attn_n_pad,
uint32_t attn_n_swa,
llama_swa_type attn_swa_type,
/* recurrent */
ggml_type recurrent_type_k,
ggml_type recurrent_type_v,
uint32_t recurrent_kv_size,
/* common */
uint32_t n_seq_max,
bool offload,
/* layer filters */
layer_filter_cb && attn_filter,
layer_filter_cb && recurrent_filter) :
hparams(model.hparams),
kv_attn(new llama_kv_cache_unified(
model,
attn_filter == nullptr ?
[&](int32_t il) { return !model.hparams.recurrent_layer(il); }
: attn_filter,
attn_type_k,
attn_type_v,
attn_v_trans,
offload,
attn_kv_size,
n_seq_max,
attn_n_pad,
attn_n_swa,
attn_swa_type
)),
kv_recurrent(new llama_kv_cache_recurrent(
model,
recurrent_filter == nullptr ?
[&](int32_t il) { return model.hparams.recurrent_layer(il); }
: recurrent_filter,
recurrent_type_k,
recurrent_type_v,
offload,
recurrent_kv_size,
n_seq_max
)) {}
llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
// since this includes a recurrent cache, we cannot use split_simple
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
// follow the recurrent pattern for creating the ubatch splits
std::vector<llama_ubatch> ubatches;
while (sbatch.n_tokens > 0) {
llama_ubatch ubatch;
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = sbatch.split_seq(n_ubatch);
} else {
ubatch = sbatch.split_equal(n_ubatch);
}
ubatches.push_back(ubatch);
}
// prepare the recurrent batches first
if (!kv_recurrent->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined state at this point?
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
// prepare the attention cache
auto heads_attn = kv_attn->prepare(ubatches);
if (heads_attn.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
}
llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this);
}
llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this, lctx, optimize);
}
bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
// Shifting is trivially supported for recurrent
return kv_attn->get_can_shift();
}
void llama_kv_cache_hybrid_recurrent::clear(bool data) {
kv_attn ->clear(data);
kv_recurrent->clear(data);
}
bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
// Try removing from the recurrent cache first since it may fail. If it does
// fail, the cache will not have been mutated.
if (!kv_recurrent->seq_rm(seq_id, p0, p1)) {
return false;
}
return kv_attn->seq_rm(seq_id, p0, p1);
}
void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
kv_attn ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) {
kv_attn ->seq_keep(seq_id);
kv_recurrent->seq_keep(seq_id);
}
void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
kv_attn->seq_add(seq_id, p0, p1, shift);
kv_recurrent->seq_add(seq_id, p0, p1, shift);
}
void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
kv_attn ->seq_div(seq_id, p0, p1, d);
kv_recurrent->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const {
// the min of the total cache is the max of the two caches' min values
return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id));
}
llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const {
// the max of the total cache is the min of the two caches' max values
return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
}
void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
kv_attn ->state_write(io, seq_id);
kv_recurrent->state_write(io, seq_id);
}
void llama_kv_cache_hybrid_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
kv_attn ->state_read(io, seq_id);
kv_recurrent->state_read(io, seq_id);
}
llama_kv_cache_unified * llama_kv_cache_hybrid_recurrent::get_kv_attn() const {
return kv_attn.get();
}
llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() const {
return kv_recurrent.get();
}
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) : status(status) {}
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
: status(LLAMA_MEMORY_STATUS_SUCCESS) {
state_attn = kv->get_kv_attn ()->init_full();
state_recurrent = kv->get_kv_recurrent()->init_full();
status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
}
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
llama_kv_cache_hybrid_recurrent * kv,
llama_context * lctx,
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
state_attn = kv->get_kv_attn ()->init_update(lctx, optimize);
state_recurrent = kv->get_kv_recurrent()->init_update(lctx, optimize);
status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
}
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
llama_kv_cache_hybrid_recurrent * kv,
llama_sbatch sbatch,
std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches)
: status(LLAMA_MEMORY_STATUS_SUCCESS),
sbatch(std::move(sbatch)),
ubatches(std::move(ubatches)) {
// note: here we copy the ubatches. not sure if this is ideal
state_attn .reset(new llama_kv_cache_unified_state (kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches));
state_recurrent.reset(new llama_kv_cache_recurrent_state(kv->get_kv_recurrent(), {}, this->ubatches));
}
bool llama_kv_cache_hybrid_recurrent_state::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
state_attn ->next();
state_recurrent->next();
if (++i_next >= ubatches.size()) {
return false;
}
return true;
}
bool llama_kv_cache_hybrid_recurrent_state::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
bool res = true;
res = res & state_attn ->apply();
res = res & state_recurrent->apply();
return res;
}
std::vector<int64_t> & llama_kv_cache_hybrid_recurrent_state::out_ids() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return sbatch.out_ids;
}
llama_memory_status llama_kv_cache_hybrid_recurrent_state::get_status() const {
return status;
}
const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const {
return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
}
const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const {
return static_cast<const llama_kv_cache_recurrent_state *>(state_recurrent.get());
}

View File

@@ -0,0 +1,144 @@
#pragma once
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache-recurrent.h"
#include "llama-kv-cache-unified.h"
#include "llama-memory.h"
#include <memory>
#include <vector>
//
// llama_kv_cache_hybrid_recurrent
// TODO: rename to llama_memory_hybrid
//
// utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to
// support models where each layer may be either attention-based or recurrent
class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_kv_cache_hybrid_recurrent(
const llama_model & model,
/* attn */
ggml_type attn_type_k,
ggml_type attn_type_v,
bool attn_v_trans,
uint32_t attn_kv_size,
uint32_t attn_n_pad,
uint32_t attn_n_swa,
llama_swa_type attn_swa_type,
/* recurrent */
ggml_type recurrent_type_k,
ggml_type recurrent_type_v,
uint32_t recurrent_kv_size,
/* common */
uint32_t n_seq_max,
bool offload,
/* layer filters */
layer_filter_cb && attn_filter = nullptr,
layer_filter_cb && recurrent_filter = nullptr);
~llama_kv_cache_hybrid_recurrent() = default;
//
// llama_memory_i
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
//
// llama_kv_cache_hybrid_recurrent specific API
//
llama_kv_cache_unified * get_kv_attn () const;
llama_kv_cache_recurrent * get_kv_recurrent() const;
private:
const llama_hparams & hparams;
const std::unique_ptr<llama_kv_cache_unified> kv_attn;
const std::unique_ptr<llama_kv_cache_recurrent> kv_recurrent;
};
class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
public:
// init failure
explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status);
// init full
explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv);
// init update
explicit llama_kv_cache_hybrid_recurrent_state(
llama_kv_cache_hybrid_recurrent * kv,
llama_context * lctx,
bool optimize);
// init success
llama_kv_cache_hybrid_recurrent_state(
llama_kv_cache_hybrid_recurrent * kv,
llama_sbatch sbatch,
std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches);
~llama_kv_cache_hybrid_recurrent_state() = default;
bool next() override;
bool apply() override;
std::vector<int64_t> & out_ids() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_hybrid_recurrent_state
//
const llama_kv_cache_unified_state * get_state_attn () const;
const llama_kv_cache_recurrent_state * get_state_recurrent() const;
private:
llama_memory_status status;
llama_sbatch sbatch;
// the index of the next ubatch to process
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
llama_memory_state_ptr state_attn;
llama_memory_state_ptr state_recurrent;
};

View File

@@ -16,12 +16,13 @@
//
llama_kv_cache_recurrent::llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer;
LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -59,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
return it->second;
};
k_l.reserve(n_layer);
v_l.reserve(n_layer);
k_l.resize(n_layer);
v_l.resize(n_layer);
for (int i = 0; i < n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
if (filter && !filter(i)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
continue;
}
const char * dev_name = "CPU";
@@ -84,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
throw std::runtime_error("failed to create ggml context for kv cache");
}
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_s()*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_s()*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
k_l.push_back(k);
v_l.push_back(v);
k_l[i] = k;
v_l[i] = v;
}
// allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -381,11 +384,11 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
return std::make_unique<llama_kv_cache_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
}
llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
return std::make_unique<llama_kv_cache_recurrent_state>(this);
}
llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
@@ -644,7 +647,9 @@ size_t llama_kv_cache_recurrent::size_k_bytes() const {
size_t size_k_bytes = 0;
for (const auto & k : k_l) {
size_k_bytes += ggml_nbytes(k);
if (k != nullptr) {
size_k_bytes += ggml_nbytes(k);
}
}
return size_k_bytes;
@@ -654,7 +659,9 @@ size_t llama_kv_cache_recurrent::size_v_bytes() const {
size_t size_v_bytes = 0;
for (const auto & v : v_l) {
size_v_bytes += ggml_nbytes(v);
if (v != nullptr) {
size_v_bytes += ggml_nbytes(v);
}
}
return size_v_bytes;
@@ -748,14 +755,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
// Iterate and write all the keys first, each row is a cell
// Get whole range at a time
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
// Write key type
const int32_t k_type_i = (int32_t)k_l[il]->type;
io.write(&k_type_i, sizeof(k_type_i));
// Write row size of key
const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
const uint64_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
io.write(&k_size_row, sizeof(k_size_row));
// Read each range of cells of k_size length each into tmp_buf and write out
@@ -768,14 +774,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
if (!v_trans) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
// Write value type
const int32_t v_type_i = (int32_t)v_l[il]->type;
io.write(&v_type_i, sizeof(v_type_i));
// Write row size of value
const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
const uint64_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
io.write(&v_size_row, sizeof(v_size_row));
// Read each range of cells of v_size length each into tmp_buf and write out
@@ -789,7 +794,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
// When v is transposed, we also need the element size and get the element ranges from each row
const uint32_t kv_size = size;
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_s = hparams.n_embd_v_s();
// Write value type
const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -800,10 +805,10 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
io.write(&v_size_el, sizeof(v_size_el));
// Write GQA embedding size
io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
io.write(&n_embd_v_s, sizeof(n_embd_v_s));
// For each row, we get the element values of each cell
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
for (uint32_t j = 0; j < n_embd_v_s; ++j) {
// Read each range of cells of v_size_el length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
@@ -936,7 +941,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
// Read type of key
int32_t k_type_i_ref;
@@ -950,7 +954,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
// Read row size of key
uint64_t k_size_row_ref;
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
const size_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
if (k_size_row != k_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
return false;
@@ -964,7 +968,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
if (!v_trans) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
// Read type of value
int32_t v_type_i_ref;
@@ -978,7 +981,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
// Read row size of value
uint64_t v_size_row_ref;
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
const size_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
if (v_size_row != v_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
return false;
@@ -992,7 +995,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
} else {
// For each layer, read the values for each cell (transposed)
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_s = hparams.n_embd_v_s();
// Read type of value
int32_t v_type_i_ref;
@@ -1012,17 +1015,17 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
return false;
}
// Read GQA embedding size
uint32_t n_embd_v_gqa_ref;
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
// Read state embedding size
uint32_t n_embd_v_s_ref;
io.read_to(&n_embd_v_s_ref, sizeof(n_embd_v_s_ref));
if (n_embd_v_s != n_embd_v_s_ref) {
LLAMA_LOG_ERROR("%s: mismatched state embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_s, n_embd_v_s_ref, il);
return false;
}
if (cell_count) {
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
for (uint32_t j = 0; j < n_embd_v_s; ++j) {
const size_t dst_offset = (head + j * size) * v_size_el;
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
}
@@ -1040,15 +1043,13 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
llama_kv_cache_recurrent * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), is_full(true) {
}
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv,
llama_sbatch sbatch,
std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;

View File

@@ -11,17 +11,24 @@
// llama_kv_cache_recurrent
//
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
// TODO: extract the cache state used for graph computation into llama_kv_cache_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
// TODO: avoid the notion of "KV cache" / "KV cells", etc.
// TODO: rename to llama_recurrent_state / llama_recurrent_cache
class llama_kv_cache_recurrent : public llama_memory_i {
public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max);
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max);
~llama_kv_cache_recurrent() = default;
@@ -126,12 +133,10 @@ public:
// used to create a full-cache state
llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv);
// used to create a state from a batch
llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv,
llama_sbatch sbatch,
std::vector<llama_ubatch> ubatches);

View File

@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
continue;
}
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
const char * dev_name = "CPU";
@@ -1430,7 +1430,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
// Write key type
const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1452,7 +1452,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Write value type
const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1476,7 +1476,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Write value type
const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1621,7 +1621,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
// Read type of key
int32_t k_type_i_ref;
@@ -1651,7 +1651,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Read type of value
int32_t v_type_i_ref;
@@ -1681,7 +1681,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Read type of value
int32_t v_type_i_ref;

View File

@@ -9,6 +9,7 @@
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache-recurrent.h"
#include "llama-kv-cache-hybrid-recurrent.h"
#include "ggml-cpp.h"
@@ -470,6 +471,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
std::fill(
hparams.recurrent_layer_arr.begin(),
hparams.recurrent_layer_arr.end(),
llm_arch_is_recurrent(ml.get_arch()));
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
@@ -9111,7 +9116,7 @@ struct llm_build_mamba : public llm_graph_context {
// {n_embd, n_tokens}
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * state_copy = build_inp_s_copy();
auto * rs_inp = build_rs_inp();
for (int il = 0; il < n_layer; ++il) {
// norm
@@ -9120,7 +9125,7 @@ struct llm_build_mamba : public llm_graph_context {
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
if (il == n_layer - 1) {
// skip computing output for unused tokens
@@ -9158,11 +9163,11 @@ struct llm_build_mamba : public llm_graph_context {
// TODO: split
ggml_tensor * build_mamba_layer(
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * state_copy,
const llama_ubatch & ubatch,
int il) const {
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
const auto kv_head = kv_state->get_head();
@@ -9187,12 +9192,12 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
// (ab)using the KV cache to store the states
ggml_tensor * conv = build_recurrent_state(
gf, conv_states_all, state_copy,
ggml_tensor * conv = build_rs(
inp, gf, conv_states_all,
hparams.n_embd_k_s(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
ggml_tensor * ssm = build_recurrent_state(
gf, ssm_states_all, state_copy,
ggml_tensor * ssm = build_rs(
inp, gf, ssm_states_all,
hparams.n_embd_v_s(), n_seqs);
ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
@@ -11904,10 +11909,10 @@ struct llm_build_rwkv6_base : public llm_graph_context {
}
ggml_tensor * build_rwkv6_time_mix(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -12031,8 +12036,8 @@ struct llm_build_rwkv6_base : public llm_graph_context {
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
}
ggml_tensor * wkv_state = build_recurrent_state(
gf, kv_state->get_v_l(il), state_copy,
ggml_tensor * wkv_state = build_rs(
inp, gf, kv_state->get_v_l(il),
hparams.n_embd_v_s(), n_seqs);
ggml_tensor * wkv_output;
@@ -12087,7 +12092,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
inpL = build_inp_embd(model.tok_embd);
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
ggml_tensor * state_copy = build_inp_s_copy();
auto * rs_inp = build_rs_inp();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12097,9 +12102,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, ubatch, il
);
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -12114,7 +12117,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
1
);
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -12184,7 +12187,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * state_copy = build_inp_s_copy();
auto * rs_inp = build_rs_inp();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12194,9 +12197,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, ubatch, il
);
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
cb(att_norm, "attn_norm", il);
@@ -12208,7 +12209,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
1
);
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -12296,10 +12297,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
}
ggml_tensor * build_rwkv7_time_mix(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch,
int il) const {
@@ -12382,8 +12383,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
ggml_tensor * wkv_state = build_recurrent_state(
gf, kv_state->get_v_l(il), state_copy,
ggml_tensor * wkv_state = build_rs(
inp, gf, kv_state->get_v_l(il),
hparams.n_embd_v_s(), n_seqs);
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12440,7 +12441,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
inpL = build_inp_embd(model.tok_embd);
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
ggml_tensor * state_copy = build_inp_s_copy();
auto * rs_inp = build_rs_inp();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12450,9 +12451,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, ubatch, il
);
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -12467,7 +12466,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
1
);
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -12533,7 +12532,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * state_copy = build_inp_s_copy();
auto * rs_inp = build_rs_inp();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12543,9 +12542,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, ubatch, il
);
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
cb(att_norm, "attn_norm", il);
@@ -12557,7 +12554,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
1
);
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -13738,6 +13735,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
llama_memory_i * res;
switch (arch) {
// Models that need specific instantiation should be handled in the
// switch statement
case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_NOMIC_BERT:
@@ -13747,57 +13746,75 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
{
res = nullptr;
} break;
case LLM_ARCH_MAMBA:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
{
res = new llama_kv_cache_recurrent(
*this,
GGML_TYPE_F32,
GGML_TYPE_F32,
cparams.offload_kqv,
std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max);
} break;
// Models that need standard caching should rely on recurrent/hybrid
// checks
default:
{
const auto padding = llama_kv_cache_unified::get_padding(cparams);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
res = new llama_kv_cache_unified_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.n_ctx,
cparams.n_seq_max,
cparams.n_ubatch,
padding);
} else {
GGML_ASSERT(!hparams.is_swa_any());
res = new llama_kv_cache_unified(
if (llm_arch_is_recurrent(arch)) {
res = new llama_kv_cache_recurrent(
*this,
nullptr,
params.type_k,
params.type_v,
!cparams.flash_attn,
GGML_TYPE_F32,
GGML_TYPE_F32,
cparams.offload_kqv,
cparams.n_ctx,
cparams.n_seq_max,
padding,
hparams.n_swa,
hparams.swa_type);
std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max);
} else if (llm_arch_is_hybrid_recurrent(arch)) {
const auto padding = llama_kv_cache_unified::get_padding(cparams);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
res = new llama_kv_cache_hybrid_recurrent(
/* model */ *this,
/* attn_type_k */ params.type_k,
/* attn_type_v */ params.type_v,
/* attn_v_trans */ !cparams.flash_attn,
/* attn_kv_size */ cparams.n_ctx,
/* attn_n_pad */ padding,
/* attn_n_swa */ hparams.n_swa,
/* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv);
} else {
const auto padding = llama_kv_cache_unified::get_padding(cparams);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
res = new llama_kv_cache_unified_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.n_ctx,
cparams.n_seq_max,
cparams.n_ubatch,
padding);
} else {
GGML_ASSERT(!hparams.is_swa_any());
res = new llama_kv_cache_unified(
*this,
nullptr,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.n_ctx,
cparams.n_seq_max,
padding,
hparams.n_swa,
hparams.swa_type);
}
}
}
}
@@ -14377,14 +14394,7 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
}
bool llama_model_is_recurrent(const llama_model * model) {
switch (model->arch) {
case LLM_ARCH_MAMBA: return true;
case LLM_ARCH_RWKV6: return true;
case LLM_ARCH_RWKV6QWEN2: return true;
case LLM_ARCH_RWKV7: return true;
case LLM_ARCH_ARWKV7: return true;
default: return false;
}
return llm_arch_is_recurrent(model->arch);
}
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {