From 384f8ab5ce2220caf00bb0815e08d33068ec5c06 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Sun, 17 Aug 2025 09:31:31 +0800 Subject: [PATCH] [PD] Support PD disaggregation with Prefill PP (#8846) Signed-off-by: Shangming Cai Signed-off-by: Shangming Cai Co-authored-by: root Co-authored-by: Ying Sheng Co-authored-by: Francis <38564764+ssssnow@users.noreply.github.com> Co-authored-by: zitto --- python/sglang/srt/disaggregation/base/conn.py | 1 + .../srt/disaggregation/mooncake/conn.py | 58 ++-- python/sglang/srt/disaggregation/prefill.py | 264 +++++++++++++++++- python/sglang/srt/managers/scheduler.py | 5 +- python/sglang/srt/managers/utils.py | 60 +++- python/sglang/srt/mem_cache/memory_pool.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 11 +- python/sglang/srt/models/deepseek_nextn.py | 4 +- python/sglang/srt/models/deepseek_v2.py | 151 +++++++--- test/srt/test_disaggregation_pp.py | 133 +++++++++ test/srt/test_pp_single_node.py | 25 ++ 11 files changed, 632 insertions(+), 82 deletions(-) create mode 100644 test/srt/test_disaggregation_pp.py diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index d37575dcf..584530e69 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -30,6 +30,7 @@ class KVArgs: # for pp prefill prefill_pp_size: int pp_rank: int + prefill_start_layer: int # for system dp system_dp_rank: int diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 9e35078e7..e58186d33 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -34,6 +34,7 @@ from sglang.srt.disaggregation.common.utils import ( ) from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.distributed import get_pp_group 
from sglang.srt.layers.dp_attention import ( get_attention_dp_rank, get_attention_dp_size, @@ -180,6 +181,7 @@ class MooncakeKVManager(BaseKVManager): self.session_failures = defaultdict(int) self.failed_sessions = set() self.session_lock = threading.Lock() + self.pp_group = get_pp_group() # Determine the number of threads to use for kv sender cpu_count = os.cpu_count() transfer_thread_pool_size = get_int_env_var( @@ -313,11 +315,11 @@ class MooncakeKVManager(BaseKVManager): layers_params = None # pp is not supported on the decode side yet + start_layer = self.kv_args.prefill_start_layer + end_layer = start_layer + len(self.kv_args.kv_data_ptrs) if self.is_mla_backend: src_kv_ptrs = self.kv_args.kv_data_ptrs layers_per_pp_stage = len(src_kv_ptrs) - start_layer = self.pp_rank * layers_per_pp_stage - end_layer = start_layer + layers_per_pp_stage dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer] kv_item_len = self.kv_args.kv_item_lens[0] layers_params = [ @@ -330,17 +332,15 @@ class MooncakeKVManager(BaseKVManager): ] else: num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 + dst_num_total_layers = num_kv_layers * self.pp_size src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers] src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:] layers_per_pp_stage = len(src_k_ptrs) - start_layer = self.pp_rank * layers_per_pp_stage - end_layer = start_layer + layers_per_pp_stage dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer] dst_v_ptrs = dst_kv_ptrs[ - num_kv_layers + start_layer : num_kv_layers + end_layer + dst_num_total_layers + start_layer : dst_num_total_layers + end_layer ] kv_item_len = self.kv_args.kv_item_lens[0] - layers_params = [ ( src_k_ptrs[layer_id], @@ -452,6 +452,7 @@ class MooncakeKVManager(BaseKVManager): # pp is not supported on the decode side yet num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 + dst_num_total_layers = num_kv_layers * self.pp_size src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers] src_v_ptrs = 
self.kv_args.kv_data_ptrs[num_kv_layers:] layers_per_pp_stage = len(src_k_ptrs) @@ -459,7 +460,7 @@ class MooncakeKVManager(BaseKVManager): end_layer = start_layer + layers_per_pp_stage dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer] dst_v_ptrs = dst_kv_ptrs[ - num_kv_layers + start_layer : num_kv_layers + end_layer + dst_num_total_layers + start_layer : dst_num_total_layers + end_layer ] # Calculate precise byte offset and length for the sub-slice within the token @@ -612,7 +613,7 @@ class MooncakeKVManager(BaseKVManager): ) polls = [] dst_ranks_infos = [] - local_rank = self.kv_args.engine_rank + local_rank = self.attn_tp_rank * self.pp_size + self.pp_rank for req in reqs_to_be_processed: if not req.is_dummy: # Early exit if the request has failed @@ -695,13 +696,14 @@ class MooncakeKVManager(BaseKVManager): break if kv_chunk.is_last: - # Only the last chunk we need to send the aux data - ret = self.send_aux( - req.mooncake_session_id, - kv_chunk.prefill_aux_index, - target_rank_registration_info.dst_aux_ptrs, - req.dst_aux_index, - ) + if self.pp_group.is_last_rank: + # Only the last chunk we need to send the aux data + ret = self.send_aux( + req.mooncake_session_id, + kv_chunk.prefill_aux_index, + target_rank_registration_info.dst_aux_ptrs, + req.dst_aux_index, + ) polls.append(True if ret == 0 else False) dst_ranks_infos.append( (req.endpoint, req.dst_port, req.room) @@ -798,10 +800,7 @@ class MooncakeKVManager(BaseKVManager): arrived_response_num = len( self.prefill_response_tracker[bootstrap_room] ) - if ( - self.is_mla_backend - or arrived_response_num == expected_response_num - ): + if arrived_response_num == expected_response_num: self.update_status(bootstrap_room, KVPoll.Success) elif status == KVPoll.Failed: self.record_failure( @@ -1183,7 +1182,9 @@ class MooncakeKVReceiver(BaseKVReceiver): self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size ) self.required_dst_info_num = 1 - self.required_prefill_response_num = 1 + 
self.required_prefill_response_num = 1 * ( + self.prefill_pp_size // self.kv_mgr.pp_size + ) self.target_tp_ranks = [self.target_tp_rank] elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size: if not self.kv_mgr.is_mla_backend: @@ -1196,7 +1197,9 @@ class MooncakeKVReceiver(BaseKVReceiver): self.required_dst_info_num = ( self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size ) - self.required_prefill_response_num = 1 + self.required_prefill_response_num = 1 * ( + self.prefill_pp_size // self.kv_mgr.pp_size + ) self.target_tp_ranks = [self.target_tp_rank] else: if not self.kv_mgr.is_mla_backend: @@ -1219,9 +1222,14 @@ class MooncakeKVReceiver(BaseKVReceiver): # or the KVPoll will never be set correctly self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 - self.required_prefill_response_num = ( - self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size - ) + if self.kv_mgr.is_mla_backend: + self.required_prefill_response_num = ( + self.prefill_pp_size // self.kv_mgr.pp_size + ) + else: + self.required_prefill_response_num = ( + self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size + ) * (self.prefill_pp_size // self.kv_mgr.pp_size) if self.data_parallel_rank is not None: logger.debug(f"Targeting DP rank: {self.data_parallel_rank}") @@ -1530,7 +1538,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): "rank_port": rank_port, } logger.debug( - f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" + f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" ) return web.Response(text="OK", status=200) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 72cf9d3f9..675e3708a 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -43,8 +43,13 @@ from sglang.srt.disaggregation.utils import ( 
prepare_abort, ) from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch -from sglang.srt.model_executor.forward_batch_info import ForwardMode -from sglang.srt.utils import require_mlp_sync +from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors +from sglang.srt.utils import ( + DynamicGradMode, + broadcast_pyobj, + point_to_point_pyobj, + require_mlp_sync, +) if TYPE_CHECKING: from torch.distributed import ProcessGroup @@ -107,6 +112,7 @@ class PrefillBootstrapQueue: kv_args.system_dp_rank = self.scheduler.dp_rank kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size kv_args.prefill_pp_size = self.pp_size + kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer kv_data_ptrs, kv_data_lens, kv_item_lens = ( self.token_to_kv_pool.get_contiguous_buf_infos() ) @@ -208,8 +214,8 @@ class PrefillBootstrapQueue: polls = poll_and_all_reduce( [req.disagg_kv_sender for req in self.queue], self.gloo_group ) - for i, (req, poll) in enumerate(zip(self.queue, polls)): + for i, (req, poll) in enumerate(zip(self.queue, polls)): if rids_to_check is not None: # if req not in reqs_info_to_check, skip if req.rid not in rids_to_check: @@ -395,7 +401,10 @@ class SchedulerDisaggregationPrefillMixin: req.output_ids.append(next_token_id) self.tree_cache.cache_unfinished_req(req) # update the tree and lock self.disagg_prefill_inflight_queue.append(req) - if logits_output.hidden_states is not None: + if ( + logits_output is not None + and logits_output.hidden_states is not None + ): last_hidden_index = ( hidden_state_offset + extend_input_len_per_req[i] - 1 ) @@ -603,3 +612,250 @@ class SchedulerDisaggregationPrefillMixin: ) return req.disagg_kv_sender.send(page_indices) + + # PP + @DynamicGradMode() + def event_loop_pp_disagg_prefill(self: Scheduler): + """ + An event loop for the prefill server in pipeline parallelism. + + Rules: + 1. Each stage runs in the same order and is notified by the previous stage. + 2. 
Each send/recv operation is blocking and matched by the neighboring stage. + + Regular Schedule: + ==================================================================== + Stage i | Stage i+1 + send ith req | recv ith req + send ith proxy | recv ith proxy + send prev (i+1)th carry | recv prev (i+1)th carry + ==================================================================== + + Prefill Server Schedule: + ==================================================================== + Stage i | Stage i+1 + send ith req | recv ith req + send ith bootstrap req | recv ith bootstrap req + send ith transferred req | recv ith transferred req + send ith proxy | recv ith proxy + send prev (i+1)th carry | recv prev (i+1)th carry + send prev (i+1)th release req | recv prev (i+1)th release req + ==================================================================== + + There are two additional elements compared to the regular schedule: + + 1. Bootstrap Requests: + a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization. + b. The first stage polls the status and propagates the bootstrapped requests down to all other stages. + c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together. + + 2. Transferred Requests + Release Requests: + a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage. + b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory. + c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage. 
+ """ + from sglang.srt.managers.scheduler import GenerationBatchResult + + mbs = [None] * self.pp_size + last_mbs = [None] * self.pp_size + self.running_mbs = [ + ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size) + ] + bids = [None] * self.pp_size + pp_outputs: Optional[PPProxyTensors] = None + + # Either success or failed + bootstrapped_rids: List[str] = [] + transferred_rids: List[str] = [] + release_rids: Optional[List[str]] = None + + # transferred microbatch + tmbs = [None] * self.pp_size + + ENABLE_RELEASE = True # For debug + + while True: + server_is_idle = True + + for mb_id in range(self.pp_size): + self.running_batch = self.running_mbs[mb_id] + self.last_batch = last_mbs[mb_id] + + recv_reqs = self.recv_requests() + + self.process_input_requests(recv_reqs) + + if self.pp_group.is_first_rank: + # First rank, pop the bootstrap reqs from the bootstrap queue + bootstrapped_reqs, failed_reqs = ( + self.disagg_prefill_bootstrap_queue.pop_bootstrapped( + return_failed_reqs=True + ) + ) + bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [ + req.rid for req in failed_reqs + ] + self.waiting_queue.extend(bootstrapped_reqs) + else: + # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus + bootstrapped_rids = self.recv_pyobj_from_prev_stage() + bootstrapped_reqs = ( + self.disagg_prefill_bootstrap_queue.pop_bootstrapped( + rids_to_check=bootstrapped_rids + ) + ) + self.waiting_queue.extend(bootstrapped_reqs) + + if self.pp_group.is_first_rank: + transferred_rids = self.get_transferred_rids() + # if other ranks, + else: + # 1. recv previous stage's transferred reqs info + prev_transferred_rids = self.recv_pyobj_from_prev_stage() + # 2. get the current stage's transferred reqs info + curr_transferred_rids = self.get_transferred_rids() + # 3. 
new consensus rids = intersection(previous consensus rids, transfer finished rids) + transferred_rids = list( + set(prev_transferred_rids) & set(curr_transferred_rids) + ) + + tmbs[mb_id] = transferred_rids + + self.process_prefill_chunk() + mbs[mb_id] = self.get_new_batch_prefill() + self.running_mbs[mb_id] = self.running_batch + + self.cur_batch = mbs[mb_id] + if self.cur_batch: + server_is_idle = False + result = self.run_batch(self.cur_batch) + + # send the outputs to the next step + if self.pp_group.is_last_rank: + if self.cur_batch: + next_token_ids, bids[mb_id] = ( + result.next_token_ids, + result.bid, + ) + pp_outputs = PPProxyTensors( + { + "next_token_ids": next_token_ids, + } + ) + # send the output from the last round to let the next stage worker run post processing + self.pp_group.send_tensor_dict( + pp_outputs.tensors, + all_gather_group=self.attn_tp_group, + ) + + if ENABLE_RELEASE: + if self.pp_group.is_last_rank: + # At the last stage, all stages has reached the consensus to release memory for transferred_rids + release_rids = transferred_rids + # send to the first rank + self.send_pyobj_to_next_stage(release_rids) + + # receive outputs and post-process (filter finished reqs) the coming microbatch + next_mb_id = (mb_id + 1) % self.pp_size + next_pp_outputs = None + next_release_rids = None + + if mbs[next_mb_id] is not None: + next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors( + self.pp_group.recv_tensor_dict( + all_gather_group=self.attn_tp_group + ) + ) + mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"] + output_result = GenerationBatchResult( + logits_output=None, + pp_hidden_states_proxy_tensors=None, + next_token_ids=next_pp_outputs["next_token_ids"], + extend_input_len_per_req=None, + extend_logprob_start_len_per_req=None, + bid=bids[next_mb_id], + can_run_cuda_graph=result.can_run_cuda_graph, + ) + self.process_batch_result_disagg_prefill( + mbs[next_mb_id], output_result + ) + + last_mbs[next_mb_id] = mbs[next_mb_id] 
+ + if ENABLE_RELEASE: + if tmbs[next_mb_id] is not None: + # recv consensus rids from the previous rank + next_release_rids = self.recv_pyobj_from_prev_stage() + self.process_disagg_prefill_inflight_queue(next_release_rids) + + # carry the outputs to the next stage + if not self.pp_group.is_last_rank: + if self.cur_batch: + bids[mb_id] = result.bid + if pp_outputs: + # send the outputs from the last round to let the next stage worker run post processing + self.pp_group.send_tensor_dict( + pp_outputs.tensors, + all_gather_group=self.attn_tp_group, + ) + if ENABLE_RELEASE: + if release_rids is not None: + self.send_pyobj_to_next_stage(release_rids) + + if not self.pp_group.is_last_rank: + # send out reqs to the next stage + self.send_pyobj_to_next_stage(recv_reqs) + self.send_pyobj_to_next_stage(bootstrapped_rids) + self.send_pyobj_to_next_stage(transferred_rids) + + # send out proxy tensors to the next stage + if self.cur_batch: + self.pp_group.send_tensor_dict( + result.pp_hidden_states_proxy_tensors, + all_gather_group=self.attn_tp_group, + ) + + pp_outputs = next_pp_outputs + release_rids = next_release_rids + + self.running_batch.batch_is_full = False + + if not ENABLE_RELEASE: + if len(self.disagg_prefill_inflight_queue) > 0: + self.process_disagg_prefill_inflight_queue() + + # When the server is idle, self-check and re-init some states + if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0: + self.check_memory() + self.check_tree_cache() + self.new_token_ratio = self.init_new_token_ratio + + def send_pyobj_to_next_stage(self, data): + if self.attn_tp_rank == 0: + dp_offset = self.attn_dp_rank * self.attn_tp_size + point_to_point_pyobj( + data, + self.pp_rank * self.tp_size + dp_offset, + self.world_group.device_group, + self.pp_rank * self.tp_size + dp_offset, + ((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset, + ) + + def recv_pyobj_from_prev_stage(self): + if self.attn_tp_rank == 0: + dp_offset = self.attn_dp_rank * 
self.attn_tp_size + data = point_to_point_pyobj( + [], + self.pp_rank * self.tp_size + dp_offset, + self.world_group.device_group, + ((self.pp_rank - 1) % self.pp_size) * self.tp_size + dp_offset, + self.pp_rank * self.tp_size + dp_offset, + ) + else: + data = None + + if self.tp_size != 1: + data = broadcast_pyobj( + data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0] + ) + return data diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 04e6f13b0..05878fe4e 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2579,7 +2579,10 @@ def run_scheduler_process( if scheduler.enable_overlap: scheduler.event_loop_overlap_disagg_prefill() else: - scheduler.event_loop_normal_disagg_prefill() + if server_args.pp_size > 1: + scheduler.event_loop_pp_disagg_prefill() + else: + scheduler.event_loop_normal_disagg_prefill() elif disaggregation_mode == DisaggregationMode.DECODE: if scheduler.enable_overlap: diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py index 2ab32f242..de83c4590 100644 --- a/python/sglang/srt/managers/utils.py +++ b/python/sglang/srt/managers/utils.py @@ -1,9 +1,16 @@ +from __future__ import annotations + import logging import multiprocessing as mp from http import HTTPStatus -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req +from sglang.srt.model_executor.forward_batch_info import PPProxyTensors + +if TYPE_CHECKING: + from sglang.srt.managers.scheduler import GenerationBatchResult logger = logging.getLogger(__name__) @@ -41,6 +48,57 @@ def validate_input_length( return None +def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict: + + logits_output = result.logits_output + assert logits_output is not None + + 
return { + "extend_input_len_per_req": result.extend_input_len_per_req, + "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req, + "next_token_logprobs": result.logits_output.next_token_logprobs, + "next_token_top_logprobs_val": result.logits_output.next_token_top_logprobs_val, + "next_token_top_logprobs_idx": result.logits_output.next_token_top_logprobs_idx, + "next_token_token_ids_logprobs_val": result.logits_output.next_token_token_ids_logprobs_val, + "next_token_token_ids_logprobs_idx": result.logits_output.next_token_token_ids_logprobs_idx, + "input_token_logprobs": result.logits_output.input_token_logprobs, + "input_top_logprobs_val": result.logits_output.input_top_logprobs_val, + "input_top_logprobs_idx": result.logits_output.input_top_logprobs_idx, + "input_token_ids_logprobs_val": result.logits_output.input_token_ids_logprobs_val, + "input_token_ids_logprobs_idx": result.logits_output.input_token_ids_logprobs_idx, + } + + +def get_logprob_from_pp_outputs( + next_pp_outputs: PPProxyTensors, +) -> tuple[LogitsProcessorOutput, list[int], list[int]]: + logits_output = LogitsProcessorOutput( + # Do not send logits and hidden states because they are large + next_token_logits=None, + hidden_states=None, + next_token_logprobs=next_pp_outputs["next_token_logprobs"], + next_token_top_logprobs_val=next_pp_outputs["next_token_top_logprobs_val"], + next_token_top_logprobs_idx=next_pp_outputs["next_token_top_logprobs_idx"], + next_token_token_ids_logprobs_val=next_pp_outputs[ + "next_token_token_ids_logprobs_val" + ], + next_token_token_ids_logprobs_idx=next_pp_outputs[ + "next_token_token_ids_logprobs_idx" + ], + input_token_logprobs=next_pp_outputs["input_token_logprobs"], + input_top_logprobs_val=next_pp_outputs["input_top_logprobs_val"], + input_top_logprobs_idx=next_pp_outputs["input_top_logprobs_idx"], + input_token_ids_logprobs_val=next_pp_outputs["input_token_ids_logprobs_val"], + 
input_token_ids_logprobs_idx=next_pp_outputs["input_token_ids_logprobs_idx"], + ) + extend_input_len_per_req = next_pp_outputs["extend_input_len_per_req"] + extend_logprob_start_len_per_req = next_pp_outputs[ + "extend_logprob_start_len_per_req" + ] + + return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req + + class DPBalanceMeta: """ This class will be use in scheduler and dp controller diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 22202cc31..07d7f5234 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -849,7 +849,7 @@ class MLATokenToKVPool(KVCache): cache_k_rope = cache_k_rope.view(self.store_dtype) set_mla_kv_buffer_triton( - self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope + self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope ) def get_cpu_copy(self, indices): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c7f6ad71c..41b9ce93f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -307,8 +307,13 @@ class ModelRunner: self.start_layer = getattr(self.model, "start_layer", 0) self.end_layer = getattr(self.model, "end_layer", model_num_layers) self.num_effective_layers = self.end_layer - self.start_layer - assert (not model_has_mtp_layers) or ( - self.num_effective_layers == model_num_layers + assert ( + (not model_has_mtp_layers) + or (self.spec_algorithm.is_none()) + or ( + (not self.spec_algorithm.is_none()) + and (self.num_effective_layers == model_num_layers) + ) ), "PP is not compatible with MTP models." 
# Apply torchao quantization @@ -1048,8 +1053,6 @@ class ModelRunner: else: num_layers = self.num_effective_layers if self.use_mla_backend: - # FIXME: pipeline parallelism is not compatible with mla backend - assert self.pp_size == 1 cell_size = ( (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) * num_layers diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index 5b1ae6e69..0d2283078 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -20,7 +20,7 @@ import torch from torch import nn from transformers import PretrainedConfig -from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.dp_attention import is_dp_attention_enabled from sglang.srt.layers.layernorm import RMSNorm @@ -135,6 +135,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): self.config = config self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config + # if not set, model load will be broken in DeepseekV3ForCausalLM load_weights() + self.pp_group = get_pp_group() self.determine_num_fused_shared_experts("DeepseekV3ForCausalLMNextN") self.model = DeepseekModelNextN( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 384698167..37274e45b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -20,7 +20,7 @@ import concurrent.futures import logging import os from enum import IntEnum, auto -from typing import Any, Dict, Iterable, Optional, Tuple +from typing import Any, Dict, Iterable, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -30,6 +30,7 @@ from transformers import PretrainedConfig from sglang.srt.distributed 
import ( get_moe_expert_parallel_world_size, + get_pp_group, get_tensor_model_parallel_world_size, parallel_state, tensor_model_parallel_all_reduce, @@ -87,13 +88,13 @@ from sglang.srt.layers.quantization.int8_utils import ( ) from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper -from sglang.srt.layers.utils import is_sm100_supported +from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.two_batch_overlap import ( MaybeTboDeepEPDispatcher, @@ -114,6 +115,7 @@ from sglang.srt.utils import ( is_hip, is_non_idle_and_non_empty, log_info_on_rank0, + make_layers, use_intel_amx_backend, ) @@ -2029,26 +2031,35 @@ class DeepseekV2Model(nn.Module): self.padding_id = config.pad_token_id self.vocab_size = config.vocab_size self.first_k_dense_replace = config.first_k_dense_replace + self.pp_group = get_pp_group() + + if self.pp_group.is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not is_dp_attention_enabled(), + ) + else: + self.embed_tokens = PPMissingLayer() - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - enable_tp=not is_dp_attention_enabled(), - ) self.alt_stream = torch.cuda.Stream() if _is_cuda else None - self.layers = nn.ModuleList( - [ - DeepseekV2DecoderLayer( - config, - layer_id, - quant_config=quant_config, - prefix=add_prefix(f"layers.{layer_id}", prefix), - alt_stream=self.alt_stream, - ) - for layer_id in 
range(config.num_hidden_layers) - ] + self.layers, self.start_layer, self.end_layer = make_layers( + config.num_hidden_layers, + lambda idx, prefix: DeepseekV2DecoderLayer( + config=config, + layer_id=idx, + quant_config=quant_config, + prefix=prefix, + alt_stream=self.alt_stream, + ), + pp_rank=self.pp_group.rank_in_group, + pp_size=self.pp_group.world_size, + prefix=add_prefix("layers", prefix), ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if self.pp_group.is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer(return_tuple=True) def get_input_embeddings(self) -> torch.Tensor: return self.embed_tokens @@ -2059,8 +2070,9 @@ class DeepseekV2Model(nn.Module): positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, - ) -> torch.Tensor: - total_num_layers = len(self.layers) + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[torch.Tensor, PPProxyTensors]: + total_num_layers = self.end_layer - self.start_layer device = input_embeds.device if input_embeds is not None else input_ids.device zero_allocator = BumpAllocator( buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1), @@ -2068,44 +2080,62 @@ class DeepseekV2Model(nn.Module): device=device, ) - if input_embeds is None: - hidden_states = self.embed_tokens(input_ids) + if self.pp_group.is_first_rank: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None else: - hidden_states = input_embeds + assert pp_proxy_tensors is not None + hidden_states = pp_proxy_tensors["hidden_states"] + residual = pp_proxy_tensors["residual"] - residual = None + normal_start_layer = self.start_layer + normal_end_layer = self.end_layer + if forward_batch.can_run_tbo: + if ( + self.first_k_dense_replace > normal_start_layer + and self.first_k_dense_replace < normal_end_layer + ): + normal_end_layer = 
self.first_k_dense_replace + elif self.first_k_dense_replace < normal_start_layer: + normal_end_layer = normal_start_layer = 0 - normal_num_layers = ( - self.first_k_dense_replace - if forward_batch.can_run_tbo - else total_num_layers - ) - for i in range(normal_num_layers): + for i in range(normal_start_layer, normal_end_layer): with get_global_expert_distribution_recorder().with_current_layer(i): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, forward_batch, residual, zero_allocator ) - if normal_num_layers != total_num_layers: + if normal_end_layer != self.end_layer: hidden_states, residual = model_forward_maybe_tbo( - layers=self.layers[normal_num_layers:], + layers=self.layers[normal_end_layer : self.end_layer], enable_tbo=True, positions=positions, forward_batch=forward_batch, hidden_states=hidden_states, residual=residual, input_data_scatter_mode=self.layers[ - normal_num_layers - 1 + normal_end_layer - 1 ].layer_scatter_modes.layer_output_mode, zero_allocator=zero_allocator, ) - if not forward_batch.forward_mode.is_idle(): - if residual is None: - hidden_states = self.norm(hidden_states) - else: - hidden_states, _ = self.norm(hidden_states, residual) + if not self.pp_group.is_last_rank: + return PPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) + else: + if not forward_batch.forward_mode.is_idle(): + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -2132,6 +2162,7 @@ class DeepseekV2ForCausalLM(nn.Module): "kv_a_proj_with_mqa", ] + self.pp_group = get_pp_group() self.config = config self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config @@ -2201,13 +2232,27 @@ class DeepseekV2ForCausalLM(nn.Module): positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> 
torch.Tensor: - hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) - - return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + hidden_states = self.model( + input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors ) + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + else: + return hidden_states + + @property + def start_layer(self): + return self.model.start_layer + + @property + def end_layer(self): + return self.model.end_layer + def post_load_weights(self, is_nextn=False, weight_names=None): # Perform post-processing after loading weights @@ -2215,7 +2260,7 @@ class DeepseekV2ForCausalLM(nn.Module): layer_ids = [self.config.num_hidden_layers] else: if weight_names is None: - layer_ids = range(self.config.num_hidden_layers) + layer_ids = range(self.model.start_layer, self.model.end_layer) else: layer_ids = set() for name in weight_names: @@ -2497,6 +2542,16 @@ class DeepseekV2ForCausalLM(nn.Module): params_dict = dict(self.named_parameters()) weight_names = [] for name, loaded_weight in weights: + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name: name = name.replace( "mlp.shared_experts", @@ -2581,6 +2636,12 @@ class DeepseekV2ForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip loading embed_tokens if not first rank in pipeline parallelism + if ".embed_tokens." in name and not self.pp_group.is_first_rank: + continue + # Skip loading norm if not last rank in pipeline parallelism + if ".norm." 
in name and not self.pp_group.is_last_rank: + continue if fuse_qkv_a_proj and ( "q_a_proj" in name or "kv_a_proj_with_mqa" in name ): diff --git a/test/srt/test_disaggregation_pp.py b/test/srt/test_disaggregation_pp.py new file mode 100644 index 000000000..6c04d0cce --- /dev/null +++ b/test/srt/test_disaggregation_pp.py @@ -0,0 +1,133 @@ +import json +import os +import subprocess +import time +import unittest +from concurrent.futures import ThreadPoolExecutor +from types import SimpleNamespace +from urllib.parse import urlparse + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.runners import DEFAULT_PROMPTS +from sglang.test.test_utils import ( + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_pd_server, +) + + +class TestPDPPAccuracy(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + parsed_url = urlparse(DEFAULT_URL_FOR_TEST) + cls.base_host = parsed_url.hostname + base_port = str(parsed_url.port) + cls.lb_port = base_port + cls.prefill_port = f"{int(base_port) + 100}" + cls.decode_port = f"{int(base_port) + 200}" + cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" + cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" + cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" + print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") + + # Non blocking start servers + cls.start_prefill() + cls.start_decode() + + # Block until both + cls.wait_server_ready(cls.prefill_url + "/health") + cls.wait_server_ready(cls.decode_url + "/health") + + lb_command = [ + "python3", + "-m", + "sglang.srt.disaggregation.mini_lb", + "--prefill", + cls.prefill_url, + "--decode", + cls.decode_url, + "--host", + cls.base_host, + "--port", + cls.lb_port, + ] + + 
print("Starting load balancer:", " ".join(lb_command)) + cls.process_lb = subprocess.Popen( + lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + cls.wait_server_ready(cls.lb_url + "/health") + + @classmethod + def start_prefill(cls): + prefill_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "prefill", + "--tp-size", + "2", + "--pp-size", + "2", + "--disaggregation-ib-device", + "mlx5_roce0", + "--disable-overlap-schedule", + ] + cls.process_prefill = popen_launch_pd_server( + cls.model, + cls.prefill_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=prefill_args, + ) + + @classmethod + def start_decode(cls): + decode_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "decode", + "--tp", + "1", + "--base-gpu-id", + "1", + "--disaggregation-ib-device", + "mlx5_roce1", + ] + cls.process_decode = popen_launch_pd_server( + cls.model, + cls.decode_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=decode_args, + ) + + @classmethod + def tearDownClass(cls): + for proc in (cls.process_lb, cls.process_prefill, cls.process_decode): kill_process_tree(proc.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.lb_port), + ) + metrics = run_eval(args) + print(f"{metrics=}") + + self.assertGreater(metrics["accuracy"], 0.24) + # Wait a little bit so that the memory check happens. 
+ time.sleep(5) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_pp_single_node.py b/test/srt/test_pp_single_node.py index 01aecdd38..f1fb3e212 100644 --- a/test/srt/test_pp_single_node.py +++ b/test/srt/test_pp_single_node.py @@ -9,6 +9,8 @@ import time import unittest from types import SimpleNamespace +import requests + from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree @@ -62,6 +64,29 @@ class TestPPAccuracy(unittest.TestCase): # Wait a little bit so that the memory check happens. time.sleep(4) + def test_logprob(self): + response = requests.post( + f"{self.base_url}/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + }, + "return_logprob": True, + "top_logprobs_num": 5, + "logprob_start_len": 0, + }, + ) + response_json = response.json() + input_token_logprobs = response_json["meta_info"]["input_token_logprobs"] + output_token_logprobs = response_json["meta_info"]["output_token_logprobs"] + output_top_logprobs = response_json["meta_info"]["output_top_logprobs"] + + assert len(input_token_logprobs) == 6 + assert len(output_token_logprobs) == 16 + assert len(output_top_logprobs) == 16 + class TestQwenPPAccuracy(unittest.TestCase): @classmethod