From 92823069c471928beb312b750c8c4f586d32f607 Mon Sep 17 00:00:00 2001
From: kk <43161300+kkHuang-amd@users.noreply.github.com>
Date: Wed, 9 Apr 2025 17:44:05 +0800
Subject: [PATCH] Fix ci test "test_eval_fp8_accuracy" failed (#5185)

Co-authored-by: wunhuang
---
 .../sglang/srt/layers/quantization/fp8_utils.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 2038938ea..63c318ba3 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -243,9 +243,19 @@ def apply_fp8_linear(
         if _is_cuda:
             qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
         else:
-            qinput, x_scale = ops.scaled_fp8_quant(
-                input_2d, input_scale, use_per_token_if_dynamic=use_per_token_if_dynamic
-            )
+            # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+            # final solution should be: 1. add support to per-tensor activation scaling.
+            # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+            if _is_hip and weight_scale.numel() == 1:
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+            else:
+                qinput, x_scale = per_token_group_quant_fp8(
+                    input_2d, group_size=input_2d.shape[1]
+                )
 
         if cutlass_fp8_supported:
             try: