kv-cache : support layer reuse (#15504)

* kv-cache : support layer reuse

ggml-ci

* cont : update comments [no ci]
This commit is contained in:
Georgi Gerganov
2025-08-24 13:07:07 +03:00
committed by GitHub
parent c9a24fb932
commit b730706a49
12 changed files with 203 additions and 136 deletions

View File

@@ -1115,6 +1115,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(5);
hparams.n_layer_kv_from_start = 20;
hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;
hparams.f_attention_scale = 1.0f;
@@ -1474,12 +1475,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// Expert gating function (GLM-4.5 uses sigmoid)
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
}
// NextN/MTP parameters
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
switch (hparams.n_layer) {
case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
@@ -10524,7 +10528,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
const int64_t n_embd_altup;
const int64_t n_altup;
const int i_altup_act;
const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
const int n_layer_sparsity = 10; // number of layers using activation sparsity
const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
@@ -10574,8 +10577,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
for (int il = 0; il < n_layer; ++il) {
// this block is made to be closely resemble Gemma3p5DecoderLayer on python code
const bool has_kv = (il < n_layer_kv);
const float freq_base_l = model.get_rope_freq_base (cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
@@ -10595,7 +10596,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
// self-attention
if (has_kv) {
if (hparams.has_kv(il)) {
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
@@ -10635,7 +10636,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
} else {
// no KV layers
// reuse KV cache of earlier layers
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -18256,12 +18257,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
if (llm_arch_is_recurrent(arch)) {
res = new llama_memory_recurrent(
*this,
nullptr,
GGML_TYPE_F32,
GGML_TYPE_F32,
cparams.offload_kqv,
std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max);
cparams.n_seq_max,
nullptr);
} else if (llm_arch_is_hybrid(arch)) {
const auto padding = llama_kv_cache::get_padding(cparams);
@@ -18302,6 +18303,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
llama_memory_i::layer_reuse_cb reuse = nullptr;
if (arch == LLM_ARCH_GEMMA3N) {
reuse = [&](int32_t il) {
if (il >= (int32_t) hparams.n_layer_kv_from_start) {
return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1);
}
return -1;
};
}
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
@@ -18316,13 +18329,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
n_ctx_per_stream,
cparams.n_seq_max,
cparams.n_ubatch,
padding);
padding,
nullptr,
reuse);
} else {
GGML_ASSERT(!hparams.is_swa_any());
res = new llama_kv_cache(
*this,
nullptr,
params.type_k,
params.type_v,
!cparams.flash_attn,
@@ -18332,7 +18346,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.n_seq_max,
padding,
hparams.n_swa,
hparams.swa_type);
hparams.swa_type,
nullptr,
nullptr);
}
}
}