From 8723b4f14633aa46e81bac434e1a6d4da3aa4f8c Mon Sep 17 00:00:00 2001
From: Elfie Guo <164945471+elfiegg@users.noreply.github.com>
Date: Tue, 12 Aug 2025 20:08:40 -0700
Subject: [PATCH] Use FlashInfer's TRTLLM FP8 Blockscale GEMM (#8588)

---
 python/sglang/srt/layers/quantization/fp8_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index d7638ce18..989056b37 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -161,16 +161,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
     q_input, x_scale = sglang_per_token_group_quant_fp8(
-        input_2d, block_size[1], column_major_scales=False
+        input_2d, block_size[1], column_major_scales=True
     )
-
+    # TRTLLM requires column-major scaling factors
     output = gemm_fp8_nt_groupwise(
         q_input,
         weight,
         x_scale,
         weight_scale,
-        scale_major_mode="K",
         out_dtype=input_2d.dtype,
+        backend="trtllm",
     )
 
     if bias is not None: