diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py
index cde2cf14a..0f0b0180f 100644
--- a/python/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -4,6 +4,7 @@ from typing import List, Optional
 
 import torch
 import triton
+from sglang.math_utils import ceil_div
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import dispose_tensor, is_cuda
@@ -15,11 +16,6 @@ if _is_cuda:
         sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )
 
-    try:
-        from deep_gemm import ceil_div
-    except ImportError:
-        logger.error(f"Failed to import ceil_div from deep_gemm.")
-
 import triton.language as tl
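
Note: ceil_div now comes from sglang.math_utils instead of deep_gemm, so the Triton kernels no longer need DeepGEMM installed just for integer math. A minimal sketch of what such a helper does, assuming the sglang.math_utils version has the usual ceiling-division semantics:

    # Ceiling division as commonly implemented; illustrative only, the real
    # sglang.math_utils.ceil_div may differ in details.
    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b  # e.g. ceil_div(129, 128) == 2
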
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index f0c0c5f6e..b0259a616 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,30 +1,11 @@
 import logging
 from typing import Callable, List, Optional, Tuple
 
+import einops
 import torch
+from sgl_kernel import silu_and_mul
 from torch.nn import Module
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-try:
-    from deep_gemm import (
-        get_col_major_tma_aligned_tensor,
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-    )
-    from sgl_kernel import silu_and_mul
-
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
-
-    use_deep_gemm = True
-except ImportError:
-    use_deep_gemm = False
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -45,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -52,10 +34,20 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
+    sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
+from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    DeepEPMode,
+    dispose_tensor,
+    get_bool_env_var,
+    is_hip,
+    set_weight_attrs,
+)
 
 _is_hip = is_hip()
@@ -680,7 +672,6 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -920,7 +911,9 @@ class DeepEPMoE(EPMoE):
         )
         self.deepep_mode = deepep_mode
         if self.deepep_mode.enable_low_latency():
-            assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+            assert (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
             self.w13_weight_fp8 = (
                 self.w13_weight,
                 (
@@ -948,7 +941,7 @@ class DeepEPMoE(EPMoE):
     ):
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
-            if _ENABLE_JIT_DEEPGEMM:
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
                     hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                 )
@@ -1145,7 +1138,7 @@ class DeepEPMoE(EPMoE):
                 dtype=torch.bfloat16,
             )
             input_tensor[1] = tma_align_input_scale(input_tensor[1])
-            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
                 input_tensor, self.w13_weight_fp8, gateup_output, m_indices
             )
             del input_tensor
@@ -1169,7 +1162,7 @@ class DeepEPMoE(EPMoE):
             )
             del down_input
             down_input_scale = tma_align_input_scale(down_input_scale)
-            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
                 (down_input_fp8, down_input_scale),
                 self.w2_weight_fp8,
                 down_output,
@@ -1202,8 +1195,13 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty(
             (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
         )
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            hidden_states_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
         )
         dispose_tensor(hidden_states_fp8[0])
@@ -1240,13 +1238,18 @@ class DeepEPMoE(EPMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            get_col_major_tma_aligned_tensor(down_input_scale),
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
         )
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
         )
         return down_output
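
Note: the masked grouped GEMM calls now go through deep_gemm_wrapper and gain a recipe argument that is forwarded only on the newer DeepGEMM (DEEPGEMM_V202506). For readers unfamiliar with the kernel, here is a pure-torch sketch of its reference semantics, inferred from the argument names; the real kernel consumes FP8 tensors plus per-block scales:

    import torch

    def masked_grouped_gemm_ref(lhs, rhs, masked_m):
        # lhs: (g, m, k) activations; rhs: (g, n, k) weights in NT layout;
        # masked_m[i] is the number of valid rows in group i.
        g, m, k = lhs.shape
        out = torch.zeros(g, m, rhs.shape[1], dtype=lhs.dtype)
        for i in range(g):
            rows = int(masked_m[i])
            out[i, :rows] = lhs[i, :rows] @ rhs[i].transpose(0, 1)
        return out
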
diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index 8089168be..2028ecf04 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -1,7 +1,7 @@
 import logging
 from dataclasses import dataclass
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.expert_distribution import (
     get_global_expert_distribution_recorder,
 )
@@ -236,14 +236,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_weights: torch.Tensor,
     ):
         topk_idx = topk_idx.to(torch.int64)
-        if _ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
             hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
 
     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        if _ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             (
                 hidden_states,
                 topk_idx,
@@ -345,7 +345,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
-            expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
+            expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )
@@ -409,7 +409,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if _ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py
new file mode 100644
index 000000000..32f3ebe04
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py
@@ -0,0 +1 @@
+from .entrypoint import *
diff --git a/python/sglang/srt/layers/quantization/deep_gemm.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
similarity index 83%
rename from python/sglang/srt/layers/quantization/deep_gemm.py
rename to python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
index 985c0a274..75ebd9298 100644
--- a/python/sglang/srt/layers/quantization/deep_gemm.py
+++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
@@ -5,33 +5,24 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple
 
-import torch
 from tqdm.contrib.concurrent import thread_map
 
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_V202506,
+    ENABLE_JIT_DEEPGEMM,
+)
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
+from sglang.srt.utils import get_bool_env_var, get_int_env_var
 
 logger = logging.getLogger(__name__)
 
-_ENABLE_JIT_DEEPGEMM = False
 try:
-    import deep_gemm
     from deep_gemm import get_num_sms
     from deep_gemm.jit import build
-    from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
     from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
-
-    sm_version = get_device_sm()
-    if sm_version == 90:
-        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
-            _ENABLE_JIT_DEEPGEMM = True
 except ImportError:
-    logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")
-
-
-def get_enable_jit_deepgemm():
-    return _ENABLE_JIT_DEEPGEMM
+    pass
 
 
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
@@ -52,8 +43,10 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
 # NVRTC may have performance loss with some cases.
 # And NVCC JIT speed is also 9x faster in the ref commit
 _USE_NVRTC_DEFAULT = "0"
-if _ENABLE_JIT_DEEPGEMM:
+if ENABLE_JIT_DEEPGEMM:
     try:
+        from deep_gemm.jit.compiler import get_nvcc_compiler
+
         get_nvcc_compiler()
     except:
         logger.warning(
@@ -114,6 +107,7 @@ class DeepGemmKernelHelper:
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
 
 
+# TODO improve naming
 def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
@@ -127,6 +121,7 @@ def _compile_warning_1():
     )
 
 
+# TODO improve naming
 def _compile_warning_2():
     logger.warning(
         "Entering DeepGEMM JIT Single Kernel Compile session. "
@@ -238,6 +233,7 @@ def _compile_gemm_nt_f8f8bf16_one(
     _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
 
 
+# TODO further refactor warmup-related
 _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
     DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
         name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
@@ -270,7 +266,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     num_groups: int,
     m_list: Optional[List[int]] = None,
 ) -> None:
-
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST
@@ -304,56 +299,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
 
 
-def grouped_gemm_nt_f8f8bf16_masked(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    masked_m: torch.Tensor,
-    expected_m: int,
-):
-    num_groups, _, k = lhs[0].shape
-    _, n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(expected_m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            lhs, rhs, out, masked_m, expected_m
-        )
-
-
-def grouped_gemm_nt_f8f8bf16_contig(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    m_indices: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    num_groups, n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
-
-
-def gemm_nt_f8f8bf16(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
-
-
 @contextmanager
 def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
     if _IN_PRECOMPILE_STAGE:
@@ -380,13 +325,14 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
 
 
 @contextmanager
-def configure_deep_gemm_num_sms(num_sms):
-    if num_sms is None:
+def deep_gemm_execution_hook(
+    m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
+):
+    # not supported yet
+    if DEEPGEMM_V202506:
+        yield
+        return
+
+    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+    with _log_jit_build(m, n, k, kernel_type):
         yield
-    else:
-        original_num_sms = deep_gemm.get_num_sms()
-        deep_gemm.set_num_sms(num_sms)
-        try:
-            yield
-        finally:
-            deep_gemm.set_num_sms(original_num_sms)
diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
new file mode 100644
index 000000000..b6c776629
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
@@ -0,0 +1,26 @@
+import logging
+
+from sglang.srt.utils import get_bool_env_var, get_device_sm
+
+logger = logging.getLogger(__name__)
+
+
+def _compute_enable_deep_gemm():
+    try:
+        import deep_gemm
+    except ImportError:
+        logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
+        return False
+
+    sm_version = get_device_sm()
+    if sm_version < 90:
+        return False
+
+    return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
+
+
+ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
+
+DEEPGEMM_V202506 = False
+
+DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506
diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
new file mode 100644
index 000000000..514a4f884
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
@@ -0,0 +1,95 @@
+import logging
+from contextlib import contextmanager
+from typing import Tuple
+
+import torch
+
+from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_SCALE_UE8M0,
+    DEEPGEMM_V202506,
+    ENABLE_JIT_DEEPGEMM,
+)
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
+    from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
+    from deep_gemm import get_col_major_tma_aligned_tensor
+    from deep_gemm import (
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+    )
+    from deep_gemm import (
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+    )
+
+
+def grouped_gemm_nt_f8f8bf16_masked(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+    recipe=None,
+):
+    num_groups, _, k = lhs[0].shape
+    _, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
+
+    with compile_utils.deep_gemm_execution_hook(
+        expected_m, n, k, num_groups, kernel_type
+    ):
+        _grouped_gemm_nt_f8f8bf16_masked_raw(
+            lhs, rhs, out, masked_m, expected_m, recipe=recipe
+        )
+
+
+def grouped_gemm_nt_f8f8bf16_contig(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    m_indices: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    num_groups, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
+
+
+def gemm_nt_f8f8bf16(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    n, _ = rhs[0].shape
+    num_groups = 1
+    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _gemm_nt_f8f8bf16_raw(
+            lhs,
+            rhs,
+            out,
+        )
+
+
+def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
+    compile_utils.update_deep_gemm_config(gpu_id, server_args)
+
+
+@contextmanager
+def configure_deep_gemm_num_sms(num_sms):
+    if num_sms is None:
+        yield
+    else:
+        original_num_sms = deep_gemm.get_num_sms()
+        deep_gemm.set_num_sms(num_sms)
+        try:
+            yield
+        finally:
+            deep_gemm.set_num_sms(original_num_sms)
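
Note: configurer.py computes the capability flags once at import time and entrypoint.py re-exports them next to the kernel wrappers, so call sites never import deep_gemm directly. A sketch of the intended call pattern (fp8_gemm_or_fallback is a hypothetical helper, not part of this patch):

    from sglang.srt.layers.quantization import deep_gemm_wrapper

    def fp8_gemm_or_fallback(lhs, rhs, out):
        # Gate on the module-level flag, then go through the facade.
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            deep_gemm_wrapper.gemm_nt_f8f8bf16(lhs, rhs, out)
        else:
            raise NotImplementedError("dispatch to the Triton kernel here")
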
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 60ba7f5c5..601a4b088 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -23,7 +23,8 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_device_core_count,
@@ -44,10 +45,6 @@ if _is_cuda:
         sgl_per_token_quant_fp8,
     )
 
-    from sglang.srt.layers.quantization.deep_gemm import (
-        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
-    )
-
 logger = logging.getLogger(__name__)
@@ -67,7 +64,6 @@ else:
 fp8_max = torch.finfo(fp8_dtype).max
 fp8_min = -fp8_max
 
-
 if supports_custom_op():
 
     def deep_gemm_fp8_fp8_bf16_nt(
         A: torch.Tensor,
         As: torch.Tensor,
         B: torch.Tensor,
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -797,12 +793,12 @@ def w8a8_block_fp8_matmul_deepgemm(
     M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
 
     # Deepgemm only supports output tensor type as bfloat16
-    assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
+    assert C.dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
 
     if supports_custom_op():
         torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
     else:
-        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
     return C
@@ -896,7 +892,7 @@ def w8a8_block_fp8_matmul(
     block_size: List[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
+    if output_dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
         return w8a8_block_fp8_matmul_deepgemm(
             A, B, As, Bs, block_size, output_dtype=output_dtype
         )
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 2408af197..46ecf6267 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -1,10 +1,10 @@
-import os
-from curses import flash
 from typing import Callable, List, Optional, Tuple
 
+import einops
 import torch
 
 from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
@@ -15,7 +15,6 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     fp8_dtype,
     fp8_max,
@@ -138,7 +137,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
         return cutlass_w8a8_block_fp8_linear_with_fallback
     elif _use_aiter:
         return aiter_w8a8_block_fp8_linear
-    elif _ENABLE_JIT_DEEPGEMM:
+    elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
         return deepgemm_w8a8_block_fp8_linear_with_fallback
     else:
         return triton_w8a8_block_fp8_linear
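
Note: every backend selected by dispatch_w8a8_block_fp8_linear computes the same blockwise-dequantized matmul. A pure-torch reference of those semantics, assuming the usual 128-wide activation groups and 128x128 weight blocks (fp32 stand-ins for the FP8 tensors):

    import torch

    def w8a8_block_matmul_ref(A, As, B, Bs, block=128):
        # A: (M, K) activations, As: (M, K//block) per-token-group scales
        # B: (N, K) weights,     Bs: (N//block, K//block) per-block scales
        A_deq = A.float() * As.repeat_interleave(block, dim=1)
        B_deq = B.float() * Bs.repeat_interleave(block, dim=0).repeat_interleave(
            block, dim=1
        )
        return A_deq @ B_deq.T  # (M, N)
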
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 995dedd02..1847af151 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -26,6 +27,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
+from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
@@ -45,10 +46,9 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
-from sglang.srt.layers.quantization.deep_gemm import (
-    _ENABLE_JIT_DEEPGEMM,
-    update_deep_gemm_config,
+from sglang.srt.layers.quantization import (
+    deep_gemm_wrapper,
+    monkey_patch_isinstance_for_vllm_base_layer,
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -205,8 +205,8 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()
 
         # Update deep gemm configure
-        if _ENABLE_JIT_DEEPGEMM:
-            update_deep_gemm_config(gpu_id, server_args)
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
 
         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 82a0c1d91..83837a748 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -54,8 +54,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
@@ -110,10 +110,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
-
-    from sglang.srt.layers.quantization.deep_gemm import (
-        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
-    )
 else:
     from vllm._custom_ops import awq_dequantize
@@ -981,7 +977,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
             )
-            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (q_nope_val, q_nope_scale),
                 (self.w_kc, self.w_scale_k),
                 q_nope_out,
@@ -1851,7 +1847,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and weight_block_size[1] == 128
                 and model_dtype == torch.bfloat16
             ):
-                if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+                if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and get_bool_env_var(
                     "SGL_USE_DEEPGEMM_BMM", "false"
                 ):
                     block_scale = weight_scale
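
Note: the MLA bmm path above pads the token dimension (aligned_m) before calling the masked grouped GEMM, and fp8_kernel.py now imports align from sglang.math_utils for the same purpose. A sketch of the usual round-up helper, assuming the sglang.math_utils version has standard semantics:

    def align(x: int, alignment: int) -> int:
        # Round x up to the next multiple of alignment.
        return (x + alignment - 1) // alignment * alignment

    assert align(100, 128) == 128
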
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index 9e83e0ba5..601d39183 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -11,7 +11,7 @@ from sglang.srt.layers.communicator import (
     ScatterMode,
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
@@ -479,7 +479,9 @@ def _model_forward_tbo(
     )
     del inputs
 
-    with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
+    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
+        operations_strategy.deep_gemm_num_sms
+    ):
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,
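
Note: configure_deep_gemm_num_sms moves into the wrapper unchanged; two-batch overlap uses it to cap the SMs DeepGEMM may occupy so the compute micro-batch leaves room for communication. Usage sketch (the SM count here is a hypothetical value; passing None makes the context manager a no-op, per its definition in entrypoint.py):

    from sglang.srt.layers.quantization import deep_gemm_wrapper

    with deep_gemm_wrapper.configure_deep_gemm_num_sms(num_sms=108):
        ...  # DeepGEMM kernels launched here are limited to 108 SMs
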