[NVIDIA] [3/N] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked (#9199)
This commit is contained in:
@@ -673,10 +673,14 @@ class DeepseekV2MoE(nn.Module):
 
         if shared_output is not None:
             x = shared_output
-            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            final_hidden_states *= self.routed_scaling_factor
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+                final_hidden_states *= self.routed_scaling_factor
 
         return final_hidden_states
 
Reference in New Issue
Block a user