From 93cec4335fed91f07683e4d69cb7980ca050e64d Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Sat, 14 Jun 2025 14:00:17 +0800
Subject: [PATCH] Support new DeepGEMM (#7172)

---
 python/sglang/srt/layers/moe/ep_moe/layer.py  |  9 ++++++-
 .../srt/layers/moe/ep_moe/token_dispatcher.py |  2 ++
 .../layers/moe/fused_moe_triton/fused_moe.py  |  5 +---
 .../deep_gemm_wrapper/configurer.py           |  8 +++++-
 .../deep_gemm_wrapper/entrypoint.py           | 26 +++++++++++++------
 .../srt/layers/quantization/fp8_kernel.py     | 23 +++++++++++++---
 .../srt/layers/quantization/fp8_utils.py      |  1 +
 python/sglang/srt/models/deepseek_v2.py       |  4 +--
 8 files changed, 59 insertions(+), 19 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index b0259a616..50bbf94c3 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -1231,6 +1231,7 @@ class DeepEPMoE(EPMoE):
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output
 
@@ -1238,7 +1239,13 @@ class DeepEPMoE(EPMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index 2028ecf04..33b7d6929 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -584,6 +584,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 use_fp8=use_fp8,
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
+                round_scale=deep_gemm_wrapper.DEEPGEMM_V202506,
+                use_ue8m0=deep_gemm_wrapper.DEEPGEMM_V202506,
             )
         )
         return packed_recv_hidden, packed_recv_count, event, hook
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 79e90e90a..e8d3b58ce 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -12,6 +12,7 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -518,10 +519,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-
 @triton.jit
 def moe_align_block_size_stage1(
     topk_ids_ptr,
diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
index b6c776629..adf52b2f1 100644
--- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
+++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
@@ -21,6 +21,12 @@ def _compute_enable_deep_gemm():
 
 ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
 
-DEEPGEMM_V202506 = False
+try:
+    from deep_gemm import fp8_gemm_nt
+
+    # They have not given a name to this breaking change
+    DEEPGEMM_V202506 = True
+except ImportError:
+    DEEPGEMM_V202506 = False
 
 DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506
diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
index 514a4f884..b2551471d 100644
--- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
+++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
@@ -16,14 +16,24 @@ logger = logging.getLogger(__name__)
 
 if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
-    from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
-    from deep_gemm import get_col_major_tma_aligned_tensor
-    from deep_gemm import (
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-    )
-    from deep_gemm import (
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-    )
+
+    if DEEPGEMM_V202506:
+        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import (
+            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
+        from deep_gemm import (
+            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+    else:
+        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import get_col_major_tma_aligned_tensor
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
 
 
 def grouped_gemm_nt_f8f8bf16_masked(
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 601a4b088..fed587ba1 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -765,7 +765,15 @@ def prepare_block_fp8_matmul_inputs(
     assert A.shape[-1] == B.shape[-1]
     assert A.shape[:-1] == As.shape[:-1]
     assert A.is_contiguous()
-    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+
+    if As.dtype == torch.float:
+        assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    elif Bs.dtype == torch.int:
+        assert (
+            triton.cdiv(triton.cdiv(A.shape[-1], block_k), 4) == As.shape[-1]
+        ), f"{A.shape=} {As.shape=} {block_size=}"
+    else:
+        raise NotImplementedError
 
     M = A.numel() // A.shape[-1]
 
@@ -773,8 +781,17 @@ def prepare_block_fp8_matmul_inputs(
     assert B.is_contiguous()
     assert Bs.ndim == 2
     N, K = B.shape
-    assert triton.cdiv(N, block_n) == Bs.shape[0]
-    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    if Bs.dtype == torch.float:
+        assert triton.cdiv(N, block_n) == Bs.shape[0]
+        assert triton.cdiv(K, block_k) == Bs.shape[1]
+    elif Bs.dtype == torch.int:
+        assert N == Bs.shape[0], f"{B.shape=} {Bs.shape=} {block_size=}"
+        assert (
+            triton.cdiv(triton.cdiv(K, block_k), 4) == Bs.shape[1]
+        ), f"{B.shape=} {Bs.shape=} {block_size=}"
+    else:
+        raise NotImplementedError
 
     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 46ecf6267..6a40a7e9d 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -238,6 +238,7 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         block_size[1],
         column_major_scales=True,
         scale_tma_aligned=True,
+        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )
 
     if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 83837a748..ee86901aa 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -51,7 +51,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization import deep_gemm_wrapper
@@ -1932,7 +1932,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                 self_attn.use_deep_gemm_bmm = True
 
-        if False:  # TODO (pr-chain)
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
             self._weight_requant_ue8m0()
 
     def _weight_requant_ue8m0(self):
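Reviewer note: a minimal sketch of the scale-tensor widths that the reworked prepare_block_fp8_matmul_inputs assertions encode, assuming the int32 scales pack four UE8M0 exponents per element; K = 7168 and block_k = 128 are illustrative values only, not taken from the patch.

    # Sketch under the assumptions stated above; ceil_div mirrors sglang.math_utils.ceil_div.
    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    K, block_k = 7168, 128
    # float32 scales: one scale per block_k-wide group along K.
    print(ceil_div(K, block_k))               # 56
    # int32 (packed UE8M0) scales: the per-block count is further divided by 4.
    print(ceil_div(ceil_div(K, block_k), 4))  # 14

Under these assumptions, the same packed-by-4 rule explains both new assertions: As keeps its leading dims and shrinks its last dim by 4x, while Bs collapses to one row per output channel (N) with the packed width along K.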