diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index 624d9ed32..8ce6e9f94 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -38,6 +38,8 @@ jobs: include: - python-version: "3.10" cuda-version: "12.4" + - python-version: "3.10" + cuda-version: "12.8" - python-version: "3.10" cuda-version: "12.9" name: Build Wheel (CUDA ${{ matrix.cuda-version }}) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 18ac91464..e35a4e017 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -248,7 +248,6 @@ class EPMoE(FusedMoE): gateup_output, masked_m, expected_m, - recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None, ) del gateup_input del gateup_input_fp8 @@ -304,7 +303,6 @@ class EPMoE(FusedMoE): down_output, masked_m, expected_m, - recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None, ) del down_input del down_input_fp8 @@ -667,7 +665,6 @@ class DeepEPMoE(EPMoE): gateup_output, masked_m, expected_m, - recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None, ) dispose_tensor(hidden_states_fp8[0]) @@ -708,9 +705,7 @@ class DeepEPMoE(EPMoE): ( down_input_scale if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - else deep_gemm_wrapper.get_col_major_tma_aligned_tensor( - down_input_scale - ) + else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale) ), ) down_output = torch.empty( @@ -722,7 +717,6 @@ class DeepEPMoE(EPMoE): down_output, masked_m, expected_m, - recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None, ) return down_output diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py index c3043f389..ca3dbf9d2 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py @@ -1,26 +1,22 @@ import logging import os from contextlib import contextmanager -from dataclasses import dataclass from enum import IntEnum, auto -from typing import Callable, Dict, List, Optional, Tuple +from typing import Dict, List, Tuple -from tqdm.contrib.concurrent import thread_map +import torch +from tqdm import tqdm from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( - DEEPGEMM_BLACKWELL, ENABLE_JIT_DEEPGEMM, ) from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import get_bool_env_var, get_int_env_var +from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var logger = logging.getLogger(__name__) -if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL: - from deep_gemm import get_num_sms - from deep_gemm.jit import build - from deep_gemm.jit_kernels.gemm import get_best_configs - from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType +if ENABLE_JIT_DEEPGEMM: + import deep_gemm _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1)) @@ -40,19 +36,7 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv( # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f # NVRTC may have performance loss with some cases. 
# And NVCC JIT speed is also 9x faster in the ref commit -_USE_NVRTC_DEFAULT = "0" -if ENABLE_JIT_DEEPGEMM: - try: - from deep_gemm.jit.compiler import get_nvcc_compiler - - get_nvcc_compiler() - except: - logger.warning( - "NVCC Compiler not found, use NVRTC for DeepGEMM JIT " - "and may have performance loss with some cases." - ) - _USE_NVRTC_DEFAULT = "1" -os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT) +os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0") def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): @@ -75,7 +59,7 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): # Default each rank will try compile all Ms to # load all symbols at the launch stages. # Avoid loading symbols at the serving stages. - _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE + _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE class DeepGemmKernelType(IntEnum): @@ -84,185 +68,15 @@ class DeepGemmKernelType(IntEnum): GEMM_NT_F8F8BF16 = auto() -@dataclass -class DeepGemmKernelHelper: - name: str - compile_func: Callable[ - [ - int, - int, - int, - Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]], - ], - None, - ] - configure_func: Callable[ - [int, int, int, int, int], - Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]], - ] - - _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict() -# TODO improve naming -def _compile_warning_1(): - if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE: - logger.warning( - "Entering DeepGEMM JIT Pre-Compile session. " - "It may takes a long time (typically 10-20 mins) " - "if you have not run `sglang.compile_deep_gemm`. " - "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`" - " for pre-compilation to reduce the overhead if you have not run it before. " - "For example: " - "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`" - ) - - -# TODO improve naming -def _compile_warning_2(): - logger.warning( - "Entering DeepGEMM JIT Single Kernel Compile session. " - "And it will makes inference throughput becomes flaky. " - "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`" - " for pre-compilation to solve this issue. 
" - "For example: " - "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`" - ) - - -def _compile_grouped_gemm_nt_f8f8bf16_masked_one( - n: int, - k: int, - num_groups: int, - config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]], -) -> None: - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config - block_k = 128 - num_tma_threads = 128 - num_math_threads_per_group = 128 - - kwargs = { - "GEMM_TYPE": GemmType.GroupedMasked, - "NUM_TMA_THREADS": num_tma_threads, - "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group, - "N": n, - "K": k, - "NUM_GROUPS": num_groups, - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SWIZZLE_D_MODE": smem_config[1], - "BLOCK_N_PADDING": smem_config[2], - "NUM_STAGES": num_stages, - "NUM_TMA_MULTICAST": tma_multicast_config[0], - "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], - "NUM_SMS": num_sms, - "SMEM_SIZE": smem_config[0], - } - - code = FP8GemmRuntime.generate(kwargs) - _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs) - - -def _compile_grouped_gemm_nt_f8f8bf16_contig_one( - n: int, - k: int, - num_groups: int, - config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]], -) -> None: - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config - block_k = 128 - num_tma_threads = 128 - num_math_threads_per_group = 128 - kwargs = { - "GEMM_TYPE": GemmType.GroupedContiguous, - "NUM_TMA_THREADS": num_tma_threads, - "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group, - "N": n, - "K": k, - "NUM_GROUPS": 1, - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SWIZZLE_D_MODE": smem_config[1], - "BLOCK_N_PADDING": smem_config[2], - "NUM_STAGES": num_stages, - "NUM_TMA_MULTICAST": tma_multicast_config[0], - "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], - "NUM_SMS": num_sms, - "SMEM_SIZE": smem_config[0], - } - - code = FP8GemmRuntime.generate(kwargs) - _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs) - - -def _compile_gemm_nt_f8f8bf16_one( - n: int, - k: int, - _: int, # _ is a dummy parameter to align with other interfaces - config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]], -) -> None: - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config - block_k = 128 - num_tma_threads = 128 - num_math_threads_per_group = 128 - kwargs = { - "GEMM_TYPE": GemmType.Normal, - "NUM_TMA_THREADS": num_tma_threads, - "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group, - "N": n, - "K": k, - "NUM_GROUPS": 1, - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SWIZZLE_D_MODE": smem_config[1], - "BLOCK_N_PADDING": smem_config[2], - "NUM_STAGES": num_stages, - "NUM_TMA_MULTICAST": tma_multicast_config[0], - "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], - "NUM_SMS": num_sms, - "SMEM_SIZE": smem_config[0], - } - - code = FP8GemmRuntime.generate(kwargs) - _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs) - - -# TODO further refactor warmup-related -_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = { - DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper( - name="m_grouped_gemm_fp8_fp8_bf16_nt_masked", - compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one, - configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs( - m, n, k, num_groups, num_sms, is_grouped_masked=True - ), - ), - 
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
-        compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms, is_grouped_contiguous=True
-        ),
-    ),
-    DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
-        name="gemm_fp8_fp8_bf16_nt",
-        compile_func=_compile_gemm_nt_f8f8bf16_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms
-        ),
-    ),
-}
-
-
+# TODO improve code
 def _maybe_compile_deep_gemm_one_type_all(
     kernel_type: DeepGemmKernelType,
     n: int,
     k: int,
     num_groups: int,
-    m_list: Optional[List[int]] = None,
 ) -> None:
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST
@@ -275,61 +89,145 @@ def _maybe_compile_deep_gemm_one_type_all(
     ):
         _INITIALIZATION_DICT[query_key] = True
 
-        kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-        _compile_warning_1()
+        # TODO maybe improve logs
+        if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
+            logger.warning(
+                "Entering DeepGEMM JIT Pre-Compile session. "
+                "It may take a long time (typically 10-20 mins) "
+                "if you have not run `sglang.compile_deep_gemm`. "
+                "It is recommended to run `sglang.compile_deep_gemm` with the same args as `sglang.launch_server`"
+                " for pre-compilation to reduce the overhead. "
+                "For example: "
+                "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+            )
+
         logger.info(
             f"Try DeepGEMM JIT Compiling for "
-            f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
+            f"<{kernel_type.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
             f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
         )
 
-        # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
-        num_sms = get_num_sms()
-        collected_configs = set()
-        for m in m_list if m_list is not None else _BUILTIN_M_LIST:
-            # Put config into set to get unique configs and reduce cases to be compiled
-            collected_configs.add(
-                kernel_helper.configure_func(m, n, k, num_groups, num_sms)
-            )
-        compile_func = lambda config: kernel_helper.compile_func(
-            n, k, num_groups, config
+        _compile_deep_gemm_one_type_all(
+            kernel_type=kernel_type,
+            n=n,
+            k=k,
+            num_groups=num_groups,
+            m_list=_BUILTIN_M_LIST,
         )
-        thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
 
 
-@contextmanager
-def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
-    if _IN_PRECOMPILE_STAGE:
-        yield
-        return
+# NOTE(alcanderian): get_num_sms should be changed when 2-batch-overlap is introduced
+def _compile_deep_gemm_one_type_all(
+    kernel_type: DeepGemmKernelType,
+    n: int,
+    k: int,
+    num_groups: int,
+    m_list: List[int],
+) -> None:
+    if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
+        m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
+        m_list = sorted(set(m for m in m_list if m % m_alignment == 0))
 
-    from deep_gemm.jit.runtime import RuntimeCache
+    executor = _BaseWarmupExecutor.create(
+        kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
+    )
 
-    origin_func = RuntimeCache.get
+    # TODO could use multiple threads
+    for m in tqdm(m_list, desc="DeepGEMM warmup"):
+        executor.execute(m=m)
 
-    def __patched_func(self, *args, **kwargs):
-        ret = origin_func(self, *args, **kwargs)
-        if ret is None:
-            kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-            if not DEEPGEMM_BLACKWELL:
-                _compile_warning_2()
-            logger.warning(
-                f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
- ) - return ret - RuntimeCache.get = __patched_func - yield - RuntimeCache.get = origin_func +class _BaseWarmupExecutor: + @staticmethod + def create(kernel_type: DeepGemmKernelType, **kwargs): + return { + DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor, + DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor, + DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor, + }[kernel_type](**kwargs) + + def execute(self, m): + raise NotImplementedError + + +def _empty_token_fp8(size): + *dims, k = size + return ( + torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn), + torch.empty( + (*dims, ceil_div(k, _BLOCK_SIZE)), device="cuda", dtype=torch.float32 + ), + ) + + +def _empty_block_fp8(size): + *dims, n, k = size + return ( + torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn), + torch.empty( + (*dims, ceil_div(n, _BLOCK_SIZE), ceil_div(k, _BLOCK_SIZE)), + device="cuda", + dtype=torch.float32, + ), + ) + + +_BLOCK_SIZE = 128 + + +class _NormalWarmupExecutor(_BaseWarmupExecutor): + def __init__(self, max_m: int, n: int, k: int, num_groups: int): + self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k)) + self.rhs_q, self.rhs_s = _empty_block_fp8((n, k)) + self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16) + + def execute(self, m): + deep_gemm.fp8_gemm_nt( + (self.lhs_q[:m], self.lhs_s[:m]), + (self.rhs_q, self.rhs_s), + self.out[:m], + ) + + +class _GroupedContWarmupExecutor(_BaseWarmupExecutor): + def __init__(self, max_m: int, n: int, k: int, num_groups: int): + self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k)) + self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k)) + self.m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32) + self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16) + + def execute(self, m): + deep_gemm.m_grouped_fp8_gemm_nt_contiguous( + (self.lhs_q[:m], self.lhs_s[:m]), + (self.rhs_q, self.rhs_s), + self.out[:m], + m_indices=self.m_indices[:m], + ) + + +class _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor): + def __init__(self, max_m: int, n: int, k: int, num_groups: int): + self.lhs_q, self.lhs_s = _empty_token_fp8((num_groups, max_m, k)) + self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k)) + self.masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32) + self.out = torch.empty( + (num_groups, max_m, n), device="cuda", dtype=torch.bfloat16 + ) + + def execute(self, m): + deep_gemm.fp8_m_grouped_gemm_nt_masked( + (self.lhs_q, self.lhs_s), + (self.rhs_q, self.rhs_s), + self.out, + masked_m=self.masked_m, + # DeepGEMM uses `expect_m` instead of input shape for `get_best_config` + expected_m=m, + ) @contextmanager def deep_gemm_execution_hook( m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType ): - # not supported yet - if not DEEPGEMM_BLACKWELL: - _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups) - - with _log_jit_build(m, n, k, kernel_type): - yield + _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups) + yield diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py index cb4c2edb1..936ca75b8 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py @@ -24,14 +24,12 @@ def _compute_enable_deep_gemm(): return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", 
default="true") +def _is_blackwell_arch() -> bool: + major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) + return major == 10 + + ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm() -try: - from deep_gemm import fp8_gemm_nt - - # They have not given a name to this breaking change - DEEPGEMM_BLACKWELL = True -except ImportError: - DEEPGEMM_BLACKWELL = False - +DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch() DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py index 9dad33f9e..eedaa3c9b 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py @@ -16,33 +16,16 @@ logger = logging.getLogger(__name__) if ENABLE_JIT_DEEPGEMM: import deep_gemm - - if DEEPGEMM_BLACKWELL: - from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw - from deep_gemm import ( - fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw, - ) - from deep_gemm import ( - m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw, - ) - else: - from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw - from deep_gemm import get_col_major_tma_aligned_tensor - from deep_gemm import ( - m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw, - ) - from deep_gemm import ( - m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw, - ) + from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor +# TODO maybe rename these functions def grouped_gemm_nt_f8f8bf16_masked( lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor, masked_m: torch.Tensor, expected_m: int, - recipe=None, ): num_groups, _, k = lhs[0].shape _, n, _ = rhs[0].shape @@ -51,13 +34,12 @@ def grouped_gemm_nt_f8f8bf16_masked( with compile_utils.deep_gemm_execution_hook( expected_m, n, k, num_groups, kernel_type ): - _grouped_gemm_nt_f8f8bf16_masked_raw( + deep_gemm.fp8_m_grouped_gemm_nt_masked( lhs, rhs, out, masked_m, expected_m, - **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {}) ) @@ -72,7 +54,7 @@ def grouped_gemm_nt_f8f8bf16_contig( kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type): - _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices) + deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices) def gemm_nt_f8f8bf16( @@ -86,7 +68,7 @@ def gemm_nt_f8f8bf16( kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16 with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type): - _gemm_nt_f8f8bf16_raw( + deep_gemm.fp8_gemm_nt( lhs, rhs, out, diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 2176ad228..f0512365b 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -298,7 +298,7 @@ def _per_token_group_quant_8bit_raw( ) if scale_ue8m0: - from deep_gemm.utils.layout import transform_sf_into_required_layout + from deep_gemm import transform_sf_into_required_layout assert group_size == 128 x_s = transform_sf_into_required_layout( @@ -338,7 +338,7 @@ def _per_token_group_quant_8bit_fuse_silu_and_mul( # scale_ue8m0=scale_ue8m0, # ) - from 
deep_gemm.utils.layout import transform_sf_into_required_layout + from deep_gemm import transform_sf_into_required_layout from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index c08cabe5e..42c894590 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -459,7 +459,7 @@ def _requant_weight_ue8m0( import deep_gemm.utils.layout sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128) - sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf) + sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) return sf out_s = _transform_scale(out_s, mn=out_w.shape[-2]) diff --git a/python/sglang/srt/layers/quantization/mxfp4_tensor.py b/python/sglang/srt/layers/quantization/mxfp4_tensor.py index e7b9a8346..76cb92c54 100644 --- a/python/sglang/srt/layers/quantization/mxfp4_tensor.py +++ b/python/sglang/srt/layers/quantization/mxfp4_tensor.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import torch @@ -24,7 +26,7 @@ class MXFP4QuantizeUtil: E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5]) @classmethod - def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple: + def quantize(cls, input: torch.Tensor, block_size: Optional[int]) -> tuple: """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported. Args: input (torch.Tensor): The input tensor to be quantized. diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 975291435..71feb6ae2 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -50,25 +50,17 @@ FetchContent_Declare( ) FetchContent_Populate(repo-cutlass) -# DeepGEMM -if("${CUDA_VERSION}" VERSION_EQUAL "12.8") - set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM") - set(DeepGEMM_TAG "blackwell") -elseif("${CUDA_VERSION}" VERSION_EQUAL "12.9") - set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM") - set(DeepGEMM_TAG "blackwell") -elseif("${CUDA_VERSION}" VERSION_EQUAL "13.0") - set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM") - set(DeepGEMM_TAG "blackwell") -else() - set(DeepGEMM_REPO "https://github.com/deepseek-ai/DeepGEMM") - set(DeepGEMM_TAG "391755ada0ffefa9a6a52b6f14dcaf22d1a463e0") -endif() +FetchContent_Declare( + repo-fmt + GIT_REPOSITORY https://github.com/fmtlib/fmt + GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28 + GIT_SHALLOW OFF +) FetchContent_Declare( repo-deepgemm - GIT_REPOSITORY ${DeepGEMM_REPO} - GIT_TAG ${DeepGEMM_TAG} + GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM + GIT_TAG sgl GIT_SHALLOW OFF ) FetchContent_Populate(repo-deepgemm) @@ -86,7 +78,7 @@ FetchContent_Populate(repo-triton) FetchContent_Declare( repo-flashinfer GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git - GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3 + GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7 GIT_SHALLOW OFF ) FetchContent_Populate(repo-flashinfer) @@ -182,28 +174,11 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) list(APPEND SGL_KERNEL_CUDA_FLAGS "-gencode=arch=compute_100,code=sm_100" "-gencode=arch=compute_100a,code=sm_100a" - "-gencode=arch=compute_103,code=sm_103" - "-gencode=arch=compute_103a,code=sm_103a" + 
"-gencode=arch=compute_101,code=sm_101" + "-gencode=arch=compute_101a,code=sm_101a" "-gencode=arch=compute_120,code=sm_120" "-gencode=arch=compute_120a,code=sm_120a" ) - - # refer sm_121, sm_110 and sm_101 description https://github.com/pytorch/pytorch/pull/156176 - if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0") - list(APPEND SGL_KERNEL_CUDA_FLAGS - "-gencode=arch=compute_110,code=sm_110" - "-gencode=arch=compute_110a,code=sm_110a" - "-gencode=arch=compute_121,code=sm_121" - "-gencode=arch=compute_121a,code=sm_121a" - "--compress-mode=size" - ) - else() - list(APPEND SGL_KERNEL_CUDA_FLAGS - "-gencode=arch=compute_101,code=sm_101" - "-gencode=arch=compute_101a,code=sm_101a" - ) - endif() - else() list(APPEND SGL_KERNEL_CUDA_FLAGS "-use_fast_math" @@ -286,6 +261,12 @@ set(SOURCES "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu" "csrc/moe/marlin_moe_wna16/ops.cu" + "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu" + "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu" + "csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu" + "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu" + "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu" + "csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu" "csrc/moe/moe_align_kernel.cu" "csrc/moe/moe_fused_gate.cu" "csrc/moe/moe_topk_softmax_kernels.cu" @@ -321,8 +302,6 @@ target_include_directories(common_ops PRIVATE ${repo-cutlass_SOURCE_DIR}/examples/common ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src ) -set_source_files_properties("csrc/gemm/per_token_group_quant_8bit" PROPERTIES COMPILE_OPTIONS "--use_fast_math") - find_package(Python3 COMPONENTS Interpreter REQUIRED) execute_process( @@ -464,13 +443,38 @@ install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel) set(DEEPGEMM_SOURCES "${repo-deepgemm_SOURCE_DIR}/csrc/python_api.cpp" ) -# JIT Logic -# DeepGEMM -install(DIRECTORY "${repo-deepgemm_SOURCE_DIR}/deep_gemm/" - DESTINATION "deep_gemm" - PATTERN ".git*" EXCLUDE - PATTERN "__pycache__" EXCLUDE) +Python_add_library(deep_gemm_cpp MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${DEEPGEMM_SOURCES}) + +# Link against necessary libraries, including nvrtc for JIT compilation. +target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} c10 cuda nvrtc mscclpp_static) + +# Add include directories needed by DeepGEMM. +target_include_directories(deep_gemm_cpp PRIVATE + ${repo-deepgemm_SOURCE_DIR}/deep_gemm/include + ${repo-cutlass_SOURCE_DIR}/include + ${repo-fmt_SOURCE_DIR}/include +) + +# Apply the same compile options as common_ops. +target_compile_options(deep_gemm_cpp PRIVATE $<$:${SGL_KERNEL_CUDA_FLAGS}>) + +# Create an empty __init__.py to make `deepgemm` a Python package. +file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py "") +install( + FILES ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py + DESTINATION deep_gemm + RENAME __init__.py +) + +# Install the compiled DeepGEMM API library. +install(TARGETS deep_gemm_cpp LIBRARY DESTINATION deep_gemm) + +# Install the source files required by DeepGEMM for runtime JIT compilation. 
+install( + DIRECTORY ${repo-deepgemm_SOURCE_DIR}/deep_gemm/ + DESTINATION deep_gemm +) install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/" DESTINATION "deep_gemm/include/cute") diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py b/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py index b3ed863a3..833d074ea 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -9,7 +9,6 @@ import jinja2 FILE_HEAD = """ // auto generated by generate.py // clang-format off -#pragma once #include "kernel.h" #include "marlin_template.h" @@ -34,17 +33,6 @@ TEMPLATE = ( "( MARLIN_KERNEL_PARAMS );" ) -KERNEL_FILE_TEMPLATE = ( - "// auto generated by generate.py\n" - "// clang-format off\n" - "#pragma once\n\n" - "{% for kernel_file in kernel_files %}" - '#include "{{ kernel_file }}"\n' - "{% endfor %}" -) - -KERNEL_FILE_NAME = "kernel_marlin.cuh" - # int8 with zero point case (sglang::kU8) is also supported, # we don't add it to reduce wheel size. SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"] @@ -60,12 +48,11 @@ DTYPES = ["fp16", "bf16"] def remove_old_kernels(): - for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"): + for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): subprocess.call(["rm", "-f", filename]) def generate_new_kernels(): - kernel_files = set() for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): has_zp = "B" not in scalar_type all_template_str_list = [] @@ -108,20 +95,10 @@ def generate_new_kernels(): file_content = FILE_HEAD + "\n\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" - filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh" + filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu" with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: f.write(file_content) - kernel_files.add(filename) - - kernel_files = list(kernel_files) - kernel_files.sort() - - file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render( - kernel_files=kernel_files - ) - with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f: - f.write(file_content) if __name__ == "__main__": diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h index afa7c377b..88d157507 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h @@ -1,4 +1,3 @@ -#pragma once #ifndef MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu index 7e83bed8f..1e3d923ae 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu @@ -1,6 +1,5 @@ // auto generated by generate.py // clang-format off -#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu index 60e2dea31..513ddc2ed 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh 
+++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu @@ -1,6 +1,5 @@ // auto generated by generate.py // clang-format off -#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu index 7eb6b18de..eebe9d3da 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu @@ -1,6 +1,5 @@ // auto generated by generate.py // clang-format off -#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu index ec41e018b..9adc6623a 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu @@ -1,6 +1,5 @@ // auto generated by generate.py // clang-format off -#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu index 7df28701b..66ca7e36a 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu @@ -1,6 +1,5 @@ // auto generated by generate.py // clang-format off -#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu index 1150844e2..21fdf0c1a 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu @@ -1,6 +1,5 @@ // auto generated by generate.py // clang-format off -#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh deleted file mode 100644 index bb828dc5b..000000000 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh +++ /dev/null @@ -1,10 +0,0 @@ -// auto generated by generate.py -// clang-format off -#pragma once - -#include "kernel_bf16_ku4.cuh" -#include "kernel_bf16_ku4b8.cuh" -#include "kernel_bf16_ku8b128.cuh" -#include "kernel_fp16_ku4.cuh" -#include "kernel_fp16_ku4b8.cuh" -#include "kernel_fp16_ku8b128.cuh" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h index ade562af6..71c91839d 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -18,8 +18,6 @@ /* * Adapted from https://github.com/IST-DASLab/marlin */ -#pragma once - #ifndef MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 
#endif
diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
index b249f6415..f430390d1 100644
--- a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
+++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
@@ -24,7 +24,6 @@
 #endif
 
 #include "kernel.h"
-#include "kernel_marlin.cuh"
 
 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
   static_assert(                                  \
diff --git a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
index c9bc8a628..050e8d52b 100644
--- a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
+++ b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
@@ -23,7 +23,6 @@ limitations under the License.
 #ifndef USE_ROCM
 #include <cub/util_type.cuh>
 #include <cub/cub.cuh>
-#include <cuda/functional>
 #else
 #include <hipcub/util_type.hpp>
 #include <hipcub/hipcub.hpp>
@@ -34,16 +33,6 @@ limitations under the License.
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 
-// Define reduction operators based on CUDA version
-// CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum
-#if CUDA_VERSION >= 12090
-using MaxReduceOp = cuda::maximum<>;
-using MinReduceOp = cuda::minimum<>;
-#else
-using MaxReduceOp = cub::Max;
-using MinReduceOp = cub::Min;
-#endif
-
 /// Aligned array type
 template <
     typename T,
@@ -83,6 +72,7 @@ __launch_bounds__(TPB) __global__
 
   const int thread_row_offset = blockIdx.x * num_cols;
 
+  cub::Sum sum;
   float threadData(-FLT_MAX);
 
   // Don't touch finished rows.
@@ -95,7 +85,7 @@
     threadData = max(convert_to_float(input[idx]), threadData);
   }
 
-  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
+  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
 
   if (threadIdx.x == 0) {
     float_max = maxElem;
@@ -109,7 +99,7 @@
     threadData += exp((convert_to_float(input[idx]) - float_max));
   }
 
-  const auto Z = BlockReduce(tmpStorage).Sum(threadData);
+  const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
 
   if (threadIdx.x == 0) {
     normalizing_factor = 1.f / Z;
diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index 52ee620e4..c47b389ec 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
 
 [project]
 name = "sgl-kernel"
-version = "0.3.6.post2"
+version = "0.3.7"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.10"
diff --git a/sgl-kernel/pyproject_cpu.toml b/sgl-kernel/pyproject_cpu.toml
index d1098e958..d5fe91c42 100644
--- a/sgl-kernel/pyproject_cpu.toml
+++ b/sgl-kernel/pyproject_cpu.toml
@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
 
 [project]
 name = "sgl-kernel"
-version = "0.3.6.post2"
+version = "0.3.7"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.10"
diff --git a/sgl-kernel/pyproject_rocm.toml b/sgl-kernel/pyproject_rocm.toml
index 9b520402f..826a77398 100644
--- a/sgl-kernel/pyproject_rocm.toml
+++ b/sgl-kernel/pyproject_rocm.toml
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.3.6.post2"
+version = "0.3.7"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.10"
diff --git a/sgl-kernel/python/sgl_kernel/version.py b/sgl-kernel/python/sgl_kernel/version.py
index 215f77650..8879c6c77 100644
--- a/sgl-kernel/python/sgl_kernel/version.py
+++ b/sgl-kernel/python/sgl_kernel/version.py
@@ -1 +1 @@
-__version__ = "0.3.6.post2"
+__version__ = "0.3.7"
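
Reviewer note: the sketch below is editorial, not part of the patch. It shows what a single step of the new `_NormalWarmupExecutor` boils down to, using the renamed DeepGEMM 2.x entry point `fp8_gemm_nt` and the scale shapes implied by `_empty_token_fp8` / `_empty_block_fp8` in compile_utils.py. The concrete (m, n, k) values and the inlined `ceil_div` helper are illustrative assumptions, not taken from the PR.

```python
# Sketch: one DeepGEMM warmup call under the new API (assumes a CUDA build
# of the sgl-project DeepGEMM fork vendored by this PR).
import torch
import deep_gemm

BLOCK = 128  # quantization block width, matching _BLOCK_SIZE in compile_utils.py


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


m, n, k = 128, 4096, 7168  # arbitrary example shape

# Activation side: per-token-group (1 x 128) scales, one fp32 scale per
# 128-wide group along k.
lhs_q = torch.empty(m, k, device="cuda", dtype=torch.float8_e4m3fn)
lhs_s = torch.empty(m, ceil_div(k, BLOCK), device="cuda", dtype=torch.float32)

# Weight side: per-block (128 x 128) scales.
rhs_q = torch.empty(n, k, device="cuda", dtype=torch.float8_e4m3fn)
rhs_s = torch.empty(
    ceil_div(n, BLOCK), ceil_div(k, BLOCK), device="cuda", dtype=torch.float32
)

out = torch.empty(m, n, device="cuda", dtype=torch.bfloat16)

# Renamed entry point (formerly gemm_fp8_fp8_bf16_nt). Tensor contents are
# left uninitialized because warmup only needs to trigger JIT compilation
# for this (m, n, k) shape, not produce meaningful output.
deep_gemm.fp8_gemm_nt((lhs_q, lhs_s), (rhs_q, rhs_s), out)
```

Running one such call per m in `_BUILTIN_M_LIST` is what replaces the old `get_best_configs` + `build` pre-compilation path: instead of generating kernel source per config, the new code executes each shape once and lets DeepGEMM's own JIT cache the result.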