[NVIDIA] [3/N] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked (#9199)

This commit is contained in:
Shu Wang
2025-09-11 22:18:43 -05:00
committed by GitHub
parent 7b141f816c
commit 3df05f4d6a
11 changed files with 694 additions and 5 deletions
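The kernel this series wires in is FlashInfer's grouped_gemm_nt_masked. As a rough reference for its semantics (a sketch under assumed shapes, not FlashInfer's actual implementation): each expert e multiplies only the first masked_m[e] rows of its input tile against that expert's weight, with the B operand stored transposed (the "NT" layout).

    import torch

    def grouped_gemm_nt_masked_ref(a, b, masked_m):
        # a: [num_experts, max_m, k] stacked per-expert inputs ("N" layout)
        # b: [num_experts, n, k] per-expert weights, stored transposed ("T" layout)
        # masked_m: [num_experts] valid row count per expert
        num_experts, max_m, _ = a.shape
        n = b.shape[1]
        out = torch.zeros(num_experts, max_m, n, dtype=a.dtype, device=a.device)
        for e in range(num_experts):
            m = int(masked_m[e])
            out[e, :m] = a[e, :m] @ b[e].T  # rows beyond masked_m[e] stay zero
        return out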


@@ -878,6 +878,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
"""Access the global enable_flashinfer_cutlass_moe setting."""
return get_moe_runner_backend().is_flashinfer_cutlass()
@property
def enable_flashinfer_cutedsl_moe(self) -> bool:
from sglang.srt.layers.moe import get_moe_runner_backend
"""Access the global enable_flashinfer_cutedsl_moe setting."""
return get_moe_runner_backend().is_flashinfer_cutedsl()
def create_weights(
self,
layer: torch.nn.Module,
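A minimal dispatch sketch, assuming a caller that already holds the quantization method and a masked per-expert batch (`method`, `cfg`, `x`, and `masked_m` are hypothetical stand-ins, not code from this PR):

    # Hypothetical gating sketch; none of these bindings come from this diff.
    if method.enable_flashinfer_cutedsl_moe:
        output = method.apply_without_routing_weights(layer, x, masked_m, cfg)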
@@ -1398,5 +1405,38 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
        ).to(x.dtype)

        # Scale by routed_scaling_factor is fused into select_experts.
        return StandardCombineInput(hidden_states=output)

    def apply_without_routing_weights(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        masked_m: torch.Tensor,
        moe_runner_config: MoeRunnerConfig,
    ) -> torch.Tensor:
        assert (
            moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."
        assert (
            self.enable_flashinfer_cutedsl_moe
        ), "Only FlashInfer CuteDSL MoE is supported."
        assert (
            not moe_runner_config.apply_router_weight_on_input
        ), "apply_router_weight_on_input is not supported for FlashInfer."
        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
            flashinfer_cutedsl_moe_masked,
        )

        out = flashinfer_cutedsl_moe_masked(
            hidden_states=x,
            input_global_scale=layer.w13_input_scale_quant,
            w1=layer.w13_weight,
            w1_blockscale=layer.w13_blockscale_swizzled,
            w1_alpha=layer.g1_alphas,
            w2=layer.w2_weight,
            a2_global_scale=layer.w2_input_scale_quant,
            w2_blockscale=layer.w2_blockscale_swizzled,
            w2_alpha=layer.g2_alphas,
            masked_m=masked_m,
        )
        return out
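An illustrative invocation under assumed shapes (a sketch, not a test from this PR). The masked layout packs each local expert's tokens into a fixed-size tile, so no gather/scatter is needed before the grouped GEMM:

    # Assumed shapes (not stated in this diff): x is [num_local_experts, max_m, hidden]
    # in bf16; masked_m is [num_local_experts] holding the valid row count per expert.
    out = method.apply_without_routing_weights(
        layer=layer,
        x=x,
        masked_m=masked_m,
        moe_runner_config=cfg,
    )
    # `out` keeps the same masked layout as x; rows past masked_m[e] are
    # presumably ignored by the combine step downstream.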