[NVIDIA] Fix num_experts in modelopt_quant (#8811)

This commit is contained in:
Shu Wang
2025-08-06 16:35:07 -05:00
committed by GitHub
parent 01c99a9959
commit 288ae41f7a
2 changed files with 7 additions and 4 deletions

View File

@@ -1063,10 +1063,15 @@ class FlashInferFP4MoE(FusedMoE):
gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=self.g1_scale_c.data,
output1_scale_gate_scalar=self.g1_alphas.data,
output2_scale_scalar=self.g2_alphas.data,

View File

@@ -764,8 +764,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
)
# TODO(ch-wan): check if this is needed
layer.num_experts = num_experts
layer.num_local_experts = num_experts
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.params_dtype = params_dtype
layer.quant_config = self.quant_config
@@ -1106,7 +1104,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
layer.w13_weight_scale,
)
print("Applied flashinfer weight processing for both w13 and w2")
logger.info_once("Applied flashinfer weight processing for both w13 and w2")
else:
# CUTLASS processing - handle w13 and w2 separately
@@ -1126,7 +1124,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
# Both flashinfer cutlass and regular cutlass use same processing for w2
print("Applied weight processing for both w13 and w2")
logger.info_once("Applied weight processing for both w13 and w2")
# Set up CUTLASS MoE parameters
device = layer.w13_weight.device