[BugFix][v0.18.0] Adjust request map pop time (#7857)
### What this PR does / why we need it? Adjusts the point at which entries are popped from the request map. This pull request optimizes the KV cache transfer mechanism by streamlining how requests are tracked and cleaned up. By removing unnecessary mapping structures and adjusting the timing of request removal, the system achieves more efficient state management during the transfer process. Picked from: https://github.com/vllm-project/vllm-ascend/pull/7855 ### How was this patch tested? By CI. <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
@@ -508,7 +508,6 @@ class KVCacheRecvingLayerThread(threading.Thread):
|
||||
self.lock = threading.Lock()
|
||||
self.done_requests = set[str]()
|
||||
self.task_tracker = dict[str, int]()
|
||||
self.request_map = dict[str, str]()
|
||||
self.ready_event = ready_event
|
||||
self.metadata = metadata
|
||||
|
||||
@@ -525,11 +524,12 @@ class KVCacheRecvingLayerThread(threading.Thread):
|
||||
|
||||
def update_task(self, req_id, trans_count):
|
||||
with self.lock:
|
||||
if req_id not in self.task_tracker:
|
||||
self.task_tracker[req_id] = 0
|
||||
self.task_tracker[req_id] += 1
|
||||
if self.task_tracker[req_id] == trans_count:
|
||||
self.task_tracker.pop(req_id)
|
||||
self.done_requests.add(self.request_map[req_id])
|
||||
self.request_map.pop(req_id)
|
||||
self.done_requests.add(req_id)
|
||||
|
||||
def run(self):
|
||||
"""Run the thread to handle KV cache transfer requests."""
|
||||
@@ -996,6 +996,7 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
self.side_channel_host = get_ip()
|
||||
self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
|
||||
self.use_mla = self.vllm_config.model_config.use_mla
|
||||
self.request_map = dict[str, str]()
|
||||
if self.use_mla:
|
||||
self.total_num_kv_heads = 1
|
||||
else:
|
||||
@@ -1248,8 +1249,12 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
if self.vllm_config.kv_transfer_config.is_kv_consumer
|
||||
else set()
|
||||
)
|
||||
done_recving = {self.request_map[s] for s in done_recving if s in self.request_map}
|
||||
done_recving.update(self.virtual_request)
|
||||
self.virtual_request = set()
|
||||
for req_id in done_recving:
|
||||
req_id = req_id[:-9]
|
||||
self.request_map.pop(req_id, None)
|
||||
if len(done_recving) > 0:
|
||||
logger.info(
|
||||
f"Number of completed KV cache recv requests: {len(done_recving)}, receive requests: {done_recving}"
|
||||
@@ -1427,9 +1432,7 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
continue
|
||||
external_req_id = get_external_request_id(req_id)
|
||||
assert self.kv_recv_layer_thread is not None
|
||||
with self.kv_recv_layer_thread.lock:
|
||||
self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
|
||||
self.kv_recv_layer_thread.request_map[external_req_id] = req_id
|
||||
self.request_map[external_req_id] = req_id
|
||||
elif self.vllm_config.kv_transfer_config.is_kv_producer:
|
||||
# update trans info
|
||||
update_metadata = {}
|
||||
@@ -1741,7 +1744,19 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx]))
|
||||
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
|
||||
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
|
||||
ack = sock.recv()
|
||||
# Avoid blocking forever waiting for the REQ/ACK response.
|
||||
sock.setsockopt(zmq.RCVTIMEO, int(self.timeout * 1000)) # type: ignore
|
||||
try:
|
||||
ack = sock.recv()
|
||||
except zmq.Again: # type: ignore
|
||||
logger.warning(
|
||||
"Timeout waiting ACK for request %s from %s:%d (timeout=%.3fs)",
|
||||
external_req_id,
|
||||
req_meta.remote_host,
|
||||
req_meta.remote_port,
|
||||
self.timeout,
|
||||
)
|
||||
return
|
||||
if ack != b"ACK":
|
||||
raise ValueError(f"Unexpected ACK response: {ack}")
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user