Remove unnecessary torch.full in DeepSeek (#5601)
This commit is contained in:
@@ -323,12 +323,6 @@ class DeepseekV2MoE(nn.Module):
         self, hidden_states: torch.Tensor, forward_mode: ForwardMode
     ) -> torch.Tensor:
         shared_output = None
-        topk_idx = torch.full(
-            (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-        )
-        topk_weights = torch.empty(
-            (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-        )
         if (
             forward_mode is not None
             and not forward_mode.is_idle()
@@ -348,6 +342,13 @@ class DeepseekV2MoE(nn.Module):
                 correction_bias=self.correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
             )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
         if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             (
Reference in New Issue
Block a user