diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index bf8c58e49..23d30f395 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -272,6 +272,19 @@ class Fp8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale_inv,
+                    input_scale=None,
+                )
+                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    weight_scale, requires_grad=False
+                )
+                layer.input_scale = None
             return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
@@ -369,7 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight=layer.weight,
             block_size=self.quant_config.weight_block_size,
             weight_scale=layer.weight_scale_inv,
-            input_scale=layer.input_scale,
+            input_scale=None,
             bias=bias,
         )
 
@@ -553,6 +566,30 @@ class Fp8MoEMethod:
 
         # Block quant doesn't need to process weights after loading
        if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w13_weight,
+                    weight_scale=layer.w13_weight_scale_inv,
+                    input_scale=None,
+                )
+                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w2_weight,
+                    weight_scale=layer.w2_weight_scale_inv,
+                    input_scale=None,
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                layer.w13_input_scale = None
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                layer.w2_input_scale = None
             return
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 43b7876f0..87fad31e6 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -22,7 +22,10 @@
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name
+from sglang.srt.utils import get_device_name, is_hip
+
+is_hip_ = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 
 logger = logging.getLogger(__name__)
@@ -73,7 +76,7 @@
 def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
@@ -95,9 +98,13 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     finfo = torch.finfo(dtype)
-    fp8_min = finfo.min
     fp8_max = finfo.max
+
+    if is_hip_:
+        fp8_max = 224.0
+
+    fp8_min = -fp8_max
 
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index deb3c91e8..140e70dd9 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -7,6 +7,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
@@ -63,8 +66,11 @@ def input_to_float8(
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
-    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    fp8_max = finfo.max
+    if is_hip_:
+        fp8_max = 224.0
+    scale = fp8_max / amax
+    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
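Reviewer note (not part of the patch): the e4m3fn → e4m3fnuz normalization that this diff calls into can be sanity-checked with a small standalone sketch. Everything below uses only public PyTorch APIs and assumes a build that ships both fp8 dtypes (recent 2.x); no SGLang code is involved.

```python
import torch

# The same bit pattern decodes to half the value in e4m3fnuz vs. e4m3fn
# (exponent bias 8 vs. 7), which is why normalize_e4m3fn_to_e4m3fnuz must
# double weight_scale after reinterpreting the weights for ROCm.
bits = torch.tensor([0x38], dtype=torch.uint8)  # encodes 1.0 in e4m3fn
print(bits.view(torch.float8_e4m3fn).float().item())    # 1.0
print(bits.view(torch.float8_e4m3fnuz).float().item())  # 0.5

# 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz, so it has to be zeroed out
# before the bit-level reinterpret.
neg_zero = torch.tensor([0x80], dtype=torch.uint8)
print(neg_zero.view(torch.float8_e4m3fn).float().item())    # -0.0
print(neg_zero.view(torch.float8_e4m3fnuz).float().item())  # nan

# e4m3fnuz tops out at 240.0; the patch clamps the quantization range
# further, to 224.0, on ROCm (a conservative cap commonly used for fp8
# on MI300-class GPUs) rather than using finfo.max directly.
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
```

This also explains why `input_scale` is passed as `None` throughout: with `activation_scheme: dynamic` there is no static activation scale to renormalize, so only the weights and weight scales need the fnuz treatment.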