feat: add DeepGEMM build warning (#5176)

Co-authored-by: grimoire <streetyao@live.com>
This commit is contained in:
Yineng Zhang
2025-04-08 21:16:23 -07:00
committed by GitHub
parent f2b70afde0
commit 6669d12707

View File

@@ -16,6 +16,7 @@ import functools
import json import json
import logging import logging
import os import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
@@ -59,6 +60,9 @@ if supports_custom_op():
Bs: torch.Tensor, Bs: torch.Tensor,
C: torch.Tensor, C: torch.Tensor,
) -> None: ) -> None:
M, K = A.shape
N, _ = B.shape
with _log_jit_build(M, N, K):
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
def deep_gemm_fp8_fp8_bf16_nt_fake( def deep_gemm_fp8_fp8_bf16_nt_fake(
@@ -708,6 +712,25 @@ def get_w8a8_block_fp8_configs(
return None return None
@contextmanager
def _log_jit_build(M: int, N: int, K: int):
from deep_gemm.jit.runtime import RuntimeCache
origin_func = RuntimeCache.__getitem__
def __patched_func(self, *args, **kwargs):
ret = origin_func(self, *args, **kwargs)
if ret is None:
logger.warning(
f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
)
return ret
RuntimeCache.__getitem__ = __patched_func
yield
RuntimeCache.__getitem__ = origin_func
def w8a8_block_fp8_matmul( def w8a8_block_fp8_matmul(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
@@ -782,6 +805,7 @@ def w8a8_block_fp8_matmul(
if supports_custom_op(): if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C) torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else: else:
with _log_jit_build(M, N, K):
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
else: else:
kernel = ( kernel = (