diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index a7af87144..4c065e4e5 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -31,7 +31,6 @@ if _is_cuda: if _is_cuda or _is_hip: from sgl_kernel import topk_softmax - expert_distribution_recorder = ExpertDistributionRecorder() @@ -99,6 +98,7 @@ def grouped_topk( topk_group: int = 0, n_share_experts_fusion: int = 0, routed_scaling_factor: Optional[float] = None, + num_token_non_padded: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" @@ -138,7 +138,9 @@ def grouped_topk( ) topk_weights = topk_weights / topk_weights_sum - return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + return topk_weights, topk_ids def biased_grouped_topk_impl( @@ -151,6 +153,7 @@ def biased_grouped_topk_impl( topk_group: int = 0, n_share_experts_fusion: int = 0, routed_scaling_factor: Optional[float] = None, + num_token_non_padded: Optional[torch.Tensor] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" @@ -197,13 +200,25 @@ def biased_grouped_topk_impl( ) topk_weights = topk_weights / topk_weights_sum - return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + return topk_weights, topk_ids def is_power_of_two(n): return n > 0 and math.log2(n).is_integer() +def _mask_topk_ids_padded_region( + topk_ids: torch.Tensor, + num_token_non_padded: Optional[torch.Tensor] = None, +): + if num_token_non_padded is None: + return + indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device) + topk_ids[indices >= num_token_non_padded, :] = -1 + + def biased_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -215,6 +230,7 @@ def biased_grouped_topk( compiled: bool = True, n_share_experts_fusion: int = 0, routed_scaling_factor: Optional[float] = None, + num_token_non_padded: Optional[torch.Tensor] = None, ): assert ( routed_scaling_factor is not None @@ -226,7 +242,7 @@ def biased_grouped_topk( <= 32 # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion. and is_power_of_two(correction_bias.shape[0]) ): - return moe_fused_gate( + topk_weights, topk_ids = moe_fused_gate( gating_output, correction_bias, num_expert_group, @@ -235,6 +251,11 @@ def biased_grouped_topk( n_share_experts_fusion, routed_scaling_factor, ) + # TODO will fuse this into kernel, thus use slow manual operation now + torch.compile( + _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend() + )(topk_ids, num_token_non_padded) + return topk_weights, topk_ids else: biased_grouped_topk_fn = ( torch.compile( @@ -253,6 +274,7 @@ def biased_grouped_topk( topk_group, n_share_experts_fusion=n_share_experts_fusion, routed_scaling_factor=routed_scaling_factor, + num_token_non_padded=num_token_non_padded, ) @@ -268,6 +290,7 @@ def select_experts( correction_bias: Optional[torch.Tensor] = None, torch_native: bool = False, routed_scaling_factor: Optional[float] = None, + num_token_non_padded: Optional[torch.Tensor] = None, ): n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"] # DeepSeek V2/V3/R1 series models use grouped_top_k @@ -284,6 +307,7 @@ def select_experts( topk_group=topk_group, n_share_experts_fusion=n_share_experts_fusion, routed_scaling_factor=routed_scaling_factor, + num_token_non_padded=num_token_non_padded, ) else: topk_weights, topk_ids = biased_grouped_topk( @@ -296,8 +320,12 @@ def select_experts( topk_group=topk_group, n_share_experts_fusion=n_share_experts_fusion, routed_scaling_factor=routed_scaling_factor, + num_token_non_padded=num_token_non_padded, ) elif torch_native and custom_routing_function is None: + assert ( + num_token_non_padded is None + ), "num_token_non_padded is not yet supported in fused_topk_native" topk_weights, topk_ids = fused_topk_native( hidden_states=hidden_states, gating_output=router_logits, @@ -305,6 +333,9 @@ def select_experts( renormalize=renormalize, ) elif custom_routing_function is None: + assert ( + num_token_non_padded is None + ), "num_token_non_padded is not yet supported in fused_topk" topk_weights, topk_ids = fused_topk( hidden_states=hidden_states, gating_output=router_logits, @@ -312,6 +343,9 @@ def select_experts( renormalize=renormalize, ) else: + assert ( + num_token_non_padded is None + ), "num_token_non_padded is not yet supported in custom_routing_function" topk_weights, topk_ids = custom_routing_function( hidden_states=hidden_states, gating_output=router_logits, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index e88022beb..40f136deb 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -240,6 +240,7 @@ class CudaGraphRunner: self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64) self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) + self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32) # pipeline parallelism if self.pp_size > 1: @@ -403,6 +404,7 @@ class CudaGraphRunner: else: encoder_lens = None mrope_positions = self.mrope_positions[:, :bs] + self.num_token_non_padded[...] = num_tokens # pipeline parallelism if self.pp_size > 1: @@ -461,6 +463,7 @@ class CudaGraphRunner: spec_info=spec_info, capture_hidden_mode=self.capture_hidden_mode, lora_paths=lora_paths, + num_token_non_padded=self.num_token_non_padded, ) if lora_paths is not None: @@ -556,6 +559,7 @@ class CudaGraphRunner: self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) self.positions[:raw_num_token].copy_(forward_batch.positions) + self.num_token_non_padded[...] = len(forward_batch.input_ids) if forward_batch.seq_lens_cpu is not None: if bs != raw_bs: self.seq_lens_cpu.fill_(1) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 5018f92d5..ea64199a5 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -247,6 +247,7 @@ class ForwardBatch: # For padding padded_static_len: int = -1 # -1 if not padded + num_token_non_padded: Optional[torch.Tensor] = None # scalar tensor # For Qwen2-VL mrope_positions: torch.Tensor = None @@ -290,6 +291,9 @@ class ForwardBatch: capture_hidden_mode=batch.capture_hidden_mode, input_embeds=batch.input_embeds, extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu, + num_token_non_padded=torch.tensor( + len(batch.input_ids), dtype=torch.int32 + ).to(device, non_blocking=True), ) # For DP attention diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 436e966db..5955332f5 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -165,7 +165,7 @@ class DeepseekV2MLP(nn.Module): ) self.act_fn = SiluAndMul() - def forward(self, x, forward_mode: Optional[ForwardMode] = None): + def forward(self, x, forward_batch: Optional[ForwardBatch] = None): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -287,12 +287,12 @@ class DeepseekV2MoE(nn.Module): ) def forward( - self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None + self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None ) -> torch.Tensor: if not global_server_args_dict["enable_deepep_moe"]: return self.forward_normal(hidden_states) else: - return self.forward_deepep(hidden_states, forward_mode) + return self.forward_deepep(hidden_states, forward_batch) def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self._forward_shared_experts(hidden_states) @@ -309,8 +309,9 @@ class DeepseekV2MoE(nn.Module): return final_hidden_states def forward_deepep( - self, hidden_states: torch.Tensor, forward_mode: ForwardMode + self, hidden_states: torch.Tensor, forward_batch: ForwardBatch ) -> torch.Tensor: + forward_mode = forward_batch.forward_mode shared_output = None if ( forward_mode is not None @@ -330,6 +331,7 @@ class DeepseekV2MoE(nn.Module): num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, routed_scaling_factor=self.routed_scaling_factor, + num_token_non_padded=forward_batch.num_token_non_padded, ) else: topk_idx = torch.full( @@ -1339,7 +1341,7 @@ class DeepseekV2DecoderLayer(nn.Module): and (not self.info.is_sparse) and hidden_states.shape[0] == 0 ): - hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) + hidden_states = self.mlp(hidden_states, forward_batch) if self.is_last_layer and self.attn_tp_size != 1: hidden_states += residual