feat: add DeepGEMM build warning (#5176)
Co-authored-by: grimoire <streetyao@live.com>
This commit is contained in:
@@ -16,6 +16,7 @@ import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -59,7 +60,10 @@ if supports_custom_op():
|
||||
Bs: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
) -> None:
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||
M, K = A.shape
|
||||
N, _ = B.shape
|
||||
with _log_jit_build(M, N, K):
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||
|
||||
def deep_gemm_fp8_fp8_bf16_nt_fake(
|
||||
A: torch.Tensor,
|
||||
@@ -708,6 +712,25 @@ def get_w8a8_block_fp8_configs(
|
||||
return None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _log_jit_build(M: int, N: int, K: int):
|
||||
from deep_gemm.jit.runtime import RuntimeCache
|
||||
|
||||
origin_func = RuntimeCache.__getitem__
|
||||
|
||||
def __patched_func(self, *args, **kwargs):
|
||||
ret = origin_func(self, *args, **kwargs)
|
||||
if ret is None:
|
||||
logger.warning(
|
||||
f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
|
||||
)
|
||||
return ret
|
||||
|
||||
RuntimeCache.__getitem__ = __patched_func
|
||||
yield
|
||||
RuntimeCache.__getitem__ = origin_func
|
||||
|
||||
|
||||
def w8a8_block_fp8_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
@@ -782,7 +805,8 @@ def w8a8_block_fp8_matmul(
|
||||
if supports_custom_op():
|
||||
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
||||
else:
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||
with _log_jit_build(M, N, K):
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||
else:
|
||||
kernel = (
|
||||
_w8a8_block_fp8_matmul_unrolledx4
|
||||
|
||||
Reference in New Issue
Block a user