From c2942907d518e297ff0977a19c88964a3e303b67 Mon Sep 17 00:00:00 2001 From: JieXin Liang Date: Tue, 22 Apr 2025 07:52:53 +0800 Subject: [PATCH] [feature] enable pre compile jit deep_gemm (#5580) --- python/sglang/compile_deep_gemm.py | 136 +++++++ .../srt/layers/quantization/deep_gemm.py | 378 ++++++++++++++++++ .../srt/layers/quantization/fp8_kernel.py | 45 +-- .../srt/layers/quantization/fp8_utils.py | 4 +- .../sglang/srt/model_executor/model_runner.py | 8 + python/sglang/srt/models/deepseek_v2.py | 13 +- python/sglang/srt/utils.py | 10 + 7 files changed, 549 insertions(+), 45 deletions(-) create mode 100644 python/sglang/compile_deep_gemm.py create mode 100644 python/sglang/srt/layers/quantization/deep_gemm.py diff --git a/python/sglang/compile_deep_gemm.py b/python/sglang/compile_deep_gemm.py new file mode 100644 index 000000000..dd3622349 --- /dev/null +++ b/python/sglang/compile_deep_gemm.py @@ -0,0 +1,136 @@ +""" +Compile DeepGEMM Kernels for a model with specify server arguments + +This script launches a server for capturing DeepGEMM calls and then compiles the kernels. +It accepts server arguments (the same as launch_server.py). + +Usage: +python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code + +""" + +import argparse +import dataclasses +import multiprocessing +import os +import time + +import requests + +from sglang.srt.entrypoints.http_server import launch_server +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import kill_process_tree +from sglang.srt.warmup import warmup + +multiprocessing.set_start_method("spawn", force=True) + +# Reduce warning +os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1" + + +@dataclasses.dataclass +class CompileArgs: + timeout: int = 3600 + + @staticmethod + def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--timeout", type=int, default=CompileArgs.timeout) + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + # use the default value's type to cast the args into correct types. + attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)] + return cls( + **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs} + ) + + +@warmup("compile-deep-gemm") +async def warm_up_compile(tokenizer_manager: TokenizerManager): + print("\nGenerate warm up request for compiling DeepGEMM...\n") + generate_req_input = GenerateReqInput( + input_ids=[0, 1, 2, 3], + sampling_params={ + "temperature": 0.0, + "max_new_tokens": 8, + "ignore_eos": True, + }, + ) + await tokenizer_manager.generate_request(generate_req_input, None).__anext__() + + +def launch_server_internal(server_args): + try: + launch_server(server_args) + except Exception as e: + raise e + finally: + kill_process_tree(os.getpid(), include_parent=False) + + +def launch_server_process_and_send_one_request( + server_args: ServerArgs, compile_args: CompileArgs +): + proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,)) + proc.start() + base_url = f"http://{server_args.host}:{server_args.port}" + timeout = compile_args.timeout + + start_time = time.time() + while time.time() - start_time < timeout: + try: + headers = { + "Content-Type": "application/json; charset=utf-8", + } + response = requests.get(f"{base_url}/v1/models", headers=headers) + if response.status_code == 200: + return proc + except requests.RequestException: + pass + time.sleep(10) + raise TimeoutError( + "DeepGEMM Kernels compilation timeout." + "\n\nFeel free and please restart the command." + ) + + +def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs): + # Disbale cuda graph and torch compile to save time + server_args.disable_cuda_graph = True + server_args.enable_torch_compile = False + print(f"Disable CUDA Graph and Torch Compile to save time...") + + # Set watchdog timeout to compile_args.timeout because compilation will take a long time + server_args.watchdog_timeout = compile_args.timeout + server_args.warmups = "compile-deep-gemm" + + +def run_compile(server_args: ServerArgs, compile_args: CompileArgs): + print( + "Begin DeepGEMM Kernels compilation...\n" + "It may take a long time and timeout maybe raised " + "while the compilation is still in progress.\n" + "Just feel free to restart the command " + "until the compilation is fully finished.\n" + ) + + proc = launch_server_process_and_send_one_request(server_args, compile_args) + + kill_process_tree(proc.pid) + + print("\nDeepGEMM Kernels compilation finished successfully.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + CompileArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + compile_args = CompileArgs.from_cli_args(args) + + refine_server_args(server_args, compile_args) + + run_compile(server_args, compile_args) diff --git a/python/sglang/srt/layers/quantization/deep_gemm.py b/python/sglang/srt/layers/quantization/deep_gemm.py new file mode 100644 index 000000000..05e16e0db --- /dev/null +++ b/python/sglang/srt/layers/quantization/deep_gemm.py @@ -0,0 +1,378 @@ +import logging +import os +from contextlib import contextmanager +from dataclasses import dataclass +from enum import IntEnum, auto +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from tqdm.contrib.concurrent import thread_map + +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda + +_ENABLE_JIT_DEEPGEMM = False +if is_cuda(): + import deep_gemm + from deep_gemm import get_num_sms + from deep_gemm.jit_kernels.gemm import get_best_configs + from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes + from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template + from deep_gemm.jit_kernels.m_grouped_gemm import ( + template as deep_gemm_grouped_gemm_template, + ) + from deep_gemm.jit_kernels.tuner import jit_tuner + + sm_version = get_device_sm() + if sm_version == 90: + if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"): + _ENABLE_JIT_DEEPGEMM = True + +logger = logging.getLogger(__name__) + +_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1)) +_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var( + "SGL_JIT_DEEPGEMM_PRECOMPILE", "true" +) +_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true") +_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4) +_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false") + +# Force redirect deep_gemm cache_dir +os.environ["DG_CACHE_DIR"] = os.getenv( + "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm" +) + + +def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): + global _BUILTIN_M_LIST + global _DO_COMPILE + + # Generate m_max + m_max = 1024 * 16 + if server_args.chunked_prefill_size < 1: + m_max = 1024 * 64 + elif server_args.chunked_prefill_size > 8192: + m_max = server_args.chunked_prefill_size * 2 + m_max = min(1024 * 128, m_max) + _BUILTIN_M_LIST = list(range(1, m_max + 1)) + + # Check if is the first rank on node + _DO_COMPILE = ServerArgs.base_gpu_id == gpu_id + + +class DeepGemmKernelType(IntEnum): + GROUPED_GEMM_NT_F8F8BF16_MASKED = auto() + GROUPED_GEMM_NT_F8F8BF16_CONTIG = auto() + GEMM_NT_F8F8BF16 = auto() + + +@dataclass +class DeepGemmKernelHelper: + name: str + compile_func: Callable[ + [ + int, + int, + int, + Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]], + ], + None, + ] + configure_func: Callable[ + [int, int, int, int, int], + Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]], + ] + + +_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict() + + +def _compile_warning_1(): + if not _IN_PRE_COMPILE_STAGE: + logger.warning( + "Entering DeepGEMM JIT Pre-Complie session. " + "And it may takes a long time(Typically 10-20 mins) " + "if you have not run `sglang.compile_deep_gemm`. " + "Recommand to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`" + " for pre-compilation to reduce the overhead if you have not run it before. " + "For example: " + "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`" + ) + + +def _compile_warning_2(): + logger.warning( + "Entering DeepGEMM JIT Single Kernel Complie session. " + "And it will makes inference throughput becomes flaky. " + "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`" + " for pre-compilation to solve this issue. " + "For example: " + "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`" + ) + + +def _compile_grouped_gemm_nt_f8f8bf16_masked_one( + n: int, + k: int, + num_groups: int, + config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]], +) -> None: + # Auto-tuning with compilation + global deep_gemm_includes, deep_gemm_grouped_gemm_template + _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config + _ = jit_tuner.compile_and_tune( + name="m_grouped_gemm_fp8_fp8_bf16_nt", + keys={ + "N": n, + "K": k, + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "SWIZZLE_D_MODE": smem_config[1], + "BLOCK_N_PADDING": smem_config[2], + "NUM_GROUPS": num_groups, + "NUM_STAGES": num_stages, + "NUM_TMA_MULTICAST": tma_multicast_config[0], + "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], + "GEMM_TYPE": "GroupedMasked", + }, + space=(), + includes=deep_gemm_includes, + arg_defs=( + ("lhs", torch.float8_e4m3fn), + ("lhs_scales", torch.float), + ("rhs", torch.float8_e4m3fn), + ("rhs_scales", torch.float), + ("out", torch.bfloat16), + ("grouped_layout", torch.int32), + ("m", int), + ("stream", torch.cuda.Stream), + ("num_sms", int), + ("smem_size", int), + ), + template=deep_gemm_grouped_gemm_template, + args=[], + ) + + +def _compile_grouped_gemm_nt_f8f8bf16_contig_one( + n: int, + k: int, + num_groups: int, + config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]], +) -> None: + global deep_gemm_includes, deep_gemm_grouped_gemm_template + _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config + _ = jit_tuner.compile_and_tune( + name="m_grouped_gemm_fp8_fp8_bf16_nt", + keys={ + "N": n, + "K": k, + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "SWIZZLE_D_MODE": smem_config[1], + "BLOCK_N_PADDING": smem_config[2], + "NUM_GROUPS": num_groups, + "NUM_STAGES": num_stages, + "NUM_TMA_MULTICAST": tma_multicast_config[0], + "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], + "GEMM_TYPE": "GroupedContiguous", + }, + space=(), + includes=deep_gemm_includes, + arg_defs=( + ("lhs", torch.float8_e4m3fn), + ("lhs_scales", torch.float), + ("rhs", torch.float8_e4m3fn), + ("rhs_scales", torch.float), + ("out", torch.bfloat16), + ("grouped_layout", torch.int32), + ("m", int), + ("num_groups", int), + ("stream", torch.cuda.Stream), + ("num_sms", int), + ("smem_size", int), + ), + template=deep_gemm_grouped_gemm_template, + args=[], + ) + + +def _compile_gemm_nt_f8f8bf16_one( + n: int, + k: int, + _: int, # _ is a dummy parameter to align with other interfaces + config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]], +) -> None: + global deep_gemm_includes, deep_gemm_gemm_template + _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config + _ = jit_tuner.compile_and_tune( + name="gemm_fp8_fp8_bf16_nt", + keys={ + "N": n, + "K": k, + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "SWIZZLE_D_MODE": smem_config[1], + "BLOCK_N_PADDING": smem_config[2], + "NUM_STAGES": num_stages, + "NUM_TMA_MULTICAST": tma_multicast_config[0], + "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], + }, + space=(), + includes=deep_gemm_includes, + arg_defs=( + ("lhs", torch.float8_e4m3fn), + ("lhs_scales", torch.float), + ("rhs", torch.float8_e4m3fn), + ("rhs_scales", torch.float), + ("out", torch.bfloat16), + ("m", int), + ("stream", torch.cuda.Stream), + ("num_sms", int), + ("smem_size", int), + ), + template=deep_gemm_gemm_template, + args=[], + ) + + +_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = { + DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper( + name="m_grouped_gemm_fp8_fp8_bf16_nt_masked", + compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one, + configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs( + m, n, k, num_groups, num_sms, is_grouped_masked=True + ), + ), + DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper( + name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous", + compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one, + configure_func=lambda m, n, k, _, num_sms: get_best_configs( + m, n, k, 1, num_sms, is_grouped_contiguous=True + ), + ), + DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper( + name="gemm_fp8_fp8_bf16_nt", + compile_func=_compile_gemm_nt_f8f8bf16_one, + configure_func=lambda m, n, k, _, num_sms: get_best_configs( + m, n, k, 1, num_sms + ), + ), +} + + +def _maybe_compile_deep_gemm_one_type_all( + kernel_type: DeepGemmKernelType, + n: int, + k: int, + num_groups: int, + m_list: Optional[List[int]] = None, +) -> None: + + global _INITIALIZATION_DICT + global _BUILTIN_M_LIST + + query_key = (kernel_type, n, k, num_groups) + if ( + _ENABLE_JIT_DEEPGEMM_PRECOMPILE + and _DO_COMPILE + and _INITIALIZATION_DICT.get(query_key) is None + ): + _INITIALIZATION_DICT[query_key] = True + + kernel_helper = _KERNEL_HELPER_DICT[kernel_type] + _compile_warning_1() + logger.info( + f"Try DeepGEMM JIT Compiling for " + f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms." + f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}" + ) + + # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced + num_sms = get_num_sms() + collected_configs = set() + for m in m_list if m_list is not None else _BUILTIN_M_LIST: + # Put config into set to get unique configs and reduce cases to be compiled + collected_configs.add( + kernel_helper.configure_func(m, n, k, num_groups, num_sms) + ) + compile_func = lambda config: kernel_helper.compile_func( + n, k, num_groups, config + ) + thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS) + + +def grouped_gemm_nt_f8f8bf16_masked( + lhs: Tuple[torch.Tensor, torch.Tensor], + rhs: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, +): + num_groups, _, k = lhs[0].shape + _, n, _ = rhs[0].shape + + kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED + _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups) + + with _log_jit_build(expected_m, n, k, kernel_type): + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + lhs, rhs, out, masked_m, expected_m + ) + + +def grouped_gemm_nt_f8f8bf16_contig( + lhs: Tuple[torch.Tensor, torch.Tensor], + rhs: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, + m_indices: torch.Tensor, +): + m, k = lhs[0].shape + num_groups, n, _ = rhs[0].shape + + kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG + _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups) + + with _log_jit_build(m, n, k, kernel_type): + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices) + + +def gemm_nt_f8f8bf16( + lhs: Tuple[torch.Tensor, torch.Tensor], + rhs: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, +): + m, k = lhs[0].shape + n, _ = rhs[0].shape + + kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16 + _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1) + + with _log_jit_build(m, n, k, kernel_type): + deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out) + + +@contextmanager +def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType): + if _IN_PRE_COMPILE_STAGE: + yield + return + + from deep_gemm.jit.runtime import RuntimeCache + + origin_func = RuntimeCache.__getitem__ + + def __patched_func(self, *args, **kwargs): + ret = origin_func(self, *args, **kwargs) + if ret is None: + kernel_helper = _KERNEL_HELPER_DICT[kernel_type] + _compile_warning_2() + logger.warning( + f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait." + ) + return ret + + RuntimeCache.__getitem__ = __patched_func + yield + RuntimeCache.__getitem__ = origin_func diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 89e8d23bf..45157527e 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -16,19 +16,17 @@ import functools import json import logging import os -from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple import torch import triton import triton.language as tl +from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.utils import ( direct_register_custom_op, - get_bool_env_var, get_device_core_count, get_device_name, - get_device_sm, is_cuda, is_hip, supports_custom_op, @@ -43,22 +41,16 @@ else: fp8_max = torch.finfo(_fp8_type).max fp8_min = -fp8_max -_enable_jit_deepgemm = False -_enable_jit_deepgemm_bmm = False if _is_cuda: - import deep_gemm from sgl_kernel import ( sgl_per_tensor_quant_fp8, sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8, ) - sm_version = get_device_sm() - if sm_version == 90: - if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"): - _enable_jit_deepgemm = True - if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"): - _enable_jit_deepgemm_bmm = True + from sglang.srt.layers.quantization.deep_gemm import ( + gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16, + ) logger = logging.getLogger(__name__) @@ -71,10 +63,7 @@ if supports_custom_op(): Bs: torch.Tensor, C: torch.Tensor, ) -> None: - M, K = A.shape - N, _ = B.shape - with _log_jit_build(M, N, K): - deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) + deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C) def deep_gemm_fp8_fp8_bf16_nt_fake( A: torch.Tensor, @@ -715,25 +704,6 @@ def get_w8a8_block_fp8_configs( return None -@contextmanager -def _log_jit_build(M: int, N: int, K: int): - from deep_gemm.jit.runtime import RuntimeCache - - origin_func = RuntimeCache.__getitem__ - - def __patched_func(self, *args, **kwargs): - ret = origin_func(self, *args, **kwargs) - if ret is None: - logger.warning( - f"DeepGEMM JIT code generation : M={M}, N={N}, K={K}. Please wait." - ) - return ret - - RuntimeCache.__getitem__ = __patched_func - yield - RuntimeCache.__getitem__ = origin_func - - def w8a8_block_fp8_matmul( A: torch.Tensor, B: torch.Tensor, @@ -804,12 +774,11 @@ def w8a8_block_fp8_matmul( ) # deepgemm only support bf16 - if C.dtype == torch.bfloat16 and _enable_jit_deepgemm: + if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM: if supports_custom_op(): torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C) else: - with _log_jit_build(M, N, K): - deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) + deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C) else: kernel = ( _w8a8_block_fp8_matmul_unrolledx4 diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 33519a49c..7948cff16 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -12,8 +12,8 @@ try: except ImportError: VLLM_AVAILABLE = False +from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.layers.quantization.fp8_kernel import ( - _enable_jit_deepgemm, per_token_group_quant_fp8, scaled_fp8_quant, sglang_per_token_quant_fp8, @@ -143,7 +143,7 @@ def apply_w8a8_block_fp8_linear( ) gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output) else: - if _enable_jit_deepgemm: + if _ENABLE_JIT_DEEPGEMM: q_input, x_scale = sglang_per_token_group_quant_fp8( input_2d, block_size[1], diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8833ccc42..5b3a8645c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -42,6 +42,10 @@ from sglang.srt.layers.dp_attention import ( ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer +from sglang.srt.layers.quantization.deep_gemm import ( + _ENABLE_JIT_DEEPGEMM, + update_deep_gemm_config, +) from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager @@ -169,6 +173,10 @@ class ModelRunner: # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + # Update deep gemm configure + if _ENABLE_JIT_DEEPGEMM: + update_deep_gemm_config(gpu_id, server_args) + # If it is a draft model tp_group can be different. self.initialize(min_per_gpu_memory) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2cd64d532..3daeab95c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -57,8 +57,8 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.layers.quantization.fp8_kernel import ( - _enable_jit_deepgemm_bmm, per_tensor_quant_mla_deep_gemm_masked_fp8, per_tensor_quant_mla_fp8, ) @@ -86,8 +86,11 @@ _is_hip = is_hip() _is_cuda = is_cuda() if _is_cuda: - from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2 + + from sglang.srt.layers.quantization.deep_gemm import ( + grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked, + ) else: from vllm._custom_ops import awq_dequantize @@ -702,7 +705,7 @@ class DeepseekV2AttentionMLA(nn.Module): q_nope_out = q_nope.new_empty( (self.num_local_heads, aligned_m, self.kv_lora_rank) ) - m_grouped_gemm_fp8_fp8_bf16_nt_masked( + deep_gemm_grouped_gemm_nt_f8f8bf16_masked( (q_nope_val, q_nope_scale), (self.w_kc, self.w_scale_k), q_nope_out, @@ -751,7 +754,7 @@ class DeepseekV2AttentionMLA(nn.Module): attn_bmm_output = attn_output.new_empty( (self.num_local_heads, aligned_m, self.v_head_dim) ) - m_grouped_gemm_fp8_fp8_bf16_nt_masked( + deep_gemm_grouped_gemm_nt_f8f8bf16_masked( (attn_output_val, attn_output_scale), (self.w_vc, self.w_scale_v), attn_bmm_output, @@ -1520,7 +1523,7 @@ class DeepseekV2ForCausalLM(nn.Module): if ( _is_cuda - and _enable_jit_deepgemm_bmm + and _ENABLE_JIT_DEEPGEMM and weight_block_size[0] == 128 and weight_block_size[1] == 128 and model_dtype == torch.bfloat16 diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 5cf0e1607..69f0c4ff0 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -98,6 +98,16 @@ def get_bool_env_var(name: str, default: str = "false") -> bool: return value in truthy_values +def get_int_env_var(name: str, default: int = 0) -> int: + value = os.getenv(name) + if value is None or not value.strip(): + return default + try: + return int(value) + except ValueError: + return default + + # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip def is_hip() -> bool: return torch.version.hip is not None