From 6669d12707c91ccd6f795a110be945fc3afc2d89 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Tue, 8 Apr 2025 21:16:23 -0700
Subject: [PATCH] feat: add DeepGEMM build warning (#5176)

Co-authored-by: grimoire
---
 .../srt/layers/quantization/fp8_kernel.py | 39 +++++++++++++++++++++++++++++++++++--
 1 file changed, 37 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index bfd74474e..535d4ecf6 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -16,6 +16,7 @@ import functools
 import json
 import logging
 import os
+from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -59,7 +60,10 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+        M, K = A.shape
+        N, _ = B.shape
+        with _log_jit_build(M, N, K):
+            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
 
     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -708,6 +712,36 @@ def get_w8a8_block_fp8_configs(
     return None
 
 
+@contextmanager
+def _log_jit_build(M: int, N: int, K: int):
+    """Log a warning when a DeepGEMM runtime-cache lookup returns ``None``.
+
+    While the context is active, ``RuntimeCache.__getitem__`` is wrapped so
+    that a lookup returning ``None`` (presumably a miss that triggers a JIT
+    build) logs a warning carrying ``M``, ``N`` and ``K``. The original
+    method is restored on exit, including when the wrapped call raises.
+    """
+    from deep_gemm.jit.runtime import RuntimeCache
+
+    origin_func = RuntimeCache.__getitem__
+
+    def __patched_func(self, *args, **kwargs):
+        ret = origin_func(self, *args, **kwargs)
+        if ret is None:
+            logger.warning(
+                f"DeepGEMM JIT code generation : M={M}, N={N}, K={K}. Please wait."
+            )
+        return ret
+
+    RuntimeCache.__getitem__ = __patched_func
+    try:
+        yield
+    finally:
+        # Without this, an exception in the GEMM call would leave the
+        # class permanently patched with a stale (M, N, K) closure.
+        RuntimeCache.__getitem__ = origin_func
+
+
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -782,7 +816,8 @@ def w8a8_block_fp8_matmul(
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else:
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+            with _log_jit_build(M, N, K):
+                deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4