Multiple tiny code cleanups (#4608)
This commit is contained in:
@@ -185,7 +185,6 @@ class DeepEPDispatcher:
|
||||
previous_event=None,
|
||||
num_max_dispatch_tokens_per_rank: int = 128,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self.hidden_shape = hidden_states.shape
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
# Todo: enable low latency dispatch
|
||||
if True: # not forward_mode.is_decode():
|
||||
@@ -375,7 +374,7 @@ class DeepEPDispatcher:
|
||||
hidden_states, self.topk_idx, self.topk_weights, self.handle
|
||||
)
|
||||
self.handle = None
|
||||
return hidden_states.view(self.hidden_shape)
|
||||
return hidden_states
|
||||
|
||||
def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
|
||||
combined_x, _, event = self.buffer_normal.combine(
|
||||
|
||||
Reference in New Issue
Block a user