diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index d4db39a33..2b413b446 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -87,6 +87,7 @@ class _DpGatheredBufferWrapper: _global_dp_buffer_len: int _local_dp_buffer_len: int _global_num_tokens: Optional[List[int]] + _is_extend_in_batch: bool @classmethod def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device): @@ -145,6 +146,14 @@ class _DpGatheredBufferWrapper: def get_dp_device(cls) -> torch.device: return cls._device + @classmethod + def set_is_extend_in_batch(cls, is_extend_in_batch: bool): + cls._is_extend_in_batch = is_extend_in_batch + + @classmethod + def get_is_extend_in_batch(cls) -> bool: + return cls._is_extend_in_batch + def set_dp_buffer_len( global_dp_buffer_len: int, @@ -188,6 +197,14 @@ def get_dp_device() -> torch.device: return _DpGatheredBufferWrapper.get_dp_device() +def set_is_extend_in_batch(is_extend_in_batch: bool): + _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch) + + +def get_is_extend_in_batch() -> bool: + return _DpGatheredBufferWrapper.get_is_extend_in_batch() + + def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size): if not enable_dp_attention: return tp_rank, tp_size, 0 diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 89bab802c..62aa94390 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -566,7 +566,9 @@ def ep_scatter( scale_hidden_size = ceil_div(scale_hidden_size, 4) assert m_indices.shape[0] % BLOCK_E == 0 - assert recv_x_scale.dtype == output_tensor_scale.dtype + assert ( + recv_x_scale.dtype == output_tensor_scale.dtype + ), f"recv_x_scale.dtype: {recv_x_scale.dtype}, output_tensor_scale.dtype: {output_tensor_scale.dtype}" assert recv_x_scale.shape[1] == 
output_tensor_scale.shape[1] == scale_hidden_size _fwd_kernel_ep_scatter_1[(grid,)]( diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index c924494b0..d05b24098 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -20,18 +20,14 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( tma_align_input_scale, ) from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.fp8_kernel import ( is_fp8_fnuz, sglang_per_token_group_quant_fp8, ) -from sglang.srt.layers.quantization.modelopt_quant import ( - CUTEDSL_MOE_NVFP4_DISPATCH, - ModelOptNvFp4FusedMoEMethod, -) from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod -from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.single_batch_overlap import DownGemmOverlapArgs from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu from sglang.srt.utils.offloader import get_offloader @@ -109,23 +105,6 @@ class DeepEPMoE(FusedMoE): self.deepep_mode = get_deepep_mode() - # TODO: move to the beginning of the file - from sglang.srt.distributed.parallel_state import get_tp_group - from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher - - self.deepep_dispatcher = MaybeTboDeepEPDispatcher( - group=get_tp_group().device_group, - router_topk=self.top_k, - permute_fusion=True, - num_experts=self.num_experts, - num_local_experts=self.num_local_experts, - hidden_size=hidden_size, - params_dtype=params_dtype, - deepep_mode=self.deepep_mode, - async_finish=True, # TODO - return_recv_hook=True, - ) - if self.deepep_mode.enable_low_latency() and not _is_npu: # NPU supports low_latency deepep 
without deepgemm assert ( @@ -165,19 +144,16 @@ class DeepEPMoE(FusedMoE): def forward( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - forward_batch: ForwardBatch, + topk_output: TopKOutput, forward_shared_experts=None, alt_stream=None, disable_sbo=False, ): + # We have to call SBO inside MoE to be compatible with hooks used in offloading return single_batch_overlap.execute_sbo( hidden_states=hidden_states, - topk_idx=topk_idx, - topk_weights=topk_weights, - forward_batch=forward_batch, + topk_output=topk_output, # SBO args experts=self, forward_shared_experts=forward_shared_experts, @@ -188,25 +164,14 @@ class DeepEPMoE(FusedMoE): def dispatch( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - forward_batch: ForwardBatch, + topk_output: TopKOutput, ): - return self.deepep_dispatcher.dispatch( + return self.dispatcher.dispatch( hidden_states=hidden_states, - topk_idx=topk_idx, - topk_weights=topk_weights, - forward_batch=forward_batch, - input_global_scale=( - self.w13_input_scale_quant - if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) - and self.quant_method.enable_flashinfer_cutedsl_moe - and CUTEDSL_MOE_NVFP4_DISPATCH - else None - ), + topk_output=topk_output, ) - def moe_impl( + def run_moe_core( self, dispatch_output: DispatchOutput, down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None, @@ -240,16 +205,14 @@ class DeepEPMoE(FusedMoE): def combine( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, topk_weights: torch.Tensor, - forward_batch: ForwardBatch, overlap_args: Optional[Dict[str, Any]] = None, ): - return self.deepep_dispatcher.combine( + return self.dispatcher.combine( hidden_states=hidden_states, - topk_idx=topk_idx, + topk_ids=topk_ids, topk_weights=topk_weights, - forward_batch=forward_batch, overlap_args=overlap_args, ) @@ -257,9 +220,9 @@ class DeepEPMoE(FusedMoE): self, dispatch_output: 
Union[DeepEPNormalOutput, DeepEPLLOutput], ): - hidden_states, topk_idx, topk_weights = ( + hidden_states, topk_ids, topk_weights = ( dispatch_output.hidden_states, - dispatch_output.topk_idx, + dispatch_output.topk_ids, dispatch_output.topk_weights, ) if hidden_states.shape[0] == 0: @@ -267,15 +230,15 @@ class DeepEPMoE(FusedMoE): # in original deepep, idx == -1 meaning invalid and will not be processed. # aiter does not accept -1, we use a expert mask to make these idx invalid # (idx == num_local_experts) meaning not used in aiter fused_moe - topk_idx_copy = topk_idx.to(torch.int32) - topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts + topk_ids_copy = topk_ids.to(torch.int32) + topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts return fused_moe( hidden_states, self.w13_weight, self.w2_weight, topk_weights, - topk_idx_copy, + topk_ids_copy, w1_scale=self.w13_weight_scale_inv, w2_scale=self.w2_weight_scale_inv, quant_type=QuantType.per_128x128, @@ -291,18 +254,21 @@ class DeepEPMoE(FusedMoE): self, dispatch_output: DeepEPNormalOutput, ): - hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = ( - dispatch_output - ) - hidden_states_fp8, hidden_states_scale = hidden_states_fp8 + ( + hidden_states, + hidden_states_scale, + topk_ids, + topk_weights, + num_recv_tokens_per_expert, + ) = dispatch_output assert self.quant_method is not None assert self.moe_runner_config.activation == "silu" if num_recv_tokens_per_expert is None: - return hidden_states_fp8.bfloat16() + return hidden_states.bfloat16() all_tokens = sum(num_recv_tokens_per_expert) if all_tokens <= 0: - return hidden_states_fp8.bfloat16() - M, K = hidden_states_fp8.size() + return hidden_states.bfloat16() + M, K = hidden_states.size() N = self.w13_weight.size(1) scale_block_size = 128 @@ -323,35 +289,35 @@ class DeepEPMoE(FusedMoE): ), ) - hidden_states_fp8_shape = hidden_states_fp8.shape - hidden_states_fp8_device = hidden_states_fp8.device - hidden_states_fp8_dtype = 
hidden_states_fp8.dtype + hidden_states_shape = hidden_states.shape + hidden_states_device = hidden_states.device + hidden_states_dtype = hidden_states.dtype input_tensor = [ torch.empty( (all_tokens, K), - device=hidden_states_fp8.device, - dtype=hidden_states_fp8.dtype, + device=hidden_states.device, + dtype=hidden_states.dtype, ), ( # TODO check whether need `zeros` torch.zeros( (ceil_div(K // 128, 4), all_tokens), - device=hidden_states_fp8.device, + device=hidden_states.device, dtype=torch.int, ).transpose(0, 1) if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 else torch.empty( (all_tokens, K // 128), - device=hidden_states_fp8.device, + device=hidden_states.device, dtype=torch.float32, ) ), ] m_indices = torch.empty( - all_tokens, device=hidden_states_fp8.device, dtype=torch.int32 + all_tokens, device=hidden_states.device, dtype=torch.int32 ) - output_index = torch.empty_like(topk_idx) + output_index = torch.empty_like(topk_ids) if get_offloader().forbid_copy_engine_usage: num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce( @@ -367,9 +333,9 @@ class DeepEPMoE(FusedMoE): expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu) ep_scatter( - hidden_states_fp8, + hidden_states, hidden_states_scale, - topk_idx, + topk_ids, num_recv_tokens_per_expert_gpu, expert_start_loc, input_tensor[0], @@ -378,11 +344,11 @@ class DeepEPMoE(FusedMoE): output_index, scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, ) - dispose_tensor(hidden_states_fp8) + dispose_tensor(hidden_states) gateup_output = torch.empty( (all_tokens, N), - device=hidden_states_fp8_device, + device=hidden_states_device, dtype=torch.bfloat16, ) if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0: @@ -403,7 +369,7 @@ class DeepEPMoE(FusedMoE): del gateup_output down_output = torch.empty( (all_tokens, K), - device=hidden_states_fp8_device, + device=hidden_states_device, dtype=torch.bfloat16, ) down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8( @@ -425,11 +391,11 @@ class 
DeepEPMoE(FusedMoE): del down_input_fp8, down_input_scale gather_out = torch.empty( - hidden_states_fp8_shape, - device=hidden_states_fp8_device, + hidden_states_shape, + device=hidden_states_device, dtype=torch.bfloat16, ) - ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out) + ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out) return gather_out @@ -438,13 +404,13 @@ class DeepEPMoE(FusedMoE): dispatch_output: DeepEPLLOutput, down_gemm_overlap_args: Optional[DownGemmOverlapArgs], ): - hidden_states, _, _, masked_m, _ = dispatch_output + hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output assert self.quant_method is not None assert self.moe_runner_config.activation == "silu" output = self.quant_method.apply_without_routing_weights( layer=self, - x=hidden_states, + x=(hidden_states, hidden_states_scale), masked_m=masked_m, moe_runner_config=self.moe_runner_config, down_gemm_overlap_args=down_gemm_overlap_args, @@ -466,25 +432,28 @@ class DeepEPMoE(FusedMoE): self, dispatch_output: DeepEPLLOutput, ): - hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output + hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output assert self.quant_method is not None assert self.moe_runner_config.activation == "silu" + assert ( + hidden_states_scale.dtype == torch.float32 + ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}" # GroupGemm-0 - num_groups, m, k = hidden_states_fp8[0].size() + num_groups, m, k = hidden_states.size() n = self.w13_weight.size(1) expected_m = min(expected_m, m) gateup_output = torch.empty( - (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16 + (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16 ) deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( - hidden_states_fp8, + (hidden_states, hidden_states_scale), self.w13_weight_fp8, gateup_output, masked_m, expected_m, ) - dispose_tensor(hidden_states_fp8[0]) + 
dispose_tensor(hidden_states) # Act down_input = torch.empty( @@ -557,11 +526,9 @@ class DeepEPMoE(FusedMoE): def _forward_normal(dispatch_output: DeepEPNormalOutput): if TYPE_CHECKING: assert isinstance(dispatch_output, DeepEPNormalOutput) - hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output - - if isinstance(hidden_states, tuple): - per_token_scale = hidden_states[1] - hidden_states = hidden_states[0] + hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = ( + dispatch_output + ) group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to( hidden_states.device @@ -571,7 +538,7 @@ class DeepEPMoE(FusedMoE): hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[self.w13_weight.permute(0, 2, 1)], - # per_token_scale=[per_token_scale], + # per_token_scale=[hidden_states_scale], split_item=2, group_list_type=group_list_type, group_type=0, @@ -591,7 +558,7 @@ class DeepEPMoE(FusedMoE): )[0] else: if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"): - hidden_states, per_token_scale = torch_npu.npu_dynamic_quant( + hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant( hidden_states ) # gmm1: gate_up_proj @@ -599,7 +566,7 @@ class DeepEPMoE(FusedMoE): x=[hidden_states], weight=[self.w13_weight], scale=[self.w13_weight_scale.to(output_dtype)], - per_token_scale=[per_token_scale], + per_token_scale=[hidden_states_scale], split_item=2, group_list_type=group_list_type, group_type=0, @@ -631,11 +598,14 @@ class DeepEPMoE(FusedMoE): def _forward_ll(dispatch_output: DeepEPLLOutput): if TYPE_CHECKING: assert isinstance(dispatch_output, DeepEPLLOutput) - hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output - - if isinstance(hidden_states, tuple): - per_token_scale = hidden_states[1] - hidden_states = hidden_states[0] + ( + hidden_states, + hidden_states_scale, + topk_ids, + topk_weights, + group_list, + _, + ) = dispatch_output group_list = group_list.to(torch.int64) @@ -644,7 
+614,7 @@ class DeepEPMoE(FusedMoE): hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[self.w13_weight.permute(0, 2, 1)], - # per_token_scale=[per_token_scale], + # per_token_scale=[hidden_states_scale], split_item=2, group_list_type=group_list_type, group_type=0, @@ -678,7 +648,7 @@ class DeepEPMoE(FusedMoE): hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( x=hidden_states, weight_scale=self.w13_weight_scale.to(torch.float32), - activation_scale=per_token_scale, + activation_scale=hidden_states_scale, bias=None, quant_scale=None, quant_offset=None, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 36e7964a8..84a35b96a 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -11,14 +11,19 @@ from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size, + get_tp_group, tensor_model_parallel_all_reduce, ) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata from sglang.srt.layers.moe import ( MoeRunnerConfig, + get_deepep_mode, + get_moe_a2a_backend, get_moe_runner_backend, should_use_flashinfer_trtllm_moe, ) +from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput +from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher from sglang.srt.layers.moe.token_dispatcher.standard import ( StandardDispatcher, StandardDispatchOutput, @@ -32,6 +37,7 @@ from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight +from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher from 
sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, @@ -71,6 +77,27 @@ def _get_tile_tokens_dim(num_tokens, top_k, num_experts): return tile_tokens_dim +def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher: + a2a_backend = get_moe_a2a_backend() + if a2a_backend.is_none(): + return StandardDispatcher(moe_runner_config) + elif a2a_backend.is_deepep(): + return MaybeTboDeepEPDispatcher( + group=get_tp_group().device_group, + router_topk=moe_runner_config.top_k, + permute_fusion=True, + num_experts=moe_runner_config.num_experts, + num_local_experts=moe_runner_config.num_local_experts, + hidden_size=moe_runner_config.hidden_size, + params_dtype=moe_runner_config.params_dtype, + deepep_mode=get_deepep_mode(), + async_finish=True, + return_recv_hook=True, + ) + else: + raise NotImplementedError(f"Unsupported a2a backend: {a2a_backend}") + + class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" CHANNEL = "channel" @@ -132,8 +159,6 @@ class FusedMoE(torch.nn.Module): self.hidden_size = hidden_size self.num_experts = num_experts self.num_fused_shared_experts = num_fused_shared_experts - self.expert_map_cpu = None - self.expert_map_gpu = None enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass() @@ -149,19 +174,6 @@ class FusedMoE(torch.nn.Module): assert num_experts % self.moe_ep_size == 0 self.num_local_experts = num_experts // self.moe_ep_size - if self.moe_ep_size > 1: - # TODO(ch-wan): support shared experts fusion - # Create a tensor of size num_experts filled with -1 - self.expert_map_cpu = torch.full( - (self.num_experts,), -1, dtype=torch.int32, device="cpu" - ) - # Create a expert map for the local experts - self.expert_map_cpu[ - self.moe_ep_rank - * self.num_local_experts : (self.moe_ep_rank + 1) - * self.num_local_experts - ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu") - assert intermediate_size % self.moe_tp_size == 0 self.intermediate_size_per_partition = 
intermediate_size // self.moe_tp_size self.reduce_results = reduce_results @@ -219,7 +231,7 @@ class FusedMoE(torch.nn.Module): ) self.quant_method.create_moe_runner(self, self.moe_runner_config) - self.dispatcher = StandardDispatcher() + self.dispatcher = create_moe_dispatcher(self.moe_runner_config) self.should_fuse_routed_scaling_factor_in_topk = isinstance( self.quant_method, ModelOptNvFp4FusedMoEMethod @@ -453,9 +465,12 @@ class FusedMoE(torch.nn.Module): expert_data.copy_(loaded_weight) def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: - if self.expert_map_cpu is None: - return expert_id - return self.expert_map_cpu[expert_id].item() + start_idx = self.moe_ep_rank * self.num_local_experts + end_idx = (self.moe_ep_rank + 1) * self.num_local_experts + if start_idx <= expert_id < end_idx: + return expert_id - start_idx + else: + return -1 def weight_loader( self, @@ -804,32 +819,18 @@ class FusedMoE(torch.nn.Module): origin_hidden_states_dim = hidden_states.shape[-1] assert self.quant_method is not None - if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe: - if self.expert_map_cpu is not None and self.expert_map_gpu is None: - # If we are in EP mode, we need to move the expert map to GPU. - self.expert_map_gpu = self.expert_map_cpu.to(device="cuda") - - if self.expert_map_gpu is not None: - if TopKOutputChecker.format_is_standard(topk_output): - topk_output = topk_output._replace( - topk_ids=self.expert_map_gpu[topk_output.topk_ids] - ) - elif TopKOutputChecker.format_is_triton_kernel(topk_output): - raise NotImplementedError() - dispatch_output = self.dispatcher.dispatch( hidden_states=hidden_states, topk_output=topk_output ) - # TODO: consider using symmetric memory - combine_input = self.quant_method.apply( - layer=self, + combine_input = self.run_moe_core( dispatch_output=dispatch_output, **kwargs, ) final_hidden_states = self.dispatcher.combine(combine_input) + # TODO: should we add some conditions here? 
final_hidden_states = final_hidden_states[ ..., :origin_hidden_states_dim ].contiguous() @@ -839,6 +840,14 @@ class FusedMoE(torch.nn.Module): return final_hidden_states + def run_moe_core(self, dispatch_output: DispatchOutput, **kwargs) -> CombineInput: + # TODO: consider using symmetric memory + return self.quant_method.apply( + layer=self, + dispatch_output=dispatch_output, + **kwargs, + ) + @classmethod def make_expert_params_mapping( cls, diff --git a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py index 7526f73de..d662b3afd 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py @@ -23,6 +23,7 @@ from sglang.srt.layers.moe.token_dispatcher.mooncake import ( ) from sglang.srt.layers.moe.token_dispatcher.standard import ( StandardCombineInput, + StandardDispatcher, StandardDispatchOutput, ) @@ -38,6 +39,7 @@ __all__ = [ "MooncakeCombineInput", "MooncakeDispatchOutput", "MooncakeEPDispatcher", + "StandardDispatcher", "StandardDispatchOutput", "StandardCombineInput", "DeepEPConfig", diff --git a/python/sglang/srt/layers/moe/token_dispatcher/base.py b/python/sglang/srt/layers/moe/token_dispatcher/base.py index 155860886..1af84caeb 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/base.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/base.py @@ -73,7 +73,7 @@ class DispatchOutputFormat(Enum): class DispatchOutput(Protocol): """Protocol for dispatch outputs in different formats.""" - # TODO: add hidden_states to the protocol + hidden_states: torch.Tensor @property def format(self) -> DispatchOutputFormat: ... 
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py index dd94d4464..8c6796bf1 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers import deep_gemm_wrapper +from sglang.srt.layers.dp_attention import get_is_extend_in_batch from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, BaseDispatcherConfig, @@ -15,6 +16,7 @@ from sglang.srt.layers.moe.token_dispatcher.base import ( DispatchOutput, DispatchOutputFormat, ) +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.utils import ( DeepEPMode, get_deepep_config, @@ -51,8 +53,6 @@ from enum import Enum, IntEnum, auto import torch import torch.distributed as dist -from sglang.srt.model_executor.forward_batch_info import ForwardBatch - _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip() logger = logging.getLogger(__name__) @@ -61,9 +61,9 @@ logger = logging.getLogger(__name__) class DeepEPNormalOutput(NamedTuple): """DeepEP normal dispatch output.""" - hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] - # hidden_states_scale - topk_idx: torch.Tensor + hidden_states: torch.Tensor + hidden_states_scale: Optional[torch.Tensor] + topk_ids: torch.Tensor topk_weights: torch.Tensor num_recv_tokens_per_expert: List[int] @@ -75,8 +75,9 @@ class DeepEPNormalOutput(NamedTuple): class DeepEPLLOutput(NamedTuple): """DeepEP low latency dispatch output.""" - hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor] - topk_idx: torch.Tensor + hidden_states: torch.Tensor + hidden_states_scale: Optional[torch.Tensor] + topk_ids: torch.Tensor topk_weights: torch.Tensor masked_m: torch.Tensor expected_m: int @@ -314,9 +315,7 @@ 
class _DeepEPDispatcherImplBase: def dispatch_a( self, hidden_states: torch.Tensor, - input_global_scale: Optional[torch.Tensor], - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + topk_output: TopKOutput, ): raise NotImplementedError @@ -326,7 +325,7 @@ class _DeepEPDispatcherImplBase: def combine_a( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, topk_weights: torch.Tensor, overlap_args: Optional["CombineOverlapArgs"], ): @@ -345,15 +344,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): self.async_finish = async_finish self.src2dst = None + self.quant_config = {} def dispatch_a( self, hidden_states: torch.Tensor, - input_global_scale: Optional[torch.Tensor], - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + topk_output: TopKOutput, ): - topk_idx = topk_idx.to(torch.int64) + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + topk_ids = topk_ids.to(torch.int64) if ( deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and not get_moe_runner_backend().is_cutlass() @@ -367,25 +366,35 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, ) previous_event = Buffer.capture() if self.async_finish else None - return hidden_states, topk_idx, topk_weights, previous_event + return hidden_states, topk_ids, topk_weights, previous_event - def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event): + def dispatch_b(self, hidden_states, topk_ids, topk_weights, previous_event): ( hidden_states, - topk_idx, + topk_ids, topk_weights, num_recv_tokens_per_expert, event, - ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event) + ) = self._dispatch_core(hidden_states, topk_ids, topk_weights, previous_event) event.current_stream_wait() if self.async_finish else () + + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_scale = hidden_states + else: + hidden_states_scale = None + return 
DeepEPNormalOutput( - hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert + hidden_states, + hidden_states_scale, + topk_ids, + topk_weights, + num_recv_tokens_per_expert, ) def _dispatch_core( self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, topk_weights: torch.Tensor, previous_event, ): @@ -397,7 +406,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): is_token_in_rank, previous_event, ) = buffer.get_dispatch_layout( - topk_idx, + topk_ids, self.num_experts, previous_event=previous_event, async_finish=self.async_finish, @@ -409,14 +418,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ( recv_x, - recv_topk_idx, + recv_topk_ids, recv_topk_weights, num_recv_tokens_per_expert, self.handle, event, ) = buffer.dispatch( x, - topk_idx=topk_idx, + topk_idx=topk_ids, topk_weights=topk_weights, num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, @@ -437,7 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): return ( recv_x, - recv_topk_idx, + recv_topk_ids, recv_topk_weights, num_recv_tokens_per_expert, event, @@ -446,40 +455,16 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): def combine_a( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, topk_weights: torch.Tensor, overlap_args: Optional["CombineOverlapArgs"], ): - from sglang.srt.layers.moe.ep_moe.kernels import ( - deepep_post_reorder_triton_kernel, - ) if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu: output = hidden_states else: - if hidden_states.shape[0] > 0: - num_tokens = self.src2dst.shape[0] // self.router_topk - output = torch.empty( - (num_tokens, hidden_states.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - deepep_post_reorder_triton_kernel[(num_tokens,)]( - hidden_states, - output, - self.src2dst, - topk_idx, - topk_weights, - 
self.router_topk, - hidden_states.shape[1], - BLOCK_SIZE=512, - ) - else: - output = torch.zeros( - (0, hidden_states.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) + raise NotImplementedError() # triton runner was supported but it's temporarily disabled + previous_event = Buffer.capture() if self.async_finish else None return output, previous_event @@ -514,6 +499,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): self.num_experts, ) + def set_quant_config(self, quant_config: dict): + self.quant_config = quant_config + class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): def __init__(self, return_recv_hook: bool, **kwargs): @@ -525,28 +513,27 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): """ self.return_recv_hook = return_recv_hook self.device_module = torch.get_device_module() + self.quant_config = {} def dispatch_a( self, hidden_states: torch.Tensor, - input_global_scale: Optional[torch.Tensor], - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + topk_output: TopKOutput, ): buffer = self._get_buffer() - topk_idx = topk_idx.to(torch.int64) + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + topk_ids = topk_ids.to(torch.int64) expected_m = ( - hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1] + hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1] + self.num_experts ) // self.num_experts hidden_states, masked_m, event, hook = self._dispatch_core( hidden_states, - input_global_scale, - topk_idx, + topk_ids, ) return ( hidden_states, - topk_idx, + topk_ids, topk_weights, masked_m, expected_m, @@ -557,7 +544,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): def dispatch_b( self, hidden_states, - topk_idx, + topk_ids, topk_weights, masked_m, expected_m, @@ -570,9 +557,15 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): masked_m ) + if isinstance(hidden_states, tuple): + hidden_states, 
hidden_states_scale = hidden_states + else: + hidden_states_scale = None + deepep_output = DeepEPLLOutput( hidden_states, - topk_idx, + hidden_states_scale, + topk_ids, topk_weights, masked_m, expected_m, @@ -582,10 +575,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): def _dispatch_core( self, hidden_states: torch.Tensor, - input_global_scale: Optional[torch.Tensor], - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, ): use_nvfp4 = use_fp8 = False + input_global_scale = self.quant_config.get("input_global_scale", None) if input_global_scale is not None: use_nvfp4 = True elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"): @@ -595,7 +588,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = ( buffer.low_latency_dispatch( hidden_states, - topk_idx, + topk_ids, self.num_max_dispatch_tokens_per_rank, self.num_experts, use_fp8=use_fp8, @@ -618,13 +611,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): def combine_a( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, topk_weights: torch.Tensor, overlap_args: Optional["CombineOverlapArgs"], ): hidden_states, event, hook = self._combine_core( hidden_states, - topk_idx, + topk_ids, topk_weights, overlap_args=overlap_args, ) @@ -644,7 +637,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): def _combine_core( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, topk_weights: torch.Tensor, overlap_args: Optional["CombineOverlapArgs"], ): @@ -658,7 +651,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): with ctx: combined_hidden_states, event, hook = buffer.low_latency_combine( x=hidden_states, - topk_idx=topk_idx, + topk_idx=topk_ids, topk_weights=topk_weights, handle=self.handle, async_finish=not self.return_recv_hook, @@ -688,6 +681,9 @@ class 
_DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): self.num_experts, ) + def set_quant_config(self, quant_config: dict): + self.quant_config = quant_config + @dataclass class _Stage(Enum): @@ -745,25 +741,20 @@ class DeepEPDispatcher(BaseDispatcher): def dispatch_a( self, hidden_states: torch.Tensor, - input_global_scale: Optional[torch.Tensor], - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - forward_batch: ForwardBatch, + topk_output: TopKOutput, ): self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A) - inner_state = self._get_impl(forward_batch).dispatch_a( + inner_state = self._get_impl().dispatch_a( hidden_states=hidden_states, - input_global_scale=input_global_scale, - topk_idx=topk_idx, - topk_weights=topk_weights, + topk_output=topk_output, ) - self._dispatch_intermediate_state = forward_batch, inner_state + self._dispatch_intermediate_state = inner_state def dispatch_b(self): self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B) - forward_batch, inner_state = self._dispatch_intermediate_state + inner_state = self._dispatch_intermediate_state del self._dispatch_intermediate_state - return self._get_impl(forward_batch).dispatch_b(*inner_state) + return self._get_impl().dispatch_b(*inner_state) def combine(self, *args, **kwargs) -> Tuple: self.combine_a(*args, **kwargs) @@ -773,30 +764,28 @@ class DeepEPDispatcher(BaseDispatcher): def combine_a( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, topk_weights: torch.Tensor, - forward_batch: ForwardBatch, overlap_args: Optional["CombineOverlapArgs"] = None, ): self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A) - inner_state = self._get_impl(forward_batch).combine_a( + inner_state = self._get_impl().combine_a( hidden_states=hidden_states, - topk_idx=topk_idx, + topk_ids=topk_ids, topk_weights=topk_weights, overlap_args=overlap_args, ) - self._combine_intermediate_state = forward_batch, inner_state + 
self._combine_intermediate_state = inner_state def combine_b(self): self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL) - forward_batch, inner_state = self._combine_intermediate_state + inner_state = self._combine_intermediate_state del self._combine_intermediate_state - return self._get_impl(forward_batch).combine_b(*inner_state) + return self._get_impl().combine_b(*inner_state) - def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase: - resolved_deepep_mode = self.deepep_mode.resolve( - forward_batch.is_extend_in_batch - ) + def _get_impl(self) -> _DeepEPDispatcherImplBase: + is_extend_in_batch = get_is_extend_in_batch() + resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch) if resolved_deepep_mode == DeepEPMode.NORMAL: return self._normal_dispatcher elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY: @@ -807,3 +796,9 @@ class DeepEPDispatcher(BaseDispatcher): def _update_stage(self, old_stage, new_stage): assert self._stage == old_stage self._stage = new_stage + + def set_quant_config(self, quant_config: dict): + if self.deepep_mode.enable_low_latency(): + self._low_latency_dispatcher.set_quant_config(quant_config) + if self.deepep_mode.enable_normal(): + self._normal_dispatcher.set_quant_config(quant_config) diff --git a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py index 54ba8f1b5..201c1b5f2 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import NamedTuple, Optional, Tuple from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.layers.dp_attention import get_is_extend_in_batch from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, CombineInput, @@ -12,6 +13,7 @@ from sglang.srt.layers.moe.token_dispatcher.base import ( 
DispatchOutput, DispatchOutputFormat, ) +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.utils import DeepEPMode from sglang.srt.utils import get_int_env_var @@ -27,16 +29,15 @@ from enum import Enum, auto import torch import torch.distributed as dist -from sglang.srt.model_executor.forward_batch_info import ForwardBatch - logger = logging.getLogger(__name__) class MooncakeDispatchOutput(NamedTuple): """Mooncake EP dispatch output.""" - hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor] - topk_idx: torch.Tensor + hidden_states: torch.Tensor + hidden_states_scale: torch.Tensor + topk_ids: torch.Tensor topk_weights: torch.Tensor masked_m: torch.Tensor expected_m: int @@ -164,23 +165,23 @@ class _MooncakeEPDispatcherImpl: def dispatch_a( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + topk_output: TopKOutput, ): + topk_ids, topk_weights = topk_output.topk_ids, topk_output.topk_weights buffer = self._get_buffer() - topk_idx = topk_idx.to(torch.int64) + topk_ids = topk_ids.to(torch.int64) expected_m = ( - hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1] + hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1] + self.num_experts ) // self.num_experts hidden_states, masked_m, event, hook = self._dispatch_core( hidden_states, - topk_idx, + topk_ids, use_fp8=True, ) return ( hidden_states, - topk_idx, + topk_ids, topk_weights, masked_m, expected_m, @@ -191,7 +192,7 @@ class _MooncakeEPDispatcherImpl: def dispatch_b( self, hidden_states, - topk_idx, + topk_ids, topk_weights, masked_m, expected_m, @@ -206,7 +207,7 @@ class _MooncakeEPDispatcherImpl: return MooncakeDispatchOutput( hidden_states, - topk_idx, + topk_ids, topk_weights, masked_m, expected_m, @@ -215,14 +216,14 @@ class _MooncakeEPDispatcherImpl: def _dispatch_core( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, use_fp8: bool = False, ): buffer = self._get_buffer() 
packed_recv_hidden, packed_recv_count, self.handle, event, hook = ( buffer.dispatch( hidden_states, - topk_idx, + topk_ids, self.active_ranks, self.num_max_dispatch_tokens_per_rank, self.num_experts, @@ -237,12 +238,12 @@ class _MooncakeEPDispatcherImpl: def combine_a( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, topk_weights: torch.Tensor, ): hidden_states, event, hook = self._combine_core( hidden_states, - topk_idx, + topk_ids, topk_weights, ) return hidden_states, event, hook @@ -254,13 +255,13 @@ class _MooncakeEPDispatcherImpl: def _combine_core( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, topk_weights: torch.Tensor, ): buffer = self._get_buffer() combined_hidden_states, event, hook = buffer.combine( hidden_states, - topk_idx, + topk_ids, topk_weights, self.active_ranks, -1 if self.first_execution else self.timeout_us, @@ -332,24 +333,20 @@ class MooncakeEPDispatcher(BaseDispatcher): def dispatch_a( self, hidden_states: torch.Tensor, - input_global_scale: Optional[torch.Tensor], - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - forward_batch: ForwardBatch, + topk_output: TopKOutput, ): self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A) - inner_state = self._get_impl(forward_batch).dispatch_a( + inner_state = self._get_impl().dispatch_a( hidden_states=hidden_states, - topk_idx=topk_idx, - topk_weights=topk_weights, + topk_output=topk_output, ) - self._dispatch_intermediate_state = forward_batch, inner_state + self._dispatch_intermediate_state = inner_state def dispatch_b(self): self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B) - forward_batch, inner_state = self._dispatch_intermediate_state + inner_state = self._dispatch_intermediate_state del self._dispatch_intermediate_state - return self._get_impl(forward_batch).dispatch_b(*inner_state) + return self._get_impl().dispatch_b(*inner_state) def combine(self, *args, **kwargs) -> Tuple: 
self.combine_a(*args, **kwargs) @@ -359,29 +356,27 @@ class MooncakeEPDispatcher(BaseDispatcher): def combine_a( self, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, + topk_ids: torch.Tensor, topk_weights: torch.Tensor, - forward_batch: ForwardBatch, overlap_args: Optional = None, ): self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A) - inner_state = self._get_impl(forward_batch).combine_a( + inner_state = self._get_impl().combine_a( hidden_states=hidden_states, - topk_idx=topk_idx, + topk_ids=topk_ids, topk_weights=topk_weights, ) - self._combine_intermediate_state = forward_batch, inner_state + self._combine_intermediate_state = inner_state def combine_b(self): self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL) - forward_batch, inner_state = self._combine_intermediate_state + inner_state = self._combine_intermediate_state del self._combine_intermediate_state - return self._get_impl(forward_batch).combine_b(*inner_state) + return self._get_impl().combine_b(*inner_state) - def _get_impl(self, forward_batch: ForwardBatch) -> _MooncakeEPDispatcherImpl: - resolved_deepep_mode = self.deepep_mode.resolve( - forward_batch.is_extend_in_batch - ) + def _get_impl(self) -> _MooncakeEPDispatcherImpl: + is_extend_in_batch = get_is_extend_in_batch() + resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch) if resolved_deepep_mode == DeepEPMode.NORMAL: raise NotImplementedError elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY: @@ -392,3 +387,6 @@ class MooncakeEPDispatcher(BaseDispatcher): def _update_stage(self, old_stage, new_stage): assert self._stage == old_stage self._stage = new_stage + + def set_quant_config(self, quant_config: dict): + pass diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py index f984104f6..5d4abde9a 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/standard.py +++ 
b/python/sglang/srt/layers/moe/token_dispatcher/standard.py @@ -4,6 +4,11 @@ from typing import TYPE_CHECKING, NamedTuple import torch +from sglang.srt.distributed import ( + get_moe_expert_parallel_rank, + get_moe_expert_parallel_world_size, +) +from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, CombineInput, @@ -11,6 +16,8 @@ from sglang.srt.layers.moe.token_dispatcher.base import ( DispatchOutput, DispatchOutputFormat, ) +from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker +from sglang.srt.layers.moe.utils import get_moe_runner_backend if TYPE_CHECKING: from sglang.srt.layers.moe.topk import TopKOutput @@ -45,9 +52,45 @@ assert isinstance(StandardCombineInput, CombineInput) class StandardDispatcher(BaseDispatcher): + def __init__(self, moe_runner_config: MoeRunnerConfig): + self.moe_ep_size = get_moe_expert_parallel_world_size() + self.enable_flashinfer_cutlass_moe = ( + get_moe_runner_backend().is_flashinfer_cutlass() + ) + self.num_experts = moe_runner_config.num_experts + self.num_local_experts = moe_runner_config.num_local_experts + self.moe_ep_rank = get_moe_expert_parallel_rank() + self.local_expert_mapping = None + def dispatch( self, hidden_states: torch.Tensor, topk_output: TopKOutput ) -> DispatchOutput: + + if ( + self.moe_ep_size > 1 + and not self.enable_flashinfer_cutlass_moe + and TopKOutputChecker.format_is_standard(topk_output) + ): + if self.local_expert_mapping is None: + self.local_expert_mapping = torch.full( + (self.num_experts,), -1, dtype=torch.int32, device="cuda" + ) + self.local_expert_mapping[ + self.moe_ep_rank + * self.num_local_experts : (self.moe_ep_rank + 1) + * self.num_local_experts + ] = torch.arange( + 0, self.num_local_experts, dtype=torch.int32, device="cuda" + ) + + if self.local_expert_mapping is not None: + if TopKOutputChecker.format_is_standard(topk_output): + topk_output = topk_output._replace( + 
topk_ids=self.local_expert_mapping[topk_output.topk_ids] + ) + elif TopKOutputChecker.format_is_triton_kernel(topk_output): + raise NotImplementedError() + return StandardDispatchOutput( hidden_states=hidden_states, topk_output=topk_output ) @@ -59,3 +102,6 @@ class StandardDispatcher(BaseDispatcher): # TODO: this branch should be removed in the future assert isinstance(combine_input, torch.Tensor) return combine_input + + def set_quant_config(self, quant_config: dict): + pass diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 9af3b8a2b..74cbfbc21 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -365,9 +365,10 @@ class TopK(CustomOp): def empty_topk_output(self, device: torch.device) -> TopKOutput: topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device) - topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device) + topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device) + # FIXME: router_logits should be of size (0, num_experts) router_logits = torch.empty((0, topk), dtype=torch.float32, device=device) - return StandardTopKOutput(topk_weights, topk_idx, router_logits) + return StandardTopKOutput(topk_weights, topk_ids, router_logits) # ------------------------------- TopK implementation ------------------------------------- diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 9ae270caf..627e991c4 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1244,6 +1244,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): (1 / w2_input_scale).to(torch.float32), requires_grad=False ) + layer.dispatcher.set_quant_config( + {"input_global_scale": layer.w13_input_scale_quant} + ) + # Validate weight 
scales for name, weight_scale in [ ("w13", layer.w13_weight_scale), diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index 7c5d4554a..44deaa8af 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -339,7 +339,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): hidden_states, topk_idx, topk_weights = ( dispatch_output.hidden_states, - dispatch_output.topk_idx, + dispatch_output.topk_ids, dispatch_output.topk_weights, ) if isinstance(hidden_states, tuple): diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index aedbd037c..9eabae6d5 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, set_dp_buffer_len, + set_is_extend_in_batch, ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer @@ -639,6 +640,7 @@ class CudaGraphRunner: # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_is_extend_in_batch(False) kwargs = {} if ( diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index f34f36d70..398bb0da0 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -44,6 +44,7 @@ from sglang.srt.layers.dp_attention import ( get_attention_dp_rank, get_attention_tp_size, set_dp_buffer_len, + set_is_extend_in_batch, ) from sglang.srt.utils import get_compiler_backend, is_npu, support_triton @@ -688,6 +689,7 @@ class ForwardBatch: 
self.global_dp_buffer_len = buffer_len set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens) + set_is_extend_in_batch(self.is_extend_in_batch) bs = self.batch_size diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index 922642e56..e4af95d17 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, set_dp_buffer_len, + set_is_extend_in_batch, ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.torchao_utils import save_gemlite_cache @@ -377,6 +378,9 @@ class PiecewiseCudaGraphRunner: # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None set_dp_buffer_len(global_dp_buffer_len, num_tokens) + # FIXME: the implementation is hacky. `is_extend_in_batch` is for determining the deepep mode. + # It is True in this context, but we need to set it to False to use the low latency deepep mode. 
+ set_is_extend_in_batch(False) kwargs = {} with set_forward_context(forward_batch, self.attention_layers): diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py index e768c0a53..a4af1e43e 100644 --- a/python/sglang/srt/models/bailing_moe.py +++ b/python/sglang/srt/models/bailing_moe.py @@ -380,7 +380,7 @@ class BailingMoESparseMoeBlock(nn.Module): if self.num_shared_experts > 0: shared_output = self.shared_experts(hidden_states) - topk_weights, topk_idx, _ = self.topk( + topk_output = self.topk( hidden_states, router_logits, num_token_non_padded=forward_batch.num_token_non_padded, @@ -389,53 +389,15 @@ class BailingMoESparseMoeBlock(nn.Module): ), ) else: - topk_idx = torch.full( - (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device - ) - topk_weights = torch.empty( - (0, self.top_k), dtype=torch.float32, device=hidden_states.device - ) - - if self.ep_size > 1: - ( - hidden_states, - topk_idx, - topk_weights, - reorder_topk_ids, - num_recv_tokens_per_expert, - seg_indptr, - masked_m, - expected_m, - ) = self.deepep_dispatcher.dispatch( - hidden_states, - topk_idx, - topk_weights, - forward_batch=forward_batch, - ) + topk_output = self.topk.empty_topk_output(hidden_states.device) final_hidden_states = self.experts( hidden_states=hidden_states, - topk_idx=topk_idx, - topk_weights=topk_weights, - reorder_topk_ids=reorder_topk_ids, - seg_indptr=seg_indptr, - masked_m=masked_m, - expected_m=expected_m, - num_recv_tokens_per_expert=num_recv_tokens_per_expert, - forward_batch=forward_batch, + topk_output=topk_output, ) - if self.ep_size > 1: - final_hidden_states = self.deepep_dispatcher.combine( - final_hidden_states, - topk_idx, - topk_weights, - forward_batch=forward_batch, - ) - - final_hidden_states *= self.routed_scaling_factor if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + final_hidden_states += shared_output return final_hidden_states diff --git 
a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 6ca168670..76a946757 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -74,7 +74,6 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe import ( - get_deepep_mode, get_moe_a2a_backend, should_use_flashinfer_cutlass_moe_fp4_allgather, should_use_flashinfer_trtllm_moe, @@ -112,10 +111,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.server_args import get_global_server_args from sglang.srt.single_batch_overlap import SboFlags from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.two_batch_overlap import ( - MaybeTboDeepEPDispatcher, - model_forward_maybe_tbo, -) +from sglang.srt.two_batch_overlap import model_forward_maybe_tbo from sglang.srt.utils import ( BumpAllocator, LazyValue, @@ -649,19 +645,6 @@ class DeepseekV2MoE(nn.Module): else None ) - self.deepep_dispatcher = MaybeTboDeepEPDispatcher( - group=parallel_state.get_tp_group().device_group, - router_topk=self.top_k, - permute_fusion=True, - num_experts=self.num_experts, - num_local_experts=config.n_routed_experts // self.tp_size, - hidden_size=config.hidden_size, - params_dtype=config.torch_dtype, - deepep_mode=get_deepep_mode(), - async_finish=True, - return_recv_hook=True, - ) - self._enable_a2a_moe = ( get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake() ) @@ -874,7 +857,7 @@ class DeepseekV2MoE(nn.Module): router_logits = self.gate(hidden_states) if not self._fuse_shared_experts_inside_sbo: shared_output = self._forward_shared_experts(hidden_states) - topk_weights, topk_idx, _ = self.topk( + topk_output = self.topk( hidden_states, router_logits, num_token_non_padded=forward_batch.num_token_non_padded, @@ -883,9 +866,7 @@ class DeepseekV2MoE(nn.Module): ), ) else: - topk_weights, topk_idx, _ = 
self.topk.empty_topk_output( - hidden_states.device - ) + topk_output = self.topk.empty_topk_output(hidden_states.device) if self._fuse_shared_experts_inside_sbo: shared_output = None @@ -896,9 +877,7 @@ class DeepseekV2MoE(nn.Module): final_hidden_states = self.experts( hidden_states=hidden_states, - topk_idx=topk_idx, - topk_weights=topk_weights, - forward_batch=forward_batch, + topk_output=topk_output, **( dict( forward_shared_experts=_forward_shared_experts_and_put_results, @@ -960,7 +939,7 @@ class DeepseekV2MoE(nn.Module): with get_global_expert_distribution_recorder().with_current_layer( self.layer_id ): - state.topk_weights_local, state.topk_idx_local, _ = self.topk( + state.topk_output = self.topk( hidden_states=hidden_states, router_logits=router_logits, num_token_non_padded=state.forward_batch.num_token_non_padded, @@ -969,21 +948,13 @@ class DeepseekV2MoE(nn.Module): ), ) else: - state.topk_idx_local = torch.full( - (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device - ) - state.topk_weights_local = torch.empty( - (0, self.top_k), dtype=torch.float32, device=hidden_states.device - ) + state.topk_output = self.topk.empty_topk_output(hidden_states.device) def op_dispatch_a(self, state): if self.ep_size > 1: - self.experts.deepep_dispatcher.dispatch_a( + self.experts.dispatcher.dispatch_a( hidden_states=state.hidden_states_mlp_input, - input_global_scale=None, - topk_idx=state.pop("topk_idx_local"), - topk_weights=state.pop("topk_weights_local"), - forward_batch=state.forward_batch, + topk_output=state.pop("topk_output"), tbo_subbatch_index=state.get("tbo_subbatch_index"), ) @@ -992,32 +963,29 @@ class DeepseekV2MoE(nn.Module): with get_global_expert_distribution_recorder().with_current_layer( self.layer_id ): - state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b( + state.dispatch_output = self.experts.dispatcher.dispatch_b( tbo_subbatch_index=state.get("tbo_subbatch_index"), ) def op_experts(self, state): - 
state.hidden_states_experts_output = self.experts.moe_impl( + state.hidden_states_experts_output = self.experts.run_moe_core( dispatch_output=state.dispatch_output, ) def op_combine_a(self, state): if self.ep_size > 1: - self.experts.deepep_dispatcher.combine_a( + self.experts.dispatcher.combine_a( hidden_states=state.pop("hidden_states_experts_output"), - topk_idx=state.dispatch_output.topk_idx, + topk_ids=state.dispatch_output.topk_ids, topk_weights=state.dispatch_output.topk_weights, - forward_batch=state.forward_batch, tbo_subbatch_index=state.get("tbo_subbatch_index"), ) state.pop("dispatch_output") def op_combine_b(self, state): if self.ep_size > 1: - state.hidden_states_after_combine = ( - self.experts.deepep_dispatcher.combine_b( - tbo_subbatch_index=state.get("tbo_subbatch_index"), - ) + state.hidden_states_after_combine = self.experts.dispatcher.combine_b( + tbo_subbatch_index=state.get("tbo_subbatch_index"), ) def op_output(self, state): diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index df1bbc362..6a5f24679 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -27,7 +27,6 @@ from sglang.srt.distributed import ( get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, tensor_model_parallel_all_reduce, ) from sglang.srt.layers.activation import SiluAndMul @@ -49,7 +48,7 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend +from sglang.srt.layers.moe import get_moe_a2a_backend from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK @@ -71,7 +70,6 @@ from sglang.srt.models.deepseek_v2 import ( DeepseekV2MoE, ) from sglang.srt.server_args import get_global_server_args 
-from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher from sglang.srt.utils import ( BumpAllocator, LazyValue, @@ -477,19 +475,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): else None ) - self.deepep_dispatcher = MaybeTboDeepEPDispatcher( - group=parallel_state.get_tp_group().device_group, - router_topk=self.top_k, - permute_fusion=True, - num_experts=self.num_experts, - num_local_experts=config.n_routed_experts // self.tp_size, - hidden_size=config.hidden_size, - params_dtype=config.torch_dtype, - deepep_mode=get_deepep_mode(), - async_finish=True, - return_recv_hook=True, - ) - self._enable_a2a_moe = ( get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake() ) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 05612fca0..1b5738adb 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -219,7 +219,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) shared_output = self._forward_shared_experts(hidden_states) - topk_weights, topk_idx, _ = self.topk( + topk_output = self.topk( hidden_states, router_logits, num_token_non_padded=forward_batch.num_token_non_padded, @@ -228,14 +228,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ), ) else: - topk_weights, topk_idx, _ = self.topk.empty_topk_output( - hidden_states.device - ) + topk_output = self.topk.empty_topk_output(hidden_states.device) final_hidden_states = self.experts( hidden_states=hidden_states, - topk_idx=topk_idx, - topk_weights=topk_weights, - forward_batch=forward_batch, + topk_output=topk_output, ) if shared_output is not None: diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index a3044bef9..721fa221f 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -180,7 +180,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): if 
hidden_states.shape[0] > 0: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - topk_weights, topk_idx, _ = self.topk( + topk_output = self.topk( hidden_states, router_logits, num_token_non_padded=forward_batch.num_token_non_padded, @@ -189,17 +189,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ), ) else: - topk_idx = torch.full( - (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device - ) - topk_weights = torch.empty( - (0, self.top_k), dtype=torch.float32, device=hidden_states.device - ) + topk_output = self.topk.empty_topk_output(hidden_states.device) final_hidden_states = self.experts( hidden_states=hidden_states, - topk_idx=topk_idx, - topk_weights=topk_weights, - forward_batch=forward_batch, + topk_output=topk_output, ) return final_hidden_states @@ -219,7 +212,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): with get_global_expert_distribution_recorder().with_current_layer( self.layer_id ): - state.topk_weights_local, state.topk_idx_local, _ = self.topk( + state.topk_output = self.topk( hidden_states=hidden_states, router_logits=router_logits, num_token_non_padded=state.forward_batch.num_token_non_padded, @@ -228,20 +221,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ), ) else: - state.topk_idx_local = torch.full( - (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device - ) - state.topk_weights_local = torch.empty( - (0, self.top_k), dtype=torch.float32, device=hidden_states.device - ) + state.topk_output = self.topk.empty_topk_output(hidden_states.device) def op_dispatch_a(self, state): if self.ep_size > 1: - self.experts.deepep_dispatcher.dispatch_a( + self.experts.dispatcher.dispatch_a( hidden_states=state.pop("hidden_states_mlp_input"), - topk_idx=state.pop("topk_idx_local"), - topk_weights=state.pop("topk_weights_local"), - forward_batch=state.forward_batch, + topk_output=state.pop("topk_output"), tbo_subbatch_index=state.get("tbo_subbatch_index"), ) @@ -250,32 +236,29 @@ class 
Qwen3MoeSparseMoeBlock(nn.Module): with get_global_expert_distribution_recorder().with_current_layer( self.layer_id ): - state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b( + state.dispatch_output = self.experts.dispatcher.dispatch_b( tbo_subbatch_index=state.get("tbo_subbatch_index"), ) def op_experts(self, state): - state.hidden_states_experts_output = self.experts.moe_impl( + state.hidden_states_experts_output = self.experts.run_moe_core( dispatch_output=state.dispatch_output, ) def op_combine_a(self, state): if self.ep_size > 1: - self.experts.deepep_dispatcher.combine_a( + self.experts.dispatcher.combine_a( hidden_states=state.pop("hidden_states_experts_output"), - topk_idx=state.dispatch_output.topk_idx, + topk_ids=state.dispatch_output.topk_ids, topk_weights=state.dispatch_output.topk_weights, - forward_batch=state.forward_batch, tbo_subbatch_index=state.get("tbo_subbatch_index"), ) state.pop("dispatch_output") def op_combine_b(self, state): if self.ep_size > 1: - state.hidden_states_after_combine = ( - self.experts.deepep_dispatcher.combine_b( - tbo_subbatch_index=state.get("tbo_subbatch_index"), - ) + state.hidden_states_after_combine = self.experts.dispatcher.combine_b( + tbo_subbatch_index=state.get("tbo_subbatch_index"), ) def op_output(self, state): diff --git a/python/sglang/srt/single_batch_overlap.py b/python/sglang/srt/single_batch_overlap.py index 77cd41d9d..425152c6f 100644 --- a/python/sglang/srt/single_batch_overlap.py +++ b/python/sglang/srt/single_batch_overlap.py @@ -1,3 +1,19 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Optional @@ -5,12 +21,12 @@ import torch from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe import get_moe_runner_backend +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.utils import is_sbo_enabled -from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import get_int_env_var if TYPE_CHECKING: - from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE class SboFlags: @@ -54,23 +70,22 @@ class DownGemmOverlapArgs: def execute_sbo( forward_shared_experts: Callable[[], Any], - experts: "DeepEPMoE", + experts: FusedMoE, hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - forward_batch: ForwardBatch, - alt_stream: Optional = None, + topk_output: TopKOutput, + alt_stream: Optional[torch.cuda.Stream] = None, disable_sbo: bool = False, ): - dispatch_output = experts.dispatch( - hidden_states, topk_idx, topk_weights, forward_batch + + dispatch_output = experts.dispatcher.dispatch( + hidden_states=hidden_states, topk_output=topk_output ) combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = ( _compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo) ) - hidden_states = experts.moe_impl( + hidden_states = experts.run_moe_core( dispatch_output, 
down_gemm_overlap_args=down_gemm_overlap_args ) if (e := meta_overlap_args.get("record_event_after_down")) is not None: @@ -83,11 +98,10 @@ def execute_sbo( ): forward_shared_experts() - hidden_states = experts.combine( - hidden_states, - dispatch_output.topk_idx, - dispatch_output.topk_weights, - forward_batch, + hidden_states = experts.dispatcher.combine( + hidden_states=hidden_states, + topk_ids=dispatch_output.topk_ids, + topk_weights=dispatch_output.topk_weights, overlap_args=combine_overlap_args, ) @@ -101,9 +115,7 @@ def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo): ): return None, None, {} - hidden_states = dispatch_output.hidden_states_fp8 - if isinstance(hidden_states, tuple): - hidden_states = hidden_states[0] + hidden_states = dispatch_output.hidden_states num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index c8f2b4e66..1b785443d 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -14,6 +14,7 @@ from sglang.srt.model_executor.cuda_graph_runner import ( get_global_graph_memory_pool, model_capture_mode, set_global_graph_memory_pool, + set_is_extend_in_batch, set_torch_compile_config, ) from sglang.srt.model_executor.forward_batch_info import ( @@ -263,6 +264,7 @@ class EAGLEDraftCudaGraphRunner: # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_is_extend_in_batch(False) # Backup two fields, which will be modified in-place in `draft_forward`. 
output_cache_loc_backup = forward_batch.out_cache_loc diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index d54d86a8c..d4b5aeb27 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -15,6 +15,7 @@ from sglang.srt.model_executor.cuda_graph_runner import ( get_global_graph_memory_pool, model_capture_mode, set_global_graph_memory_pool, + set_is_extend_in_batch, set_torch_compile_config, ) from sglang.srt.model_executor.forward_batch_info import ( @@ -294,6 +295,7 @@ class EAGLEDraftExtendCudaGraphRunner: # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_is_extend_in_batch(False) # Backup two fields, which will be modified in-place in `draft_forward`. output_cache_loc_backup = forward_batch.out_cache_loc diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 69d3f03c1..a07c112fe 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -1000,3 +1000,7 @@ class MaybeTboDeepEPDispatcher: def combine_b(self, **kwargs): return self._execute("combine_b", **kwargs) + + def set_quant_config(self, quant_config: dict): + for inner in self._inners: + inner.set_quant_config(quant_config)