Use FlashInfer FP4 gemm. (#8241)
python/sglang/srt/layers/quantization/modelopt_quant.py: 25 changed lines (Normal file → Executable file)
@@ -35,10 +35,20 @@ if TYPE_CHECKING:
 from sglang.srt.layers.moe.topk import TopKOutput
 
 if is_cuda():
-    from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
+    from sgl_kernel import scaled_fp4_quant
+
+try:
+    from flashinfer import mm_fp4 as fp4_gemm
+
+    enable_flashinfer_fp4_gemm = True
+except ImportError:
+    if is_cuda():
+        from sgl_kernel import cutlass_scaled_fp4_mm as fp4_gemm
+    else:
+        fp4_gemm = None
+    enable_flashinfer_fp4_gemm = False
 
 try:
     from flashinfer import fp4_quantize as fp4_quantize
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
 except ImportError:
     flashinfer_cutlass_fused_moe = None
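Note: the import hunk above selects the FP4 GEMM backend at import time. FlashInfer's mm_fp4 is preferred and bound to the shared name fp4_gemm; on CUDA builds without FlashInfer, sgl_kernel's cutlass_scaled_fp4_mm is kept as the fallback, and enable_flashinfer_fp4_gemm records which backend won. A minimal sketch of how downstream code could observe the chosen backend (the module path is taken from this diff; the access pattern is illustrative, not an official API):

    # Illustrative sketch only: after import, fp4_gemm is already bound to
    # whichever backend resolved in the try/except above.
    from sglang.srt.layers.quantization import modelopt_quant

    if modelopt_quant.enable_flashinfer_fp4_gemm:
        print("FP4 GEMM backend: flashinfer.mm_fp4")
    else:
        print("FP4 GEMM backend: sgl_kernel.cutlass_scaled_fp4_mm (fallback)")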
@@ -683,11 +693,16 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.alpha.dtype == torch.float32
 
-        out = cutlass_scaled_fp4_mm(
+        w = layer.weight
+        w_scale_interleaved = layer.weight_scale_interleaved
+        if enable_flashinfer_fp4_gemm:
+            w = layer.weight.T
+            w_scale_interleaved = layer.weight_scale_interleaved.T
+        out = fp4_gemm(
             x_fp4,
-            layer.weight,
+            w,
             x_scale_interleaved,
-            layer.weight_scale_interleaved,
+            w_scale_interleaved,
             layer.alpha,
             output_dtype,
         )
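Note: the second hunk routes ModelOptFp4LinearMethod.apply through the backend-agnostic fp4_gemm name. Per this diff, flashinfer's mm_fp4 takes the weight and its interleaved blockscale transposed relative to cutlass_scaled_fp4_mm, so the .T views (views, not copies) are applied only on the FlashInfer path while the call site stays shared. A hedged sketch of the resulting dispatch; _apply_fp4_gemm is a hypothetical helper whose names mirror the diff, not the exact method body:

    def _apply_fp4_gemm(layer, x_fp4, x_scale_interleaved, output_dtype):
        # fp4_gemm and enable_flashinfer_fp4_gemm come from the import
        # fallback in the first hunk.
        w = layer.weight
        w_scale_interleaved = layer.weight_scale_interleaved
        if enable_flashinfer_fp4_gemm:
            # flashinfer.mm_fp4 expects the weight operand (and its
            # blockscale) transposed relative to cutlass_scaled_fp4_mm.
            w = w.T
            w_scale_interleaved = w_scale_interleaved.T
        # Shared argument order per the diff:
        # (a_fp4, b, a_blockscale, b_blockscale, alpha, out_dtype)
        return fp4_gemm(
            x_fp4, w, x_scale_interleaved, w_scale_interleaved, layer.alpha, output_dtype
        )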