diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 9e994a734..e78b5c542 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -264,15 +264,23 @@ class _SinglePassGatherer(ABC): return _DetailSinglePassGatherer( server_args, expert_location_metadata, rank ) + + if server_args.expert_distribution_recorder_mode == "stat_approx": + if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"): + return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank) + else: + raise NotImplementedError + if server_args.enable_deepep_moe: if server_args.deepep_mode == "normal": - return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank) + return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank) elif server_args.deepep_mode == "low_latency": return _DeepepLowLatencySinglePassGatherer( expert_location_metadata, rank ) else: raise NotImplementedError + return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank) def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): @@ -347,7 +355,9 @@ class _DetailSinglePassGatherer(_SinglePassGatherer): ) def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): - self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], :] = topk_ids + self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], : topk_ids.shape[1]] = ( + topk_ids + ) def on_deepep_dispatch_normal( self, @@ -380,7 +390,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer): ) -class _LayerBasedSinglePassGatherer(_SinglePassGatherer): +class _LayerBasedCpuSinglePassGatherer(_SinglePassGatherer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._objects_of_layer = {} @@ -409,29 +419,63 @@ def _list_sum(a: List, b: List) -> List: return [x + y for x, y in zip(a, b, strict=True)] -class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer): - # pretty slow, but we will use the DeepEP Gatherer in production - def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): - topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() - torch.cuda.synchronize() +class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer): + def __init__(self, *args, enable_global_physical_experts: bool, **kwargs): + super().__init__(*args, **kwargs) + self._enable_global_physical_experts = enable_global_physical_experts + self._data = torch.zeros( + ( + self._expert_location_metadata.num_layers, + ( + self._expert_location_metadata.num_physical_experts + if enable_global_physical_experts + else self._expert_location_metadata.num_local_physical_experts + ), + ), + dtype=torch.int, + device="cuda", + ) - global_physical_count = [ - 0 - ] * self._expert_location_metadata.num_physical_experts - for token_record in topk_ids_list: - for global_physical_expert_idx in token_record: - global_physical_count[global_physical_expert_idx] += 1 - - self._on_layer_data(layer_idx, global_physical_count) + def reset(self): + self._data[...] = 0 def collect(self) -> Dict: - global_physical_count = super()._collect_objects( - pad_len=self._expert_location_metadata.num_physical_experts - ) + if self._enable_global_physical_experts: + global_physical_count = self._data + else: + # Can optimize if bottleneck + global_physical_count = _convert_local_to_global_physical_count( + self._data, + rank=self._rank, + num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts, + num_physical_experts=self._expert_location_metadata.num_physical_experts, + ) + return dict(global_physical_count=global_physical_count) -class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer): +class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, enable_global_physical_experts=True) + + # can optimize (e.g. fuse / compile) + def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): + topk_ids = topk_ids.flatten() + mask = topk_ids != -1 + self._data[layer_idx, :].scatter_add_( + dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int() + ) + + +class _DeepepNormalSinglePassGatherer(_LayerBasedCpuSinglePassGatherer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if torch.distributed.get_rank() == 0: + logger.info( + "DeepepNormalSinglePassGatherer gathers approximate statistics. " + "If used with small batch size, consider using expert_distribution_recorder_mode=stat." + ) + def on_deepep_dispatch_normal( self, layer_idx: int, @@ -456,17 +500,9 @@ class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer): return dict(global_physical_count=global_physical_count) -class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): +class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer): def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._data = torch.zeros( - ( - self._expert_location_metadata.num_layers, - self._expert_location_metadata.num_local_physical_experts, - ), - dtype=torch.int, - device="cuda", - ) + super().__init__(*args, **kwargs, enable_global_physical_experts=False) def on_deepep_dispatch_low_latency( self, layer_idx: int, local_physical_count_of_layer: torch.Tensor @@ -474,19 +510,6 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): # Most naive implementation, can optimize later self._data[layer_idx, :] += local_physical_count_of_layer - def reset(self): - self._data[...] = 0 - - def collect(self) -> Dict: - # Can optimize if bottleneck - global_physical_count = _convert_local_to_global_physical_count( - self._data, - rank=self._rank, - num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts, - num_physical_experts=self._expert_location_metadata.num_physical_experts, - ) - return dict(global_physical_count=global_physical_count) - def _convert_local_to_global_physical_count( local_physical_count: torch.Tensor, @@ -525,6 +548,7 @@ class _Accumulator(ABC): def get_class(server_args: ServerArgs) -> Type["_Accumulator"]: return { "stat": _StatAccumulator, + "stat_approx": _StatAccumulator, "per_pass": _DetailAccumulator, "per_token": _DetailAccumulator, }[server_args.expert_distribution_recorder_mode] diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 383b3138c..2160a9e6c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -460,22 +460,25 @@ class DeepseekV2MoE(nn.Module): hidden_states = state.hidden_states_mlp_input if router_logits is not None: - state.topk_weights_local, state.topk_idx_local = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=True, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - correction_bias=self.correction_bias, - routed_scaling_factor=self.routed_scaling_factor, - num_token_non_padded=state.forward_batch.num_token_non_padded, - expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( - layer_id=self.layer_id, - ), - ) + with get_global_expert_distribution_recorder().with_current_layer( + self.layer_id + ): + state.topk_weights_local, state.topk_idx_local = select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=True, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + num_fused_shared_experts=self.num_fused_shared_experts, + correction_bias=self.correction_bias, + routed_scaling_factor=self.routed_scaling_factor, + num_token_non_padded=state.forward_batch.num_token_non_padded, + expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( + layer_id=self.layer_id, + ), + ) else: state.topk_idx_local = torch.full( (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 0724ea779..f885500a9 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -255,17 +255,20 @@ class Qwen3MoeSparseMoeBlock(nn.Module): router_logits = state.pop("router_logits") hidden_states = state.hidden_states_mlp_input if router_logits is not None: - state.topk_weights_local, state.topk_idx_local = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=False, - renormalize=self.renormalize, - num_token_non_padded=state.forward_batch.num_token_non_padded, - expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( - layer_id=self.layer_id, - ), - ) + with get_global_expert_distribution_recorder().with_current_layer( + self.layer_id + ): + state.topk_weights_local, state.topk_idx_local = select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=self.renormalize, + num_token_non_padded=state.forward_batch.num_token_non_padded, + expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( + layer_id=self.layer_id, + ), + ) else: state.topk_idx_local = torch.full( (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 92a86a0aa..ac04cdc74 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -182,7 +182,7 @@ class ServerArgs: eplb_rebalance_num_iterations: int = 1000 eplb_rebalance_layers_per_chunk: Optional[int] = None expert_distribution_recorder_mode: Optional[ - Literal["stat", "per_pass", "per_token"] + Literal["stat", "stat_approx", "per_pass", "per_token"] ] = None expert_distribution_recorder_buffer_size: Optional[int] = None enable_expert_distribution_metrics: bool = False