[perf] slightly improve DeepSeek-R1-FP4 TP8 (#7481)
This commit is contained in:
@@ -362,12 +362,14 @@ class DeepseekV2MoE(nn.Module):
|
||||
return self.forward_deepep(hidden_states, forward_batch)
|
||||
|
||||
def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
current_stream = torch.cuda.current_stream()
|
||||
self.alt_stream.wait_stream(current_stream)
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
|
||||
with torch.cuda.stream(self.alt_stream):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user