Fix PD disaggregation bugs (#5326)

This commit is contained in:
Liangsheng Yin
2025-04-14 19:27:30 +08:00
committed by GitHub
parent 072df75354
commit 44afde82d7
2 changed files with 141 additions and 107 deletions

View File

@@ -81,7 +81,7 @@ class PrefillBootstrapQueue:
self.gloo_group = gloo_group
self.bootstrap_port = bootstrap_port
def allocate_token_id(self, idx: int, token_id: int):
def store_prefill_results(self, idx: int, token_id: int):
assert token_id >= 0, f"token_id: {token_id} is negative"
output_id_buffer = self.metadata_buffers[0]
output_id_buffer[idx] = token_id
@@ -146,7 +146,7 @@ class PrefillBootstrapQueue:
elif poll == KVPoll.Failed:
raise Exception("Bootstrap failed")
# KV.WaitingForInput - init here
# KV.WaitingForInput
num_kv_indices = len(req.origin_input_ids)
if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
break
@@ -222,6 +222,7 @@ class SchedulerDisaggregationPrefillMixin:
elif poll == KVPoll.Success: # transfer done
self.tree_cache.cache_finished_req(req) # unlock the tree
req.finished_reason = FINISH_LENGTH(length=0)
# FIXME: clean up req's data in transfer engine
done_reqs.append(req)
elif poll == KVPoll.Failed:
raise Exception("Transferring failed")
@@ -256,14 +257,18 @@ class SchedulerDisaggregationPrefillMixin:
"""
start_idx = req.start_send_idx
end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
# Update next start_send_idx
req.start_send_idx = end_idx
kv_indices = (
self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
.cpu()
.numpy()
)
req.start_send_idx = end_idx
if token_id is not None:
self.disagg_prefill_pending_queue.allocate_token_id(
self.disagg_prefill_pending_queue.store_prefill_results(
req.metadata_buffer_index, token_id
)
req.disagg_kv_sender.send(kv_indices)
is_last = token_id is not None
req.disagg_kv_sender.send(kv_indices, slice(start_idx, end_idx), is_last)