diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
index 8667d8747..c944ef679 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -235,6 +235,15 @@ class DeepEPBuffer:
         cls.clean_buffer()
         cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
+    @classmethod
+    def set_dispatch_mode(cls, mode: DeepEPMode):
+        if mode.is_low_latency():
+            cls.set_dispatch_mode_as_low_latency()
+        elif mode.is_normal():
+            cls.set_dispatch_mode_as_normal()
+        else:
+            raise Exception("unsupported mode")
+
 
 class DeepEPConfig(BaseDispatcherConfig):
     _instance = None
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 90635c776..aedbd037c 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -40,6 +40,8 @@ from sglang.srt.layers.dp_attention import (
     set_dp_buffer_len,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
+from sglang.srt.layers.moe.utils import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -240,6 +242,8 @@ class CudaGraphRunner:
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
 
+        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
+
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
@@ -653,6 +657,8 @@ class CudaGraphRunner:
             )
             return logits_output_or_pp_proxy_tensors
 
+        self.deepep_adapter.capture(is_extend_in_batch=False)
+
         for _ in range(2):
             self.device_module.synchronize()
             self.model_runner.tp_group.barrier()
@@ -796,6 +802,8 @@ class CudaGraphRunner:
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        self.deepep_adapter.replay()
+
         if not skip_attn_backend_init:
             self.replay_prepare(forward_batch, pp_proxy_tensors)
         else:
@@ -872,3 +880,23 @@ CUDA_GRAPH_CAPTURE_FAILED_MSG = (
     "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
     "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
 )
+
+
+class DeepEPCudaGraphRunnerAdapter:
+    def __init__(self):
+        # Record DeepEP mode used during capture to ensure replay consistency
+        self._captured_deepep_mode = None
+
+    def capture(self, is_extend_in_batch: bool):
+        if not get_moe_a2a_backend().is_deepep():
+            return
+        self._captured_deepep_mode = get_deepep_mode().resolve(
+            is_extend_in_batch=is_extend_in_batch
+        )
+        DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)
+
+    def replay(self):
+        if not get_moe_a2a_backend().is_deepep():
+            return
+        assert self._captured_deepep_mode is not None
+        DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index a2ce4614b..c82df4d2e 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -9,6 +9,7 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
+    DeepEPCudaGraphRunnerAdapter,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
     model_capture_mode,
@@ -61,6 +62,7 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_profile_cuda_graph = (
             model_runner.server_args.enable_profile_cuda_graph
         )
+        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
         server_args = model_runner.server_args
 
         # Batch sizes to capture
@@ -264,6 +266,8 @@ class EAGLEDraftCudaGraphRunner:
             forward_batch.spec_info.hidden_states = hidden_states_backup
             return ret
 
+        self.deepep_adapter.capture(is_extend_in_batch=False)
+
         for _ in range(2):
             torch.cuda.synchronize()
             self.model_runner.tp_group.barrier()
@@ -285,6 +289,8 @@ class EAGLEDraftCudaGraphRunner:
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
 
+        self.deepep_adapter.replay()
+
         raw_bs = forward_batch.batch_size
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
index 9612a8da2..39d8e0f6a 100644
--- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -9,6 +9,7 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
+    DeepEPCudaGraphRunnerAdapter,
     LogitsProcessorOutput,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
@@ -61,6 +62,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         )
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.padded_static_len = -1
+        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
 
         # Attention backend
         self.num_tokens_per_bs = self.speculative_num_steps + 1
@@ -243,6 +245,8 @@ class EAGLEDraftExtendCudaGraphRunner:
             )
             spec_info.positions = None
 
+        self.deepep_adapter.capture(is_extend_in_batch=True)
+
         # Forward batch
         forward_batch = ForwardBatch(
             forward_mode=ForwardMode.DRAFT_EXTEND,
@@ -318,6 +322,8 @@ class EAGLEDraftExtendCudaGraphRunner:
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
 
+        self.deepep_adapter.replay()
+
         # batch_size and num_seqs can be different in case there are finished examples
         # in the batch, which will not be counted as num_seqs
         raw_bs = forward_batch.batch_size
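
Note on the pattern above (not part of the patch): DeepEPCudaGraphRunnerAdapter records the DeepEP dispatch mode that is active when a CUDA graph is captured and re-asserts it immediately before every replay, since other code paths may flip the global buffer mode between capture and replay. Below is a minimal, self-contained sketch of that contract; DispatchMode, StubBuffer, DispatchModeAdapter, and the extend->NORMAL / decode->LOW_LATENCY mapping are illustrative stand-ins, not sglang's real APIs.

    # Sketch of the capture/replay consistency contract (hypothetical stand-ins).
    from enum import Enum, auto
    from typing import Optional


    class DispatchMode(Enum):
        NORMAL = auto()
        LOW_LATENCY = auto()


    class StubBuffer:
        """Stand-in for DeepEPBuffer: holds one process-global dispatch mode."""

        mode: Optional[DispatchMode] = None

        @classmethod
        def set_dispatch_mode(cls, mode: DispatchMode) -> None:
            cls.mode = mode


    class DispatchModeAdapter:
        """Record the mode at capture time; re-apply it before every replay."""

        def __init__(self) -> None:
            self._captured_mode: Optional[DispatchMode] = None

        def capture(self, is_extend_in_batch: bool) -> None:
            # Assumed mapping for this sketch: extend batches resolve to
            # NORMAL dispatch, decode-only batches to LOW_LATENCY.
            self._captured_mode = (
                DispatchMode.NORMAL if is_extend_in_batch else DispatchMode.LOW_LATENCY
            )
            StubBuffer.set_dispatch_mode(self._captured_mode)

        def replay(self) -> None:
            assert self._captured_mode is not None, "replay() before capture()"
            StubBuffer.set_dispatch_mode(self._captured_mode)


    adapter = DispatchModeAdapter()
    adapter.capture(is_extend_in_batch=False)      # during graph capture
    StubBuffer.set_dispatch_mode(DispatchMode.NORMAL)  # some other path flips the mode
    adapter.replay()                               # restore before replaying the graph
    assert StubBuffer.mode is DispatchMode.LOW_LATENCY

This is why the decode-graph runners call capture(is_extend_in_batch=False) while the draft-extend runner passes is_extend_in_batch=True: each graph bakes in the dispatch mode it was captured with, and replay() guarantees the global buffer agrees with it.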