[PD] Support PD disaggregation with Prefill PP (#8846)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com> Signed-off-by: Shangming Cai <csmthu@gmail.com> Co-authored-by: root <huzhiyuan@xiaohongshu.com> Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: Francis <38564764+ssssnow@users.noreply.github.com> Co-authored-by: zitto <zhjc1124@gmail.com>
This commit is contained in:
@@ -30,6 +30,7 @@ class KVArgs:
|
||||
# for pp prefill
|
||||
prefill_pp_size: int
|
||||
pp_rank: int
|
||||
prefill_start_layer: int
|
||||
# for system dp
|
||||
system_dp_rank: int
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.common.utils import (
|
||||
)
|
||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.distributed import get_pp_group
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_dp_rank,
|
||||
get_attention_dp_size,
|
||||
@@ -180,6 +181,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
self.session_failures = defaultdict(int)
|
||||
self.failed_sessions = set()
|
||||
self.session_lock = threading.Lock()
|
||||
self.pp_group = get_pp_group()
|
||||
# Determine the number of threads to use for kv sender
|
||||
cpu_count = os.cpu_count()
|
||||
transfer_thread_pool_size = get_int_env_var(
|
||||
@@ -313,11 +315,11 @@ class MooncakeKVManager(BaseKVManager):
|
||||
layers_params = None
|
||||
|
||||
# pp is not supported on the decode side yet
|
||||
start_layer = self.kv_args.prefill_start_layer
|
||||
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
|
||||
if self.is_mla_backend:
|
||||
src_kv_ptrs = self.kv_args.kv_data_ptrs
|
||||
layers_per_pp_stage = len(src_kv_ptrs)
|
||||
start_layer = self.pp_rank * layers_per_pp_stage
|
||||
end_layer = start_layer + layers_per_pp_stage
|
||||
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||
layers_params = [
|
||||
@@ -330,17 +332,15 @@ class MooncakeKVManager(BaseKVManager):
|
||||
]
|
||||
else:
|
||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
||||
dst_num_total_layers = num_kv_layers * self.pp_size
|
||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
||||
layers_per_pp_stage = len(src_k_ptrs)
|
||||
start_layer = self.pp_rank * layers_per_pp_stage
|
||||
end_layer = start_layer + layers_per_pp_stage
|
||||
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||
dst_v_ptrs = dst_kv_ptrs[
|
||||
num_kv_layers + start_layer : num_kv_layers + end_layer
|
||||
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
||||
]
|
||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||
|
||||
layers_params = [
|
||||
(
|
||||
src_k_ptrs[layer_id],
|
||||
@@ -452,6 +452,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
|
||||
# pp is not supported on the decode side yet
|
||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
||||
dst_num_total_layers = num_kv_layers * self.pp_size
|
||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
||||
layers_per_pp_stage = len(src_k_ptrs)
|
||||
@@ -459,7 +460,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
end_layer = start_layer + layers_per_pp_stage
|
||||
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||
dst_v_ptrs = dst_kv_ptrs[
|
||||
num_kv_layers + start_layer : num_kv_layers + end_layer
|
||||
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
||||
]
|
||||
|
||||
# Calculate precise byte offset and length for the sub-slice within the token
|
||||
@@ -612,7 +613,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
)
|
||||
polls = []
|
||||
dst_ranks_infos = []
|
||||
local_rank = self.kv_args.engine_rank
|
||||
local_rank = self.attn_tp_rank * self.pp_size + self.pp_rank
|
||||
for req in reqs_to_be_processed:
|
||||
if not req.is_dummy:
|
||||
# Early exit if the request has failed
|
||||
@@ -695,13 +696,14 @@ class MooncakeKVManager(BaseKVManager):
|
||||
break
|
||||
|
||||
if kv_chunk.is_last:
|
||||
# Only the last chunk we need to send the aux data
|
||||
ret = self.send_aux(
|
||||
req.mooncake_session_id,
|
||||
kv_chunk.prefill_aux_index,
|
||||
target_rank_registration_info.dst_aux_ptrs,
|
||||
req.dst_aux_index,
|
||||
)
|
||||
if self.pp_group.is_last_rank:
|
||||
# Only the last chunk we need to send the aux data
|
||||
ret = self.send_aux(
|
||||
req.mooncake_session_id,
|
||||
kv_chunk.prefill_aux_index,
|
||||
target_rank_registration_info.dst_aux_ptrs,
|
||||
req.dst_aux_index,
|
||||
)
|
||||
polls.append(True if ret == 0 else False)
|
||||
dst_ranks_infos.append(
|
||||
(req.endpoint, req.dst_port, req.room)
|
||||
@@ -798,10 +800,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
arrived_response_num = len(
|
||||
self.prefill_response_tracker[bootstrap_room]
|
||||
)
|
||||
if (
|
||||
self.is_mla_backend
|
||||
or arrived_response_num == expected_response_num
|
||||
):
|
||||
if arrived_response_num == expected_response_num:
|
||||
self.update_status(bootstrap_room, KVPoll.Success)
|
||||
elif status == KVPoll.Failed:
|
||||
self.record_failure(
|
||||
@@ -1183,7 +1182,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
||||
)
|
||||
self.required_dst_info_num = 1
|
||||
self.required_prefill_response_num = 1
|
||||
self.required_prefill_response_num = 1 * (
|
||||
self.prefill_pp_size // self.kv_mgr.pp_size
|
||||
)
|
||||
self.target_tp_ranks = [self.target_tp_rank]
|
||||
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
|
||||
if not self.kv_mgr.is_mla_backend:
|
||||
@@ -1196,7 +1197,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.required_dst_info_num = (
|
||||
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
|
||||
)
|
||||
self.required_prefill_response_num = 1
|
||||
self.required_prefill_response_num = 1 * (
|
||||
self.prefill_pp_size // self.kv_mgr.pp_size
|
||||
)
|
||||
self.target_tp_ranks = [self.target_tp_rank]
|
||||
else:
|
||||
if not self.kv_mgr.is_mla_backend:
|
||||
@@ -1219,9 +1222,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
# or the KVPoll will never be set correctly
|
||||
self.target_tp_rank = self.target_tp_ranks[0]
|
||||
self.required_dst_info_num = 1
|
||||
self.required_prefill_response_num = (
|
||||
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
||||
)
|
||||
if self.kv_mgr.is_mla_backend:
|
||||
self.required_prefill_response_num = (
|
||||
self.prefill_pp_size // self.kv_mgr.pp_size
|
||||
)
|
||||
else:
|
||||
self.required_prefill_response_num = (
|
||||
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
||||
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
|
||||
|
||||
if self.data_parallel_rank is not None:
|
||||
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
||||
@@ -1530,7 +1538,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
||||
"rank_port": rank_port,
|
||||
}
|
||||
logger.debug(
|
||||
f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
||||
f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
||||
)
|
||||
|
||||
return web.Response(text="OK", status=200)
|
||||
|
||||
@@ -43,8 +43,13 @@ from sglang.srt.disaggregation.utils import (
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import require_mlp_sync
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||
from sglang.srt.utils import (
|
||||
DynamicGradMode,
|
||||
broadcast_pyobj,
|
||||
point_to_point_pyobj,
|
||||
require_mlp_sync,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -107,6 +112,7 @@ class PrefillBootstrapQueue:
|
||||
kv_args.system_dp_rank = self.scheduler.dp_rank
|
||||
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
|
||||
kv_args.prefill_pp_size = self.pp_size
|
||||
kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
|
||||
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
||||
self.token_to_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
@@ -208,8 +214,8 @@ class PrefillBootstrapQueue:
|
||||
polls = poll_and_all_reduce(
|
||||
[req.disagg_kv_sender for req in self.queue], self.gloo_group
|
||||
)
|
||||
for i, (req, poll) in enumerate(zip(self.queue, polls)):
|
||||
|
||||
for i, (req, poll) in enumerate(zip(self.queue, polls)):
|
||||
if rids_to_check is not None:
|
||||
# if req not in reqs_info_to_check, skip
|
||||
if req.rid not in rids_to_check:
|
||||
@@ -395,7 +401,10 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
req.output_ids.append(next_token_id)
|
||||
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
||||
self.disagg_prefill_inflight_queue.append(req)
|
||||
if logits_output.hidden_states is not None:
|
||||
if (
|
||||
logits_output is not None
|
||||
and logits_output.hidden_states is not None
|
||||
):
|
||||
last_hidden_index = (
|
||||
hidden_state_offset + extend_input_len_per_req[i] - 1
|
||||
)
|
||||
@@ -603,3 +612,250 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
)
|
||||
return
|
||||
req.disagg_kv_sender.send(page_indices)
|
||||
|
||||
# PP
|
||||
@DynamicGradMode()
def event_loop_pp_disagg_prefill(self: Scheduler):
    """
    An event loop for the prefill server in pipeline parallelism.

    Rules:
    1. Each stage runs in the same order and is notified by the previous stage.
    2. Each send/recv operation is blocking and matched by the neighboring stage.

    Regular Schedule:
    ====================================================================
    Stage i                       | Stage i+1
    send ith req                  | recv ith req
    send ith proxy                | recv ith proxy
    send prev (i+1)th carry       | recv prev (i+1)th carry
    ====================================================================

    Prefill Server Schedule:
    ====================================================================
    Stage i                       | Stage i+1
    send ith req                  | recv ith req
    send ith bootstrap req        | recv ith bootstrap req
    send ith transferred req      | recv ith transferred req
    send ith proxy                | recv ith proxy
    send prev (i+1)th carry       | recv prev (i+1)th carry
    send prev (i+1)th release req | recv prev (i+1)th release req
    ====================================================================

    There are two additional elements compared to the regular schedule:

    1. Bootstrap Requests:
        a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
        b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
        c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.

    2. Transferred Requests + Release Requests:
        a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
        b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
        c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.

    This loop never returns: the outer ``while True`` has no ``break``.
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably to avoid a circular import with the scheduler module - confirm.
    from sglang.srt.managers.scheduler import GenerationBatchResult

    # Per-microbatch state, one slot per pipeline stage (indexed by mb_id).
    # mbs[i]: the batch scheduled for microbatch i in the most recent sweep (None if nothing was scheduled).
    mbs = [None] * self.pp_size
    # last_mbs[i]: the batch of microbatch i that was post-processed most recently; restored into self.last_batch.
    last_mbs = [None] * self.pp_size
    self.running_mbs = [
        ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
    ]
    # bids[i]: `result.bid` of the batch last run for microbatch i.
    bids = [None] * self.pp_size
    # Outputs received for the next microbatch in the previous iteration, waiting to be carried to the next stage.
    pp_outputs: Optional[PPProxyTensors] = None

    # Either success or failed
    bootstrapped_rids: List[str] = []
    transferred_rids: List[str] = []
    # Release list received in the previous iteration, waiting to be forwarded to the next stage.
    release_rids: Optional[List[str]] = None

    # transferred microbatch: tmbs[i] holds the transferred rids computed for microbatch i
    tmbs = [None] * self.pp_size

    ENABLE_RELEASE = True  # For debug

    while True:
        server_is_idle = True

        for mb_id in range(self.pp_size):
            # Swap in the scheduler state that belongs to this microbatch slot.
            self.running_batch = self.running_mbs[mb_id]
            self.last_batch = last_mbs[mb_id]

            recv_reqs = self.recv_requests()

            self.process_input_requests(recv_reqs)

            if self.pp_group.is_first_rank:
                # First rank, pop the bootstrap reqs from the bootstrap queue
                bootstrapped_reqs, failed_reqs = (
                    self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                        return_failed_reqs=True
                    )
                )
                # Failed rids are forwarded together with the successful ones,
                # but only the successfully bootstrapped reqs enter the waiting queue.
                bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
                    req.rid for req in failed_reqs
                ]
                self.waiting_queue.extend(bootstrapped_reqs)
            else:
                # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
                bootstrapped_rids = self.recv_pyobj_from_prev_stage()
                bootstrapped_reqs = (
                    self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                        rids_to_check=bootstrapped_rids
                    )
                )
                self.waiting_queue.extend(bootstrapped_reqs)

            if self.pp_group.is_first_rank:
                transferred_rids = self.get_transferred_rids()
            # if other ranks,
            else:
                # 1. recv previous stage's transferred reqs info
                prev_transferred_rids = self.recv_pyobj_from_prev_stage()
                # 2. get the current stage's transferred reqs info
                curr_transferred_rids = self.get_transferred_rids()
                # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
                # NOTE: going through set() means the order of the resulting list is unspecified.
                transferred_rids = list(
                    set(prev_transferred_rids) & set(curr_transferred_rids)
                )

            tmbs[mb_id] = transferred_rids

            self.process_prefill_chunk()
            mbs[mb_id] = self.get_new_batch_prefill()
            # Save the running batch back into this microbatch's slot.
            self.running_mbs[mb_id] = self.running_batch

            self.cur_batch = mbs[mb_id]
            if self.cur_batch:
                server_is_idle = False
                result = self.run_batch(self.cur_batch)

            # send the outputs to the next step
            if self.pp_group.is_last_rank:
                if self.cur_batch:
                    next_token_ids, bids[mb_id] = (
                        result.next_token_ids,
                        result.bid,
                    )
                    pp_outputs = PPProxyTensors(
                        {
                            "next_token_ids": next_token_ids,
                        }
                    )
                    # send the output from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            if ENABLE_RELEASE:
                if self.pp_group.is_last_rank:
                    # At the last stage, all stages have reached the consensus to release memory for transferred_rids
                    release_rids = transferred_rids
                    # send to the first rank
                    self.send_pyobj_to_next_stage(release_rids)

            # receive outputs and post-process (filter finished reqs) the coming microbatch
            next_mb_id = (mb_id + 1) % self.pp_size
            next_pp_outputs = None
            next_release_rids = None

            if mbs[next_mb_id] is not None:
                next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                    self.pp_group.recv_tensor_dict(
                        all_gather_group=self.attn_tp_group
                    )
                )
                mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                # Only next_token_ids and bid describe the received microbatch;
                # the remaining fields are explicitly left as None.
                # NOTE(review): `result` is only assigned when this rank ran a batch, so
                # `can_run_cuda_graph` below may come from a different microbatch than
                # `next_mb_id` - confirm this is intended.
                output_result = GenerationBatchResult(
                    logits_output=None,
                    pp_hidden_states_proxy_tensors=None,
                    next_token_ids=next_pp_outputs["next_token_ids"],
                    extend_input_len_per_req=None,
                    extend_logprob_start_len_per_req=None,
                    bid=bids[next_mb_id],
                    can_run_cuda_graph=result.can_run_cuda_graph,
                )
                self.process_batch_result_disagg_prefill(
                    mbs[next_mb_id], output_result
                )

                last_mbs[next_mb_id] = mbs[next_mb_id]

            if ENABLE_RELEASE:
                if tmbs[next_mb_id] is not None:
                    # recv consensus rids from the previous rank
                    next_release_rids = self.recv_pyobj_from_prev_stage()
                    self.process_disagg_prefill_inflight_queue(next_release_rids)

            # carry the outputs to the next stage
            if not self.pp_group.is_last_rank:
                if self.cur_batch:
                    bids[mb_id] = result.bid
                if pp_outputs:
                    # send the outputs from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )
                if ENABLE_RELEASE:
                    # Forward the release list received in the previous iteration.
                    if release_rids is not None:
                        self.send_pyobj_to_next_stage(release_rids)

            if not self.pp_group.is_last_rank:
                # send out reqs to the next stage
                # The next stage receives these three objects in this same order.
                self.send_pyobj_to_next_stage(recv_reqs)
                self.send_pyobj_to_next_stage(bootstrapped_rids)
                self.send_pyobj_to_next_stage(transferred_rids)

                # send out proxy tensors to the next stage
                if self.cur_batch:
                    self.pp_group.send_tensor_dict(
                        result.pp_hidden_states_proxy_tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            # What was received for the next microbatch becomes the carry of the next iteration.
            pp_outputs = next_pp_outputs
            release_rids = next_release_rids

            self.running_batch.batch_is_full = False

        # Debug path: with the release protocol disabled, the inflight queue is
        # processed locally once per sweep, without a consensus rid list.
        if not ENABLE_RELEASE:
            if len(self.disagg_prefill_inflight_queue) > 0:
                self.process_disagg_prefill_inflight_queue()

        # When the server is idle, self-check and re-init some states
        if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
            self.check_memory()
            self.check_tree_cache()
            self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
def send_pyobj_to_next_stage(self, data):
    """Send a Python object to the following pipeline stage.

    Only the worker whose ``attn_tp_rank`` is 0 sends anything; every other
    worker returns immediately. The target stage index wraps around
    (``(pp_rank + 1) % pp_size``), so the last stage sends to stage 0.

    Args:
        data: The object handed to ``point_to_point_pyobj``.

    Returns:
        None.
    """
    if self.attn_tp_rank != 0:
        return

    # Offset of this attention-DP replica inside one pipeline stage.
    dp_offset = self.attn_dp_rank * self.attn_tp_size
    own_rank = self.pp_rank * self.tp_size + dp_offset
    following_stage = (self.pp_rank + 1) % self.pp_size
    peer_rank = following_stage * self.tp_size + dp_offset

    # NOTE(review): the last two positional arguments are presumably the
    # source and destination ranks - confirm against point_to_point_pyobj.
    point_to_point_pyobj(
        data,
        own_rank,
        self.world_group.device_group,
        own_rank,
        peer_rank,
    )
|
||||
|
||||
def recv_pyobj_from_prev_stage(self):
    """Receive a Python object from the preceding pipeline stage.

    Only the worker whose ``attn_tp_rank`` is 0 takes part in the
    point-to-point transfer; the others start from ``None``. The source
    stage index wraps around (``(pp_rank - 1) % pp_size``), so stage 0
    receives from the last stage. When ``tp_size`` is not 1, the value is
    then passed through ``broadcast_pyobj`` over ``tp_cpu_group`` with
    ``tp_group.ranks[0]`` as the source.

    Returns:
        The result of ``broadcast_pyobj`` when ``tp_size != 1``; otherwise
        the received object, or ``None`` if ``attn_tp_rank`` is not 0.
    """
    received = None
    if self.attn_tp_rank == 0:
        # Offset of this attention-DP replica inside one pipeline stage.
        dp_offset = self.attn_dp_rank * self.attn_tp_size
        own_rank = self.pp_rank * self.tp_size + dp_offset
        preceding_stage = (self.pp_rank - 1) % self.pp_size
        peer_rank = preceding_stage * self.tp_size + dp_offset

        # NOTE(review): the last two positional arguments are presumably the
        # source and destination ranks - confirm against point_to_point_pyobj.
        received = point_to_point_pyobj(
            [],
            own_rank,
            self.world_group.device_group,
            peer_rank,
            own_rank,
        )

    if self.tp_size == 1:
        return received

    return broadcast_pyobj(
        received,
        self.tp_group.rank,
        self.tp_cpu_group,
        src=self.tp_group.ranks[0],
    )
|
||||
|
||||
Reference in New Issue
Block a user