From a0fb70e9c15dadc32103adf3acc9b29abe5516c5 Mon Sep 17 00:00:00 2001
From: lizhigong <306128847@qq.com>
Date: Wed, 22 Oct 2025 15:16:12 +0800
Subject: [PATCH] adapt w4a8 marlin deepep dp ep

---
Notes (reviewer summary of the hunks below; not part of the commit message):

* EPMoE.__init__ gains a `use_w4a8_marlin` flag. It is True only when
  `quant_config` is a `SlimQuantW4A8Int8MarlinConfig`, and False on the
  Fp8Config branch and on the fallback branch.
* DeepEPMoE.__init__: the "low-latency mode requires deep_gemm" assertion
  and the `w13_weight_fp8` / `w2_weight_fp8` tuple setup are commented out
  rather than removed.
* DeepEPMoE.forward: for the DeepEP "normal" dispatch format the hard assert
  is replaced by a three-way branch: deep_gemm + fp8 ->
  `forward_deepgemm_contiguous`, w4a8 marlin ->
  `forward_deepgemm_w4a8_marlin_contiguous`, anything else -> ValueError.
* `forward_deepgemm_w4a8_marlin_contiguous` is currently a stub: it unpacks
  the dispatch output, asserts the activation is "silu", and returns
  `hidden_states_int8.bfloat16()`. The `quant_method.apply_ep(...)` call is
  present only as commented-out code.
* deepep.py (hunk under `_DeepEPDispatcherImplNormal`): `output` is now always
  `hidden_states`; the `deepep_post_reorder_triton_kernel` fallback path is
  commented out.

 python/sglang/srt/layers/moe/ep_moe/layer.py  | 97 ++++++++++++++-----
 .../srt/layers/moe/token_dispatcher/deepep.py | 52 +++++-----
 2 files changed, 98 insertions(+), 51 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 0bd49600e..7049563a5 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 from typing import TYPE_CHECKING, List, Optional, Union
 
+from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
 import torch
 import triton
 import triton.language as tl
@@ -124,7 +125,6 @@ class EPMoE(FusedMoE):
         )
 
         self.intermediate_size = intermediate_size
-
         if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
@@ -135,11 +135,23 @@ class EPMoE(FusedMoE):
             self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
+            self.use_w4a8_marlin = False
+        elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
+            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
+            self.block_shape = (
+                self.quant_method.quant_config.weight_block_size
+                if self.use_block_quant
+                else None
+            )
+            self.use_fp8_w8a8 = False
+            self.activation_scheme = None
+            self.use_w4a8_marlin = True
         else:
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
+            self.use_w4a8_marlin = False
 
     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
@@ -386,11 +398,11 @@ class DeepEPMoE(EPMoE):
             return_recv_hook=True,
         )
 
-        if self.deepep_mode.enable_low_latency() and not _is_npu:
-            # NPU supports low_latency deepep without deepgemm
-            assert (
-                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+        # if self.deepep_mode.enable_low_latency() and not _is_npu:
+        #     # NPU supports low_latency deepep without deepgemm
+        #     assert (
+        #         deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+        #     ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
         if _use_aiter:
             # expert_mask is of size (self.num_local_experts + 1),
             # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
@@ -404,23 +416,23 @@ class DeepEPMoE(EPMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        elif not _is_npu:
-            self.w13_weight_fp8 = (
-                self.w13_weight,
-                (
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w13_weight_scale
-                ),
-            )
-            self.w2_weight_fp8 = (
-                self.w2_weight,
-                (
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w2_weight_scale
-                ),
-            )
+        # elif not _is_npu:
+        #     self.w13_weight_fp8 = (
+        #         self.w13_weight,
+        #         (
+        #             self.w13_weight_scale_inv
+        #             if self.use_block_quant
+        #             else self.w13_weight_scale
+        #         ),
+        #     )
+        #     self.w2_weight_fp8 = (
+        #         self.w2_weight,
+        #         (
+        #             self.w2_weight_scale_inv
+        #             if self.use_block_quant
+        #             else self.w2_weight_scale
+        #         ),
+        #     )
 
     def forward(
         self,
@@ -466,8 +478,15 @@ class DeepEPMoE(EPMoE):
             assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             return self.forward_npu(dispatch_output)
         if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
-            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
-            return self.forward_deepgemm_contiguous(dispatch_output)
+            #assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+                return self.forward_deepgemm_contiguous(dispatch_output)
+            elif self.use_w4a8_marlin:
+                return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
+            else:
+                raise ValueError(
+                    f"Dispatch output is not supported"
+                )
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             if get_moe_runner_backend().is_flashinfer_cutedsl():
                 return self.forward_flashinfer_cutedsl(dispatch_output)
@@ -526,6 +545,34 @@ class DeepEPMoE(EPMoE):
             expert_mask=self.expert_mask,
         )
 
+    def forward_deepgemm_w4a8_marlin_contiguous(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+        # if num_recv_tokens_per_expert is None:
+        return hidden_states_int8.bfloat16()
+        # expert_output = self.quant_method.apply_ep(
+        #     layer=self,
+        #     x=dispatch_output,
+        #     topk_weights=topk_weights,
+        #     topk_ids=topk_idx,
+        #     global_num_experts=self.global_num_experts,
+        #     expert_map=self.expert_map,
+        #     activation=self.activation,
+        #     apply_router_weight_on_input=self.apply_router_weight_on_input,
+        #     use_nn_moe=self.use_nn_moe,
+        #     num_local_tokens=dispatch_recv_num_token,
+        #     config_select_bs=hidden_states.shape[0],
+        #     scales=dispatch_scales if self.use_int8_dispatch else None
+        #     # routed_scaling_factor=self.routed_scaling_factor,
+        # )
+        # return expert_output
+
     def forward_deepgemm_contiguous(
         self,
         dispatch_output: DeepEPNormalOutput,
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
index 598f51331..db445b5a1 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -431,32 +431,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             deepep_post_reorder_triton_kernel,
         )
 
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
-            output = hidden_states
-        else:
-            if hidden_states.shape[0] > 0:
-                num_tokens = self.src2dst.shape[0] // self.router_topk
-                output = torch.empty(
-                    (num_tokens, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-                deepep_post_reorder_triton_kernel[(num_tokens,)](
-                    hidden_states,
-                    output,
-                    self.src2dst,
-                    topk_idx,
-                    topk_weights,
-                    self.router_topk,
-                    hidden_states.shape[1],
-                    BLOCK_SIZE=512,
-                )
-            else:
-                output = torch.zeros(
-                    (0, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
+        #if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
+        output = hidden_states
+        # else:
+        #     if hidden_states.shape[0] > 0:
+        #         num_tokens = self.src2dst.shape[0] // self.router_topk
+        #         output = torch.empty(
+        #             (num_tokens, hidden_states.shape[1]),
+        #             device=hidden_states.device,
+        #             dtype=hidden_states.dtype,
+        #         )
+        #         deepep_post_reorder_triton_kernel[(num_tokens,)](
+        #             hidden_states,
+        #             output,
+        #             self.src2dst,
+        #             topk_idx,
+        #             topk_weights,
+        #             self.router_topk,
+        #             hidden_states.shape[1],
+        #             BLOCK_SIZE=512,
+        #         )
+        #     else:
+        #         output = torch.zeros(
+        #             (0, hidden_states.shape[1]),
+        #             device=hidden_states.device,
+        #             dtype=hidden_states.dtype,
+        #         )
         previous_event = Buffer.capture() if self.async_finish else None
         return output, previous_event
 