[NVIDIA] [3/N] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked (#9199)
@@ -878,6 +878,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         """Access the global enable_flashinfer_cutlass_moe setting."""
         return get_moe_runner_backend().is_flashinfer_cutlass()
 
+    @property
+    def enable_flashinfer_cutedsl_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
+        """Access the global enable_flashinfer_cutedsl_moe setting."""
+        return get_moe_runner_backend().is_flashinfer_cutedsl()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -1398,5 +1405,38 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
 
         return StandardCombineInput(hidden_states=output)
+
+    def apply_without_routing_weights(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        masked_m: torch.Tensor,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        assert self.enable_flashinfer_cutedsl_moe, "Only FlashInfer CuteDSL MoE is supported."
+        assert (
+            not moe_runner_config.apply_router_weight_on_input
+        ), "apply_router_weight_on_input is not supported for FlashInfer."
+
+        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
+            flashinfer_cutedsl_moe_masked,
+        )
+
+        out = flashinfer_cutedsl_moe_masked(
+            hidden_states=x,
+            input_global_scale=layer.w13_input_scale_quant,
+            w1=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alpha=layer.g1_alphas,
+            w2=layer.w2_weight,
+            a2_global_scale=layer.w2_input_scale_quant,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alpha=layer.g2_alphas,
+            masked_m=masked_m,
+        )
+        return out
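Note: a rough usage sketch of the new entry point, for reviewers. It assumes the masked grouped layout used by the CuteDSL path, where x holds one padded block of tokens per local expert and masked_m gives the number of valid rows in each block. The run_masked_experts helper and the sizes below are illustrative stand-ins, not part of this commit; quant_method, layer, and moe_runner_config are expected to come from an already-initialized NVFP4 FusedMoE module.

    import torch

    def run_masked_experts(quant_method, layer, moe_runner_config):
        # Hypothetical driver: `quant_method` is the ModelOptNvFp4FusedMoEMethod
        # attached to `layer`; both come from an already-built sglang model.
        # Sizes are illustrative; the real layout is produced by the dispatcher.
        num_local_experts, capacity, hidden = 8, 128, 4096
        x = torch.randn(
            num_local_experts, capacity, hidden, dtype=torch.bfloat16, device="cuda"
        )
        # Valid token count per expert; rows beyond masked_m[e] are padding.
        masked_m = torch.randint(
            0, capacity + 1, (num_local_experts,), dtype=torch.int32, device="cuda"
        )
        # moe_runner_config.activation must be "silu" per the asserts above.
        return quant_method.apply_without_routing_weights(
            layer=layer,
            x=x,
            masked_m=masked_m,
            moe_runner_config=moe_runner_config,
        )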