fix: flashinfer_cutlass_moe: Use max of global expert scales instead of local for input scale (#10296)

This commit is contained in:
Trevor Morris
2025-09-11 20:19:17 -07:00
committed by GitHub
parent 3df05f4d6a
commit c7e85f5378
2 changed files with 9 additions and 3 deletions

View File

@@ -996,13 +996,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
)
w13_input_scale = PerTensorScaleParameter(
data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = PerTensorScaleParameter(
data=torch.empty(layer.num_local_experts, dtype=torch.float32),
data=torch.empty(layer.num_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_input_scale", w2_input_scale)