[KVPOOL]decode save kvcache (#5168)

### What this PR does / why we need it?

kvpool decode save kvcache
now only support mla

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: Chao Lei <leichao139636@163.com>
This commit is contained in:
baxingpiaochong
2026-01-04 22:22:01 +08:00
committed by GitHub
parent 350b95efcf
commit 46c2fc6a3c
6 changed files with 156 additions and 31 deletions

View File

@@ -257,7 +257,23 @@ python3 -m vllm.entrypoints.openai.api_server \
}'
```
#### 2.Start proxy_server.
Currently, the key-value pool in PD Disaggregate only stores the kv cache generated by the Prefill node by default. In models using MLA, it is now supported that the Decode node stores the kv cache for use by the Prefill node, enabled by adding `consumer_is_to_put: true` to the AscendStoreConnector. If the Prefill node enables PP, `prefill_pp_size` or `prefill_pp_layer_partition` also needs to be set. Example as follows:
```
{
"kv_connector": "AscendStoreConnector",
"kv_role": "kv_consumer",
"kv_connector_extra_config": {
"lookup_rpc_port":"0",
"backend": "mooncake"
"consumer_is_to_put": true,
"prefill_pp_size": 2
"prefill_pp_layer_partition": "30,31"
}
}
```
#### 2、Start proxy_server.
```
python vllm-ascend/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py \

View File

@@ -34,6 +34,8 @@ class AscendStoreConnector(KVConnectorBase_V1):
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"use_layerwise", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
connector_name = vllm_config.kv_transfer_config.kv_connector
if connector_name == "MooncakeConnectorStoreV1":
@@ -121,7 +123,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
self.connector_worker.save_kv_layer(self._get_connector_metadata())
def wait_for_save(self):
if self.kv_role == "kv_consumer":
if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
# Don't do save if the role is kv_consumer
return
@@ -135,7 +137,8 @@ class AscendStoreConnector(KVConnectorBase_V1):
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
meta = self._get_connector_metadata()
done_sending, done_recving = self.connector_worker.get_finished()
done_sending, done_recving = self.connector_worker.get_finished(
finished_req_ids)
sended_and_finished: set[str] = set()
for item in list(self.sended_but_unfinished_reqs):
if item not in meta.unfinished_request_ids:

View File

@@ -87,12 +87,14 @@ class LayerPoolKey(PoolKey):
class ChunkedTokenDatabase():
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool):
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
partitions: Optional[List[int]]):
self.metadata = metadata
self.block_size = block_size
self.use_mla = use_mla
self.kv_caches_base_addr: list[int] = []
self.block_len: list[int] = []
self.partitions = partitions
def _make_key_by_hash(self,
chunk_hash: str,
@@ -188,6 +190,28 @@ class ChunkedTokenDatabase():
else:
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
def decode_adaptor_prefill_pp(self, key, addr, size):
if self.partitions is None or len(self.partitions) == 1:
return key, addr, size
new_key = []
new_addr = []
new_size = []
for i, (addr_list, size_list) in enumerate(zip(addr, size)):
start = 0
for j, part in enumerate(self.partitions):
# part * 2 because addr and size contain both k and v
end = len(addr_list) if j == len(
self.partitions) - 1 else start + part * 2
new_str = key[i].replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{j}", 1)
new_key.append(new_str)
new_addr.append(addr_list[start:end])
new_size.append(size_list[start:end])
start = end
return new_key, new_addr, new_size
#Parameters related to the connector metadata
@dataclass
@@ -247,15 +271,12 @@ class RequestTracker:
def update(
self,
new_token_ids: list[int],
new_block_ids: Union[tuple[list[int], ...], list[int]],
) -> None:
"""Update the request tracker when a running request is
scheduled again
"""
self.token_len = self.token_len + len(new_token_ids)
if len(new_block_ids) == 0:
new_block_ids = []
elif isinstance(new_block_ids, tuple):

View File

@@ -1,5 +1,6 @@
import queue
import threading
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any
@@ -99,7 +100,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
ready_event: threading.Event):
kv_role: str, ready_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
@@ -108,6 +109,17 @@ class KVCacheStoreSendingThread(KVTransferThread):
ready_event,
name="KVCacheSendingThread")
self.put_step = put_step
self.kv_role = kv_role
self.stored_requests = defaultdict[str, int](int)
def add_stored_request(self, req_id: str):
with self.done_task_lock:
self.stored_requests[req_id] += 1
def delete_finished_stored_request(self, req_id: str):
with self.done_task_lock:
if req_id in self.stored_requests:
del self.stored_requests[req_id]
def _handle_request(self, req_meta: ReqMeta):
token_len = req_meta.token_len_chunk
@@ -154,13 +166,6 @@ class KVCacheStoreSendingThread(KVTransferThread):
req_id,
)
addrs = []
sizes = []
for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(
start, ends[index], block_ids)
addrs.append(addr)
sizes.append(size)
if keys:
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
@@ -168,12 +173,24 @@ class KVCacheStoreSendingThread(KVTransferThread):
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
addrs = []
sizes = []
for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(
start, ends[index], block_ids)
addrs.append(addr)
sizes.append(size)
if self.kv_role == "kv_consumer":
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(
keys, addrs, sizes)
if current_event is not None:
current_event.synchronize()
self.m_store.put(keys, addrs, sizes)
if is_last_chunk:
self.set_finished_request(req_id)
with self.done_task_lock:
self.stored_requests[req_id] -= 1
self.request_queue.task_done()

View File

@@ -24,6 +24,8 @@ class KVPoolScheduler:
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_load", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.client = LookupKeyClient(vllm_config)
@@ -149,7 +151,8 @@ class KVPoolScheduler:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
force_skip_save = self.kv_role == "kv_consumer"
force_skip_save = (self.kv_role == "kv_consumer"
and not self.consumer_is_to_put)
for finished_req_id in scheduler_output.finished_req_ids:
self._request_trackers.pop(finished_req_id, None)
@@ -197,6 +200,7 @@ class KVPoolScheduler:
num_current_tokens = request_tracker.token_len
new_token_ids = request.all_token_ids[
num_current_tokens:num_current_tokens + num_new_tokens]
request_tracker.token_len += len(new_token_ids)
else:
raise ValueError(
f"Request {req_id} is not in _unfinished_requests, "
@@ -204,10 +208,7 @@ class KVPoolScheduler:
new_block_ids = cached_reqs.new_block_ids[i]
if not new_block_ids:
continue
request_tracker.update(new_token_ids, new_block_ids)
# decode not save
if request_tracker.token_len > len(request.prompt_token_ids):
continue
request_tracker.update(new_block_ids)
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
@@ -270,7 +271,7 @@ class KVPoolScheduler:
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
if self.kv_role == "kv_consumer":
if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
return False, None
tracker = self._request_trackers.get(request.request_id)
if tracker is not None and tracker.num_saved_tokens <= 0:

View File

@@ -61,6 +61,8 @@ class KVPoolWorker:
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"backend", "mooncake")
self.block_size = vllm_config.cache_config.block_size
@@ -92,9 +94,44 @@ class KVPoolWorker:
self.pp_rank,
)
partitions = None
if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
num_hidden_layers = model_config.hf_config.num_hidden_layers
partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_layer_partition", None)
prefill_pp_size = int(
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_size", 1))
if partition_list_str is not None:
try:
partitions = [
int(layer) for layer in partition_list_str.split(",")
]
except ValueError as err:
raise ValueError("Invalid partition string: {}".format(
partition_list_str)) from err
if len(partitions) != prefill_pp_size:
raise ValueError(
f"{len(partitions)=} does not match {prefill_pp_size=}."
)
if sum(partitions) != num_hidden_layers:
raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}."
)
else:
layers_per_partition = num_hidden_layers // prefill_pp_size
partitions = [
layers_per_partition for _ in range(prefill_pp_size)
]
if remaining_layers := num_hidden_layers % prefill_pp_size:
for i in range(2, remaining_layers + 2):
partitions[-i] += 1
self.token_database = ChunkedTokenDatabase(self.metadata,
self.block_size,
self.use_mla)
self.use_mla, partitions)
real_backend = backend_map.get(self.backend.lower())
self.m_store = real_backend( # type: ignore[misc]
@@ -103,6 +140,8 @@ class KVPoolWorker:
self.kv_send_thread: Optional[KVTransferThread] = None
self.kv_recv_thread: Optional[KVTransferThread] = None
self.finished_store_req: set[str] = set()
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
@@ -176,11 +215,12 @@ class KVPoolWorker:
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both']:
if self.kv_role in ['kv_producer', 'kv_both'
] or self.consumer_is_to_put:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step,
self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
ready_event_sending)
self.kv_send_thread.start()
if self.load_async:
@@ -289,6 +329,8 @@ class KVPoolWorker:
continue
request.current_event = current_event
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
request.req_id)
self.kv_send_thread.add_request( # type: ignore[union-attr]
request, )
@@ -413,11 +455,13 @@ class KVPoolWorker:
for layer_id in range(self.num_layers):
yield
def get_finished(self) -> tuple[set[str], set[str]]:
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
done_sending = (
self.kv_send_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
self.get_and_clear_finished_requests(
finished_req_ids # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both']
or self.consumer_is_to_put else set())
done_recving = (
self.kv_recv_thread.
@@ -430,6 +474,29 @@ class KVPoolWorker:
self.tp_rank)
return done_sending, done_recving
def get_and_clear_finished_requests(self, finished_req_ids) -> set[str]:
finished_sending = set()
for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr]
):
if self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
req_id] == 0 and req_id in self.finished_store_req:
self.finished_store_req.remove(req_id)
finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
for req_id in finished_req_ids:
req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr]
req_id)
if req_remain_jobs == 0:
finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
elif req_remain_jobs is not None:
self.finished_store_req.add(req_id)
return finished_sending
def lookup(
self,
token_len: int,