server : host-memory prompt caching (#16391)

* minor : code style

* server : fix prompt similarity calculation

* server : initial host-memory prompt caching

* cont

* server : refactor

* cont

* cont : make the server task of the slot const

* cont : minor [no ci]

* server : cache prompts and checkpoints only for completion tasks

* server : improve prompt caching logic

* cont : fix check for number of cached prompts [no ci]

* server : improve caching logic, add -cram CLI arg

* server : print prompt mismatch info

* cont : better naming [no ci]

* server : improve prompt cache loading logic

* server : add option to debug the slot contents (#16482)

* server : add option to debug the slot contents

* Update tools/server/server.cpp

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

* server : add option to disable prompt cache

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
This commit is contained in:
Georgi Gerganov
2025-10-09 18:54:51 +03:00
committed by GitHub
parent 8328fd4bae
commit d00cbea63c
10 changed files with 813 additions and 471 deletions

View File

@@ -31,10 +31,10 @@
using json = nlohmann::ordered_json;
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
@@ -1102,6 +1102,7 @@ public:
~server_tokens() = default;
// Prevent copying
// TODO: server_tokens should be copyable - remove this:
server_tokens(const server_tokens&) = delete;
server_tokens& operator=(const server_tokens&) = delete;
@@ -1119,7 +1120,7 @@ public:
}
}
server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
// for debugging
std::string str() const {
@@ -1144,9 +1145,8 @@ public:
auto it = map_pos_to_media.find(pos);
if (it != map_pos_to_media.end()) {
return it->second;
} else {
throw std::runtime_error("Chunk not found");
}
throw std::runtime_error("Chunk not found");
}
void push_back(llama_token tok) {
@@ -1170,7 +1170,7 @@ public:
map_pos_to_media[start_pos] = std::move(new_chunk);
} else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
size_t n_tokens;
auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
for (size_t i = 0; i < n_tokens; ++i) {
push_back(text_tokens[i]);
}
@@ -1190,7 +1190,7 @@ public:
// We could also just check, but this will prevent silently dropping MTMD data.
GGML_ASSERT(has_mtmd);
for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
auto chunk = tokens.map_pos_to_media[it->first].get();
auto * chunk = tokens.map_pos_to_media[it->first].get();
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
}
@@ -1271,33 +1271,52 @@ public:
}
size_t get_common_prefix(const server_tokens & b) const {
size_t max_idx = std::min(tokens.size(), b.tokens.size());
for (size_t i = 0; i < max_idx; ++i) {
auto & ai = tokens[i];
auto & bi = b.tokens[i];
const size_t max_idx = std::min(tokens.size(), b.tokens.size());
if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
GGML_ASSERT(has_mtmd);
const auto & a_chunk = find_chunk(i);
const auto & b_chunk = b.find_chunk(i);
GGML_ASSERT(a_chunk && b_chunk);
std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get());
std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get());
size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get());
size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get());
if (ai_id == bi_id && a_pos == b_pos) {
GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
i += a_pos - 1; // will be +1 by the for loop
if (!has_mtmd) {
for (size_t i = 0; i < max_idx; ++i) {
if (tokens[i] == b.tokens[i]) {
continue;
} else {
return i;
}
} else if (ai == bi) {
continue;
} else {
return i;
}
return max_idx;
}
for (size_t i = 0; i < max_idx; ++i) {
const llama_token ai = tokens[i];
const llama_token bi = b.tokens[i];
if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
const auto & a_chunk = find_chunk(i);
const auto & b_chunk = b.find_chunk(i);
GGML_ASSERT(a_chunk && b_chunk);
const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
if (id_ai == id_bi && pos_a == pos_b) {
GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
i += pos_a - 1; // will be +1 by the for loop
continue;
}
return i;
}
if (ai == bi) {
continue;
}
return i;
}
return max_idx; // all tokens are equal
}
@@ -1308,7 +1327,7 @@ public:
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
for (size_t i = 0; i < tokens.size(); ++i) {
auto & t = tokens[i];
const auto & t = tokens[i];
if (t == LLAMA_TOKEN_NULL) {
try {
const auto & chunk = find_chunk(i);
@@ -1330,8 +1349,8 @@ public:
mtmd_context * mctx,
llama_pos n_past,
int32_t seq_id,
llama_pos & n_pos_out) {
auto & chunk = find_chunk(n_past);
llama_pos & n_pos_out) const {
const auto & chunk = find_chunk(n_past);
const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
? "image" : "audio";
SRV_INF("processing %s...\n", name);