diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index d7638ce18..989056b37 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -161,16 +161,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
     q_input, x_scale = sglang_per_token_group_quant_fp8(
-        input_2d, block_size[1], column_major_scales=False
+        input_2d, block_size[1], column_major_scales=True
     )
-
+    # TRTLLM requires column-major scaling factors
     output = gemm_fp8_nt_groupwise(
         q_input,
         weight,
         x_scale,
         weight_scale,
-        scale_major_mode="K",
         out_dtype=input_2d.dtype,
+        backend="trtllm",
     )
 
     if bias is not None: