[FIX] Fix incorrect output when both DeepGEMM and torch.compile are enabled (#4359)
Co-authored-by: xuyongfei.xyf <xuyongfei.xyf@antgroup.com>
This commit is contained in:
@@ -22,7 +22,14 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.utils import get_device_core_count, get_device_name, is_cuda, is_hip
|
from sglang.srt.utils import (
|
||||||
|
direct_register_custom_op,
|
||||||
|
get_device_core_count,
|
||||||
|
get_device_name,
|
||||||
|
is_cuda,
|
||||||
|
is_hip,
|
||||||
|
supports_custom_op,
|
||||||
|
)
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
@@ -36,6 +43,33 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
|
_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
|
||||||
|
|
||||||
|
if supports_custom_op():
    # Expose the DeepGEMM call as a registered custom op. The matmul path in
    # this file reaches it through ``torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt``
    # whenever custom ops are supported.

    def deep_gemm_fp8_fp8_bf16_nt(
        A: torch.Tensor,
        As: torch.Tensor,
        B: torch.Tensor,
        Bs: torch.Tensor,
        C: torch.Tensor,
    ) -> None:
        """Forward the operands to ``deep_gemm.gemm_fp8_fp8_bf16_nt``.

        Args:
            A: First operand tensor.
            As: Tensor paired with ``A`` in the first tuple argument
                (presumably its quantization scales -- TODO confirm).
            B: Second operand tensor.
            Bs: Tensor paired with ``B`` in the second tuple argument
                (presumably its quantization scales -- TODO confirm).
            C: Tensor passed as the last argument; it is the one declared
                as mutated when the op is registered below.

        Returns:
            None.
        """
        lhs = (A, As)
        rhs = (B, Bs)
        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, C)

    def deep_gemm_fp8_fp8_bf16_nt_fake(
        A: torch.Tensor,
        As: torch.Tensor,
        B: torch.Tensor,
        Bs: torch.Tensor,
        C: torch.Tensor,
    ) -> None:
        """Fake implementation registered with the op: same signature, no work."""
        return None

    # ``C`` is the only argument declared as mutated by the op.
    direct_register_custom_op(
        op_name="deep_gemm_fp8_fp8_bf16_nt",
        op_func=deep_gemm_fp8_fp8_bf16_nt,
        mutates_args=["C"],
        fake_impl=deep_gemm_fp8_fp8_bf16_nt_fake,
    )
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _per_token_group_quant_fp8(
|
def _per_token_group_quant_fp8(
|
||||||
@@ -728,7 +762,10 @@ def w8a8_block_fp8_matmul(
|
|||||||
|
|
||||||
# deepgemm only support bf16
|
# deepgemm only support bf16
|
||||||
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
|
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
|
||||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
if supports_custom_op():
|
||||||
|
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
||||||
|
else:
|
||||||
|
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||||
else:
|
else:
|
||||||
kernel = (
|
kernel = (
|
||||||
_w8a8_block_fp8_matmul_unrolledx4
|
_w8a8_block_fp8_matmul_unrolledx4
|
||||||
|
|||||||
Reference in New Issue
Block a user