diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 31544f563..d5c1db3a8 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -852,25 +852,15 @@ class ModelOptFp4LinearMethod(LinearMethodBase): if enable_flashinfer_fp4_gemm: w = layer.weight.T w_scale_interleaved = layer.weight_scale_interleaved.T - if USE_CUTLASS_BACKEND_FOR_FP4_GEMM: - out = fp4_gemm( - x_fp4, - w, - x_scale_interleaved, - w_scale_interleaved, - layer.alpha, - output_dtype, - backend="cutlass", - ) - else: - out = fp4_gemm( - x_fp4, - w, - x_scale_interleaved, - w_scale_interleaved, - layer.alpha, - output_dtype, - ) + out = fp4_gemm( + x_fp4, + w, + x_scale_interleaved, + w_scale_interleaved, + layer.alpha, + output_dtype, + **(dict(backend="cutlass") if USE_CUTLASS_BACKEND_FOR_FP4_GEMM else dict()), + ) if bias is not None: out = out + bias return out.view(*output_shape)