server : context checkpointing for hybrid and recurrent models (#16382)

* initial commit for branch 3

* generalize `swa_checkpoint` to `ctx_checkpoint`

this extends `llama-server`'s SWA checkpointing logic to include
hybrid/recurrent models such as Jamba, Granite

* oops

* disable debug prints

* keep backwards compat with `--swa-checkpoints`

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* update prompt re-processing message

* fix off-by-one error per GG

* keep `seq_rm` log per GG

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server : fix checkpoint logic to support recurrent caches

* server : cleanup and fixes

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
ddh0
2025-10-03 13:34:51 -05:00
committed by GitHub
parent 606a73f531
commit f6dcda3900
8 changed files with 87 additions and 72 deletions

View File

@@ -220,7 +220,7 @@ bool llama_kv_cache_iswa::get_can_shift() const {
}
void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
kv_base->state_write(io, seq_id, flags);
}
@@ -228,7 +228,7 @@ void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id
}
void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
kv_base->state_read(io, seq_id, flags);
}

View File

@@ -175,17 +175,17 @@ std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdo
}
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);
mem_attn->state_write(io, seq_id);
mem_recr->state_write(io, seq_id);
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
mem_attn->state_write(io, seq_id, flags);
}
mem_recr->state_write(io, seq_id, flags);
}
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);
mem_attn->state_read(io, seq_id);
mem_recr->state_read(io, seq_id);
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
mem_attn->state_read(io, seq_id, flags);
}
mem_recr->state_read(io, seq_id, flags);
}
llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {

View File

@@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) {
}
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
//printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
uint32_t new_head = size;
if (p0 < 0) {
@@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
if (tail_id >= 0) {
const auto & cell = cells[tail_id];
// partial intersection is invalid
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
return false;
}
// invalidate tails which will be cleared
@@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
} else {
// seq_id is negative, then the range should include everything or nothing
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n");
return false;
}
}

View File

@@ -20151,6 +20151,10 @@ bool llama_model_is_recurrent(const llama_model * model) {
return llm_arch_is_recurrent(model->arch);
}
bool llama_model_is_hybrid(const llama_model * model) {
return llm_arch_is_hybrid(model->arch);
}
bool llama_model_is_diffusion(const llama_model * model) {
return llm_arch_is_diffusion(model->arch);
}