[BugFix][v0.18.0]Adjust request map pop time (#7857)

### What this PR does / why we need it?
Adjust request map pop time.This pull request optimizes the KV cache
transfer mechanism by streamlining how requests are tracked and cleaned
up. By removing unnecessary mapping structures and adjusting the timing
of request removal, the system achieves more efficient state management
during the transfer process.
pick-from:https://github.com/vllm-project/vllm-ascend/pull/7855


### How was this patch tested?
By ci
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
wangxiaoteng888
2026-03-31 18:55:36 +08:00
committed by GitHub
parent 66db070423
commit 82e26b5a6e

View File

@@ -508,7 +508,6 @@ class KVCacheRecvingLayerThread(threading.Thread):
self.lock = threading.Lock()
self.done_requests = set[str]()
self.task_tracker = dict[str, int]()
self.request_map = dict[str, str]()
self.ready_event = ready_event
self.metadata = metadata
@@ -525,11 +524,12 @@ class KVCacheRecvingLayerThread(threading.Thread):
def update_task(self, req_id, trans_count):
with self.lock:
if req_id not in self.task_tracker:
self.task_tracker[req_id] = 0
self.task_tracker[req_id] += 1
if self.task_tracker[req_id] == trans_count:
self.task_tracker.pop(req_id)
self.done_requests.add(self.request_map[req_id])
self.request_map.pop(req_id)
self.done_requests.add(req_id)
def run(self):
"""Run the thread to handle KV cache transfer requests."""
@@ -996,6 +996,7 @@ class MooncakeLayerwiseConnectorWorker:
self.side_channel_host = get_ip()
self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
self.use_mla = self.vllm_config.model_config.use_mla
self.request_map = dict[str, str]()
if self.use_mla:
self.total_num_kv_heads = 1
else:
@@ -1248,8 +1249,12 @@ class MooncakeLayerwiseConnectorWorker:
if self.vllm_config.kv_transfer_config.is_kv_consumer
else set()
)
done_recving = {self.request_map[s] for s in done_recving if s in self.request_map}
done_recving.update(self.virtual_request)
self.virtual_request = set()
for req_id in done_recving:
req_id = req_id[:-9]
self.request_map.pop(req_id, None)
if len(done_recving) > 0:
logger.info(
f"Number of completed KV cache recv requests: {len(done_recving)}, receive requests: {done_recving}"
@@ -1427,9 +1432,7 @@ class MooncakeLayerwiseConnectorWorker:
continue
external_req_id = get_external_request_id(req_id)
assert self.kv_recv_layer_thread is not None
with self.kv_recv_layer_thread.lock:
self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
self.kv_recv_layer_thread.request_map[external_req_id] = req_id
self.request_map[external_req_id] = req_id
elif self.vllm_config.kv_transfer_config.is_kv_producer:
# update trans info
update_metadata = {}
@@ -1741,7 +1744,19 @@ class MooncakeLayerwiseConnectorWorker:
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx]))
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
ack = sock.recv()
# Avoid blocking forever waiting for the REQ/ACK response.
sock.setsockopt(zmq.RCVTIMEO, int(self.timeout * 1000)) # type: ignore
try:
ack = sock.recv()
except zmq.Again: # type: ignore
logger.warning(
"Timeout waiting ACK for request %s from %s:%d (timeout=%.3fs)",
external_req_id,
req_meta.remote_host,
req_meta.remote_port,
self.timeout,
)
return
if ack != b"ACK":
raise ValueError(f"Unexpected ACK response: {ack}")
except Exception as e: