From 587b4c6e92681cd2e6d14123bd1c0bc69db86863 Mon Sep 17 00:00:00 2001 From: yilian49 <43861414+yilian49@users.noreply.github.com> Date: Wed, 25 Jun 2025 01:16:56 -0700 Subject: [PATCH] EPLB support for MTP (#7510) --- .../srt/managers/expert_distribution.py | 21 +++++++++++++++++++ python/sglang/srt/models/deepseek_nextn.py | 11 ++++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index e78b5c542..0b2158150 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -61,6 +61,10 @@ class ExpertDistributionRecorder(ABC): def with_debug_name(self, debug_name): yield + @contextmanager + def disable_this_region(self): + yield + @contextmanager def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch): yield @@ -116,6 +120,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder): self._expert_location_metadata = expert_location_metadata self._recording = False + self._disable_all = False self._current_forward_pass_id = Withable() self._current_layer_idx = Withable() self._current_debug_name = Withable() @@ -148,6 +153,16 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder): finally: self._on_forward_pass_end(forward_pass_id) + @contextmanager + def disable_this_region(self): + """Context manager to temporarily disable recording.""" + previous_disable_all = self._disable_all + self._disable_all = True + try: + yield + finally: + self._disable_all = previous_disable_all + def _on_forward_pass_start(self, forward_batch: ForwardBatch): if not self._recording: return @@ -189,6 +204,8 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder): ) def _on_hook(self, hook_name: str, **kwargs): + if self._disable_all: + return if not (self._recording or torch.cuda.is_current_stream_capturing()): return gatherer = 
self._single_pass_gatherers[ @@ -462,6 +479,10 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer): def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids = topk_ids.flatten() mask = topk_ids != -1 + assert self._data[layer_idx, :].shape == topk_ids.shape, ( + "Shape mismatch between data and topk_ids. " + "Selecting expert is not supported for multiple token prediction at the moment." + ) self._data[layer_idx, :].scatter_add_( dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int() ) diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index 270b11436..d83586358 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -28,6 +28,9 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.expert_distribution import ( + get_global_expert_distribution_recorder, +) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM @@ -82,7 +85,6 @@ class DeepseekModelNextN(nn.Module): forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - zero_allocator = BumpAllocator( buffer_size=2, dtype=torch.float32, @@ -108,9 +110,10 @@ class DeepseekModelNextN(nn.Module): ) residual = None - hidden_states, residual = self.decoder( - positions, hidden_states, forward_batch, residual, zero_allocator - ) + with get_global_expert_distribution_recorder().disable_this_region(): + hidden_states, residual = self.decoder( + positions, hidden_states, forward_batch, residual, zero_allocator + ) if not forward_batch.forward_mode.is_idle(): if residual is not None: