[KVPOOL]decode save kvcache (#5168)

### What this PR does / why we need it?

Allow the decode node to save its KV cache to the KV pool.
Currently only MLA is supported.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: Chao Lei <leichao139636@163.com>
This commit is contained in:
baxingpiaochong
2026-01-04 22:22:01 +08:00
committed by GitHub
parent 350b95efcf
commit 46c2fc6a3c
6 changed files with 156 additions and 31 deletions

View File

@@ -61,6 +61,8 @@ class KVPoolWorker:
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"backend", "mooncake")
self.block_size = vllm_config.cache_config.block_size
@@ -92,9 +94,44 @@ class KVPoolWorker:
self.pp_rank,
)
partitions = None
if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
num_hidden_layers = model_config.hf_config.num_hidden_layers
partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_layer_partition", None)
prefill_pp_size = int(
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_size", 1))
if partition_list_str is not None:
try:
partitions = [
int(layer) for layer in partition_list_str.split(",")
]
except ValueError as err:
raise ValueError("Invalid partition string: {}".format(
partition_list_str)) from err
if len(partitions) != prefill_pp_size:
raise ValueError(
f"{len(partitions)=} does not match {prefill_pp_size=}."
)
if sum(partitions) != num_hidden_layers:
raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}."
)
else:
layers_per_partition = num_hidden_layers // prefill_pp_size
partitions = [
layers_per_partition for _ in range(prefill_pp_size)
]
if remaining_layers := num_hidden_layers % prefill_pp_size:
for i in range(2, remaining_layers + 2):
partitions[-i] += 1
self.token_database = ChunkedTokenDatabase(self.metadata,
self.block_size,
self.use_mla)
self.use_mla, partitions)
real_backend = backend_map.get(self.backend.lower())
self.m_store = real_backend( # type: ignore[misc]
@@ -103,6 +140,8 @@ class KVPoolWorker:
self.kv_send_thread: Optional[KVTransferThread] = None
self.kv_recv_thread: Optional[KVTransferThread] = None
self.finished_store_req: set[str] = set()
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
@@ -176,11 +215,12 @@ class KVPoolWorker:
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both']:
if self.kv_role in ['kv_producer', 'kv_both'
] or self.consumer_is_to_put:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step,
self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
ready_event_sending)
self.kv_send_thread.start()
if self.load_async:
@@ -289,6 +329,8 @@ class KVPoolWorker:
continue
request.current_event = current_event
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
request.req_id)
self.kv_send_thread.add_request( # type: ignore[union-attr]
request, )
@@ -413,11 +455,13 @@ class KVPoolWorker:
for layer_id in range(self.num_layers):
yield
def get_finished(self) -> tuple[set[str], set[str]]:
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
done_sending = (
self.kv_send_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
self.get_and_clear_finished_requests(
finished_req_ids # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both']
or self.consumer_is_to_put else set())
done_recving = (
self.kv_recv_thread.
@@ -430,6 +474,29 @@ class KVPoolWorker:
self.tp_rank)
return done_sending, done_recving
def get_and_clear_finished_requests(self, finished_req_ids) -> set[str]:
    """Collect request IDs whose store bookkeeping can be released.

    ``self.kv_send_thread.stored_requests`` maps a request ID to a counter
    (presumably the number of store jobs still outstanding -- confirm
    against the sending thread). A request is reported once that counter
    is ``0`` and the request has been announced as finished.

    Args:
        finished_req_ids: IDs announced as finished on this call.

    Returns:
        The set of request IDs reported on this call. Each reported ID is
        also passed to ``delete_finished_stored_request`` on the send
        thread.
    """
    sender = self.kv_send_thread
    completed = set()

    # Pass 1: requests announced as finished on an earlier call that were
    # parked in ``finished_store_req`` because their counter was non-zero.
    for pending_id in sender.stored_requests.copy(  # type: ignore[union-attr]
    ):
        if sender.stored_requests[  # type: ignore[union-attr]
                pending_id] != 0 or pending_id not in self.finished_store_req:
            continue
        self.finished_store_req.remove(pending_id)
        completed.add(pending_id)
        sender.delete_finished_stored_request(  # type: ignore[union-attr]
            pending_id)

    # Pass 2: requests announced as finished on this call.
    for announced_id in finished_req_ids:
        remaining = sender.stored_requests.get(  # type: ignore[union-attr]
            announced_id)
        if remaining is None:
            # Not tracked by the send thread: nothing to report or park.
            continue
        if remaining == 0:
            completed.add(announced_id)
            sender.delete_finished_stored_request(  # type: ignore[union-attr]
                announced_id)
        else:
            # Counter still non-zero: park the ID for a later call.
            self.finished_store_req.add(announced_id)

    return completed
def lookup(
self,
token_len: int,