[KVPOOL]decode save kvcache (#5168)

### What this PR does / why we need it?

Enable the Decode node to save its kv cache into the kv pool so the Prefill node can reuse it.
Currently only MLA models are supported.

### Does this PR introduce _any_ user-facing change?

Yes. It adds the `consumer_is_to_put`, `prefill_pp_size`, and `prefill_pp_layer_partition` options to the AscendStoreConnector `kv_connector_extra_config`.

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: Chao Lei <leichao139636@163.com>
Authored by baxingpiaochong on 2026-01-04 22:22:01 +08:00, committed by GitHub
parent 350b95efcf
commit 46c2fc6a3c
6 changed files with 156 additions and 31 deletions

View File

@@ -257,7 +257,23 @@ python3 -m vllm.entrypoints.openai.api_server \
   }'
 ```
-#### 2、Start proxy_server.
+Currently, the key-value pool in PD disaggregation only stores the kv cache produced by the Prefill node. For models using MLA, the Decode node can now also store its kv cache for reuse by the Prefill node; enable this by adding `consumer_is_to_put: true` to the AscendStoreConnector configuration. If the Prefill node uses pipeline parallelism, `prefill_pp_size` must also be set, and `prefill_pp_layer_partition` can optionally pin the per-stage layer split. Example as follows:
+```
+{
+    "kv_connector": "AscendStoreConnector",
+    "kv_role": "kv_consumer",
+    "kv_connector_extra_config": {
+        "lookup_rpc_port": "0",
+        "backend": "mooncake",
+        "consumer_is_to_put": true,
+        "prefill_pp_size": 2,
+        "prefill_pp_layer_partition": "30,31"
+    }
+}
+```
+#### 2. Start proxy_server.
 ```
 python vllm-ascend/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py \
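
For reference, here is a standalone sketch (the function name is mine, not part of the PR) of how the worker derives the default layer split when only `prefill_pp_size` is given, mirroring the `KVPoolWorker` logic added below. If this default does not match the Prefill deployment's actual split, set `prefill_pp_layer_partition` explicitly.

```
# Split num_hidden_layers evenly across prefill_pp_size stages, handing any
# remainder to the later stages (starting from the second-to-last), as the
# added KVPoolWorker code does.
def default_prefill_partition(num_hidden_layers: int,
                              prefill_pp_size: int) -> list[int]:
    layers_per_partition = num_hidden_layers // prefill_pp_size
    partitions = [layers_per_partition] * prefill_pp_size
    if remaining_layers := num_hidden_layers % prefill_pp_size:
        for i in range(2, remaining_layers + 2):
            partitions[-i] += 1
    return partitions

print(default_prefill_partition(61, 2))  # [31, 30]
print(default_prefill_partition(62, 4))  # [15, 16, 16, 15]
```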

View File

@@ -34,6 +34,8 @@ class AscendStoreConnector(KVConnectorBase_V1):
         self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "use_layerwise", False)
+        self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+            "consumer_is_to_put", False)
         connector_name = vllm_config.kv_transfer_config.kv_connector
         if connector_name == "MooncakeConnectorStoreV1":
@@ -121,7 +123,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
             self.connector_worker.save_kv_layer(self._get_connector_metadata())

     def wait_for_save(self):
-        if self.kv_role == "kv_consumer":
+        if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
             # Don't do save if the role is kv_consumer
             return
@@ -135,7 +137,8 @@ class AscendStoreConnector(KVConnectorBase_V1):
"""Get the finished recving and sending requests.""" """Get the finished recving and sending requests."""
assert self.connector_worker is not None assert self.connector_worker is not None
meta = self._get_connector_metadata() meta = self._get_connector_metadata()
done_sending, done_recving = self.connector_worker.get_finished() done_sending, done_recving = self.connector_worker.get_finished(
finished_req_ids)
sended_and_finished: set[str] = set() sended_and_finished: set[str] = set()
for item in list(self.sended_but_unfinished_reqs): for item in list(self.sended_but_unfinished_reqs):
if item not in meta.unfinished_request_ids: if item not in meta.unfinished_request_ids:
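
The `wait_for_save` change reduces to a single predicate; a minimal sketch (the helper name is mine, for illustration) of the new gating:

```
# Save is skipped only for a pure consumer; with consumer_is_to_put enabled,
# a Decode-side consumer now falls through to the normal save path.
def should_save(kv_role: str, consumer_is_to_put: bool) -> bool:
    return not (kv_role == "kv_consumer" and not consumer_is_to_put)

assert should_save("kv_producer", False)
assert not should_save("kv_consumer", False)  # old behavior, unchanged
assert should_save("kv_consumer", True)       # new: decode node saves kv cache
```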

View File

@@ -87,12 +87,14 @@ class LayerPoolKey(PoolKey):
 class ChunkedTokenDatabase():
-    def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool):
+    def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
+                 partitions: Optional[List[int]]):
         self.metadata = metadata
         self.block_size = block_size
         self.use_mla = use_mla
         self.kv_caches_base_addr: list[int] = []
         self.block_len: list[int] = []
+        self.partitions = partitions

     def _make_key_by_hash(self,
                           chunk_hash: str,
@@ -188,6 +190,28 @@ class ChunkedTokenDatabase():
             else:
                 yield start_idx, end_idx, self._make_key_by_hash(hash_val)
+
+    def decode_adaptor_prefill_pp(self, key, addr, size):
+        if self.partitions is None or len(self.partitions) == 1:
+            return key, addr, size
+        new_key = []
+        new_addr = []
+        new_size = []
+        for i, (addr_list, size_list) in enumerate(zip(addr, size)):
+            start = 0
+            for j, part in enumerate(self.partitions):
+                # part * 2 because addr and size contain both k and v
+                end = len(addr_list) if j == len(
+                    self.partitions) - 1 else start + part * 2
+                new_str = key[i].replace(  # type: ignore[attr-defined]
+                    "@pp_rank:0", f"@pp_rank:{j}", 1)
+                new_key.append(new_str)
+                new_addr.append(addr_list[start:end])
+                new_size.append(size_list[start:end])
+                start = end
+        return new_key, new_addr, new_size

 #Parameters related to the connector metadata
 @dataclass
@@ -247,15 +271,12 @@ class RequestTracker:
     def update(
         self,
-        new_token_ids: list[int],
         new_block_ids: Union[tuple[list[int], ...], list[int]],
     ) -> None:
         """Update the request tracker when a running request is
         scheduled again
         """
-        self.token_len = self.token_len + len(new_token_ids)
         if len(new_block_ids) == 0:
             new_block_ids = []
         elif isinstance(new_block_ids, tuple):

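To make `decode_adaptor_prefill_pp` concrete, here is a self-contained toy run (the standalone function name, keys, and addresses are made up; `partitions=[2, 2]` models a 4-layer model split across two Prefill PP ranks):

```
# Each request has one key tagged "@pp_rank:0" plus flat addr/size lists with
# k and v entries per layer. The adaptor cuts the lists into per-rank slices
# and rewrites the pp_rank tag so each slice is stored under the key the
# matching Prefill PP rank will look up.
def split_for_prefill_pp(keys, addrs, sizes, partitions):
    if partitions is None or len(partitions) == 1:
        return keys, addrs, sizes
    new_keys, new_addrs, new_sizes = [], [], []
    for i, (addr_list, size_list) in enumerate(zip(addrs, sizes)):
        start = 0
        for j, part in enumerate(partitions):
            # part * 2 because the flat lists hold both k and v per layer
            end = len(addr_list) if j == len(partitions) - 1 else start + part * 2
            new_keys.append(keys[i].replace("@pp_rank:0", f"@pp_rank:{j}", 1))
            new_addrs.append(addr_list[start:end])
            new_sizes.append(size_list[start:end])
            start = end
    return new_keys, new_addrs, new_sizes

keys = ["chunkhash123@pp_rank:0"]
addrs = [[100, 101, 200, 201, 300, 301, 400, 401]]  # 4 layers x (k, v)
sizes = [[64] * 8]
k, a, s = split_for_prefill_pp(keys, addrs, sizes, [2, 2])
print(k)  # ['chunkhash123@pp_rank:0', 'chunkhash123@pp_rank:1']
print(a)  # [[100, 101, 200, 201], [300, 301, 400, 401]]
```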
View File

@@ -1,5 +1,6 @@
 import queue
 import threading
+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
@@ -99,7 +100,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                  block_size: int, tp_rank: int, dcp_size: int, put_step: int,
-                 ready_event: threading.Event):
+                 kv_role: str, ready_event: threading.Event):
         super().__init__(m_store,
                          token_database,
                          block_size,
@@ -108,6 +109,17 @@ class KVCacheStoreSendingThread(KVTransferThread):
                          ready_event,
                          name="KVCacheSendingThread")
         self.put_step = put_step
+        self.kv_role = kv_role
+        self.stored_requests = defaultdict[str, int](int)
+
+    def add_stored_request(self, req_id: str):
+        with self.done_task_lock:
+            self.stored_requests[req_id] += 1
+
+    def delete_finished_stored_request(self, req_id: str):
+        with self.done_task_lock:
+            if req_id in self.stored_requests:
+                del self.stored_requests[req_id]

     def _handle_request(self, req_meta: ReqMeta):
         token_len = req_meta.token_len_chunk
@@ -154,13 +166,6 @@ class KVCacheStoreSendingThread(KVTransferThread):
                 req_id,
             )
-        addrs = []
-        sizes = []
-        for index, start in enumerate(starts):
-            addr, size, _ = self.token_database.prepare_value(
-                start, ends[index], block_ids)
-            addrs.append(addr)
-            sizes.append(size)
         if keys:
             """
             Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
@@ -168,12 +173,24 @@ class KVCacheStoreSendingThread(KVTransferThread):
             You can manually build the master branch of the project at https://gitcode.com/cann/hixl
             to resolve this issue before the 8.5.RC1 release.
             """
+            addrs = []
+            sizes = []
+            for index, start in enumerate(starts):
+                addr, size, _ = self.token_database.prepare_value(
+                    start, ends[index], block_ids)
+                addrs.append(addr)
+                sizes.append(size)
+            if self.kv_role == "kv_consumer":
+                keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(
+                    keys, addrs, sizes)
             if current_event is not None:
                 current_event.synchronize()
             self.m_store.put(keys, addrs, sizes)
-        if is_last_chunk:
-            self.set_finished_request(req_id)
+        with self.done_task_lock:
+            self.stored_requests[req_id] -= 1
         self.request_queue.task_done()

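The `stored_requests` counter replaces the old `is_last_chunk` signal: the worker increments it once per queued store job, `_handle_request` decrements it after the put completes, and a request only counts as fully sent once its counter drains to zero. A condensed, hypothetical model of that handshake (the real code splits this between `KVPoolWorker` and the sending thread):

```
import threading
from collections import defaultdict

class StoreBookkeeping:
    """Toy version of the counter shared by the worker and sending thread."""

    def __init__(self):
        self.lock = threading.Lock()           # stands in for done_task_lock
        self.stored_requests = defaultdict(int)

    def add_stored_request(self, req_id: str):
        with self.lock:                        # worker: one more put queued
            self.stored_requests[req_id] += 1

    def put_completed(self, req_id: str):
        with self.lock:                        # sending thread: one put done
            self.stored_requests[req_id] -= 1

    def drain(self, finished_req_ids: set[str]) -> set[str]:
        # Report requests whose puts have all landed and which the
        # scheduler has already marked finished.
        with self.lock:
            done = {r for r in finished_req_ids
                    if self.stored_requests.get(r) == 0}
            for r in done:
                del self.stored_requests[r]
            return done

bk = StoreBookkeeping()
bk.add_stored_request("req-1")
bk.add_stored_request("req-1")   # two chunks queued
bk.put_completed("req-1")
print(bk.drain({"req-1"}))       # set(): one put still in flight
bk.put_completed("req-1")
print(bk.drain({"req-1"}))       # {'req-1'}
```

The real `get_and_clear_finished_requests` (last file below) additionally parks requests that finish scheduling while puts are still in flight in `finished_store_req`, reporting them on a later call.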
View File

@@ -24,6 +24,8 @@ class KVPoolScheduler:
         self.kv_role = vllm_config.kv_transfer_config.kv_role
         self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "consumer_is_to_load", False)
+        self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+            "consumer_is_to_put", False)
         self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "load_async", False)
         self.client = LookupKeyClient(vllm_config)
@@ -149,7 +151,8 @@ class KVPoolScheduler:
             scheduler_output (SchedulerOutput): the scheduler output object.
         """
-        force_skip_save = self.kv_role == "kv_consumer"
+        force_skip_save = (self.kv_role == "kv_consumer"
+                           and not self.consumer_is_to_put)
         for finished_req_id in scheduler_output.finished_req_ids:
             self._request_trackers.pop(finished_req_id, None)
@@ -197,6 +200,7 @@ class KVPoolScheduler:
                 num_current_tokens = request_tracker.token_len
                 new_token_ids = request.all_token_ids[
                     num_current_tokens:num_current_tokens + num_new_tokens]
+                request_tracker.token_len += len(new_token_ids)
             else:
                 raise ValueError(
                     f"Request {req_id} is not in _unfinished_requests, "
@@ -204,10 +208,7 @@ class KVPoolScheduler:
             new_block_ids = cached_reqs.new_block_ids[i]
             if not new_block_ids:
                 continue
-            request_tracker.update(new_token_ids, new_block_ids)
-            # decode not save
-            if request_tracker.token_len > len(request.prompt_token_ids):
-                continue
+            request_tracker.update(new_block_ids)
             last_chunk_tokens_num = ((len(request.prompt_token_ids) //
                                       self._block_size * self._block_size)
@@ -270,7 +271,7 @@ class KVPoolScheduler:
         Once a request is finished, determine whether request blocks
         should be freed now or will be sent asynchronously and freed later.
         """
-        if self.kv_role == "kv_consumer":
+        if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
             return False, None
         tracker = self._request_trackers.get(request.request_id)
         if tracker is not None and tracker.num_saved_tokens <= 0:
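
A rough standalone model (hypothetical name, covering only the branches visible in this diff) of the updated `request_finished` decision:

```
# Whether block freeing is delayed until the async save finishes. A pure
# consumer never delays; a consumer with consumer_is_to_put behaves like a
# producer. The final producer-side branch is inferred from the surrounding
# docstring, not shown in this diff.
def delay_block_free(kv_role: str, consumer_is_to_put: bool,
                     num_saved_tokens: int) -> bool:
    if kv_role == "kv_consumer" and not consumer_is_to_put:
        return False              # nothing stored here; free blocks now
    if num_saved_tokens <= 0:
        return False              # tracker saved nothing; free blocks now
    return True                   # blocks sent asynchronously, freed later

assert delay_block_free("kv_consumer", False, 128) is False
assert delay_block_free("kv_consumer", True, 128) is True
```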

View File

@@ -61,6 +61,8 @@ class KVPoolWorker:
         self.kv_role = vllm_config.kv_transfer_config.kv_role
         self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "load_async", False)
+        self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+            "consumer_is_to_put", False)
         self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "backend", "mooncake")
         self.block_size = vllm_config.cache_config.block_size
@@ -92,9 +94,44 @@ class KVPoolWorker:
             self.pp_rank,
         )
+        partitions = None
+        if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
+            num_hidden_layers = model_config.hf_config.num_hidden_layers
+            partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+                "prefill_pp_layer_partition", None)
+            prefill_pp_size = int(
+                vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+                    "prefill_pp_size", 1))
+            if partition_list_str is not None:
+                try:
+                    partitions = [
+                        int(layer) for layer in partition_list_str.split(",")
+                    ]
+                except ValueError as err:
+                    raise ValueError("Invalid partition string: {}".format(
+                        partition_list_str)) from err
+                if len(partitions) != prefill_pp_size:
+                    raise ValueError(
+                        f"{len(partitions)=} does not match {prefill_pp_size=}."
+                    )
+                if sum(partitions) != num_hidden_layers:
+                    raise ValueError(
+                        f"{sum(partitions)=} does not match {num_hidden_layers=}."
+                    )
+            else:
+                layers_per_partition = num_hidden_layers // prefill_pp_size
+                partitions = [
+                    layers_per_partition for _ in range(prefill_pp_size)
+                ]
+                if remaining_layers := num_hidden_layers % prefill_pp_size:
+                    for i in range(2, remaining_layers + 2):
+                        partitions[-i] += 1
         self.token_database = ChunkedTokenDatabase(self.metadata,
                                                    self.block_size,
-                                                   self.use_mla)
+                                                   self.use_mla, partitions)

         real_backend = backend_map.get(self.backend.lower())
         self.m_store = real_backend(  # type: ignore[misc]
@@ -103,6 +140,8 @@ class KVPoolWorker:
         self.kv_send_thread: Optional[KVTransferThread] = None
         self.kv_recv_thread: Optional[KVTransferThread] = None
+        self.finished_store_req: set[str] = set()

     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         _, first_kv_cache_tuple = next(iter(kv_caches.items()))
         first_kv_cache = first_kv_cache_tuple[0]
@@ -176,11 +215,12 @@ class KVPoolWorker:
                 self.kv_recv_thread.start()
                 ready_event.wait()
         else:
-            if self.kv_role in ['kv_producer', 'kv_both']:
+            if self.kv_role in ['kv_producer', 'kv_both'
+                                ] or self.consumer_is_to_put:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreSendingThread(
                     self.m_store, self.token_database, self.block_size,
-                    self.tp_rank, self.dcp_size, self.put_step,
+                    self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
                     ready_event_sending)
                 self.kv_send_thread.start()
             if self.load_async:
@@ -289,6 +329,8 @@ class KVPoolWorker:
                 continue
             request.current_event = current_event
+            self.kv_send_thread.add_stored_request(  # type: ignore[union-attr]
+                request.req_id)
             self.kv_send_thread.add_request(  # type: ignore[union-attr]
                 request, )
@@ -413,11 +455,13 @@ class KVPoolWorker:
         for layer_id in range(self.num_layers):
             yield

-    def get_finished(self) -> tuple[set[str], set[str]]:
+    def get_finished(self,
+                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
         done_sending = (
-            self.kv_send_thread.
-            get_and_clear_finished_requests(  # type: ignore[union-attr]
-            ) if self.kv_role in ['kv_producer', 'kv_both'] else set())
+            self.get_and_clear_finished_requests(
+                finished_req_ids  # type: ignore[union-attr]
+            ) if self.kv_role in ['kv_producer', 'kv_both']
+            or self.consumer_is_to_put else set())
         done_recving = (
             self.kv_recv_thread.
@@ -430,6 +474,29 @@ class KVPoolWorker:
                 self.tp_rank)
         return done_sending, done_recving

+    def get_and_clear_finished_requests(self, finished_req_ids) -> set[str]:
+        finished_sending = set()
+        for req_id in self.kv_send_thread.stored_requests.copy(  # type: ignore[union-attr]
+        ):
+            if self.kv_send_thread.stored_requests[  # type: ignore[union-attr]
+                    req_id] == 0 and req_id in self.finished_store_req:
+                self.finished_store_req.remove(req_id)
+                finished_sending.add(req_id)
+                self.kv_send_thread.delete_finished_stored_request(  # type: ignore[union-attr]
+                    req_id)
+        for req_id in finished_req_ids:
+            req_remain_jobs = self.kv_send_thread.stored_requests.get(  # type: ignore[union-attr]
+                req_id)
+            if req_remain_jobs == 0:
+                finished_sending.add(req_id)
+                self.kv_send_thread.delete_finished_stored_request(  # type: ignore[union-attr]
+                    req_id)
+            elif req_remain_jobs is not None:
+                self.finished_store_req.add(req_id)
+        return finished_sending
     def lookup(
         self,
         token_len: int,