diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index 8ddbef82e..533813c6f 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -108,10 +108,15 @@ class W8A8Fp8LinearMethod(LinearMethodBase): layer.weight, layer.weight.shape[-1] ) weight_scale = weight_scale.t().contiguous() + if _is_hip: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale + ) else: # if cutlass not supported, we fall back to use torch._scaled_mm # which requires per tensor quantization on weight - qweight, weight_scale = input_to_float8(layer.weight) + fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn + qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype) # Update the layer with the new values. layer.weight = Parameter(qweight.t(), requires_grad=False)