[FIX] fix incorrect output when enable both deepgemm and torch compile (#4359)
Co-authored-by: xuyongfei.xyf <xuyongfei.xyf@antgroup.com>
This commit is contained in:
@@ -22,7 +22,14 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.utils import get_device_core_count, get_device_name, is_cuda, is_hip
|
||||
from sglang.srt.utils import (
|
||||
direct_register_custom_op,
|
||||
get_device_core_count,
|
||||
get_device_name,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
supports_custom_op,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
@@ -36,6 +43,33 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
|
||||
|
||||
if supports_custom_op():
|
||||
|
||||
def deep_gemm_fp8_fp8_bf16_nt(
|
||||
A: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
) -> None:
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||
|
||||
def deep_gemm_fp8_fp8_bf16_nt_fake(
|
||||
A: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="deep_gemm_fp8_fp8_bf16_nt",
|
||||
op_func=deep_gemm_fp8_fp8_bf16_nt,
|
||||
mutates_args=["C"],
|
||||
fake_impl=deep_gemm_fp8_fp8_bf16_nt_fake,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _per_token_group_quant_fp8(
|
||||
@@ -728,7 +762,10 @@ def w8a8_block_fp8_matmul(
|
||||
|
||||
# deepgemm only support bf16
|
||||
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||
if supports_custom_op():
|
||||
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
||||
else:
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||
else:
|
||||
kernel = (
|
||||
_w8a8_block_fp8_matmul_unrolledx4
|
||||
|
||||
Reference in New Issue
Block a user