[PD] Support PD disaggregation with Prefill PP (#8846)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: root <huzhiyuan@xiaohongshu.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Francis <38564764+ssssnow@users.noreply.github.com>
Co-authored-by: zitto <zhjc1124@gmail.com>
This commit is contained in:
Shangming Cai
2025-08-17 09:31:31 +08:00
committed by GitHub
parent 6a9d6ca33c
commit 384f8ab5ce
11 changed files with 632 additions and 82 deletions

View File

@@ -30,6 +30,7 @@ class KVArgs:
# for pp prefill
prefill_pp_size: int
pp_rank: int
prefill_start_layer: int
# for system dp
system_dp_rank: int

View File

@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.common.utils import (
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.dp_attention import (
get_attention_dp_rank,
get_attention_dp_size,
@@ -180,6 +181,7 @@ class MooncakeKVManager(BaseKVManager):
self.session_failures = defaultdict(int)
self.failed_sessions = set()
self.session_lock = threading.Lock()
self.pp_group = get_pp_group()
# Determine the number of threads to use for kv sender
cpu_count = os.cpu_count()
transfer_thread_pool_size = get_int_env_var(
@@ -313,11 +315,11 @@ class MooncakeKVManager(BaseKVManager):
layers_params = None
# pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
if self.is_mla_backend:
src_kv_ptrs = self.kv_args.kv_data_ptrs
layers_per_pp_stage = len(src_kv_ptrs)
start_layer = self.pp_rank * layers_per_pp_stage
end_layer = start_layer + layers_per_pp_stage
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [
@@ -330,17 +332,15 @@ class MooncakeKVManager(BaseKVManager):
]
else:
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
dst_num_total_layers = num_kv_layers * self.pp_size
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
start_layer = self.pp_rank * layers_per_pp_stage
end_layer = start_layer + layers_per_pp_stage
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
num_kv_layers + start_layer : num_kv_layers + end_layer
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [
(
src_k_ptrs[layer_id],
@@ -452,6 +452,7 @@ class MooncakeKVManager(BaseKVManager):
# pp is not supported on the decode side yet
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
dst_num_total_layers = num_kv_layers * self.pp_size
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
@@ -459,7 +460,7 @@ class MooncakeKVManager(BaseKVManager):
end_layer = start_layer + layers_per_pp_stage
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
num_kv_layers + start_layer : num_kv_layers + end_layer
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
# Calculate precise byte offset and length for the sub-slice within the token
@@ -612,7 +613,7 @@ class MooncakeKVManager(BaseKVManager):
)
polls = []
dst_ranks_infos = []
local_rank = self.kv_args.engine_rank
local_rank = self.attn_tp_rank * self.pp_size + self.pp_rank
for req in reqs_to_be_processed:
if not req.is_dummy:
# Early exit if the request has failed
@@ -695,13 +696,14 @@ class MooncakeKVManager(BaseKVManager):
break
if kv_chunk.is_last:
# Only the last chunk we need to send the aux data
ret = self.send_aux(
req.mooncake_session_id,
kv_chunk.prefill_aux_index,
target_rank_registration_info.dst_aux_ptrs,
req.dst_aux_index,
)
if self.pp_group.is_last_rank:
# Only the last chunk we need to send the aux data
ret = self.send_aux(
req.mooncake_session_id,
kv_chunk.prefill_aux_index,
target_rank_registration_info.dst_aux_ptrs,
req.dst_aux_index,
)
polls.append(True if ret == 0 else False)
dst_ranks_infos.append(
(req.endpoint, req.dst_port, req.room)
@@ -798,10 +800,7 @@ class MooncakeKVManager(BaseKVManager):
arrived_response_num = len(
self.prefill_response_tracker[bootstrap_room]
)
if (
self.is_mla_backend
or arrived_response_num == expected_response_num
):
if arrived_response_num == expected_response_num:
self.update_status(bootstrap_room, KVPoll.Success)
elif status == KVPoll.Failed:
self.record_failure(
@@ -1183,7 +1182,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
)
self.required_dst_info_num = 1
self.required_prefill_response_num = 1
self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank]
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
if not self.kv_mgr.is_mla_backend:
@@ -1196,7 +1197,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.required_dst_info_num = (
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
)
self.required_prefill_response_num = 1
self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank]
else:
if not self.kv_mgr.is_mla_backend:
@@ -1219,9 +1222,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
# or the KVPoll will never be set correctly
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
self.required_prefill_response_num = (
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
)
if self.kv_mgr.is_mla_backend:
self.required_prefill_response_num = (
self.prefill_pp_size // self.kv_mgr.pp_size
)
else:
self.required_prefill_response_num = (
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
if self.data_parallel_rank is not None:
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
@@ -1530,7 +1538,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
"rank_port": rank_port,
}
logger.debug(
f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
)
return web.Response(text="OK", status=200)

View File

@@ -43,8 +43,13 @@ from sglang.srt.disaggregation.utils import (
prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import require_mlp_sync
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.utils import (
DynamicGradMode,
broadcast_pyobj,
point_to_point_pyobj,
require_mlp_sync,
)
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
@@ -107,6 +112,7 @@ class PrefillBootstrapQueue:
kv_args.system_dp_rank = self.scheduler.dp_rank
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
kv_args.prefill_pp_size = self.pp_size
kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
@@ -208,8 +214,8 @@ class PrefillBootstrapQueue:
polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.queue], self.gloo_group
)
for i, (req, poll) in enumerate(zip(self.queue, polls)):
for i, (req, poll) in enumerate(zip(self.queue, polls)):
if rids_to_check is not None:
# if req not in reqs_info_to_check, skip
if req.rid not in rids_to_check:
@@ -395,7 +401,10 @@ class SchedulerDisaggregationPrefillMixin:
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
self.disagg_prefill_inflight_queue.append(req)
if logits_output.hidden_states is not None:
if (
logits_output is not None
and logits_output.hidden_states is not None
):
last_hidden_index = (
hidden_state_offset + extend_input_len_per_req[i] - 1
)
@@ -603,3 +612,250 @@ class SchedulerDisaggregationPrefillMixin:
)
return
req.disagg_kv_sender.send(page_indices)
# PP
@DynamicGradMode()
def event_loop_pp_disagg_prefill(self: Scheduler):
"""
An event loop for the prefill server in pipeline parallelism.
Rules:
1. Each stage runs in the same order and is notified by the previous stage.
2. Each send/recv operation is blocking and matched by the neighboring stage.
Regular Schedule:
====================================================================
Stage i | Stage i+1
send ith req | recv ith req
send ith proxy | recv ith proxy
send prev (i+1)th carry | recv prev (i+1)th carry
====================================================================
Prefill Server Schedule:
====================================================================
Stage i | Stage i+1
send ith req | recv ith req
send ith bootstrap req | recv ith bootstrap req
send ith transferred req | recv ith transferred req
send ith proxy | recv ith proxy
send prev (i+1)th carry | recv prev (i+1)th carry
send prev (i+1)th release req | recv prev (i+1)th release req
====================================================================
There are two additional elements compared to the regular schedule:
1. Bootstrap Requests:
a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.
2. Transferred Requests + Release Requests:
a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
"""
from sglang.srt.managers.scheduler import GenerationBatchResult
mbs = [None] * self.pp_size
last_mbs = [None] * self.pp_size
self.running_mbs = [
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
]
bids = [None] * self.pp_size
pp_outputs: Optional[PPProxyTensors] = None
# Either success or failed
bootstrapped_rids: List[str] = []
transferred_rids: List[str] = []
release_rids: Optional[List[str]] = None
# transferred microbatch
tmbs = [None] * self.pp_size
ENABLE_RELEASE = True # For debug
while True:
server_is_idle = True
for mb_id in range(self.pp_size):
self.running_batch = self.running_mbs[mb_id]
self.last_batch = last_mbs[mb_id]
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
if self.pp_group.is_first_rank:
# First rank, pop the bootstrap reqs from the bootstrap queue
bootstrapped_reqs, failed_reqs = (
self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
return_failed_reqs=True
)
)
bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
req.rid for req in failed_reqs
]
self.waiting_queue.extend(bootstrapped_reqs)
else:
# Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
bootstrapped_rids = self.recv_pyobj_from_prev_stage()
bootstrapped_reqs = (
self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
rids_to_check=bootstrapped_rids
)
)
self.waiting_queue.extend(bootstrapped_reqs)
if self.pp_group.is_first_rank:
transferred_rids = self.get_transferred_rids()
# if other ranks,
else:
# 1. recv previous stage's transferred reqs info
prev_transferred_rids = self.recv_pyobj_from_prev_stage()
# 2. get the current stage's transferred reqs info
curr_transferred_rids = self.get_transferred_rids()
# 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
transferred_rids = list(
set(prev_transferred_rids) & set(curr_transferred_rids)
)
tmbs[mb_id] = transferred_rids
self.process_prefill_chunk()
mbs[mb_id] = self.get_new_batch_prefill()
self.running_mbs[mb_id] = self.running_batch
self.cur_batch = mbs[mb_id]
if self.cur_batch:
server_is_idle = False
result = self.run_batch(self.cur_batch)
# send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids, bids[mb_id] = (
result.next_token_ids,
result.bid,
)
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
}
)
# send the output from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)
if ENABLE_RELEASE:
if self.pp_group.is_last_rank:
# At the last stage, all stages has reached the consensus to release memory for transferred_rids
release_rids = transferred_rids
# send to the first rank
self.send_pyobj_to_next_stage(release_rids)
# receive outputs and post-process (filter finished reqs) the coming microbatch
next_mb_id = (mb_id + 1) % self.pp_size
next_pp_outputs = None
next_release_rids = None
if mbs[next_mb_id] is not None:
next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
self.pp_group.recv_tensor_dict(
all_gather_group=self.attn_tp_group
)
)
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
output_result = GenerationBatchResult(
logits_output=None,
pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=None,
extend_logprob_start_len_per_req=None,
bid=bids[next_mb_id],
can_run_cuda_graph=result.can_run_cuda_graph,
)
self.process_batch_result_disagg_prefill(
mbs[next_mb_id], output_result
)
last_mbs[next_mb_id] = mbs[next_mb_id]
if ENABLE_RELEASE:
if tmbs[next_mb_id] is not None:
# recv consensus rids from the previous rank
next_release_rids = self.recv_pyobj_from_prev_stage()
self.process_disagg_prefill_inflight_queue(next_release_rids)
# carry the outputs to the next stage
if not self.pp_group.is_last_rank:
if self.cur_batch:
bids[mb_id] = result.bid
if pp_outputs:
# send the outputs from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)
if ENABLE_RELEASE:
if release_rids is not None:
self.send_pyobj_to_next_stage(release_rids)
if not self.pp_group.is_last_rank:
# send out reqs to the next stage
self.send_pyobj_to_next_stage(recv_reqs)
self.send_pyobj_to_next_stage(bootstrapped_rids)
self.send_pyobj_to_next_stage(transferred_rids)
# send out proxy tensors to the next stage
if self.cur_batch:
self.pp_group.send_tensor_dict(
result.pp_hidden_states_proxy_tensors,
all_gather_group=self.attn_tp_group,
)
pp_outputs = next_pp_outputs
release_rids = next_release_rids
self.running_batch.batch_is_full = False
if not ENABLE_RELEASE:
if len(self.disagg_prefill_inflight_queue) > 0:
self.process_disagg_prefill_inflight_queue()
# When the server is idle, self-check and re-init some states
if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
def send_pyobj_to_next_stage(self, data):
    """Blocking-send a picklable object to the next pipeline stage.

    Only the attn-TP-rank-0 worker of each DP replica participates; the
    peer is the same DP-offset rank on stage ``(pp_rank + 1) % pp_size``
    (wrapping from the last stage back to the first).
    """
    if self.attn_tp_rank != 0:
        return
    dp_offset = self.attn_dp_rank * self.attn_tp_size
    src_rank = self.pp_rank * self.tp_size + dp_offset
    dst_rank = ((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset
    point_to_point_pyobj(
        data,
        src_rank,
        self.world_group.device_group,
        src_rank,
        dst_rank,
    )
def recv_pyobj_from_prev_stage(self):
    """Blocking-receive a picklable object from the previous pipeline stage.

    The attn-TP-rank-0 worker receives from the matching rank on stage
    ``(pp_rank - 1) % pp_size``; with ``tp_size > 1`` the object is then
    broadcast to the rest of the TP group so every rank returns it.
    """
    if self.attn_tp_rank == 0:
        dp_offset = self.attn_dp_rank * self.attn_tp_size
        self_rank = self.pp_rank * self.tp_size + dp_offset
        prev_rank = (
            (self.pp_rank - 1) % self.pp_size
        ) * self.tp_size + dp_offset
        data = point_to_point_pyobj(
            [],
            self_rank,
            self.world_group.device_group,
            prev_rank,
            self_rank,
        )
    else:
        data = None

    if self.tp_size != 1:
        data = broadcast_pyobj(
            data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0]
        )
    return data