[KVPOOL]decode save kvcache (#5168)

### What this PR does / why we need it?

kvpool decode save kvcache
now only support mla

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: Chao Lei <leichao139636@163.com>
This commit is contained in:
baxingpiaochong
2026-01-04 22:22:01 +08:00
committed by GitHub
parent 350b95efcf
commit 46c2fc6a3c
6 changed files with 156 additions and 31 deletions

View File

@@ -34,6 +34,8 @@ class AscendStoreConnector(KVConnectorBase_V1):
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"use_layerwise", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
connector_name = vllm_config.kv_transfer_config.kv_connector
if connector_name == "MooncakeConnectorStoreV1":
@@ -121,7 +123,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
self.connector_worker.save_kv_layer(self._get_connector_metadata())
def wait_for_save(self):
if self.kv_role == "kv_consumer":
if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
# Don't do save if the role is kv_consumer
return
@@ -135,7 +137,8 @@ class AscendStoreConnector(KVConnectorBase_V1):
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
meta = self._get_connector_metadata()
done_sending, done_recving = self.connector_worker.get_finished()
done_sending, done_recving = self.connector_worker.get_finished(
finished_req_ids)
sended_and_finished: set[str] = set()
for item in list(self.sended_but_unfinished_reqs):
if item not in meta.unfinished_request_ids:

View File

@@ -87,12 +87,14 @@ class LayerPoolKey(PoolKey):
class ChunkedTokenDatabase():
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool):
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
partitions: Optional[List[int]]):
self.metadata = metadata
self.block_size = block_size
self.use_mla = use_mla
self.kv_caches_base_addr: list[int] = []
self.block_len: list[int] = []
self.partitions = partitions
def _make_key_by_hash(self,
chunk_hash: str,
@@ -188,6 +190,28 @@ class ChunkedTokenDatabase():
else:
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
def decode_adaptor_prefill_pp(self, key, addr, size):
if self.partitions is None or len(self.partitions) == 1:
return key, addr, size
new_key = []
new_addr = []
new_size = []
for i, (addr_list, size_list) in enumerate(zip(addr, size)):
start = 0
for j, part in enumerate(self.partitions):
# part * 2 because addr and size contain both k and v
end = len(addr_list) if j == len(
self.partitions) - 1 else start + part * 2
new_str = key[i].replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{j}", 1)
new_key.append(new_str)
new_addr.append(addr_list[start:end])
new_size.append(size_list[start:end])
start = end
return new_key, new_addr, new_size
#Parameters related to the connector metadata
@dataclass
@@ -247,15 +271,12 @@ class RequestTracker:
def update(
self,
new_token_ids: list[int],
new_block_ids: Union[tuple[list[int], ...], list[int]],
) -> None:
"""Update the request tracker when a running request is
scheduled again
"""
self.token_len = self.token_len + len(new_token_ids)
if len(new_block_ids) == 0:
new_block_ids = []
elif isinstance(new_block_ids, tuple):
@@ -378,4 +399,4 @@ class LasyerMultiBlockReqMeta:
block_ids: list[int]
layer_id: int
is_last_chunk: Optional[bool] = True
current_event: Optional[torch.npu.Event] = None
current_event: Optional[torch.npu.Event] = None

View File

@@ -1,5 +1,6 @@
import queue
import threading
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any
@@ -99,7 +100,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
ready_event: threading.Event):
kv_role: str, ready_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
@@ -108,6 +109,17 @@ class KVCacheStoreSendingThread(KVTransferThread):
ready_event,
name="KVCacheSendingThread")
self.put_step = put_step
self.kv_role = kv_role
self.stored_requests = defaultdict[str, int](int)
def add_stored_request(self, req_id: str):
with self.done_task_lock:
self.stored_requests[req_id] += 1
def delete_finished_stored_request(self, req_id: str):
with self.done_task_lock:
if req_id in self.stored_requests:
del self.stored_requests[req_id]
def _handle_request(self, req_meta: ReqMeta):
token_len = req_meta.token_len_chunk
@@ -154,13 +166,6 @@ class KVCacheStoreSendingThread(KVTransferThread):
req_id,
)
addrs = []
sizes = []
for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(
start, ends[index], block_ids)
addrs.append(addr)
sizes.append(size)
if keys:
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
@@ -168,12 +173,24 @@ class KVCacheStoreSendingThread(KVTransferThread):
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
addrs = []
sizes = []
for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(
start, ends[index], block_ids)
addrs.append(addr)
sizes.append(size)
if self.kv_role == "kv_consumer":
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(
keys, addrs, sizes)
if current_event is not None:
current_event.synchronize()
self.m_store.put(keys, addrs, sizes)
if is_last_chunk:
self.set_finished_request(req_id)
with self.done_task_lock:
self.stored_requests[req_id] -= 1
self.request_queue.task_done()

View File

@@ -24,6 +24,8 @@ class KVPoolScheduler:
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_load", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.client = LookupKeyClient(vllm_config)
@@ -149,7 +151,8 @@ class KVPoolScheduler:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
force_skip_save = self.kv_role == "kv_consumer"
force_skip_save = (self.kv_role == "kv_consumer"
and not self.consumer_is_to_put)
for finished_req_id in scheduler_output.finished_req_ids:
self._request_trackers.pop(finished_req_id, None)
@@ -197,6 +200,7 @@ class KVPoolScheduler:
num_current_tokens = request_tracker.token_len
new_token_ids = request.all_token_ids[
num_current_tokens:num_current_tokens + num_new_tokens]
request_tracker.token_len += len(new_token_ids)
else:
raise ValueError(
f"Request {req_id} is not in _unfinished_requests, "
@@ -204,10 +208,7 @@ class KVPoolScheduler:
new_block_ids = cached_reqs.new_block_ids[i]
if not new_block_ids:
continue
request_tracker.update(new_token_ids, new_block_ids)
# decode not save
if request_tracker.token_len > len(request.prompt_token_ids):
continue
request_tracker.update(new_block_ids)
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
@@ -270,7 +271,7 @@ class KVPoolScheduler:
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
if self.kv_role == "kv_consumer":
if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
return False, None
tracker = self._request_trackers.get(request.request_id)
if tracker is not None and tracker.num_saved_tokens <= 0:

View File

@@ -61,6 +61,8 @@ class KVPoolWorker:
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"backend", "mooncake")
self.block_size = vllm_config.cache_config.block_size
@@ -92,9 +94,44 @@ class KVPoolWorker:
self.pp_rank,
)
partitions = None
if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
num_hidden_layers = model_config.hf_config.num_hidden_layers
partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_layer_partition", None)
prefill_pp_size = int(
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_size", 1))
if partition_list_str is not None:
try:
partitions = [
int(layer) for layer in partition_list_str.split(",")
]
except ValueError as err:
raise ValueError("Invalid partition string: {}".format(
partition_list_str)) from err
if len(partitions) != prefill_pp_size:
raise ValueError(
f"{len(partitions)=} does not match {prefill_pp_size=}."
)
if sum(partitions) != num_hidden_layers:
raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}."
)
else:
layers_per_partition = num_hidden_layers // prefill_pp_size
partitions = [
layers_per_partition for _ in range(prefill_pp_size)
]
if remaining_layers := num_hidden_layers % prefill_pp_size:
for i in range(2, remaining_layers + 2):
partitions[-i] += 1
self.token_database = ChunkedTokenDatabase(self.metadata,
self.block_size,
self.use_mla)
self.use_mla, partitions)
real_backend = backend_map.get(self.backend.lower())
self.m_store = real_backend( # type: ignore[misc]
@@ -103,6 +140,8 @@ class KVPoolWorker:
self.kv_send_thread: Optional[KVTransferThread] = None
self.kv_recv_thread: Optional[KVTransferThread] = None
self.finished_store_req: set[str] = set()
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
@@ -176,11 +215,12 @@ class KVPoolWorker:
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both']:
if self.kv_role in ['kv_producer', 'kv_both'
] or self.consumer_is_to_put:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step,
self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
ready_event_sending)
self.kv_send_thread.start()
if self.load_async:
@@ -289,6 +329,8 @@ class KVPoolWorker:
continue
request.current_event = current_event
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
request.req_id)
self.kv_send_thread.add_request( # type: ignore[union-attr]
request, )
@@ -413,11 +455,13 @@ class KVPoolWorker:
for layer_id in range(self.num_layers):
yield
def get_finished(self) -> tuple[set[str], set[str]]:
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
done_sending = (
self.kv_send_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
self.get_and_clear_finished_requests(
finished_req_ids # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both']
or self.consumer_is_to_put else set())
done_recving = (
self.kv_recv_thread.
@@ -430,6 +474,29 @@ class KVPoolWorker:
self.tp_rank)
return done_sending, done_recving
def get_and_clear_finished_requests(self, finished_req_ids) -> set[str]:
finished_sending = set()
for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr]
):
if self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
req_id] == 0 and req_id in self.finished_store_req:
self.finished_store_req.remove(req_id)
finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
for req_id in finished_req_ids:
req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr]
req_id)
if req_remain_jobs == 0:
finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
elif req_remain_jobs is not None:
self.finished_store_req.add(req_id)
return finished_sending
def lookup(
self,
token_len: int,