fix: flashinfer_cutlass_moe: Use max of global expert scales instead of local for input scale (#10296)
This commit is contained in:
@@ -996,13 +996,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
w13_input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
|
||||
data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
|
||||
w2_input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(layer.num_local_experts, dtype=torch.float32),
|
||||
data=torch.empty(layer.num_experts, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
Reference in New Issue
Block a user