diff --git a/docs/source/user_guide/feature_guide/kv_pool.md b/docs/source/user_guide/feature_guide/kv_pool.md index 2ba0ea9e..4371ce92 100644 --- a/docs/source/user_guide/feature_guide/kv_pool.md +++ b/docs/source/user_guide/feature_guide/kv_pool.md @@ -257,7 +257,23 @@ python3 -m vllm.entrypoints.openai.api_server \ }' ``` -#### 2.Start proxy_server. +Currently, the key-value pool in PD Disaggregate only stores the kv cache generated by the Prefill node by default. In models using MLA, it is now supported that the Decode node stores the kv cache for use by the Prefill node, enabled by adding `consumer_is_to_put: true` to the AscendStoreConnector. If the Prefill node enables PP, `prefill_pp_size` or `prefill_pp_layer_partition` also needs to be set. Example as follows: + +``` +{ + "kv_connector": "AscendStoreConnector", + "kv_role": "kv_consumer", + "kv_connector_extra_config": { + "lookup_rpc_port":"0", + "backend": "mooncake" + "consumer_is_to_put": true, + "prefill_pp_size": 2 + "prefill_pp_layer_partition": "30,31" + } +} +``` + +#### 2、Start proxy_server. ``` python vllm-ascend/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py \ diff --git a/vllm_ascend/distributed/kvpool/ascend_store_connector.py b/vllm_ascend/distributed/kvpool/ascend_store_connector.py index 093f3c07..753806c3 100644 --- a/vllm_ascend/distributed/kvpool/ascend_store_connector.py +++ b/vllm_ascend/distributed/kvpool/ascend_store_connector.py @@ -34,6 +34,8 @@ class AscendStoreConnector(KVConnectorBase_V1): self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "use_layerwise", False) + self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "consumer_is_to_put", False) connector_name = vllm_config.kv_transfer_config.kv_connector if connector_name == "MooncakeConnectorStoreV1": @@ -121,7 +123,7 @@ class AscendStoreConnector(KVConnectorBase_V1): self.connector_worker.save_kv_layer(self._get_connector_metadata()) def wait_for_save(self): - if self.kv_role == "kv_consumer": + if self.kv_role == "kv_consumer" and not self.consumer_is_to_put: # Don't do save if the role is kv_consumer return @@ -135,7 +137,8 @@ class AscendStoreConnector(KVConnectorBase_V1): """Get the finished recving and sending requests.""" assert self.connector_worker is not None meta = self._get_connector_metadata() - done_sending, done_recving = self.connector_worker.get_finished() + done_sending, done_recving = self.connector_worker.get_finished( + finished_req_ids) sended_and_finished: set[str] = set() for item in list(self.sended_but_unfinished_reqs): if item not in meta.unfinished_request_ids: diff --git a/vllm_ascend/distributed/kvpool/config_data.py b/vllm_ascend/distributed/kvpool/config_data.py index 5de21350..8800b5f5 100644 --- a/vllm_ascend/distributed/kvpool/config_data.py +++ b/vllm_ascend/distributed/kvpool/config_data.py @@ -87,12 +87,14 @@ class LayerPoolKey(PoolKey): class ChunkedTokenDatabase(): - def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool): + def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool, + partitions: Optional[List[int]]): self.metadata = metadata self.block_size = block_size self.use_mla = use_mla self.kv_caches_base_addr: list[int] = [] self.block_len: list[int] = [] + self.partitions = partitions def _make_key_by_hash(self, chunk_hash: str, @@ -188,6 +190,28 @@ class ChunkedTokenDatabase(): else: yield start_idx, end_idx, self._make_key_by_hash(hash_val) + def decode_adaptor_prefill_pp(self, key, addr, size): + if self.partitions is None or len(self.partitions) == 1: + return key, addr, size + + new_key = [] + new_addr = [] + new_size = [] + + for i, (addr_list, size_list) in enumerate(zip(addr, size)): + start = 0 + for j, part in enumerate(self.partitions): + # part * 2 because addr and size contain both k and v + end = len(addr_list) if j == len( + self.partitions) - 1 else start + part * 2 + new_str = key[i].replace( # type: ignore[attr-defined] + "@pp_rank:0", f"@pp_rank:{j}", 1) + new_key.append(new_str) + new_addr.append(addr_list[start:end]) + new_size.append(size_list[start:end]) + start = end + return new_key, new_addr, new_size + #Parameters related to the connector metadata @dataclass @@ -247,15 +271,12 @@ class RequestTracker: def update( self, - new_token_ids: list[int], new_block_ids: Union[tuple[list[int], ...], list[int]], ) -> None: """Update the request tracker when a running request is scheduled again """ - self.token_len = self.token_len + len(new_token_ids) - if len(new_block_ids) == 0: new_block_ids = [] elif isinstance(new_block_ids, tuple): @@ -378,4 +399,4 @@ class LasyerMultiBlockReqMeta: block_ids: list[int] layer_id: int is_last_chunk: Optional[bool] = True - current_event: Optional[torch.npu.Event] = None \ No newline at end of file + current_event: Optional[torch.npu.Event] = None diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py index 02d79c88..bfb6eba0 100644 --- a/vllm_ascend/distributed/kvpool/kv_transfer.py +++ b/vllm_ascend/distributed/kvpool/kv_transfer.py @@ -1,5 +1,6 @@ import queue import threading +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from typing import Any @@ -99,7 +100,7 @@ class KVCacheStoreSendingThread(KVTransferThread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, block_size: int, tp_rank: int, dcp_size: int, put_step: int, - ready_event: threading.Event): + kv_role: str, ready_event: threading.Event): super().__init__(m_store, token_database, block_size, @@ -108,6 +109,17 @@ class KVCacheStoreSendingThread(KVTransferThread): ready_event, name="KVCacheSendingThread") self.put_step = put_step + self.kv_role = kv_role + self.stored_requests = defaultdict[str, int](int) + + def add_stored_request(self, req_id: str): + with self.done_task_lock: + self.stored_requests[req_id] += 1 + + def delete_finished_stored_request(self, req_id: str): + with self.done_task_lock: + if req_id in self.stored_requests: + del self.stored_requests[req_id] def _handle_request(self, req_meta: ReqMeta): token_len = req_meta.token_len_chunk @@ -154,13 +166,6 @@ class KVCacheStoreSendingThread(KVTransferThread): req_id, ) - addrs = [] - sizes = [] - for index, start in enumerate(starts): - addr, size, _ = self.token_database.prepare_value( - start, ends[index], block_ids) - addrs.append(addr) - sizes.append(size) if keys: """ Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang. @@ -168,12 +173,24 @@ class KVCacheStoreSendingThread(KVTransferThread): You can manually build the master branch of the project at https://gitcode.com/cann/hixl to resolve this issue before the 8.5.RC1 release. """ + addrs = [] + sizes = [] + for index, start in enumerate(starts): + addr, size, _ = self.token_database.prepare_value( + start, ends[index], block_ids) + addrs.append(addr) + sizes.append(size) + + if self.kv_role == "kv_consumer": + keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp( + keys, addrs, sizes) + if current_event is not None: current_event.synchronize() self.m_store.put(keys, addrs, sizes) - if is_last_chunk: - self.set_finished_request(req_id) + with self.done_task_lock: + self.stored_requests[req_id] -= 1 self.request_queue.task_done() diff --git a/vllm_ascend/distributed/kvpool/pool_scheduler.py b/vllm_ascend/distributed/kvpool/pool_scheduler.py index 9e1d982a..30199fd9 100644 --- a/vllm_ascend/distributed/kvpool/pool_scheduler.py +++ b/vllm_ascend/distributed/kvpool/pool_scheduler.py @@ -24,6 +24,8 @@ class KVPoolScheduler: self.kv_role = vllm_config.kv_transfer_config.kv_role self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "consumer_is_to_load", False) + self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "consumer_is_to_put", False) self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "load_async", False) self.client = LookupKeyClient(vllm_config) @@ -149,7 +151,8 @@ class KVPoolScheduler: scheduler_output (SchedulerOutput): the scheduler output object. """ - force_skip_save = self.kv_role == "kv_consumer" + force_skip_save = (self.kv_role == "kv_consumer" + and not self.consumer_is_to_put) for finished_req_id in scheduler_output.finished_req_ids: self._request_trackers.pop(finished_req_id, None) @@ -197,6 +200,7 @@ class KVPoolScheduler: num_current_tokens = request_tracker.token_len new_token_ids = request.all_token_ids[ num_current_tokens:num_current_tokens + num_new_tokens] + request_tracker.token_len += len(new_token_ids) else: raise ValueError( f"Request {req_id} is not in _unfinished_requests, " @@ -204,10 +208,7 @@ class KVPoolScheduler: new_block_ids = cached_reqs.new_block_ids[i] if not new_block_ids: continue - request_tracker.update(new_token_ids, new_block_ids) - # decode not save - if request_tracker.token_len > len(request.prompt_token_ids): - continue + request_tracker.update(new_block_ids) last_chunk_tokens_num = ((len(request.prompt_token_ids) // self._block_size * self._block_size) @@ -270,7 +271,7 @@ class KVPoolScheduler: Once a request is finished, determine whether request blocks should be freed now or will be sent asynchronously and freed later. """ - if self.kv_role == "kv_consumer": + if self.kv_role == "kv_consumer" and not self.consumer_is_to_put: return False, None tracker = self._request_trackers.get(request.request_id) if tracker is not None and tracker.num_saved_tokens <= 0: diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py index 97bc4c5f..8a5e6718 100644 --- a/vllm_ascend/distributed/kvpool/pool_worker.py +++ b/vllm_ascend/distributed/kvpool/pool_worker.py @@ -61,6 +61,8 @@ class KVPoolWorker: self.kv_role = vllm_config.kv_transfer_config.kv_role self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "load_async", False) + self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "consumer_is_to_put", False) self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "backend", "mooncake") self.block_size = vllm_config.cache_config.block_size @@ -92,9 +94,44 @@ class KVPoolWorker: self.pp_rank, ) + partitions = None + if self.kv_role == "kv_consumer" and self.consumer_is_to_put: + num_hidden_layers = model_config.hf_config.num_hidden_layers + partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "prefill_pp_layer_partition", None) + prefill_pp_size = int( + vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "prefill_pp_size", 1)) + + if partition_list_str is not None: + try: + partitions = [ + int(layer) for layer in partition_list_str.split(",") + ] + except ValueError as err: + raise ValueError("Invalid partition string: {}".format( + partition_list_str)) from err + if len(partitions) != prefill_pp_size: + raise ValueError( + f"{len(partitions)=} does not match {prefill_pp_size=}." + ) + if sum(partitions) != num_hidden_layers: + raise ValueError( + f"{sum(partitions)=} does not match {num_hidden_layers=}." + ) + else: + layers_per_partition = num_hidden_layers // prefill_pp_size + partitions = [ + layers_per_partition for _ in range(prefill_pp_size) + ] + + if remaining_layers := num_hidden_layers % prefill_pp_size: + for i in range(2, remaining_layers + 2): + partitions[-i] += 1 + self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, - self.use_mla) + self.use_mla, partitions) real_backend = backend_map.get(self.backend.lower()) self.m_store = real_backend( # type: ignore[misc] @@ -103,6 +140,8 @@ class KVPoolWorker: self.kv_send_thread: Optional[KVTransferThread] = None self.kv_recv_thread: Optional[KVTransferThread] = None + self.finished_store_req: set[str] = set() + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): _, first_kv_cache_tuple = next(iter(kv_caches.items())) first_kv_cache = first_kv_cache_tuple[0] @@ -176,11 +215,12 @@ class KVPoolWorker: self.kv_recv_thread.start() ready_event.wait() else: - if self.kv_role in ['kv_producer', 'kv_both']: + if self.kv_role in ['kv_producer', 'kv_both' + ] or self.consumer_is_to_put: ready_event_sending = threading.Event() self.kv_send_thread = KVCacheStoreSendingThread( self.m_store, self.token_database, self.block_size, - self.tp_rank, self.dcp_size, self.put_step, + self.tp_rank, self.dcp_size, self.put_step, self.kv_role, ready_event_sending) self.kv_send_thread.start() if self.load_async: @@ -289,6 +329,8 @@ class KVPoolWorker: continue request.current_event = current_event + self.kv_send_thread.add_stored_request( # type: ignore[union-attr] + request.req_id) self.kv_send_thread.add_request( # type: ignore[union-attr] request, ) @@ -413,11 +455,13 @@ class KVPoolWorker: for layer_id in range(self.num_layers): yield - def get_finished(self) -> tuple[set[str], set[str]]: + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: done_sending = ( - self.kv_send_thread. - get_and_clear_finished_requests( # type: ignore[union-attr] - ) if self.kv_role in ['kv_producer', 'kv_both'] else set()) + self.get_and_clear_finished_requests( + finished_req_ids # type: ignore[union-attr] + ) if self.kv_role in ['kv_producer', 'kv_both'] + or self.consumer_is_to_put else set()) done_recving = ( self.kv_recv_thread. @@ -430,6 +474,29 @@ class KVPoolWorker: self.tp_rank) return done_sending, done_recving + def get_and_clear_finished_requests(self, finished_req_ids) -> set[str]: + finished_sending = set() + for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr] + ): + if self.kv_send_thread.stored_requests[ # type: ignore[union-attr] + req_id] == 0 and req_id in self.finished_store_req: + self.finished_store_req.remove(req_id) + finished_sending.add(req_id) + self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr] + req_id) + + for req_id in finished_req_ids: + req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr] + req_id) + if req_remain_jobs == 0: + finished_sending.add(req_id) + self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr] + req_id) + elif req_remain_jobs is not None: + self.finished_store_req.add(req_id) + + return finished_sending + def lookup( self, token_len: int,