From 4c6675c4fc24b832043ff5f886ee89c3ee3e510a Mon Sep 17 00:00:00 2001 From: valarLip <103567126+valarLip@users.noreply.github.com> Date: Wed, 25 Jun 2025 17:02:31 +0800 Subject: [PATCH] enable aiter fp8 blockscale quant (#7520) --- python/sglang/srt/layers/quantization/fp8_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 9b401a4ee..a2abf975c 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -42,7 +42,10 @@ _is_fp8_fnuz = is_fp8_fnuz() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _use_aiter: - from aiter import gemm_a8w8_blockscale_CK + import aiter + from aiter import gemm_a8w8_blockscale_CK, get_hip_quant + + aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128) if _is_cuda: from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm @@ -271,9 +274,7 @@ def aiter_w8a8_block_fp8_linear( input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] - q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=False - ) + q_input, x_scale = aiter_per1x128_quant(input_2d, quant_dtype=aiter.dtypes.fp8) output = gemm_a8w8_blockscale_CK( q_input, weight, x_scale, weight_scale, dtype=input.dtype )