From 71d1785f2d0a4424f53caa9f5fa4adcb9a195e30 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Wed, 23 Apr 2025 12:24:29 +0800 Subject: [PATCH] Remove unnecessary `torch.full` in DeepSeek (#5601) --- python/sglang/srt/models/deepseek_v2.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 3d230c326..bc59e729b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -323,12 +323,6 @@ class DeepseekV2MoE(nn.Module): self, hidden_states: torch.Tensor, forward_mode: ForwardMode ) -> torch.Tensor: shared_output = None - topk_idx = torch.full( - (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device - ) - topk_weights = torch.empty( - (0, self.top_k), dtype=torch.float32, device=hidden_states.device - ) if ( forward_mode is not None and not forward_mode.is_idle() @@ -348,6 +342,13 @@ class DeepseekV2MoE(nn.Module): correction_bias=self.correction_bias, routed_scaling_factor=self.routed_scaling_factor, ) + else: + topk_idx = torch.full( + (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device + ) + topk_weights = torch.empty( + (0, self.top_k), dtype=torch.float32, device=hidden_states.device + ) if self.ep_size > 1: # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value (