diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 2038938ea..63c318ba3 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -243,9 +243,19 @@ def apply_fp8_linear( if _is_cuda: qinput, x_scale = sglang_per_token_quant_fp8(input_2d) else: - qinput, x_scale = ops.scaled_fp8_quant( - input_2d, input_scale, use_per_token_if_dynamic=use_per_token_if_dynamic - ) + # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling + # final solution should be: 1. add support to per-tensor activation scaling. + # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308) + if _is_hip and weight_scale.numel() == 1: + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + use_per_token_if_dynamic=use_per_token_if_dynamic, + ) + else: + qinput, x_scale = per_token_group_quant_fp8( + input_2d, group_size=input_2d.shape[1] + ) if cutlass_fp8_supported: try: