server : support unified cache across slots (#16736)

* server : support unified context across slots

* cont : fix speculative decoding initialization

* context : fix n_ctx_per_seq computation

* server : purge slots one by one

* tests : add unified cache server tests

* llama : update per-seq context computation

* test-thread-safety : handle tiny training context of the input model

* server : fix server_tokens clear()

* server : use 4 slots + unified KV by default

* llama : add note about context size queries

* cont : update todos [no ci]

* context : do not cap the size of the context

* tests : adjust parameters to be CI friendlier

* context : add warning
This commit is contained in:
Georgi Gerganov
2025-11-02 18:14:04 +02:00
committed by GitHub
parent 87c9efc3b2
commit cd5e3b5754
12 changed files with 163 additions and 48 deletions

View File

@@ -112,11 +112,24 @@ llama_context::llama_context(
}
}
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
if (cparams.kv_unified) {
cparams.n_ctx_seq = cparams.n_ctx;
} else {
cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
if (cparams.n_ctx_seq == 0) {
throw std::runtime_error("n_ctx_seq == 0");
}
if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
}
}
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@@ -125,14 +138,14 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
if (n_ctx_per_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}
if (n_ctx_per_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
if (cparams.n_ctx_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}
if (!hparams.vocab_only) {
@@ -453,8 +466,8 @@ uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}
uint32_t llama_context::n_ctx_per_seq() const {
return cparams.n_ctx / cparams.n_seq_max;
uint32_t llama_context::n_ctx_seq() const {
return cparams.n_ctx_seq;
}
uint32_t llama_context::n_batch() const {
@@ -2383,6 +2396,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
return ctx->n_ctx();
}
uint32_t llama_n_ctx_seq(const llama_context * ctx) {
return ctx->n_ctx_seq();
}
uint32_t llama_n_batch(const llama_context * ctx) {
return ctx->n_batch();
}

View File

@@ -43,11 +43,11 @@ struct llama_context {
ggml_backend_sched_t get_sched() const;
uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
uint32_t n_ctx() const;
uint32_t n_ctx_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
uint32_t n_threads() const;
uint32_t n_threads_batch() const;

View File

@@ -8,6 +8,7 @@
struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_ctx_seq; // context for a single sequence
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;

View File

@@ -6712,14 +6712,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
}
ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
const uint32_t n_ctx_seq = cparams.n_ctx_seq;
// choose long/short freq factors based on the context size
if (layers[il].rope_freqs != nullptr) {
return layers[il].rope_freqs;
}
if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
return layers[il].rope_long;
}
@@ -6795,12 +6795,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
uint32_t n_ctx_per_stream = cparams.n_ctx;
if (!cparams.kv_unified) {
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
}
llama_memory_i::layer_reuse_cb reuse = nullptr;
if (arch == LLM_ARCH_GEMMA3N) {
@@ -6824,7 +6818,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
@@ -6840,7 +6834,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_ctx_seq,
cparams.n_seq_max,
1,
hparams.n_swa,