diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 0eb2a9170..36e7964a8 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -800,7 +800,7 @@ class FusedMoE(torch.nn.Module): f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded." ) - def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs): origin_hidden_states_dim = hidden_states.shape[-1] assert self.quant_method is not None @@ -825,6 +825,7 @@ class FusedMoE(torch.nn.Module): combine_input = self.quant_method.apply( layer=self, dispatch_output=dispatch_output, + **kwargs, ) final_hidden_states = self.dispatcher.combine(combine_input) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index f1c6dafb5..572a8e8d7 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +import os from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch @@ -1347,6 +1348,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): self, layer: FusedMoE, dispatch_output: StandardDispatchOutput, + forward_shared_experts=None, + alt_stream=None, ) -> CombineInput: from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput @@ -1418,9 +1421,19 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): )[0] if should_use_flashinfer_cutlass_moe_fp4_allgather(): output, global_output = get_local_dp_buffer(), output + + if forward_shared_experts is not None: + alt_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(alt_stream): + forward_shared_experts() + get_tp_group().reduce_scatterv( global_output, output=output, sizes=get_dp_global_num_tokens() ) + + if forward_shared_experts is not None: + torch.cuda.current_stream().wait_stream(alt_stream) + return StandardCombineInput(hidden_states=output) from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index f24923a73..cdc1beb6f 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -655,6 +655,7 @@ class DeepseekV2MoE(nn.Module): self._enable_a2a_moe = ( get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake() ) + self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo() def get_moe_weights(self): return [ @@ -746,9 +747,10 @@ class DeepseekV2MoE(nn.Module): return self.forward_cpu(hidden_states, should_allreduce_fusion) if hidden_states.shape[0] > 0: - shared_output = self._forward_shared_experts( - hidden_states, gemm_output_zero_allocator - ) + if not self._fuse_shared_experts_inside_sbo: + shared_output = self._forward_shared_experts( + hidden_states, gemm_output_zero_allocator + ) # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states, gemm_output_zero_allocator) topk_output = self.topk(hidden_states, router_logits) @@ -756,7 +758,27 @@ class DeepseekV2MoE(nn.Module): shared_output = None topk_output = self.topk.empty_topk_output(hidden_states.device) - final_hidden_states = self.experts(hidden_states, topk_output) + if self._fuse_shared_experts_inside_sbo: + shared_output = None + + def _forward_shared_experts_and_put_results(): + nonlocal shared_output + shared_output = self._forward_shared_experts( + hidden_states, gemm_output_zero_allocator + ) + + final_hidden_states = self.experts( + hidden_states, + topk_output, + **( + dict( + forward_shared_experts=_forward_shared_experts_and_put_results, + alt_stream=self.alt_stream, + ) + if self._fuse_shared_experts_inside_sbo + else {} + ), + ) if not _is_cuda and not _use_aiter: # fused in biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor