From b091a7a5c9e47eebee9cfab97adcbc17ba61c726 Mon Sep 17 00:00:00 2001
From: lizhigong <306128847@qq.com>
Date: Wed, 22 Oct 2025 15:16:12 +0800
Subject: [PATCH] Adapt the W4A8 Marlin MoE path to DeepEP DP/EP

(cherry picked from commit a0fb70e9c15dadc32103adf3acc9b29abe5516c5)
---
 python/sglang/srt/layers/moe/ep_moe/layer.py  | 384 ++++++++++++++++--
 .../srt/layers/moe/token_dispatcher/deepep.py |  31 +-
 2 files changed, 383 insertions(+), 32 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 12f04eb9e..892ead5b2 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
 
 import torch
 from sglang.srt import single_batch_overlap
@@ -54,7 +55,286 @@ if _use_aiter:
 
 logger = logging.getLogger(__name__)
 
-class DeepEPMoE(FusedMoE):
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
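# A minimal, self-contained sketch (separate from the patch) of the rounding
# rule that _cast_to_e8m0_with_rounding_up above implements, for one FP32
# value. E8M0 keeps only the 8-bit exponent, so the exponent is bumped by one
# whenever dropping the mantissa would otherwise round the scale down.
# The helper name `e8m0_round_up_scalar` is illustrative, not from the patch.
import struct

def e8m0_round_up_scalar(x: float) -> int:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    exp, mant = (bits >> 23) & 0xFF, bits & 0x7FFFFF
    round_up = mant > 0 and exp != 0xFE and not (exp == 0 and mant <= 0x400000)
    return exp + 1 if round_up else exp

assert e8m0_round_up_scalar(2.0) == 128   # exact power of two: 2.0 == 2**(128-127)
assert e8m0_round_up_scalar(3.0) == 129   # rounds up to 4.0 == 2**(129-127)
assert e8m0_round_up_scalar(0.75) == 127  # rounds up to 1.0 == 2**(127-127)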
+
+
+class EPMoE(FusedMoE):
+    """MoE Expert Parallel Impl."""
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        layer_id: int,
+        num_fused_shared_experts: int = 0,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
+        gemm1_alpha: Optional[float] = None,
+        gemm1_clamp_limit: Optional[float] = None,
+        with_bias: bool = False,
+    ):
+        super().__init__(
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_fused_shared_experts=num_fused_shared_experts,
+            layer_id=layer_id,
+            top_k=top_k,
+            params_dtype=params_dtype,
+            quant_config=quant_config,
+            prefix=prefix,
+            activation=activation,
+            # apply_router_weight_on_input=apply_router_weight_on_input,
+            routed_scaling_factor=routed_scaling_factor,
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
+            with_bias=with_bias,
+        )
+
+        self.intermediate_size = intermediate_size
+        if isinstance(quant_config, Fp8Config):
+            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
+            self.block_shape = (
+                self.quant_method.quant_config.weight_block_size
+                if self.use_block_quant
+                else None
+            )
+            self.use_fp8_w8a8 = True
+            self.fp8_dtype = torch.float8_e4m3fn
+            self.activation_scheme = quant_config.activation_scheme
+            self.use_w4a8_marlin = False
+        elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
+            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
+            self.block_shape = (
+                self.quant_method.quant_config.weight_block_size
+                if self.use_block_quant
+                else None
+            )
+            self.use_fp8_w8a8 = False
+            self.activation_scheme = None
+            self.use_w4a8_marlin = True
+        else:
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.block_shape = None
+            self.activation_scheme = None
+            self.use_w4a8_marlin = False
+
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+            return self.forward_deepgemm(hidden_states, topk_output)
+        else:
+            return super().forward(hidden_states, topk_output)
+
+    def forward_deepgemm(
+        self,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
+    ):
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
+
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
+        topk_weights, topk_ids, _ = topk_output
+
+        if not self.use_block_quant:
+            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+            scale_block_size = 128
+            w13_weight_scale_n = 2 * (
+                (self.intermediate_size + scale_block_size - 1) // scale_block_size
+            )
+            w13_weight_scale_k = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w13_weight_scale = (
+                self.w13_weight_scale.unsqueeze(1)
+                .repeat_interleave(w13_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w13_weight_scale_k, dim=2)
+            )
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                w13_weight_scale,
+            )
+            w2_weight_scale_n = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale_k = (
+                self.intermediate_size + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale = (
+                self.w2_weight_scale.unsqueeze(1)
+                .repeat_interleave(w2_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w2_weight_scale_k, dim=2)
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                w2_weight_scale,
+            )
+
+        # PreReorder
+        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
+            moe_ep_deepgemm_preprocess(
+                topk_ids,
+                self.num_experts,
+                hidden_states,
+                self.top_k,
+                self.start_expert_id,
+                self.end_expert_id,
+                self.block_shape,
+            )
+        )
+
+        dispose_tensor(hidden_states)
+
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = gateup_input_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
+        # GroupGemm-0
+        gateup_input_fp8 = (
+            gateup_input,
+            (
+                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                    gateup_input_scale
+                )
+            ),
+        )
+        num_groups, m, k = gateup_input_fp8[0].size()
+        n = self.w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            gateup_input_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+        )
+        del gateup_input
+        del gateup_input_fp8
+
+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=self.fp8_dtype,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        )
+        del gateup_output
+
+        # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
+            ),
+        )
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+        )
+        del down_input
+        del down_input_fp8
+
+        # PostReorder
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
+            down_output,
+            output,
+            src2dst,
+            topk_ids,
+            topk_weights,
+            self.start_expert_id,
+            self.end_expert_id,
+            self.top_k,
+            hidden_states_shape[1],
+            m_max * self.start_expert_id,
+            BLOCK_SIZE=512,
+        )
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
+        return output
+
+
+class DeepEPMoE(EPMoE):
     """
     MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
     Mooncake EP shares the same class, as they expose the same interface.
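# A minimal sketch (separate from the patch) of the per-tensor -> per-block
# scale expansion performed in EPMoE.forward_deepgemm above: one scale per
# expert is repeated so the grouped GEMM sees one scale per 128x128 weight
# block. All shapes here are illustrative assumptions, not real model dims.
import torch

num_experts, n, k, block = 4, 512, 256, 128
per_tensor_scale = torch.rand(num_experts, dtype=torch.float32)
per_block_scale = (
    per_tensor_scale.unsqueeze(1)
    .repeat_interleave((n + block - 1) // block, dim=1)  # blocks along N
    .unsqueeze(2)
    .repeat_interleave((k + block - 1) // block, dim=2)  # blocks along K
)
assert per_block_scale.shape == (num_experts, n // block, k // block)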
@@ -106,11 +386,28 @@
 
         self.deepep_mode = get_deepep_mode()
 
-        if self.deepep_mode.enable_low_latency() and not _is_npu:
-            # NPU supports low_latency deepep without deepgemm
-            assert (
-                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+        # TODO: move to the beginning of the file
+        from sglang.srt.distributed.parallel_state import get_tp_group
+        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+            deepep_mode=self.deepep_mode,
+            async_finish=True,  # TODO
+            return_recv_hook=True,
+        )
+
+        if self.deepep_mode.enable_low_latency() and not (
+            _is_npu or self.use_w4a8_marlin
+        ):
+            # NPU supports low_latency deepep without deepgemm,
+            # and the W4A8 Marlin path does not go through deep_gemm either.
+            assert (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
         if _use_aiter:
             # expert_mask is of size (self.num_local_experts + 1),
             # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
@@ -124,23 +421,23 @@
             )  # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        elif not _is_npu:
-            self.w13_weight_fp8 = (
-                self.w13_weight,
-                (
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant or self.use_w4afp8
-                    else self.w13_weight_scale
-                ),
-            )
-            self.w2_weight_fp8 = (
-                self.w2_weight,
-                (
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant or self.use_w4afp8
-                    else self.w2_weight_scale
-                ),
-            )
+        # NOTE: w13_weight_fp8 / w2_weight_fp8 are now assembled on demand in
+        # EPMoE.forward_deepgemm, so they are no longer precomputed here.
 
     def forward(
         self,
@@ -187,10 +484,15 @@
             assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             return self.forward_npu(dispatch_output)
         if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
-            if self.use_w4afp8:
-                return self.forward_cutlass_w4afp8(dispatch_output)
-            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
-            return self.forward_deepgemm_contiguous(dispatch_output)
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+                return self.forward_deepgemm_contiguous(dispatch_output)
+            elif self.use_w4a8_marlin:
+                return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
+            else:
+                raise ValueError(
+                    "DeepEP normal dispatch requires DeepGEMM FP8 or W4A8 Marlin"
+                )
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             if (
                 get_moe_runner_backend().is_flashinfer_cutedsl()
@@ -255,6 +557,34 @@
                 expert_mask=self.expert_mask,
             )
 
+    def forward_deepgemm_w4a8_marlin_contiguous(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+        # TODO: placeholder that passes the dispatched activations through
+        # unchanged. Replace it with the W4A8 Marlin expert computation
+        # (e.g. a self.quant_method.apply_ep(...) call) once that is wired up.
+        return hidden_states_int8.bfloat16()
+
     def forward_deepgemm_contiguous(
         self,
         dispatch_output: DeepEPNormalOutput,
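# A minimal sketch (separate from the patch) of the masked grouped-GEMM layout
# consumed by grouped_gemm_nt_f8f8bf16_masked in forward_deepgemm above: every
# expert gets a fixed-size slice of m_max rows, and masked_m records how many
# of those rows are real tokens, so the kernel only computes the first
# masked_m[e] rows of each slice. Plain-torch equivalent with made-up sizes:
import torch

num_experts, m_max, k, n = 4, 8, 16, 32
x = torch.randn(num_experts, m_max, k)   # padded per-expert activations
w = torch.randn(num_experts, n, k)       # per-expert weights ("nt": B transposed)
masked_m = torch.tensor([3, 8, 0, 5])    # valid rows per expert

out = torch.zeros(num_experts, m_max, n)
for e in range(num_experts):
    m = int(masked_m[e])
    out[e, :m] = x[e, :m] @ w[e].t()     # rows beyond masked_m[e] stay untouched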
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
index 8c6796bf1..f31d3733c 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -460,11 +460,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         overlap_args: Optional["CombineOverlapArgs"],
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
-            output = hidden_states
-        else:
-            raise NotImplementedError() # triton runner was supported but it's temporarily disabled
-
+        # Pass the hidden states straight through to combine for every backend,
+        # so paths that do not use deep_gemm (e.g. W4A8 Marlin) also work; the
+        # Triton post-reorder fallback (deepep_post_reorder_triton_kernel)
+        # remains temporarily disabled.
+        output = hidden_states
+
         previous_event = Buffer.capture() if self.async_finish else None
         return output, previous_event
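# A minimal torch sketch (separate from the patch) of the disabled Triton
# fallback referenced above (deepep_post_reorder_triton_kernel): before
# combine it reduced the per-(token, expert) rows back to one row per token,
# weighted by the router weights. `post_reorder_ref` and its argument names
# are illustrative, not the kernel's real signature.
import torch

def post_reorder_ref(expert_rows, src2dst, topk_idx, topk_weights):
    num_tokens, top_k = topk_idx.shape
    out = torch.zeros(num_tokens, expert_rows.shape[1], dtype=expert_rows.dtype)
    for i in range(num_tokens):
        for j in range(top_k):
            if topk_idx[i, j] >= 0:  # negative ids mark slots with no expert
                out[i] += topk_weights[i, j] * expert_rows[src2dst[i * top_k + j]]
    return out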