Multiple tiny code cleanups (#4608)
This commit is contained in:
@@ -185,7 +185,6 @@ class DeepEPDispatcher:
|
|||||||
previous_event=None,
|
previous_event=None,
|
||||||
num_max_dispatch_tokens_per_rank: int = 128,
|
num_max_dispatch_tokens_per_rank: int = 128,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
self.hidden_shape = hidden_states.shape
|
|
||||||
topk_idx = topk_idx.to(torch.int64)
|
topk_idx = topk_idx.to(torch.int64)
|
||||||
# Todo: enable low latency dispatch
|
# Todo: enable low latency dispatch
|
||||||
if True: # not forward_mode.is_decode():
|
if True: # not forward_mode.is_decode():
|
||||||
@@ -375,7 +374,7 @@ class DeepEPDispatcher:
|
|||||||
hidden_states, self.topk_idx, self.topk_weights, self.handle
|
hidden_states, self.topk_idx, self.topk_weights, self.handle
|
||||||
)
|
)
|
||||||
self.handle = None
|
self.handle = None
|
||||||
return hidden_states.view(self.hidden_shape)
|
return hidden_states
|
||||||
|
|
||||||
def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
|
def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
|
||||||
combined_x, _, event = self.buffer_normal.combine(
|
combined_x, _, event = self.buffer_normal.combine(
|
||||||
|
|||||||
@@ -250,8 +250,6 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
return self.forward_deepep(hidden_states, forward_mode)
|
return self.forward_deepep(hidden_states, forward_mode)
|
||||||
|
|
||||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
|
||||||
if self.n_shared_experts is not None:
|
if self.n_shared_experts is not None:
|
||||||
shared_output = self.shared_experts(hidden_states)
|
shared_output = self.shared_experts(hidden_states)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
@@ -264,13 +262,11 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
return final_hidden_states
|
||||||
|
|
||||||
def forward_deepep(
|
def forward_deepep(
|
||||||
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
|
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
|
||||||
shared_output = None
|
shared_output = None
|
||||||
topk_idx = torch.full(
|
topk_idx = torch.full(
|
||||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||||
@@ -319,7 +315,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
|
|
||||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
return final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||||
|
|||||||
Reference in New Issue
Block a user