From 755f314785fc1fa620526778f929c2425d645671 Mon Sep 17 00:00:00 2001
From: Alex Sun
Date: Tue, 24 Jun 2025 17:05:47 +0800
Subject: [PATCH] [AMD] add aiter fused moe in DeepEP path (#7268)

---
 python/sglang/srt/layers/moe/ep_moe/layer.py  | 91 ++++++++++++++++---
 .../srt/layers/moe/ep_moe/token_dispatcher.py | 21 ++++-
 2 files changed, 98 insertions(+), 14 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 5b654b2d8..30fb0b6b7 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -54,10 +54,16 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
+if _use_aiter:
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)
@@ -1046,6 +1052,15 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 w2_weight_scale, requires_grad=False
             )
             layer.w2_input_scale = None
+        if _use_aiter:
+            layer.w13_weight = torch.nn.Parameter(
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
+                requires_grad=False,
+            )
+            layer.w2_weight = torch.nn.Parameter(
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
+                requires_grad=False,
+            )
         return
 
     def apply(
@@ -1117,18 +1132,36 @@ class DeepEPMoE(EPMoE):
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
+        if _use_aiter:
+            # expert_mask is of size (self.num_experts_per_partition + 1);
+            # the extra 1 is for the invalid rank_id (in the original deepep, the invalid rank_id is -1, but aiter does not allow -1, so we use a mask to make those ids invalid)
+            # for instance, if we have 4 experts on this rank, we would have an expert_mask like:
+            # self.expert_mask = [1, 1, 1, 1, 0]
+            # idx 0-3 are valid and will be processed, while idx == 4 will be masked out
+            self.expert_mask = torch.zeros(
+                (self.num_experts_per_partition + 1),
+                device=torch.cuda.current_device(),
+                dtype=torch.int,
+            )
+            # the last slot marks the invalid rank_id
+            self.expert_mask[:-1] = 1
+        else:
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                (
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                (
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+            )
 
     def forward(
         self,
@@ -1142,6 +1175,9 @@ class DeepEPMoE(EPMoE):
         num_recv_tokens_per_expert: List[int],
         forward_mode: ForwardMode,
     ):
+        if _use_aiter:
+            # in forward_aiter, we skip token permutation and unpermutation, which are fused inside the aiter kernel
+            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
@@ -1274,6 +1310,37 @@ class DeepEPMoE(EPMoE):
             )
         return down_output
 
+    def forward_aiter(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states
+        # in the original deepep, idx == -1 means invalid and will not be processed;
+        # aiter does not accept -1, so we use an expert mask to make these idx invalid
+        # (idx == num_experts_per_partition, meaning the slot is not used in aiter fused_moe)
+        topk_idx_copy = topk_idx.to(torch.int32)
+        topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+
+        return fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            topk_weights,
+            topk_idx_copy,
+            w1_scale=self.w13_weight_scale_inv,
+            w2_scale=self.w2_weight_scale_inv,
+            quant_type=QuantType.per_128x128,
+            activation=(
+                ActivationType.Silu
+                if self.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            expert_mask=self.expert_mask,
+        )
+
     def forward_deepgemm_contiguous(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index 091e9ec69..ac9217da8 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -6,7 +6,13 @@ from sglang.srt.managers.expert_distribution import (
     get_global_expert_distribution_recorder,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import DeepEPMode, get_int_env_var, load_json_config
+from sglang.srt.utils import (
+    DeepEPMode,
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    load_json_config,
+)
 
 try:
     from deep_ep import Buffer, Config
@@ -32,6 +38,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+
 logger = logging.getLogger(__name__)
@@ -376,6 +384,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
         https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
         """
+        if _use_aiter:
+            # skip permutation here, as aiter fused_moe fuses it internally
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
 
         reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
             topk_idx, self.num_experts
@@ -409,7 +426,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
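
This path is only taken when the SGLANG_USE_AITER environment variable is set on a ROCm (HIP) build; otherwise the existing DeepGEMM path is used unchanged. Below is a minimal, self-contained sketch of the expert-mask and index-remapping trick described in the in-code comments, assuming a hypothetical partition of 4 experts; it mirrors the patch logic in plain torch and is not the patch code itself.

    import torch

    num_experts_per_partition = 4  # hypothetical rank-local expert count

    # One extra slot at the end marks the "invalid" expert id.
    expert_mask = torch.zeros(num_experts_per_partition + 1, dtype=torch.int)
    expert_mask[:-1] = 1  # -> tensor([1, 1, 1, 1, 0])

    # DeepEP marks unused routing slots with -1; aiter does not accept -1,
    # so they are remapped to the masked-out slot (== num_experts_per_partition).
    topk_idx = torch.tensor([[0, 2, -1], [3, -1, -1]], dtype=torch.int32)
    topk_idx[topk_idx == -1] = num_experts_per_partition
    # topk_idx is now [[0, 2, 4], [3, 4, 4]]; ids equal to 4 are ignored
    # because expert_mask[4] == 0.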