[BugFix][v0.18.0]Adjust request map pop time (#7857)

### What this PR does / why we need it?
Adjust request map pop time.This pull request optimizes the KV cache
transfer mechanism by streamlining how requests are tracked and cleaned
up. By removing unnecessary mapping structures and adjusting the timing
of request removal, the system achieves more efficient state management
during the transfer process.
pick-from:https://github.com/vllm-project/vllm-ascend/pull/7855


### How was this patch tested?
By ci
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
wangxiaoteng888
2026-03-31 18:55:36 +08:00
committed by GitHub
parent 66db070423
commit 82e26b5a6e

View File

@@ -508,7 +508,6 @@ class KVCacheRecvingLayerThread(threading.Thread):
self.lock = threading.Lock() self.lock = threading.Lock()
self.done_requests = set[str]() self.done_requests = set[str]()
self.task_tracker = dict[str, int]() self.task_tracker = dict[str, int]()
self.request_map = dict[str, str]()
self.ready_event = ready_event self.ready_event = ready_event
self.metadata = metadata self.metadata = metadata
@@ -525,11 +524,12 @@ class KVCacheRecvingLayerThread(threading.Thread):
def update_task(self, req_id, trans_count): def update_task(self, req_id, trans_count):
with self.lock: with self.lock:
if req_id not in self.task_tracker:
self.task_tracker[req_id] = 0
self.task_tracker[req_id] += 1 self.task_tracker[req_id] += 1
if self.task_tracker[req_id] == trans_count: if self.task_tracker[req_id] == trans_count:
self.task_tracker.pop(req_id) self.task_tracker.pop(req_id)
self.done_requests.add(self.request_map[req_id]) self.done_requests.add(req_id)
self.request_map.pop(req_id)
def run(self): def run(self):
"""Run the thread to handle KV cache transfer requests.""" """Run the thread to handle KV cache transfer requests."""
@@ -996,6 +996,7 @@ class MooncakeLayerwiseConnectorWorker:
self.side_channel_host = get_ip() self.side_channel_host = get_ip()
self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
self.use_mla = self.vllm_config.model_config.use_mla self.use_mla = self.vllm_config.model_config.use_mla
self.request_map = dict[str, str]()
if self.use_mla: if self.use_mla:
self.total_num_kv_heads = 1 self.total_num_kv_heads = 1
else: else:
@@ -1248,8 +1249,12 @@ class MooncakeLayerwiseConnectorWorker:
if self.vllm_config.kv_transfer_config.is_kv_consumer if self.vllm_config.kv_transfer_config.is_kv_consumer
else set() else set()
) )
done_recving = {self.request_map[s] for s in done_recving if s in self.request_map}
done_recving.update(self.virtual_request) done_recving.update(self.virtual_request)
self.virtual_request = set() self.virtual_request = set()
for req_id in done_recving:
req_id = req_id[:-9]
self.request_map.pop(req_id, None)
if len(done_recving) > 0: if len(done_recving) > 0:
logger.info( logger.info(
f"Number of completed KV cache recv requests: {len(done_recving)}, receive requests: {done_recving}" f"Number of completed KV cache recv requests: {len(done_recving)}, receive requests: {done_recving}"
@@ -1427,9 +1432,7 @@ class MooncakeLayerwiseConnectorWorker:
continue continue
external_req_id = get_external_request_id(req_id) external_req_id = get_external_request_id(req_id)
assert self.kv_recv_layer_thread is not None assert self.kv_recv_layer_thread is not None
with self.kv_recv_layer_thread.lock: self.request_map[external_req_id] = req_id
self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
self.kv_recv_layer_thread.request_map[external_req_id] = req_id
elif self.vllm_config.kv_transfer_config.is_kv_producer: elif self.vllm_config.kv_transfer_config.is_kv_producer:
# update trans info # update trans info
update_metadata = {} update_metadata = {}
@@ -1741,7 +1744,19 @@ class MooncakeLayerwiseConnectorWorker:
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx])) encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx]))
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}") ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
ack = sock.recv() # Avoid blocking forever waiting for the REQ/ACK response.
sock.setsockopt(zmq.RCVTIMEO, int(self.timeout * 1000)) # type: ignore
try:
ack = sock.recv()
except zmq.Again: # type: ignore
logger.warning(
"Timeout waiting ACK for request %s from %s:%d (timeout=%.3fs)",
external_req_id,
req_meta.remote_host,
req_meta.remote_port,
self.timeout,
)
return
if ack != b"ACK": if ack != b"ACK":
raise ValueError(f"Unexpected ACK response: {ack}") raise ValueError(f"Unexpected ACK response: {ack}")
except Exception as e: except Exception as e: