diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index f46521c3a..abe604fc6 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -504,14 +504,8 @@ class FusedMoE(torch.nn.Module):
             param.data[:, :dim1, :dim2].copy_(loaded_weight)
             return
 
-        # ModelOptNvFp4FusedMoEMethod uses max of global expert scaling factors for input scaling factor
-        load_global_experts = (
-            isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
-            and "input_scale" in weight_name
-        )
-
         global_expert_location_metadata = get_global_expert_location_metadata()
-        if global_expert_location_metadata is None or load_global_experts:
+        if global_expert_location_metadata is None:
             self._weight_loader_impl(
                 param=param,
                 loaded_weight=loaded_weight,
@@ -548,10 +542,12 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        # WARN: This makes the `expert_id` mean "local" and "global" in different cases
+        if not getattr(param, "_sglang_require_global_experts", False):
+            expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+            if expert_id == -1:
+                return
-        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
-        if expert_id == -1:
-            return
 
         self._weight_loader_impl(
             param=param,
             loaded_weight=loaded_weight,
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index bf8c6c1ed..38894f8c9 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -999,12 +999,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w13_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
             data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w2_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def swizzle_blockscale(self, scale: torch.Tensor):
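
For reviewers, the snippet below is a minimal standalone sketch (not part of the patch) of the loading behavior the `_sglang_require_global_experts` flag opts into: parameters such as the NVFP4 input scales keep global expert indexing (because the method reduces over all experts' scales), while everything else is remapped to local expert slots. `map_global_to_local`, `NUM_LOCAL_EXPERTS`, and `LOCAL_EXPERT_OFFSET` are hypothetical stand-ins for the expert-location metadata the real `_map_global_expert_id_to_local_expert_id` consults.

```python
# Sketch only: illustrates the flag's effect, not the sglang implementation.
import torch

NUM_LOCAL_EXPERTS = 4
LOCAL_EXPERT_OFFSET = 4  # hypothetical: this rank owns global experts [4, 8)

def map_global_to_local(global_expert_id: int) -> int:
    """Return the local slot for a global expert id, or -1 if it is not on this rank."""
    local_id = global_expert_id - LOCAL_EXPERT_OFFSET
    return local_id if 0 <= local_id < NUM_LOCAL_EXPERTS else -1

def weight_loader(param: torch.nn.Parameter, loaded: torch.Tensor, expert_id: int) -> None:
    # Flagged params (e.g. w13_input_scale / w2_input_scale) keep the global id,
    # since NVFP4 takes the max over *all* experts' input scaling factors.
    if not getattr(param, "_sglang_require_global_experts", False):
        expert_id = map_global_to_local(expert_id)
        if expert_id == -1:
            return  # expert lives on another rank; nothing to load here
    param.data[expert_id].copy_(loaded)

# Globally indexed scale: sized by the global expert count, flag set as in the diff.
w2_input_scale = torch.nn.Parameter(torch.empty(8), requires_grad=False)
w2_input_scale._sglang_require_global_experts = True
weight_loader(w2_input_scale, torch.tensor(0.5), expert_id=1)  # stored at global slot 1

# Locally indexed weight: global expert 5 maps to local slot 1 on this rank.
w2_weight = torch.nn.Parameter(torch.empty(NUM_LOCAL_EXPERTS, 2, 2), requires_grad=False)
weight_loader(w2_weight, torch.ones(2, 2), expert_id=5)
weight_loader(w2_weight, torch.ones(2, 2), expert_id=0)  # not on this rank; skipped
```

This is also why the added `WARN` comment matters: after this change, `expert_id` inside `_weight_loader_impl` is a local id for ordinary parameters but a global id for flagged ones.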