Remove unnecessary torch.full in DeepSeek (#5601)

This commit is contained in:
fzyzcjy
2025-04-23 12:24:29 +08:00
committed by GitHub
parent 3f87f83116
commit 71d1785f2d

View File

@@ -323,12 +323,6 @@ class DeepseekV2MoE(nn.Module):
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
) -> torch.Tensor:
shared_output = None
topk_idx = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if (
forward_mode is not None
and not forward_mode.is_idle()
@@ -348,6 +342,13 @@ class DeepseekV2MoE(nn.Module):
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
else:
topk_idx = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
(