Multiple tiny code cleanups (#4608)
This commit is contained in:
@@ -185,7 +185,6 @@ class DeepEPDispatcher:
|
||||
previous_event=None,
|
||||
num_max_dispatch_tokens_per_rank: int = 128,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self.hidden_shape = hidden_states.shape
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
# Todo: enable low latency dispatch
|
||||
if True: # not forward_mode.is_decode():
|
||||
@@ -375,7 +374,7 @@ class DeepEPDispatcher:
|
||||
hidden_states, self.topk_idx, self.topk_weights, self.handle
|
||||
)
|
||||
self.handle = None
|
||||
return hidden_states.view(self.hidden_shape)
|
||||
return hidden_states
|
||||
|
||||
def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
|
||||
combined_x, _, event = self.buffer_normal.combine(
|
||||
|
||||
@@ -250,8 +250,6 @@ class DeepseekV2MoE(nn.Module):
|
||||
return self.forward_deepep(hidden_states, forward_mode)
|
||||
|
||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
@@ -264,13 +262,11 @@ class DeepseekV2MoE(nn.Module):
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
return final_hidden_states
|
||||
|
||||
def forward_deepep(
|
||||
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
|
||||
) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
shared_output = None
|
||||
topk_idx = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
@@ -319,7 +315,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
|
||||
Reference in New Issue
Block a user