diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index 0ab963396..d72526a61 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -77,6 +77,11 @@ logger = logging.getLogger(__name__)
 CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
     "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
 )
+# When enabled, fp4_gemm is called with backend="cutlass"; otherwise no
+# backend argument is passed and fp4_gemm falls back to its own default.
+USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
+    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
+)
 
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
@@ -844,14 +849,20 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if enable_flashinfer_fp4_gemm:
             w = layer.weight.T
             w_scale_interleaved = layer.weight_scale_interleaved.T
+        # Forward `backend` only when the cutlass override is requested, so the
+        # call is otherwise identical to the one made without the override.
+        backend_kwargs = (
+            {"backend": "cutlass"} if USE_CUTLASS_BACKEND_FOR_FP4_GEMM else {}
+        )
         out = fp4_gemm(
             x_fp4,
             w,
             x_scale_interleaved,
             w_scale_interleaved,
             layer.alpha,
             output_dtype,
+            **backend_kwargs,
         )
         if bias is not None:
             out = out + bias
         return out.view(*output_shape)