[0.11.0][Bugfix] Fix delayed freeing of prefill requests & add prefix-cache support on D nodes (#3609)

### What this PR does / why we need it?
Fix the mooncake connector. In heterogeneous-TP scenarios, when the
prefill TP size is smaller than the number of key-value heads,
`_get_remote_tp_ranks_for_req` returns a list of np.arrays. A membership
test such as `some_int in list_of_np_arrays` then raises an error.
Converting each np.array slice into a plain Python list (via `.tolist()`)
resolves this issue.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
qwen235B
P tp16, D tp1
P tp8, D tp1
P tp4, D tp1
P tp8, D tp2


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
liziyu
2025-10-23 20:39:35 +08:00
committed by GitHub
parent 6975d46627
commit f3ea657e93
2 changed files with 29 additions and 24 deletions

View File

@@ -342,9 +342,15 @@ class KVCacheRecvingThread(threading.Thread):
# Full prefix cache hit: do not need to read remote blocks, just notify
# P worker that we have the blocks we need.
if len(local_block_ids) == 0:
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
return
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
# Check if we have the remote metadata cached.
if remote_engine_id not in self.kv_caches_base_addr or \
remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
@@ -736,13 +742,11 @@ class MooncakeConnectorScheduler:
num_computed_tokens, params)
if params is not None and params.get("do_remote_prefill"):
assert num_computed_tokens == 0, "Currently only support " \
"prefill with num_computed_tokens == 0."
# Assume that the request's KV cache is already fully prefilled and
# can be fetched entirely from the prefill node.
count = max(len(request.prompt_token_ids) - 1, 0)
if count > 0:
return count, True
# Remote prefill: get all prompt blocks from remote.
assert num_computed_tokens % self.block_size == 0
# Note: We use the full token count as transmit data here.
count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
return count, count > 0
# No remote prefill for this request.
return 0, False
@@ -1137,10 +1141,13 @@ class MooncakeConnectorWorker:
if self.kv_send_thread is not None:
for req_id, delay_start_time in metadata.requests_to_send.items():
if self.tp_rank in self._get_remote_tp_ranks_for_req(req_id):
if self.tp_rank in self._prefill_get_remote_tp_rank(req_id):
self.kv_send_thread.add_delayed_request(
req_id, delay_start_time)
def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]:
    """Return the flattened list of prefill-side TP ranks serving *req_id*.

    ``_get_remote_tp_ranks_for_req`` returns one list of prefill ranks per
    decode rank; this helper flattens them so a membership test such as
    ``self.tp_rank in ...`` operates on plain ints.
    """
    # Flatten with a comprehension: sum(lists, []) copies the accumulator on
    # every step and is O(n^2) in the number of ranks.
    return [rank
            for ranks in self._get_remote_tp_ranks_for_req(req_id)
            for rank in ranks]
def _get_remote_tp_rank(self, req_id: str) -> List[int]:
    """Return the prefill TP ranks assigned to this worker's decode rank."""
    ranks_per_decode_rank = self._get_remote_tp_ranks_for_req(req_id)
    return ranks_per_decode_rank[self.tp_rank]
@@ -1176,8 +1183,8 @@ class MooncakeConnectorWorker:
else:
group_size = self._prefill_tp_size // self._decode_tp_size
for i in range(self._decode_tp_size):
slice = ori_data[i * group_size:(i + 1) * group_size]
sampled_nums.append(slice)
ori_data_slice = ori_data[i * group_size:(i + 1) * group_size]
sampled_nums.append(ori_data_slice.tolist())
return sampled_nums