From 6a47b73024ac80fa4748d14c5856497fba428bf2 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sun, 1 Jun 2025 18:30:54 -0700 Subject: [PATCH] Remove contiguous before Flashinfer groupwise fp8 gemm (#6804) --- python/sglang/srt/layers/quantization/fp8_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 78816b1d9..c180c0a77 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -166,11 +166,13 @@ def flashinfer_gemm_w8a8_block_fp8_linear( input_2d, block_size[1], column_major_scales=False ) - x_scale_input = x_scale.T.contiguous() - weight_scale_input = weight_scale.T.contiguous() - output = gemm_fp8_nt_groupwise( - q_input, weight, x_scale_input, weight_scale_input, out_dtype=input_2d.dtype + q_input, + weight, + x_scale, + weight_scale, + scale_major_mode="K", + out_dtype=input_2d.dtype, ) if bias is not None: