[NVIDIA] [3/N] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked (#9199)

This commit is contained in:
Shu Wang
2025-09-11 22:18:43 -05:00
committed by GitHub
parent 7b141f816c
commit 3df05f4d6a
11 changed files with 694 additions and 5 deletions

View File

@@ -673,10 +673,14 @@ class DeepseekV2MoE(nn.Module):
         if shared_output is not None:
             x = shared_output
-            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            final_hidden_states *= self.routed_scaling_factor
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+                final_hidden_states *= self.routed_scaling_factor
         return final_hidden_states