diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index 6c6c609..9a4084d 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -667,10 +667,6 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase): self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3]) self.assertEqual(len(self.scheduler._reqs_need_recv), 0) - def test_get_finished_count(self): - count = self.scheduler.get_finished_count() - self.assertEqual(count, 2) - class TestHelperFunctions(unittest.TestCase): diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 57b4494..0118dcd 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -74,6 +74,10 @@ class KVCacheTaskTracker: self.record_finished_requests: set[str] = set() self.delayed_free_requests: OrderedDict[str, float] = OrderedDict() + def add_not_transfer_request(self, request_id: str): + with self.done_task_lock: + self.finished_requests.add(request_id) + def update_done_task_count(self, request_id: str): with self.done_task_lock: self.finished_requests.add(request_id) @@ -151,6 +155,9 @@ class KVCacheSendingThread(threading.Thread): """ return self.task_tracker.get_and_clear_finished_requests() + def add_not_transfer_request(self, request_id: str): + self.task_tracker.add_not_transfer_request(request_id) + def add_delayed_request(self, request_id: str, delay_start_time: float): return self.task_tracker.add_delayed_request(request_id, delay_start_time) @@ -652,10 +659,6 @@ class MooncakeConnector(KVConnectorBase_V1): assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) - def get_finished_count(self) -> Optional[int]: - assert self.connector_scheduler is not None - return self.connector_scheduler.get_finished_count() - ############################################################ # Worker Side Methods ############################################################ @@ -840,39 +843,6 @@ class MooncakeConnectorScheduler: last_token_id=request.output_token_ids[-1], ) - def get_finished_count(self) -> Optional[int]: - prefill_parallel_config: dict[ - str, - Any] = self.vllm_config.kv_transfer_config.get_from_extra_config( - "prefill", {}) - - assert "tp_size" in prefill_parallel_config.keys() - self._prefill_tp_size = prefill_parallel_config["tp_size"] - decode_parallel_config: dict[ - str, - Any] = self.vllm_config.kv_transfer_config.get_from_extra_config( - "decode", {}) - assert "tp_size" in decode_parallel_config.keys() - self._decode_tp_size = decode_parallel_config["tp_size"] - num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads - if self.vllm_config.model_config.use_mla or hasattr( - self.vllm_config.model_config.hf_config, "index_topk"): - num_need_pulls = 1 - else: - num_p_block_heads = max( - 1, num_key_value_heads // self._prefill_tp_size) - num_d_block_heads = max( - 1, num_key_value_heads // self._decode_tp_size) - num_need_pulls = num_d_block_heads // num_p_block_heads - kv_role = self.vllm_config.kv_transfer_config.kv_role - logger.debug( - "get_finished_count, kv_role=%s, num_need_pulls=%d, decode_tp_size=%d", - kv_role, num_need_pulls, self._decode_tp_size) - if kv_role == 'kv_producer': - return num_need_pulls * self._decode_tp_size - else: - return self._decode_tp_size - class MooncakeConnectorWorker: """Implementation of Worker side methods""" @@ -1144,6 +1114,8 @@ class MooncakeConnectorWorker: if self.tp_rank in self._prefill_get_remote_tp_rank(req_id): self.kv_send_thread.add_delayed_request( req_id, delay_start_time) + else: + self.kv_send_thread.add_not_transfer_request(req_id) def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]: return sum(self._get_remote_tp_ranks_for_req(req_id), [])