[Bugfix] fix_multi_connector_bug (#3332)
### What this PR does / why we need it?
When the multi connector is used, it does not define
`get_finished_count`, which causes the KV cache to be released.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main:
83f478bb19
---------
Signed-off-by: baxingpiaochong <771405853@qq.com>
This commit is contained in:
@@ -673,10 +673,6 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
|
|||||||
self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
|
self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
|
||||||
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
|
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
|
||||||
|
|
||||||
def test_get_finished_count(self):
|
|
||||||
count = self.scheduler.get_finished_count()
|
|
||||||
self.assertEqual(count, 2)
|
|
||||||
|
|
||||||
|
|
||||||
class TestHelperFunctions(unittest.TestCase):
|
class TestHelperFunctions(unittest.TestCase):
|
||||||
|
|
||||||
|
|||||||
@@ -80,6 +80,10 @@ class KVCacheTaskTracker:
|
|||||||
self.record_finished_requests: set[str] = set()
|
self.record_finished_requests: set[str] = set()
|
||||||
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
|
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
|
||||||
|
|
||||||
|
def add_not_transfer_request(self, request_id: str):
|
||||||
|
with self.done_task_lock:
|
||||||
|
self.finished_requests.add(request_id)
|
||||||
|
|
||||||
def update_done_task_count(self, request_id: str):
|
def update_done_task_count(self, request_id: str):
|
||||||
with self.done_task_lock:
|
with self.done_task_lock:
|
||||||
self.finished_requests.add(request_id)
|
self.finished_requests.add(request_id)
|
||||||
@@ -157,6 +161,9 @@ class KVCacheSendingThread(threading.Thread):
|
|||||||
"""
|
"""
|
||||||
return self.task_tracker.get_and_clear_finished_requests()
|
return self.task_tracker.get_and_clear_finished_requests()
|
||||||
|
|
||||||
|
def add_not_transfer_request(self, request_id: str):
|
||||||
|
self.task_tracker.add_not_transfer_request(request_id)
|
||||||
|
|
||||||
def add_delayed_request(self, request_id: str, delay_start_time: float):
|
def add_delayed_request(self, request_id: str, delay_start_time: float):
|
||||||
return self.task_tracker.add_delayed_request(request_id,
|
return self.task_tracker.add_delayed_request(request_id,
|
||||||
delay_start_time)
|
delay_start_time)
|
||||||
@@ -658,10 +665,6 @@ class MooncakeConnector(KVConnectorBase_V1):
|
|||||||
assert self.connector_scheduler is not None
|
assert self.connector_scheduler is not None
|
||||||
return self.connector_scheduler.request_finished(request, block_ids)
|
return self.connector_scheduler.request_finished(request, block_ids)
|
||||||
|
|
||||||
def get_finished_count(self) -> Optional[int]:
|
|
||||||
assert self.connector_scheduler is not None
|
|
||||||
return self.connector_scheduler.get_finished_count()
|
|
||||||
|
|
||||||
############################################################
|
############################################################
|
||||||
# Worker Side Methods
|
# Worker Side Methods
|
||||||
############################################################
|
############################################################
|
||||||
@@ -846,39 +849,6 @@ class MooncakeConnectorScheduler:
|
|||||||
last_token_id=request.output_token_ids[-1],
|
last_token_id=request.output_token_ids[-1],
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_finished_count(self) -> Optional[int]:
|
|
||||||
prefill_parallel_config: dict[
|
|
||||||
str,
|
|
||||||
Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
|
|
||||||
"prefill", {})
|
|
||||||
|
|
||||||
assert "tp_size" in prefill_parallel_config.keys()
|
|
||||||
self._prefill_tp_size = prefill_parallel_config["tp_size"]
|
|
||||||
decode_parallel_config: dict[
|
|
||||||
str,
|
|
||||||
Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
|
|
||||||
"decode", {})
|
|
||||||
assert "tp_size" in decode_parallel_config.keys()
|
|
||||||
self._decode_tp_size = decode_parallel_config["tp_size"]
|
|
||||||
num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
|
|
||||||
if self.vllm_config.model_config.use_mla or hasattr(
|
|
||||||
self.vllm_config.model_config.hf_config, "index_topk"):
|
|
||||||
num_need_pulls = 1
|
|
||||||
else:
|
|
||||||
num_p_block_heads = max(
|
|
||||||
1, num_key_value_heads // self._prefill_tp_size)
|
|
||||||
num_d_block_heads = max(
|
|
||||||
1, num_key_value_heads // self._decode_tp_size)
|
|
||||||
num_need_pulls = num_d_block_heads // num_p_block_heads
|
|
||||||
kv_role = self.vllm_config.kv_transfer_config.kv_role
|
|
||||||
logger.debug(
|
|
||||||
"get_finished_count, kv_role=%s, num_need_pulls=%d, decode_tp_size=%d",
|
|
||||||
kv_role, num_need_pulls, self._decode_tp_size)
|
|
||||||
if kv_role == 'kv_producer':
|
|
||||||
return num_need_pulls * self._decode_tp_size
|
|
||||||
else:
|
|
||||||
return self._decode_tp_size
|
|
||||||
|
|
||||||
|
|
||||||
class MooncakeConnectorWorker:
|
class MooncakeConnectorWorker:
|
||||||
"""Implementation of Worker side methods"""
|
"""Implementation of Worker side methods"""
|
||||||
@@ -1150,6 +1120,8 @@ class MooncakeConnectorWorker:
|
|||||||
if self.tp_rank in self._prefill_get_remote_tp_rank(req_id):
|
if self.tp_rank in self._prefill_get_remote_tp_rank(req_id):
|
||||||
self.kv_send_thread.add_delayed_request(
|
self.kv_send_thread.add_delayed_request(
|
||||||
req_id, delay_start_time)
|
req_id, delay_start_time)
|
||||||
|
else:
|
||||||
|
self.kv_send_thread.add_not_transfer_request(req_id)
|
||||||
|
|
||||||
def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]:
|
def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]:
|
||||||
return sum(self._get_remote_tp_ranks_for_req(req_id), [])
|
return sum(self._get_remote_tp_ranks_for_req(req_id), [])
|
||||||
|
|||||||
Reference in New Issue
Block a user