[NVIDIA] Add Flashinfer MoE blockscale fp8 backend (#8036)
This commit is contained in:
@@ -711,7 +711,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
" quantization. Please use Blackwell and"
|
||||
" above."
|
||||
)
|
||||
self.enable_flashinfer_moe = False
|
||||
self.enable_flashinfer_cutlass_moe = False
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -865,7 +865,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
|
||||
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
|
||||
|
||||
if self.enable_flashinfer_moe:
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
|
||||
else:
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
||||
@@ -894,7 +894,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
|
||||
|
||||
# GEMM 2
|
||||
if self.enable_flashinfer_moe:
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
|
||||
else:
|
||||
w2_input_scale = layer.w2_input_scale
|
||||
@@ -934,7 +934,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
@property
|
||||
def load_up_proj_weight_first(self) -> bool:
|
||||
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
|
||||
return self.enable_flashinfer_moe
|
||||
return self.enable_flashinfer_cutlass_moe
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -954,7 +954,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
if self.enable_flashinfer_moe:
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
assert (
|
||||
not apply_router_weight_on_input
|
||||
), "apply_router_weight_on_input is not supported for Flashinfer"
|
||||
|
||||
Reference in New Issue
Block a user