From 85ef7f64e4802115a07a5f76a843017830973875 Mon Sep 17 00:00:00 2001
From: AniZpZ
Date: Thu, 13 Mar 2025 12:34:09 +0800
Subject: [PATCH] [FIX] fix incorrect output when enable both deepgemm and
 torch compile (#4359)

Co-authored-by: xuyongfei.xyf
---
 .../srt/layers/quantization/fp8_kernel.py     | 41 ++++++++++++++++++-
 1 file changed, 39 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 1b61575ca..7a6219527 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -22,7 +22,14 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_core_count, get_device_name, is_cuda, is_hip
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_device_core_count,
+    get_device_name,
+    is_cuda,
+    is_hip,
+    supports_custom_op,
+)
 
 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
@@ -36,6 +43,33 @@ logger = logging.getLogger(__name__)
 
 _enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
 
+if supports_custom_op():
+
+    def deep_gemm_fp8_fp8_bf16_nt(
+        A: torch.Tensor,
+        As: torch.Tensor,
+        B: torch.Tensor,
+        Bs: torch.Tensor,
+        C: torch.Tensor,
+    ) -> None:
+        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+
+    def deep_gemm_fp8_fp8_bf16_nt_fake(
+        A: torch.Tensor,
+        As: torch.Tensor,
+        B: torch.Tensor,
+        Bs: torch.Tensor,
+        C: torch.Tensor,
+    ) -> None:
+        return
+
+    direct_register_custom_op(
+        op_name="deep_gemm_fp8_fp8_bf16_nt",
+        op_func=deep_gemm_fp8_fp8_bf16_nt,
+        mutates_args=["C"],
+        fake_impl=deep_gemm_fp8_fp8_bf16_nt_fake,
+    )
+
 
 @triton.jit
 def _per_token_group_quant_fp8(
@@ -728,7 +762,10 @@ def w8a8_block_fp8_matmul(
 
     # deepgemm only support bf16
     if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
-        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+        if supports_custom_op():
+            torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
+        else:
+            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4