server : host-memory prompt caching (#16391)
* minor : code style
* server : fix prompt similarity calculation
* server : initial host-memory prompt caching
* cont
* server : refactor
* cont
* cont : make the server task of the slot const
* cont : minor [no ci]
* server : cache prompts and checkpoints only for completion tasks
* server : improve prompt caching logic
* cont : fix check for number of cached prompts [no ci]
* server : improve caching logic, add -cram CLI arg
* server : print prompt mismatch info
* cont : better naming [no ci]
* server : improve prompt cache loading logic
* server : add option to debug the slot contents (#16482)
  * server : add option to debug the slot contents
  * Update tools/server/server.cpp
  ---------
  Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
* server : add option to disable prompt cache
---------
Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
This commit is contained in:
@@ -31,10 +31,10 @@
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
|
||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
@@ -1102,6 +1102,7 @@ public:
|
||||
~server_tokens() = default;
|
||||
|
||||
// Prevent copying
|
||||
// TODO: server_tokens should be copyable - remove this:
|
||||
server_tokens(const server_tokens&) = delete;
|
||||
server_tokens& operator=(const server_tokens&) = delete;
|
||||
|
||||
@@ -1119,7 +1120,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
|
||||
server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
|
||||
|
||||
// for debugging
|
||||
std::string str() const {
|
||||
@@ -1144,9 +1145,8 @@ public:
|
||||
auto it = map_pos_to_media.find(pos);
|
||||
if (it != map_pos_to_media.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
throw std::runtime_error("Chunk not found");
|
||||
}
|
||||
throw std::runtime_error("Chunk not found");
|
||||
}
|
||||
|
||||
void push_back(llama_token tok) {
|
||||
@@ -1170,7 +1170,7 @@ public:
|
||||
map_pos_to_media[start_pos] = std::move(new_chunk);
|
||||
} else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
size_t n_tokens;
|
||||
auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
|
||||
const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
push_back(text_tokens[i]);
|
||||
}
|
||||
@@ -1190,7 +1190,7 @@ public:
|
||||
// We could also just check, but this will prevent silently dropping MTMD data.
|
||||
GGML_ASSERT(has_mtmd);
|
||||
for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
|
||||
auto chunk = tokens.map_pos_to_media[it->first].get();
|
||||
auto * chunk = tokens.map_pos_to_media[it->first].get();
|
||||
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
|
||||
map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
|
||||
}
|
||||
@@ -1271,33 +1271,52 @@ public:
|
||||
}
|
||||
|
||||
size_t get_common_prefix(const server_tokens & b) const {
|
||||
size_t max_idx = std::min(tokens.size(), b.tokens.size());
|
||||
for (size_t i = 0; i < max_idx; ++i) {
|
||||
auto & ai = tokens[i];
|
||||
auto & bi = b.tokens[i];
|
||||
const size_t max_idx = std::min(tokens.size(), b.tokens.size());
|
||||
|
||||
if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
|
||||
GGML_ASSERT(has_mtmd);
|
||||
const auto & a_chunk = find_chunk(i);
|
||||
const auto & b_chunk = b.find_chunk(i);
|
||||
GGML_ASSERT(a_chunk && b_chunk);
|
||||
std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get());
|
||||
std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get());
|
||||
size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get());
|
||||
size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get());
|
||||
if (ai_id == bi_id && a_pos == b_pos) {
|
||||
GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
|
||||
i += a_pos - 1; // will be +1 by the for loop
|
||||
if (!has_mtmd) {
|
||||
for (size_t i = 0; i < max_idx; ++i) {
|
||||
if (tokens[i] == b.tokens[i]) {
|
||||
continue;
|
||||
} else {
|
||||
return i;
|
||||
}
|
||||
} else if (ai == bi) {
|
||||
continue;
|
||||
} else {
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
return max_idx;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < max_idx; ++i) {
|
||||
const llama_token ai = tokens[i];
|
||||
const llama_token bi = b.tokens[i];
|
||||
|
||||
if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
|
||||
const auto & a_chunk = find_chunk(i);
|
||||
const auto & b_chunk = b.find_chunk(i);
|
||||
|
||||
GGML_ASSERT(a_chunk && b_chunk);
|
||||
|
||||
const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
|
||||
const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
|
||||
|
||||
const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
|
||||
const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
|
||||
|
||||
if (id_ai == id_bi && pos_a == pos_b) {
|
||||
GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
|
||||
i += pos_a - 1; // will be +1 by the for loop
|
||||
continue;
|
||||
}
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
if (ai == bi) {
|
||||
continue;
|
||||
}
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
return max_idx; // all tokens are equal
|
||||
}
|
||||
|
||||
@@ -1308,7 +1327,7 @@ public:
|
||||
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
for (size_t i = 0; i < tokens.size(); ++i) {
|
||||
auto & t = tokens[i];
|
||||
const auto & t = tokens[i];
|
||||
if (t == LLAMA_TOKEN_NULL) {
|
||||
try {
|
||||
const auto & chunk = find_chunk(i);
|
||||
@@ -1330,8 +1349,8 @@ public:
|
||||
mtmd_context * mctx,
|
||||
llama_pos n_past,
|
||||
int32_t seq_id,
|
||||
llama_pos & n_pos_out) {
|
||||
auto & chunk = find_chunk(n_past);
|
||||
llama_pos & n_pos_out) const {
|
||||
const auto & chunk = find_chunk(n_past);
|
||||
const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
|
||||
? "image" : "audio";
|
||||
SRV_INF("processing %s...\n", name);
|
||||
|
||||
Reference in New Issue
Block a user