[NVIDIA] Add Flashinfer MoE blockscale fp8 backend (#8036)

commit 85486b6f6f (parent e34cf6ad75)
Author: Kaixi Hou
Date:   2025-07-27 00:34:41 -07:00
Committed by: GitHub

8 changed files with 179 additions and 47 deletions


@@ -711,7 +711,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.enable_flashinfer_moe = False
+        self.enable_flashinfer_cutlass_moe = False

     def create_weights(
         self,
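Note: the error message in the hunk above reflects the hardware gate for this path: the NVFP4 fused-MoE kernels require Blackwell-class GPUs. A minimal sketch of such a guard, assuming Blackwell reports CUDA compute capability 10.0 and using a hypothetical helper name (this is not vLLM's actual check):

import torch

def flashinfer_cutlass_moe_supported() -> bool:
    # Hypothetical guard: assume Blackwell reports compute capability >= 10.0.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (10, 0)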
@@ -865,7 +865,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
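Note: this branch (and the matching GEMM 2 branch in the next hunk) differs only in how the calibrated input scales are collapsed. The FlashInfer CUTLASS path reduces them to a single scalar shared by all experts, while the default path keeps one scale per expert. A standalone sketch of the two reductions, with an assumed [num_experts, 2] scale layout (one entry each for the w1 and w3 shards):

import torch

num_experts = 8
# Assumed layout: one calibrated input scale per (expert, w1/w3 shard).
w13_input_scale = torch.rand(num_experts, 2)

# FlashInfer CUTLASS path: one global scalar shared by all experts.
global_scale = w13_input_scale.max().to(torch.float32)                  # shape []

# Default path: one scale per expert (max over the shard dim).
per_expert_scale = w13_input_scale.max(dim=1).values.to(torch.float32)  # shape [8]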
@@ -894,7 +894,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)

         # GEMM 2
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         else:
             w2_input_scale = layer.w2_input_scale
@@ -934,7 +934,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     @property
     def load_up_proj_weight_first(self) -> bool:
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
-        return self.enable_flashinfer_moe
+        return self.enable_flashinfer_cutlass_moe

     def apply(
         self,
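Note: the comment in the hunk above is the reason this property exists: the FlashInfer CUTLASS kernel packs the fused gate/up projection as [Up, Gate] rather than the more common [Gate, Up]. As long as the SwiGLU combine picks the matching halves, the two packings compute the same result; a toy sketch (dimensions illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter = 16, 32
gate_proj = torch.randn(inter, hidden)  # "w1"
up_proj = torch.randn(inter, hidden)    # "w3"
x = torch.randn(4, hidden)

# Common packing: W13 = [Gate, Up]
w13 = torch.cat([gate_proj, up_proj], dim=0)
g, u = (x @ w13.t()).chunk(2, dim=-1)
out_gate_first = F.silu(g) * u

# FlashInfer CUTLASS packing: W13 = [Up, Gate]
w13 = torch.cat([up_proj, gate_proj], dim=0)
u, g = (x @ w13.t()).chunk(2, dim=-1)
out_up_first = F.silu(g) * u

assert torch.allclose(out_gate_first, out_up_first, atol=1e-5)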
@@ -954,7 +954,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."

-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 not apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"