Use FlashInfer FP4 gemm. (#8241)

Author: Elfie Guo
Date: 2025-07-27 01:05:22 -07:00
Committed by: GitHub
Parent: bf0f448fe5
Commit: 5c9c275bc8
2 changed files with 230 additions and 5 deletions

python/sglang/srt/layers/quantization/modelopt_quant.py (Normal file → Executable file)

@@ -35,10 +35,20 @@ if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
 
 if is_cuda():
-    from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
+    from sgl_kernel import scaled_fp4_quant
+
+try:
+    from flashinfer import mm_fp4 as fp4_gemm
+
+    enable_flashinfer_fp4_gemm = True
+except ImportError:
+    if is_cuda():
+        from sgl_kernel import cutlass_scaled_fp4_mm as fp4_gemm
+    else:
+        fp4_gemm = None
+    enable_flashinfer_fp4_gemm = False
 
 try:
+    from flashinfer import fp4_quantize as fp4_quantize
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
 except ImportError:
     flashinfer_cutlass_fused_moe = None
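
The net effect of this import block: fp4_gemm always resolves to whichever backend is available (FlashInfer's mm_fp4 if importable, otherwise sgl-kernel's CUTLASS kernel, or None on non-CUDA builds), and enable_flashinfer_fp4_gemm records which path won. A quick standalone probe of the same resolution order, purely illustrative and not part of the commit:

try:
    from flashinfer import mm_fp4  # noqa: F401 -- preferred backend

    print("FP4 GEMM backend: flashinfer.mm_fp4")
except ImportError:
    try:
        from sgl_kernel import cutlass_scaled_fp4_mm  # noqa: F401 -- fallback

        print("FP4 GEMM backend: sgl_kernel.cutlass_scaled_fp4_mm")
    except ImportError:
        print("FP4 GEMM backend: none available")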
@@ -683,11 +693,16 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.alpha.dtype == torch.float32
 
-        out = cutlass_scaled_fp4_mm(
+        w = layer.weight
+        w_scale_interleaved = layer.weight_scale_interleaved
+        if enable_flashinfer_fp4_gemm:
+            w = layer.weight.T
+            w_scale_interleaved = layer.weight_scale_interleaved.T
+        out = fp4_gemm(
             x_fp4,
-            layer.weight,
+            w,
             x_scale_interleaved,
-            layer.weight_scale_interleaved,
+            w_scale_interleaved,
             layer.alpha,
             output_dtype,
         )
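
Why the transposes: the sgl-kernel CUTLASS path consumes layer.weight and its interleaved blockscale as stored, while FlashInfer's mm_fp4 evidently expects both transposed; .T is a metadata-only view in PyTorch, so the FlashInfer branch adds no copy. A condensed sketch of the new dispatch, with argument roles inferred from the diff rather than from FlashInfer's documentation:

def fp4_linear(x_fp4, x_scale_interleaved, layer, output_dtype):
    # Mirrors the apply() change above: both backends are reached through
    # the single fp4_gemm alias; only the weight layout differs.
    w = layer.weight
    w_scale = layer.weight_scale_interleaved
    if enable_flashinfer_fp4_gemm:
        # Assumption: flashinfer.mm_fp4 takes the weight (and its
        # blockscale) transposed relative to cutlass_scaled_fp4_mm.
        # .T returns a strided view; no data movement happens here.
        w = w.T
        w_scale = w_scale.T
    return fp4_gemm(
        x_fp4, w, x_scale_interleaved, w_scale, layer.alpha, output_dtype
    )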