diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 4e1e97713..ac1a831ac 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -33,10 +33,12 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     scaled_fp8_quant,
     sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.managers.expert_location import get_global_expert_location_metadata
 from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -50,6 +52,7 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
 
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
@@ -843,6 +846,33 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 torch.max(layer.w13_weight_scale, dim=1).values,
                 requires_grad=False,
             )
+        if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if _is_fp8_fnuz:
+                # activation_scheme: dynamic
+                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w13_weight,
+                    weight_scale=layer.w13_weight_scale_inv,
+                    input_scale=None,
+                )
+                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w2_weight,
+                    weight_scale=layer.w2_weight_scale_inv,
+                    input_scale=None,
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(
+                    w13_weight, requires_grad=False
+                )
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                layer.w13_input_scale = None
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                layer.w2_input_scale = None
         return
 
     def apply(
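For context on the change: `is_fp8_fnuz()` reports whether the platform's native FP8 format is e4m3fnuz (as on AMD MI300-class GPUs under ROCm), and `normalize_e4m3fn_to_e4m3fnuz` converts checkpoint weights stored as e4m3fn into that format. Below is a minimal sketch of what that conversion does, assuming the standard bit-level relationship between the two formats; the `_sketch` suffix is ours, and the real helper lives in sglang.srt.layers.quantization.fp8_utils.

from typing import Optional, Tuple

import torch  # requires a torch build with float8 dtypes (>= 2.1)


def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The bit pattern 0b10000000 (int8 -128) encodes -0.0 in e4m3fn but NaN in
    # e4m3fnuz, so remap it to +0.0 before reinterpreting the storage. Note
    # the int8 view shares storage with `weight`, so this mutates in place.
    weight_as_int8 = weight.view(torch.int8)
    weight_as_int8[weight_as_int8 == -128] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
    # For the same bit pattern, an e4m3fnuz value is half the e4m3fn value
    # (exponent bias 8 vs. 7), so double the scales to keep the dequantized
    # values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale

Doing this once in process_weights_after_loading keeps the hot path unchanged: the block-quant kernels see correctly scaled e4m3fnuz weights, with no per-forward conversion cost.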