diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 78816b1d9..c180c0a77 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -166,11 +166,13 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
         input_2d, block_size[1], column_major_scales=False
     )
 
-    x_scale_input = x_scale.T.contiguous()
-    weight_scale_input = weight_scale.T.contiguous()
-
     output = gemm_fp8_nt_groupwise(
-        q_input, weight, x_scale_input, weight_scale_input, out_dtype=input_2d.dtype
+        q_input,
+        weight,
+        x_scale,
+        weight_scale,
+        scale_major_mode="K",
+        out_dtype=input_2d.dtype,
     )
 
     if bias is not None:
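
For review context, a minimal sketch of the updated call site as I read it. The tensor shapes, the 128-wide quantization-group assumption, and the `flashinfer.gemm` import path are illustrative guesses based on this diff, not verified details of the PR:

```python
# Sketch only, not part of the patch. Assumes M x K activations, N x K
# weights, and 128-wide FP8 quantization groups; all shapes are guesses.
import torch
from flashinfer.gemm import gemm_fp8_nt_groupwise  # assumed import path

M, N, K, group = 256, 512, 1024, 128

q_input = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
weight = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)

# With scale_major_mode="K", the scales are passed K-major, i.e. in the
# layout the quantization step already produces, so the old
# x_scale.T.contiguous() / weight_scale.T.contiguous() copies go away.
x_scale = torch.rand(M, K // group, device="cuda", dtype=torch.float32)
weight_scale = torch.rand(N // group, K // group, device="cuda", dtype=torch.float32)

output = gemm_fp8_nt_groupwise(
    q_input,
    weight,
    x_scale,
    weight_scale,
    scale_major_mode="K",
    out_dtype=torch.bfloat16,
)
assert output.shape == (M, N)
```

The practical effect of the change, as the removed lines show, is that the kernel now consumes the scale layout that `sglang_per_token_group_quant_fp8` already emits, eliminating two transpose-plus-`contiguous()` copies on every linear call.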