[bugfix] resolve kv cache leak on P-side due to incorrect req_id (#6325)

### What this PR does / why we need it?
This PR fixes a critical bug in the PD-separated inference pipeline
where KV cache on the Prefill (P) side was not being properly released.
The issue arises when multiple clients use the same x-request-id: to
avoid request ID collisions, both Prefill and Decode nodes append a
random suffix to the incoming x-request-id. A previous PR ensured
consistency by having the P-side pass its final request_id as
remote_request_id to the D-side via kv_transfer_param. However, during
KV cache cleanup, the D-side incorrectly used the local req_id (instead
of remote_request_id) to select the target P-side rank. This mismatch
caused the P-side KV cache to remain unreleased on certain ranks,
leading to memory leaks. This PR corrects the logic to use
remote_request_id consistently when determining the P-side rank.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
The fix was validated by running multiple concurrent benchmark instances against the same PD deployment and confirming the P-side KV cache is released on all ranks.

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

Signed-off-by: ghphotoframe <854746559@qq.com>
This commit is contained in:
JiangWeixiang
2026-01-29 16:05:56 +08:00
committed by GitHub
parent 597091be9f
commit 41a52beb26

View File

@@ -1544,6 +1544,7 @@ class MooncakeConnectorWorker:
prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size) tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
remote_req_id = meta.remote_request_id
if meta.remote_pcp_size * meta.remote_dcp_size > 1: if meta.remote_pcp_size * meta.remote_dcp_size > 1:
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata( remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
@@ -1562,7 +1563,7 @@ class MooncakeConnectorWorker:
) )
self.kv_recv_thread.add_request( self.kv_recv_thread.add_request(
request_id=req_id, request_id=req_id,
remote_request_id=meta.remote_request_id, remote_request_id=remote_req_id,
local_block_ids=local_block_ids_list[pcp_dcp_rank], local_block_ids=local_block_ids_list[pcp_dcp_rank],
remote_block_ids=remote_block_ids_list[pcp_dcp_rank], remote_block_ids=remote_block_ids_list[pcp_dcp_rank],
remote_engine_id=remote_engine_id, remote_engine_id=remote_engine_id,
@@ -1576,7 +1577,7 @@ class MooncakeConnectorWorker:
), ),
) )
else: # TODO: support prefill context parallel and pipeline parallel open at the same time else: # TODO: support prefill context parallel and pipeline parallel open at the same time
choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size) choosen_rank_list = self._get_remote_rank(remote_req_id, prefill_tp_size)
remote_handshake_port_list = [[x + meta.remote_port] for x in choosen_rank_list] remote_handshake_port_list = [[x + meta.remote_port] for x in choosen_rank_list]
for i in range(tp_num_need_pulls * self._prefill_pp_size): for i in range(tp_num_need_pulls * self._prefill_pp_size):
assert self.kv_recv_thread is not None assert self.kv_recv_thread is not None
@@ -1589,7 +1590,7 @@ class MooncakeConnectorWorker:
) )
self.kv_recv_thread.add_request( self.kv_recv_thread.add_request(
request_id=req_id, request_id=req_id,
remote_request_id=meta.remote_request_id, remote_request_id=remote_req_id,
local_block_ids=meta.local_block_ids, local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids, remote_block_ids=meta.remote_block_ids,
remote_engine_id=remote_engine_id, remote_engine_id=remote_engine_id,