### What this PR does / why we need it?
When using the multi connector, `MultiConnector` does not define
`get_finished_count`, which causes the KV cache to be released prematurely.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: 83f478bb19
Signed-off-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: baxingpiaochong <771405853@qq.com>
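
For context: a connector that never overrides `get_finished_count()` inherits the base-class default, so the scheduler cannot learn how many per-request completion acks to wait for before freeing blocks. Below is a minimal sketch of the kind of delegation this fix relies on; every class name except `get_finished_count` is hypothetical, and the max-based aggregation is only one plausible policy, not necessarily what vLLM's `MultiConnector` actually does.

```python
from typing import Optional


class BaseConnector:
    """Stand-in for KVConnectorBase_V1; assume the default returns None."""

    def get_finished_count(self) -> Optional[int]:
        return None


class ChildConnector(BaseConnector):
    """Stand-in for a concrete connector such as MooncakeConnector."""

    def __init__(self, count: int):
        self._count = count

    def get_finished_count(self) -> Optional[int]:
        return self._count


class MultiConnectorSketch(BaseConnector):
    """Hypothetical wrapper that forwards the query to its children.

    Without an override like this, the wrapper silently returns the base
    default and the scheduler may free KV blocks too early.
    """

    def __init__(self, children: list[BaseConnector]):
        self._children = children

    def get_finished_count(self) -> Optional[int]:
        counts = [c.get_finished_count() for c in self._children]
        known = [c for c in counts if c is not None]
        return max(known) if known else None


mc = MultiConnectorSketch([ChildConnector(4), ChildConnector(8)])
print(mc.get_finished_count())  # -> 8
```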
`@@ -74,6 +74,10 @@ class KVCacheTaskTracker:`

```python
        self.record_finished_requests: set[str] = set()
        self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()

    def add_not_transfer_request(self, request_id: str):
        with self.done_task_lock:
            self.finished_requests.add(request_id)

    def update_done_task_count(self, request_id: str):
        with self.done_task_lock:
            self.finished_requests.add(request_id)
```
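
Both tracker methods funnel into the same lock-protected set, so the sending thread and the poller never race on membership updates. A self-contained sketch of that pattern follows; the class is a simplified model, not the full `KVCacheTaskTracker`:

```python
import threading
from collections import OrderedDict


class TaskTrackerSketch:
    """Simplified model of KVCacheTaskTracker's thread-safe bookkeeping."""

    def __init__(self):
        self.done_task_lock = threading.Lock()
        self.finished_requests: set[str] = set()
        self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()

    def add_not_transfer_request(self, request_id: str) -> None:
        # A request this rank never transfers is marked finished right away,
        # so the scheduler still receives the ack it is counting on.
        with self.done_task_lock:
            self.finished_requests.add(request_id)

    def get_and_clear_finished_requests(self) -> set[str]:
        # Swap the set out under the lock so callers get a consistent batch.
        with self.done_task_lock:
            finished, self.finished_requests = self.finished_requests, set()
        return finished
```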

`@@ -151,6 +155,9 @@ class KVCacheSendingThread(threading.Thread):`

```python
        """
        return self.task_tracker.get_and_clear_finished_requests()

    def add_not_transfer_request(self, request_id: str):
        self.task_tracker.add_not_transfer_request(request_id)

    def add_delayed_request(self, request_id: str, delay_start_time: float):
        return self.task_tracker.add_delayed_request(request_id,
                                                     delay_start_time)
```

`@@ -652,10 +659,6 @@ class MooncakeConnector(KVConnectorBase_V1):`

```python
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    def get_finished_count(self) -> Optional[int]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_finished_count()

    ############################################################
    # Worker Side Methods
    ############################################################
```
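
The asserts reflect the usual split in V1-style connectors: the scheduler process constructs a scheduler-side object and the worker process a worker-side object, so scheduler-only methods guard against being called on the wrong side. A rough sketch of that shape, with stand-in classes and an illustrative return value:

```python
from typing import Optional


class SchedulerSide:
    """Stand-in for MooncakeConnectorScheduler."""

    def get_finished_count(self) -> Optional[int]:
        return 4  # illustrative value


class WorkerSide:
    """Stand-in for MooncakeConnectorWorker."""


class ConnectorSketch:
    """Only one of the two sides exists in any given process."""

    def __init__(self, role: str):
        self.connector_scheduler = SchedulerSide() if role == "scheduler" else None
        self.connector_worker = WorkerSide() if role == "worker" else None

    def get_finished_count(self) -> Optional[int]:
        # Scheduler-side method: fail fast if invoked in a worker process.
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_finished_count()
```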

`@@ -840,39 +843,6 @@ class MooncakeConnectorScheduler:`

```python
            last_token_id=request.output_token_ids[-1],
        )

    def get_finished_count(self) -> Optional[int]:
        prefill_parallel_config: dict[
            str,
            Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
                "prefill", {})

        assert "tp_size" in prefill_parallel_config.keys()
        self._prefill_tp_size = prefill_parallel_config["tp_size"]
        decode_parallel_config: dict[
            str,
            Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
                "decode", {})
        assert "tp_size" in decode_parallel_config.keys()
        self._decode_tp_size = decode_parallel_config["tp_size"]
        num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
        if self.vllm_config.model_config.use_mla or hasattr(
                self.vllm_config.model_config.hf_config, "index_topk"):
            num_need_pulls = 1
        else:
            num_p_block_heads = max(
                1, num_key_value_heads // self._prefill_tp_size)
            num_d_block_heads = max(
                1, num_key_value_heads // self._decode_tp_size)
            num_need_pulls = num_d_block_heads // num_p_block_heads
        kv_role = self.vllm_config.kv_transfer_config.kv_role
        logger.debug(
            "get_finished_count, kv_role=%s, num_need_pulls=%d, decode_tp_size=%d",
            kv_role, num_need_pulls, self._decode_tp_size)
        if kv_role == 'kv_producer':
            return num_need_pulls * self._decode_tp_size
        else:
            return self._decode_tp_size


class MooncakeConnectorWorker:
    """Implementation of Worker side methods"""
```
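
The producer/consumer split above is easiest to check with concrete numbers. The following standalone rendition of the same arithmetic uses made-up TP sizes and head counts:

```python
def finished_count(kv_role: str, num_key_value_heads: int,
                   prefill_tp_size: int, decode_tp_size: int,
                   use_mla: bool = False) -> int:
    """Mirrors the arithmetic in MooncakeConnectorScheduler.get_finished_count."""
    if use_mla:
        num_need_pulls = 1  # MLA-style KV needs a single pull per decode rank
    else:
        num_p_block_heads = max(1, num_key_value_heads // prefill_tp_size)
        num_d_block_heads = max(1, num_key_value_heads // decode_tp_size)
        num_need_pulls = num_d_block_heads // num_p_block_heads
    if kv_role == "kv_producer":
        return num_need_pulls * decode_tp_size
    return decode_tp_size


# 8 KV heads, prefill TP=8 (1 head per rank), decode TP=4 (2 heads per rank):
# each decode rank pulls from 2 prefill ranks, so the producer side waits for
# 2 * 4 = 8 acks while the consumer side waits for 4.
assert finished_count("kv_producer", 8, 8, 4) == 8
assert finished_count("kv_consumer", 8, 8, 4) == 4
```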

`@@ -1144,6 +1114,8 @@ class MooncakeConnectorWorker:`

```python
            if self.tp_rank in self._prefill_get_remote_tp_rank(req_id):
                self.kv_send_thread.add_delayed_request(
                    req_id, delay_start_time)
            else:
                self.kv_send_thread.add_not_transfer_request(req_id)

    def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]:
        return sum(self._get_remote_tp_ranks_for_req(req_id), [])
```
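
The helper at the end flattens a list of per-group rank lists via `sum(lists, [])`, which concatenates sublists left to right. Shown standalone with illustrative inputs:

```python
remote_ranks = [[0, 1], [2], [3, 4]]  # e.g. remote TP ranks per group
flat = sum(remote_ranks, [])
print(flat)  # [0, 1, 2, 3, 4]
```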