diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 91caf99db..447fffb54 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -53,13 +53,7 @@ from sglang.srt.mem_cache.memory_pool import ( NSATokenToKVPool, SWAKVPool, ) -from sglang.srt.model_executor.forward_batch_info import PPProxyTensors -from sglang.srt.utils import ( - DynamicGradMode, - broadcast_pyobj, - point_to_point_pyobj, - require_mlp_sync, -) +from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync if TYPE_CHECKING: from torch.distributed import ProcessGroup @@ -685,218 +679,6 @@ class SchedulerDisaggregationPrefillMixin: return req.disagg_kv_sender.send(page_indices, state_indices) - # PP - @DynamicGradMode() - def event_loop_pp_disagg_prefill(self: Scheduler): - """ - An event loop for the prefill server in pipeline parallelism. - - Rules: - 1. Each stage runs in the same order and is notified by the previous stage. - 2. Each send/recv operation is blocking and matched by the neighboring stage. - - Regular Schedule: - ==================================================================== - Stage i | Stage i+1 - send ith req | recv ith req - send ith proxy | recv ith proxy - send prev (i+1)th carry | recv prev (i+1)th carry - ==================================================================== - - Prefill Server Schedule: - ==================================================================== - Stage i | Stage i+1 - send ith req | recv ith req - send ith bootstrap req | recv ith bootstrap req - send ith transferred req | recv ith transferred req - send ith proxy | recv ith proxy - send prev (i+1)th carry | recv prev (i+1)th carry - send prev (i+1)th release req | recv prev (i+1)th release req - ==================================================================== - - There are two additional elements compared to the regular schedule: - - 1. Bootstrap Requests: - a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization. - b. The first stage polls the status and propagates the bootstrapped requests down to all other stages. - c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together. - - 2. Transferred Requests + Release Requests: - a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage. - b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory. - c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage. - """ - from sglang.srt.managers.scheduler import GenerationBatchResult - - mbs = [None] * self.pp_size - last_mbs = [None] * self.pp_size - self.running_mbs = [ - ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size) - ] - pp_outputs: Optional[PPProxyTensors] = None - - # Either success or failed - bootstrapped_rids: List[str] = [] - transferred_rids: List[str] = [] - release_rids: Optional[List[str]] = None - - # transferred microbatch - tmbs = [None] * self.pp_size - - ENABLE_RELEASE = True # For debug - - while True: - server_is_idle = True - - for mb_id in range(self.pp_size): - self.running_batch = self.running_mbs[mb_id] - self.last_batch = last_mbs[mb_id] - - recv_reqs = self.recv_requests() - - self.process_input_requests(recv_reqs) - - if self.pp_group.is_first_rank: - # First rank, pop the bootstrap reqs from the bootstrap queue - bootstrapped_reqs, failed_reqs = ( - self.disagg_prefill_bootstrap_queue.pop_bootstrapped( - return_failed_reqs=True - ) - ) - bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [ - req.rid for req in failed_reqs - ] - self.waiting_queue.extend(bootstrapped_reqs) - else: - # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus - bootstrapped_rids = self.recv_pyobj_from_prev_stage() - bootstrapped_reqs = ( - self.disagg_prefill_bootstrap_queue.pop_bootstrapped( - rids_to_check=bootstrapped_rids - ) - ) - self.waiting_queue.extend(bootstrapped_reqs) - - if self.pp_group.is_first_rank: - transferred_rids = self.get_transferred_rids() - # if other ranks, - else: - # 1. recv previous stage's transferred reqs info - prev_transferred_rids = self.recv_pyobj_from_prev_stage() - # 2. get the current stage's transferred reqs info - curr_transferred_rids = self.get_transferred_rids() - # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids) - transferred_rids = list( - set(prev_transferred_rids) & set(curr_transferred_rids) - ) - - tmbs[mb_id] = transferred_rids - - self.process_prefill_chunk() - mbs[mb_id] = self.get_new_batch_prefill() - self.running_mbs[mb_id] = self.running_batch - - self.cur_batch = mbs[mb_id] - if self.cur_batch: - server_is_idle = False - result = self.run_batch(self.cur_batch) - - # send the outputs to the next step - if self.pp_group.is_last_rank: - if self.cur_batch: - next_token_ids = result.next_token_ids - pp_outputs = PPProxyTensors( - { - "next_token_ids": next_token_ids, - } - ) - # send the output from the last round to let the next stage worker run post processing - self.pp_group.send_tensor_dict( - pp_outputs.tensors, - all_gather_group=self.attn_tp_group, - ) - - if ENABLE_RELEASE: - if self.pp_group.is_last_rank: - # At the last stage, all stages has reached the consensus to release memory for transferred_rids - release_rids = transferred_rids - # send to the first rank - self.send_pyobj_to_next_stage(release_rids) - - # receive outputs and post-process (filter finished reqs) the coming microbatch - next_mb_id = (mb_id + 1) % self.pp_size - next_pp_outputs = None - next_release_rids = None - - if mbs[next_mb_id] is not None: - next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors( - self.pp_group.recv_tensor_dict( - all_gather_group=self.attn_tp_group - ) - ) - mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"] - output_result = GenerationBatchResult( - logits_output=None, - pp_hidden_states_proxy_tensors=None, - next_token_ids=next_pp_outputs["next_token_ids"], - extend_input_len_per_req=None, - extend_logprob_start_len_per_req=None, - can_run_cuda_graph=result.can_run_cuda_graph, - ) - self.process_batch_result_disagg_prefill( - mbs[next_mb_id], output_result - ) - - last_mbs[next_mb_id] = mbs[next_mb_id] - - if ENABLE_RELEASE: - if tmbs[next_mb_id] is not None: - # recv consensus rids from the previous rank - next_release_rids = self.recv_pyobj_from_prev_stage() - self.process_disagg_prefill_inflight_queue(next_release_rids) - - # carry the outputs to the next stage - if not self.pp_group.is_last_rank: - if pp_outputs: - # send the outputs from the last round to let the next stage worker run post processing - self.pp_group.send_tensor_dict( - pp_outputs.tensors, - all_gather_group=self.attn_tp_group, - ) - if ENABLE_RELEASE: - if release_rids is not None: - self.send_pyobj_to_next_stage(release_rids) - - if not self.pp_group.is_last_rank: - # send out reqs to the next stage - self.send_pyobj_to_next_stage(recv_reqs) - self.send_pyobj_to_next_stage(bootstrapped_rids) - self.send_pyobj_to_next_stage(transferred_rids) - - # send out proxy tensors to the next stage - if self.cur_batch: - # FIXME(lsyin): remove this assert - assert result.pp_hidden_states_proxy_tensors.tensors is not None - self.pp_group.send_tensor_dict( - result.pp_hidden_states_proxy_tensors.tensors, - all_gather_group=self.attn_tp_group, - ) - - pp_outputs = next_pp_outputs - release_rids = next_release_rids - - self.running_batch.batch_is_full = False - - if not ENABLE_RELEASE: - if len(self.disagg_prefill_inflight_queue) > 0: - self.process_disagg_prefill_inflight_queue() - - # When the server is idle, self-check and re-init some states - if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0: - self.check_memory() - self.check_tree_cache() - self.new_token_ratio = self.init_new_token_ratio - def send_pyobj_to_next_stage(self, data): if self.attn_tp_rank == 0: dp_offset = self.attn_dp_rank * self.attn_tp_size diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ba095da9b..45c104c25 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -78,6 +78,7 @@ from sglang.srt.utils import flatten_nested_list if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 @@ -1527,8 +1528,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): if self.is_v2_eagle: # TODO(spec-v2): all v2 spec should go through this path - from sglang.srt.speculative.eagle_info import EagleDraftInput - draft_input: EagleDraftInput = self.spec_info draft_input.prepare_for_decode(self) @@ -1585,8 +1584,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): def maybe_wait_verify_done(self): if self.is_v2_eagle: - from sglang.srt.speculative.eagle_info import EagleDraftInput - draft_input: EagleDraftInput = self.spec_info if draft_input.verify_done is not None: draft_input.verify_done.synchronize() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9b1cb19d0..a64439b46 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -63,7 +63,6 @@ from sglang.srt.distributed import get_pp_group, get_world_group from sglang.srt.environ import envs from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.dp_attention import compute_dp_attention_world_info -from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.moe import initialize_moe_config from sglang.srt.managers.io_struct import ( AbortReq, @@ -114,7 +113,7 @@ from sglang.srt.managers.io_struct import ( UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.mm_utils import init_embedding_cache -from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap +from sglang.srt.managers.overlap_utils import FutureMap from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, ModelWorkerBatch, @@ -136,22 +135,21 @@ from sglang.srt.managers.scheduler_metrics_mixin import ( from sglang.srt.managers.scheduler_output_processor_mixin import ( SchedulerOutputProcessorMixin, ) +from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper from sglang.srt.managers.scheduler_update_weights_mixin import ( SchedulerUpdateWeightsMixin, ) from sglang.srt.managers.session_controller import Session -from sglang.srt.managers.utils import validate_input_length +from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache -from sglang.srt.model_executor.forward_batch_info import PPProxyTensors from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args -from sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.tracing.trace import ( process_tracing_init, @@ -198,77 +196,6 @@ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT") GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300)) -@dataclass -class GenerationBatchResult: - logits_output: Optional[LogitsProcessorOutput] = None - pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None - next_token_ids: Optional[torch.Tensor] = None - num_accepted_tokens: Optional[int] = None - can_run_cuda_graph: bool = False - - # For output processing - extend_input_len_per_req: Optional[List[int]] = None - extend_logprob_start_len_per_req: Optional[List[int]] = None - - # For overlap scheduling - copy_done: Optional[torch.cuda.Event] = None - delay_sample_func: Optional[callable] = None - future_indices: Optional[FutureIndices] = None - - # FIXME(lsyin): maybe move to a better place? - # sync path: forward stream -> output processor - accept_lens: Optional[torch.Tensor] = None - allocate_lens: Optional[torch.Tensor] = None - - # relay path: forward stream -> next step forward - next_draft_input: Optional[EagleDraftInput] = None - - def copy_to_cpu(self, return_logprob: bool = False): - """Copy tensors to CPU in overlap scheduling. - Only the tensors which are needed for processing results are copied, - e.g., next_token_ids, logits outputs - """ - if return_logprob: - if self.logits_output.next_token_logits is not None: - self.logits_output.next_token_logits = ( - self.logits_output.next_token_logits.to("cpu", non_blocking=True) - ) - if self.logits_output.input_token_logprobs is not None: - self.logits_output.input_token_logprobs = ( - self.logits_output.input_token_logprobs.to("cpu", non_blocking=True) - ) - if self.logits_output.hidden_states is not None: - self.logits_output.hidden_states = self.logits_output.hidden_states.to( - "cpu", non_blocking=True - ) - self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True) - - if self.accept_lens is not None: - self.accept_lens = self.accept_lens.to("cpu", non_blocking=True) - - if self.allocate_lens is not None: - self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True) - - self.copy_done.record() - - @classmethod - def from_pp_proxy( - cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph - ): - # TODO(lsyin): refactor PP and avoid using dict - proxy_dict = next_pp_outputs.tensors - return cls( - logits_output=logits_output, - pp_hidden_states_proxy_tensors=None, - next_token_ids=next_pp_outputs["next_token_ids"], - extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None), - extend_logprob_start_len_per_req=proxy_dict.get( - "extend_logprob_start_len_per_req", None - ), - can_run_cuda_graph=can_run_cuda_graph, - ) - - @dataclass class EmbeddingBatchResult: embeddings: torch.Tensor @@ -281,6 +208,7 @@ class Scheduler( SchedulerMetricsMixin, SchedulerDisaggregationDecodeMixin, SchedulerDisaggregationPrefillMixin, + SchedulerPPMixin, ): """A scheduler that manages a tensor parallel GPU worker.""" @@ -1058,128 +986,6 @@ class Scheduler( self.launch_batch_sample_if_needed(batch_result) self.last_batch = batch - @DynamicGradMode() - def event_loop_pp(self): - """A non-overlap scheduler loop for pipeline parallelism.""" - mbs = [None] * self.pp_size - last_mbs = [None] * self.pp_size - self.running_mbs = [ - ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size) - ] - pp_outputs: Optional[PPProxyTensors] = None - while True: - server_is_idle = True - for mb_id in range(self.pp_size): - self.running_batch = self.running_mbs[mb_id] - self.last_batch = last_mbs[mb_id] - - recv_reqs = self.recv_requests() - self.process_input_requests(recv_reqs) - mbs[mb_id] = self.get_next_batch_to_run() - self.running_mbs[mb_id] = self.running_batch - - self.cur_batch = mbs[mb_id] - if self.cur_batch: - server_is_idle = False - result = self.run_batch(self.cur_batch) - - # (last rank) send the outputs to the next step - if self.pp_group.is_last_rank: - if self.cur_batch: - next_token_ids = result.next_token_ids - if self.cur_batch.return_logprob: - pp_outputs = PPProxyTensors( - { - "next_token_ids": next_token_ids, - "extend_input_len_per_req": result.extend_input_len_per_req, - "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req, - } - | ( - { - f"logits_output.{k}": v - for k, v in result.logits_output.__dict__.items() - } - if result.logits_output is not None - else {} - ) - ) - else: - pp_outputs = PPProxyTensors( - { - "next_token_ids": next_token_ids, - } - ) - # send the output from the last round to let the next stage worker run post processing - self.pp_group.send_tensor_dict( - pp_outputs.tensors, - all_gather_group=self.attn_tp_group, - ) - - # receive outputs and post-process (filter finished reqs) the coming microbatch - next_mb_id = (mb_id + 1) % self.pp_size - next_pp_outputs = None - if mbs[next_mb_id] is not None: - next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors( - self.pp_group.recv_tensor_dict( - all_gather_group=self.attn_tp_group - ) - ) - mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"] - logits_output_args = { - k[len("logits_output.") :]: v - for k, v in next_pp_outputs.tensors.items() - if k.startswith("logits_output.") - } - if len(logits_output_args) > 0: - logits_output = LogitsProcessorOutput(**logits_output_args) - else: - logits_output = None - - output_result = GenerationBatchResult.from_pp_proxy( - logits_output=logits_output, - next_pp_outputs=next_pp_outputs, - can_run_cuda_graph=result.can_run_cuda_graph, - ) - self.process_batch_result(mbs[next_mb_id], output_result) - last_mbs[next_mb_id] = mbs[next_mb_id] - - # (not last rank) - if not self.pp_group.is_last_rank: - # carry the outputs to the next stage - # send the outputs from the last round to let the next stage worker run post processing - if pp_outputs: - self.pp_group.send_tensor_dict( - pp_outputs.tensors, - all_gather_group=self.attn_tp_group, - ) - - # send out reqs to the next stage - dp_offset = self.attn_dp_rank * self.attn_tp_size - if self.attn_tp_rank == 0: - point_to_point_pyobj( - recv_reqs, - self.pp_rank * self.tp_size + dp_offset, - self.world_group.device_group, - self.pp_rank * self.tp_size + dp_offset, - (self.pp_rank + 1) * self.tp_size + dp_offset, - ) - - # send out proxy tensors to the next stage - if self.cur_batch: - # FIXME(lsyin): remove this assert - assert result.pp_hidden_states_proxy_tensors.tensors is not None - self.pp_group.send_tensor_dict( - result.pp_hidden_states_proxy_tensors.tensors, - all_gather_group=self.attn_tp_group, - ) - - pp_outputs = next_pp_outputs - - # When the server is idle, self-check and re-init some states - if server_is_idle: - # When the server is idle, do self-check and re-init some states - self.self_check_during_idle() - def recv_requests(self) -> List[Req]: """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py new file mode 100644 index 000000000..e177d3b56 --- /dev/null +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -0,0 +1,341 @@ +from typing import List, Optional + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.utils import GenerationBatchResult +from sglang.srt.model_executor.forward_batch_info import PPProxyTensors +from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj + + +class SchedulerPPMixin: + + @DynamicGradMode() + def event_loop_pp(self): + """A non-overlap scheduler loop for pipeline parallelism.""" + mbs = [None] * self.pp_size + last_mbs = [None] * self.pp_size + self.running_mbs = [ + ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size) + ] + pp_outputs: Optional[PPProxyTensors] = None + while True: + server_is_idle = True + for mb_id in range(self.pp_size): + self.running_batch = self.running_mbs[mb_id] + self.last_batch = last_mbs[mb_id] + + recv_reqs = self.recv_requests() + self.process_input_requests(recv_reqs) + mbs[mb_id] = self.get_next_batch_to_run() + self.running_mbs[mb_id] = self.running_batch + + self.cur_batch = mbs[mb_id] + if self.cur_batch: + server_is_idle = False + result = self.run_batch(self.cur_batch) + + # (last rank) send the outputs to the next step + if self.pp_group.is_last_rank: + if self.cur_batch: + next_token_ids = result.next_token_ids + if self.cur_batch.return_logprob: + pp_outputs = PPProxyTensors( + { + "next_token_ids": next_token_ids, + "extend_input_len_per_req": result.extend_input_len_per_req, + "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req, + } + | ( + { + f"logits_output.{k}": v + for k, v in result.logits_output.__dict__.items() + } + if result.logits_output is not None + else {} + ) + ) + else: + pp_outputs = PPProxyTensors( + { + "next_token_ids": next_token_ids, + } + ) + # send the output from the last round to let the next stage worker run post processing + self.pp_group.send_tensor_dict( + pp_outputs.tensors, + all_gather_group=self.attn_tp_group, + ) + + # receive outputs and post-process (filter finished reqs) the coming microbatch + next_mb_id = (mb_id + 1) % self.pp_size + next_pp_outputs = None + if mbs[next_mb_id] is not None: + next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors( + self.pp_group.recv_tensor_dict( + all_gather_group=self.attn_tp_group + ) + ) + mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"] + logits_output_args = { + k[len("logits_output.") :]: v + for k, v in next_pp_outputs.tensors.items() + if k.startswith("logits_output.") + } + if len(logits_output_args) > 0: + logits_output = LogitsProcessorOutput(**logits_output_args) + else: + logits_output = None + + output_result = GenerationBatchResult.from_pp_proxy( + logits_output=logits_output, + next_pp_outputs=next_pp_outputs, + can_run_cuda_graph=result.can_run_cuda_graph, + ) + self.process_batch_result(mbs[next_mb_id], output_result) + last_mbs[next_mb_id] = mbs[next_mb_id] + + # (not last rank) + if not self.pp_group.is_last_rank: + # carry the outputs to the next stage + # send the outputs from the last round to let the next stage worker run post processing + if pp_outputs: + self.pp_group.send_tensor_dict( + pp_outputs.tensors, + all_gather_group=self.attn_tp_group, + ) + + # send out reqs to the next stage + dp_offset = self.attn_dp_rank * self.attn_tp_size + if self.attn_tp_rank == 0: + point_to_point_pyobj( + recv_reqs, + self.pp_rank * self.tp_size + dp_offset, + self.world_group.device_group, + self.pp_rank * self.tp_size + dp_offset, + (self.pp_rank + 1) * self.tp_size + dp_offset, + ) + + # send out proxy tensors to the next stage + if self.cur_batch: + # FIXME(lsyin): remove this assert + assert result.pp_hidden_states_proxy_tensors.tensors is not None + self.pp_group.send_tensor_dict( + result.pp_hidden_states_proxy_tensors.tensors, + all_gather_group=self.attn_tp_group, + ) + + pp_outputs = next_pp_outputs + + # When the server is idle, self-check and re-init some states + if server_is_idle: + # When the server is idle, do self-check and re-init some states + self.self_check_during_idle() + + @DynamicGradMode() + def event_loop_pp_disagg_prefill(self): + """ + An event loop for the prefill server in pipeline parallelism. + + Rules: + 1. Each stage runs in the same order and is notified by the previous stage. + 2. Each send/recv operation is blocking and matched by the neighboring stage. + + Regular Schedule: + ==================================================================== + Stage i | Stage i+1 + send ith req | recv ith req + send ith proxy | recv ith proxy + send prev (i+1)th carry | recv prev (i+1)th carry + ==================================================================== + + Prefill Server Schedule: + ==================================================================== + Stage i | Stage i+1 + send ith req | recv ith req + send ith bootstrap req | recv ith bootstrap req + send ith transferred req | recv ith transferred req + send ith proxy | recv ith proxy + send prev (i+1)th carry | recv prev (i+1)th carry + send prev (i+1)th release req | recv prev (i+1)th release req + ==================================================================== + + There are two additional elements compared to the regular schedule: + + 1. Bootstrap Requests: + a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization. + b. The first stage polls the status and propagates the bootstrapped requests down to all other stages. + c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together. + + 2. Transferred Requests + Release Requests: + a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage. + b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory. + c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage. + """ + mbs = [None] * self.pp_size + last_mbs = [None] * self.pp_size + self.running_mbs = [ + ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size) + ] + pp_outputs: Optional[PPProxyTensors] = None + + # Either success or failed + bootstrapped_rids: List[str] = [] + transferred_rids: List[str] = [] + release_rids: Optional[List[str]] = None + + # transferred microbatch + tmbs = [None] * self.pp_size + + ENABLE_RELEASE = True # For debug + + while True: + server_is_idle = True + + for mb_id in range(self.pp_size): + self.running_batch = self.running_mbs[mb_id] + self.last_batch = last_mbs[mb_id] + + recv_reqs = self.recv_requests() + + self.process_input_requests(recv_reqs) + + if self.pp_group.is_first_rank: + # First rank, pop the bootstrap reqs from the bootstrap queue + bootstrapped_reqs, failed_reqs = ( + self.disagg_prefill_bootstrap_queue.pop_bootstrapped( + return_failed_reqs=True + ) + ) + bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [ + req.rid for req in failed_reqs + ] + self.waiting_queue.extend(bootstrapped_reqs) + else: + # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus + bootstrapped_rids = self.recv_pyobj_from_prev_stage() + bootstrapped_reqs = ( + self.disagg_prefill_bootstrap_queue.pop_bootstrapped( + rids_to_check=bootstrapped_rids + ) + ) + self.waiting_queue.extend(bootstrapped_reqs) + + if self.pp_group.is_first_rank: + transferred_rids = self.get_transferred_rids() + # if other ranks, + else: + # 1. recv previous stage's transferred reqs info + prev_transferred_rids = self.recv_pyobj_from_prev_stage() + # 2. get the current stage's transferred reqs info + curr_transferred_rids = self.get_transferred_rids() + # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids) + transferred_rids = list( + set(prev_transferred_rids) & set(curr_transferred_rids) + ) + + tmbs[mb_id] = transferred_rids + + self.process_prefill_chunk() + mbs[mb_id] = self.get_new_batch_prefill() + self.running_mbs[mb_id] = self.running_batch + + self.cur_batch = mbs[mb_id] + if self.cur_batch: + server_is_idle = False + result = self.run_batch(self.cur_batch) + + # send the outputs to the next step + if self.pp_group.is_last_rank: + if self.cur_batch: + next_token_ids = result.next_token_ids + pp_outputs = PPProxyTensors( + { + "next_token_ids": next_token_ids, + } + ) + # send the output from the last round to let the next stage worker run post processing + self.pp_group.send_tensor_dict( + pp_outputs.tensors, + all_gather_group=self.attn_tp_group, + ) + + if ENABLE_RELEASE: + if self.pp_group.is_last_rank: + # At the last stage, all stages has reached the consensus to release memory for transferred_rids + release_rids = transferred_rids + # send to the first rank + self.send_pyobj_to_next_stage(release_rids) + + # receive outputs and post-process (filter finished reqs) the coming microbatch + next_mb_id = (mb_id + 1) % self.pp_size + next_pp_outputs = None + next_release_rids = None + + if mbs[next_mb_id] is not None: + next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors( + self.pp_group.recv_tensor_dict( + all_gather_group=self.attn_tp_group + ) + ) + mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"] + output_result = GenerationBatchResult( + logits_output=None, + pp_hidden_states_proxy_tensors=None, + next_token_ids=next_pp_outputs["next_token_ids"], + extend_input_len_per_req=None, + extend_logprob_start_len_per_req=None, + can_run_cuda_graph=result.can_run_cuda_graph, + ) + self.process_batch_result_disagg_prefill( + mbs[next_mb_id], output_result + ) + + last_mbs[next_mb_id] = mbs[next_mb_id] + + if ENABLE_RELEASE: + if tmbs[next_mb_id] is not None: + # recv consensus rids from the previous rank + next_release_rids = self.recv_pyobj_from_prev_stage() + self.process_disagg_prefill_inflight_queue(next_release_rids) + + # carry the outputs to the next stage + if not self.pp_group.is_last_rank: + if pp_outputs: + # send the outputs from the last round to let the next stage worker run post processing + self.pp_group.send_tensor_dict( + pp_outputs.tensors, + all_gather_group=self.attn_tp_group, + ) + if ENABLE_RELEASE: + if release_rids is not None: + self.send_pyobj_to_next_stage(release_rids) + + if not self.pp_group.is_last_rank: + # send out reqs to the next stage + self.send_pyobj_to_next_stage(recv_reqs) + self.send_pyobj_to_next_stage(bootstrapped_rids) + self.send_pyobj_to_next_stage(transferred_rids) + + # send out proxy tensors to the next stage + if self.cur_batch: + # FIXME(lsyin): remove this assert + assert result.pp_hidden_states_proxy_tensors.tensors is not None + self.pp_group.send_tensor_dict( + result.pp_hidden_states_proxy_tensors.tensors, + all_gather_group=self.attn_tp_group, + ) + + pp_outputs = next_pp_outputs + release_rids = next_release_rids + + self.running_batch.batch_is_full = False + + if not ENABLE_RELEASE: + if len(self.disagg_prefill_inflight_queue) > 0: + self.process_disagg_prefill_inflight_queue() + + # When the server is idle, self-check and re-init some states + if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0: + self.check_memory() + self.check_tree_cache() + self.new_token_ratio = self.init_new_token_ratio diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py index fa3435198..6e753f165 100644 --- a/python/sglang/srt/managers/utils.py +++ b/python/sglang/srt/managers/utils.py @@ -1,18 +1,95 @@ from __future__ import annotations +import dataclasses import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional + +import torch from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.overlap_utils import FutureIndices from sglang.srt.managers.schedule_batch import Req from sglang.srt.model_executor.forward_batch_info import PPProxyTensors if TYPE_CHECKING: from sglang.srt.managers.scheduler import GenerationBatchResult + from sglang.srt.speculative.eagle_info import EagleDraftInput + logger = logging.getLogger(__name__) +@dataclasses.dataclass +class GenerationBatchResult: + logits_output: Optional[LogitsProcessorOutput] = None + pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None + next_token_ids: Optional[torch.Tensor] = None + num_accepted_tokens: Optional[int] = None + can_run_cuda_graph: bool = False + + # For output processing + extend_input_len_per_req: Optional[List[int]] = None + extend_logprob_start_len_per_req: Optional[List[int]] = None + + # For overlap scheduling + copy_done: Optional[torch.cuda.Event] = None + delay_sample_func: Optional[callable] = None + future_indices: Optional[FutureIndices] = None + + # FIXME(lsyin): maybe move to a better place? + # sync path: forward stream -> output processor + accept_lens: Optional[torch.Tensor] = None + allocate_lens: Optional[torch.Tensor] = None + + # relay path: forward stream -> next step forward + next_draft_input: Optional[EagleDraftInput] = None + + def copy_to_cpu(self, return_logprob: bool = False): + """Copy tensors to CPU in overlap scheduling. + Only the tensors which are needed for processing results are copied, + e.g., next_token_ids, logits outputs + """ + if return_logprob: + if self.logits_output.next_token_logits is not None: + self.logits_output.next_token_logits = ( + self.logits_output.next_token_logits.to("cpu", non_blocking=True) + ) + if self.logits_output.input_token_logprobs is not None: + self.logits_output.input_token_logprobs = ( + self.logits_output.input_token_logprobs.to("cpu", non_blocking=True) + ) + if self.logits_output.hidden_states is not None: + self.logits_output.hidden_states = self.logits_output.hidden_states.to( + "cpu", non_blocking=True + ) + self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True) + + if self.accept_lens is not None: + self.accept_lens = self.accept_lens.to("cpu", non_blocking=True) + + if self.allocate_lens is not None: + self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True) + + self.copy_done.record() + + @classmethod + def from_pp_proxy( + cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph + ): + # TODO(lsyin): refactor PP and avoid using dict + proxy_dict = next_pp_outputs.tensors + return cls( + logits_output=logits_output, + pp_hidden_states_proxy_tensors=None, + next_token_ids=next_pp_outputs["next_token_ids"], + extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None), + extend_logprob_start_len_per_req=proxy_dict.get( + "extend_logprob_start_len_per_req", None + ), + can_run_cuda_graph=can_run_cuda_graph, + ) + + def validate_input_length( req: Req, max_req_input_len: int, allow_auto_truncate: bool ) -> Optional[str]: