feat: add DeepGEMM build warning (#5176)
Co-authored-by: grimoire <streetyao@live.com>
This commit is contained in:
@@ -16,6 +16,7 @@ import functools
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -59,7 +60,10 @@ if supports_custom_op():
|
|||||||
Bs: torch.Tensor,
|
Bs: torch.Tensor,
|
||||||
C: torch.Tensor,
|
C: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
M, K = A.shape
|
||||||
|
N, _ = B.shape
|
||||||
|
with _log_jit_build(M, N, K):
|
||||||
|
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||||
|
|
||||||
def deep_gemm_fp8_fp8_bf16_nt_fake(
|
def deep_gemm_fp8_fp8_bf16_nt_fake(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
@@ -708,6 +712,25 @@ def get_w8a8_block_fp8_configs(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _log_jit_build(M: int, N: int, K: int):
|
||||||
|
from deep_gemm.jit.runtime import RuntimeCache
|
||||||
|
|
||||||
|
origin_func = RuntimeCache.__getitem__
|
||||||
|
|
||||||
|
def __patched_func(self, *args, **kwargs):
|
||||||
|
ret = origin_func(self, *args, **kwargs)
|
||||||
|
if ret is None:
|
||||||
|
logger.warning(
|
||||||
|
f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
|
||||||
|
)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
RuntimeCache.__getitem__ = __patched_func
|
||||||
|
yield
|
||||||
|
RuntimeCache.__getitem__ = origin_func
|
||||||
|
|
||||||
|
|
||||||
def w8a8_block_fp8_matmul(
|
def w8a8_block_fp8_matmul(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
@@ -782,7 +805,8 @@ def w8a8_block_fp8_matmul(
|
|||||||
if supports_custom_op():
|
if supports_custom_op():
|
||||||
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
||||||
else:
|
else:
|
||||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
with _log_jit_build(M, N, K):
|
||||||
|
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||||
else:
|
else:
|
||||||
kernel = (
|
kernel = (
|
||||||
_w8a8_block_fp8_matmul_unrolledx4
|
_w8a8_block_fp8_matmul_unrolledx4
|
||||||
|
|||||||
Reference in New Issue
Block a user