fix: flashinfer_cutlass_moe: Use max of global expert scales instead of local for input scale (#10296)
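Why a global max matters (a minimal standalone sketch with made-up numbers, not code from this repo): under expert parallelism each rank previously reduced only its local experts' input scales, so different ranks could end up with different per-tensor input scales for the same layer. Reducing over all global experts gives one value that every rank agrees on.

import torch

# Illustrative scales only: one per global expert, split across two EP ranks.
all_scales = torch.tensor([0.9, 1.1, 2.4, 0.7])
rank0_local, rank1_local = all_scales[:2], all_scales[2:]

print(rank0_local.max().item())  # local max on rank 0 (pre-fix behavior)
print(rank1_local.max().item())  # local max on rank 1 differs from rank 0
print(all_scales.max().item())   # global max: identical on every rank (the fix)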
@@ -503,8 +503,14 @@ class FusedMoE(torch.nn.Module):
             param.data[:, :dim1, :dim2].copy_(loaded_weight)
             return
 
+        # ModelOptNvFp4FusedMoEMethod uses max of global expert scaling factors for input scaling factor
+        load_global_experts = (
+            isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+            and "input_scale" in weight_name
+        )
+
         global_expert_location_metadata = get_global_expert_location_metadata()
-        if global_expert_location_metadata is None:
+        if global_expert_location_metadata is None or load_global_experts:
             self._weight_loader_impl(
                 param=param,
                 loaded_weight=loaded_weight,
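A minimal sketch of the consequence of the hunk above (hypothetical shapes and values, not the repo's loader): with the expert-location remapping bypassed for input scales, an entry is written for every global expert id, so the destination parameter must be sized by the global expert count; that is what the second hunk below changes.

import torch

num_experts = 8  # global expert count, assumed for illustration
w13_input_scale = torch.empty(num_experts, 2, dtype=torch.float32)

# Each checkpoint shard supplies a per-expert scale; global expert ids index
# the buffer directly, so a num_local_experts-sized buffer would be too small.
for expert_id in range(num_experts):
    loaded = torch.full((2,), 0.5 + 0.1 * expert_id)  # stand-in checkpoint scale
    w13_input_scale[expert_id].copy_(loaded)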
||||
@@ -996,13 +996,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_input_scale", w2_input_scale)
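Once the buffers hold every global expert's scale, a plausible post-load step (a sketch only; the names below are illustrative and assume the quant method reduces these buffers after loading, which the commit title implies but the hunks do not show) collapses them to a single per-tensor input scale per projection:

import torch

num_experts = 8
w13_input_scale = torch.rand(num_experts, 2)  # stand-in for the loaded w13 scales
w2_input_scale = torch.rand(num_experts)      # stand-in for the loaded w2 scales

# Reduce over every global expert so all EP ranks derive the same scalar.
w13_scale = w13_input_scale.max()
w2_scale = w2_input_scale.max()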