From 137e75daa1d337b35a7ddc268f9d9e22de063530 Mon Sep 17 00:00:00 2001
From: Even Zhou
Date: Sat, 9 Aug 2025 16:35:00 +0800
Subject: [PATCH] [Feature] Optimize DeepSeek's DeepEP on Ascend NPU (#8355)

Co-authored-by: ronnie_zheng
Co-authored-by: Hexq0210
---
 .../sglang/srt/distributed/parallel_state.py  |  6 +-
 .../srt/layers/attention/ascend_backend.py    |  3 +
 python/sglang/srt/layers/moe/ep_moe/layer.py  | 62 +++++++++++++-
 .../srt/layers/moe/token_dispatcher/deepep.py | 85 +++++++++++++------
 python/sglang/srt/layers/moe/topk.py          |  3 +-
 .../srt/layers/quantization/w8a8_int8.py      | 70 ++++++++-------
 python/sglang/srt/layers/rotary_embedding.py  | 42 ++++++++-
 7 files changed, 210 insertions(+), 61 deletions(-)

diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index ad336c808..adb43158f 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -50,6 +50,8 @@ from sglang.srt.utils import (
     supports_custom_op,
 )
 
+_is_npu = is_npu()
+
 
 @dataclass
 class GraphCaptureContext:
@@ -591,7 +593,7 @@ class GroupCoordinator:
         )
 
     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if not supports_custom_op():
+        if _is_npu or not supports_custom_op():
            self._all_gather_into_tensor(output, input)
         else:
             torch.ops.sglang.reg_all_gather_into_tensor(
@@ -1127,7 +1129,7 @@ def init_model_parallel_group(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=not is_npu(),
+        use_pynccl=not _is_npu,
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
         use_hpu_communicator=True,
diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py
index 7bce68655..020f04dcd 100644
--- a/python/sglang/srt/layers/attention/ascend_backend.py
+++ b/python/sglang/srt/layers/attention/ascend_backend.py
@@ -75,6 +75,9 @@ class AscendAttnBackend(AttentionBackend):
         )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self,
         q,
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 464d5c938..8e99d212d 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -34,6 +34,7 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
+        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -387,7 +388,8 @@ class DeepEPMoE(EPMoE):
             return_recv_hook=True,
         )
 
-        if self.deepep_mode.enable_low_latency():
+        if self.deepep_mode.enable_low_latency() and not _is_npu:
+            # NPU supports low-latency DeepEP without DeepGEMM
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
@@ -404,7 +406,7 @@ class DeepEPMoE(EPMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        else:
+        elif not _is_npu:
             self.w13_weight_fp8 = (
                 self.w13_weight,
                 (
@@ -459,6 +461,8 @@ class DeepEPMoE(EPMoE):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
+        if _is_npu:
+            return self.forward_npu(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
@@ -723,6 +727,60 @@ class DeepEPMoE(EPMoE):
 
         return down_output
 
+    def forward_npu(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        if TYPE_CHECKING:
+            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
+        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+
+        # NOTE: Ascend's Dispatch & Combine does not support FP16
+        output_dtype = torch.bfloat16
+
+        pertoken_scale = hidden_states[1]
+        hidden_states = hidden_states[0]
+
+        group_list_type = 1
+        seg_indptr = seg_indptr.to(torch.int64)
+
+        import torch_npu
+
+        # gmm1: gate_up_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w13_weight],
+            scale=[self.w13_weight_scale.to(output_dtype)],
+            per_token_scale=[pertoken_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        # act_fn: swiglu
+        hidden_states = torch_npu.npu_swiglu(hidden_states)
+
+        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+        # gmm2: down_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w2_weight],
+            scale=[self.w2_weight_scale.to(output_dtype)],
+            per_token_scale=[swiglu_out_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        return hidden_states
+
 
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
index c711d4427..372717bf9 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -23,14 +23,23 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    is_npu,
+    load_json_config,
+)
+
+_is_npu = is_npu()
 
 try:
     from deep_ep import Buffer, Config
 
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
+    if not _is_npu:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
 
     use_deepep = True
 except ImportError:
@@ -80,8 +89,24 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.deepep_ll
 
 
+class AscendDeepEPLLOutput(NamedTuple):
+    """AscendDeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    seg_indptr: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
 assert isinstance(DeepEPNormalOutput, DispatchOutput)
 assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
 
 
 class DeepEPDispatchMode(IntEnum):
@@ -150,19 +175,20 @@ class DeepEPBuffer:
         else:
             raise NotImplementedError
 
-        total_num_sms = torch.cuda.get_device_properties(
-            device="cuda"
-        ).multi_processor_count
-        if (
-            (deepep_mode != DeepEPMode.LOW_LATENCY)
-            and not global_server_args_dict["enable_two_batch_overlap"]
-            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
-        ):
-            logger.warning(
-                f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
-                f"This may result in highly suboptimal performance. "
-                f"Consider using --deepep-config to change the behavior."
-            )
+        if not _is_npu:
+            total_num_sms = torch.cuda.get_device_properties(
+                device="cuda"
+            ).multi_processor_count
+            if (
+                (deepep_mode != DeepEPMode.LOW_LATENCY)
+                and not global_server_args_dict["enable_two_batch_overlap"]
+                and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+            ):
+                logger.warning(
+                    f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                    f"This may result in highly suboptimal performance. "
+                    f"Consider using --deepep-config to change the behavior."
+                )
 
         cls._buffer = Buffer(
             group,
@@ -507,13 +533,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-        return DeepEPLLOutput(
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            masked_m,
-            expected_m,
-        )
+        if _is_npu:
+            deepep_output = AscendDeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                self.handle[1],
+                expected_m,
+            )
+        else:
+            deepep_output = DeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                expected_m,
+            )
+        return deepep_output
 
     def _dispatch_core(
         self,
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index d1b560219..192510608 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -250,10 +250,11 @@ class TopK(CustomOp):
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if global_num_experts == 256:
+            router_logits = router_logits.to(torch.float32)
             return torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.top_k,
-                bias=self.correction_bias,
+                bias=self.correction_bias.to(torch.float32),
                 k_group=self.topk_group,
                 group_count=self.num_expert_group,
                 group_select_mode=1,
diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py
index 826a8c8e8..4e33d4be8 100644
--- a/python/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -3,7 +3,18 @@ from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import torch
 from torch.nn.parameter import Parameter
@@ -79,22 +90,16 @@ def npu_wrapper_rmsnorm_forward(func):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        original_dtype = x.dtype
-        x = x.to(torch.float32)
         if residual is not None:
-            x = x + residual.to(torch.float32)
-            residual = x.to(original_dtype)
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            out = out + self.bias
+            return out.to(x.dtype), residual_out
 
-        x = (
-            torch_npu.npu_rms_norm(
-                x, self.weight.to(torch.float32), self.variance_epsilon
-            )[0]
-            + self.bias
-        )
-
-        if residual is None:
-            return x.to(original_dtype)
-        return x.to(original_dtype), residual
+        out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+        out = out + self.bias
+        return out.to(x.dtype)
 
     return _rmsnorm_forward_oot
@@ -571,8 +576,10 @@ class NPU_W8A8LinearMethodImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
@@ -583,8 +590,12 @@ class NPU_W8A8LinearMethodImpl:
                 -1,
                 True,
             )
-
-        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
         return torch_npu.npu_quant_matmul(
             x,
             layer.weight,
@@ -651,13 +662,21 @@ class NPU_W8A8LinearMethodMTImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
 
-        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
+
         return ops.quant_matmul(
             x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
         )
@@ -737,11 +756,6 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
@@ -780,7 +794,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
         original_dtype = x.dtype
-        # use ATB quantize
         quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
         return torch_npu.npu_quant_matmul(
             quant_out,
@@ -863,11 +876,6 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index c583a5d23..52d4f28c1 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -680,7 +680,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         )
 
         # Re-dispatch
-        if _is_hip or _is_npu:
+        if _is_hip:
             self._forward_method = self.forward_native
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -765,6 +765,46 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             key = key_rot
         return query.to(dtype), key.to(dtype)
 
+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # NOTE: npu_mrope currently supports only the `numQHeads*headSize <= 4096`
+        # pattern; generalization to more scenarios will be supported in the future.
+        if query.shape[1] * query.shape[2] > 4096:
+            return self.forward_native(positions, query, key, offsets)
+        num_tokens = query.shape[0]
+        rotary_mode = "half" if self.is_neox_style else "interleave"
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+
+        query_rot, key_rot = torch_npu.npu_mrope(
+            torch.add(positions, offsets) if offsets is not None else positions,
+            query_rot.reshape(num_tokens, -1),
+            key_rot.reshape(num_tokens, -1),
+            self.cos_sin_cache,
+            self.rotary_dim,
+            mrope_section=[0, 0, 0],
+            rotary_mode=rotary_mode,
+        )
+        query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
+        key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
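
--
Reviewer notes (not part of the patch; nothing below is meant to be applied).

1. For readers without access to torch_npu, here is a rough pure-PyTorch sketch
of the semantics of forward_npu above: a grouped int8 GEMM for gate_up_proj,
SwiGLU, dynamic per-token re-quantization, and a grouped int8 GEMM for
down_proj. Everything in it is an illustrative assumption (the shapes, the
gate/up split order, and reading seg_indptr as cumulative per-expert segment
boundaries); it is not the torch_npu.npu_grouped_matmul contract.

import torch

def moe_forward_reference(x_int8, x_scale, w13, w13_scale, w2, w2_scale, seg_indptr):
    # Assumed shapes: x_int8 [T, H] int8, x_scale [T] per-token scale,
    # w13 [E, H, 2*I] int8 with w13_scale [E, 2*I],
    # w2 [E, I, H] int8 with w2_scale [E, H],
    # seg_indptr [E + 1] cumulative token counts per expert.
    outs = []
    for e in range(w13.shape[0]):
        lo, hi = int(seg_indptr[e]), int(seg_indptr[e + 1])
        if lo == hi:
            continue  # no tokens routed to this expert
        # gmm1 (gate_up_proj): int8 GEMM, dequantized by per-token * per-channel scales
        h = (x_int8[lo:hi].to(torch.int32) @ w13[e].to(torch.int32)).to(torch.float32)
        h = h * x_scale[lo:hi, None] * w13_scale[e][None, :]
        gate, up = h.chunk(2, dim=-1)
        h = torch.nn.functional.silu(gate) * up  # swiglu
        # dynamic per-token re-quantization (the role of npu_dynamic_quant)
        s = h.abs().amax(dim=-1, keepdim=True) / 127.0 + 1e-12
        h_q = torch.clamp(torch.round(h / s), -127, 127).to(torch.int32)
        # gmm2 (down_proj), dequantized the same way
        o = (h_q @ w2[e].to(torch.int32)).to(torch.float32)
        o = o * s * w2_scale[e][None, :]
        outs.append(o.to(torch.bfloat16))
    return torch.cat(outs, dim=0)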
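
2. AscendDeepEPLLOutput inserts seg_indptr (taken from self.handle[1]) between
masked_m and expected_m, which is why forward_npu unpacks six fields and
discards the fourth and sixth. A standalone sanity check of that field
alignment, using stand-in tensors rather than a real dispatch:

from typing import NamedTuple, Tuple
import torch

class AscendDeepEPLLOutput(NamedTuple):
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    masked_m: torch.Tensor
    seg_indptr: torch.Tensor
    expected_m: int

out = AscendDeepEPLLOutput(
    (torch.zeros(8, 16, dtype=torch.int8), torch.ones(8)),  # (activations, per-token scale)
    torch.zeros(8, 2, dtype=torch.int64),
    torch.ones(8, 2),
    torch.tensor([4, 4]),
    torch.tensor([0, 4, 8]),
    8,
)
# Mirrors the unpacking at the top of forward_npu:
hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = out
assert seg_indptr is out.seg_indptr and hidden_states is out.hidden_states_fp8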
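
3. On the RowParallelLinear / layer.tp_rank > 0 checks in w8a8_int8.py: a
row-parallel linear all-reduces partial products across TP ranks, so a bias
fused into the GEMM on every rank would be added tp_size times. A toy
single-process illustration (plain tensor shards standing in for TP ranks,
and a Python sum standing in for the all-reduce):

import torch

def row_parallel_forward(x_shards, w_shards, bias, fuse_bias_on_all_ranks):
    partials = []
    for rank, (x, w) in enumerate(zip(x_shards, w_shards)):
        y = x @ w
        # the patch fuses quant_bias only on rank 0 of RowParallelLinear
        if fuse_bias_on_all_ranks or rank == 0:
            y = y + bias
        partials.append(y)
    return sum(partials)  # stands in for the all-reduce

x, w, bias = torch.randn(4, 8), torch.randn(8, 3), torch.randn(3)
x_shards = x.chunk(2, dim=1)  # each rank holds half of the input features
w_shards = w.chunk(2, dim=0)  # and the matching rows of the weight
right = row_parallel_forward(x_shards, w_shards, bias, False)
wrong = row_parallel_forward(x_shards, w_shards, bias, True)
assert torch.allclose(right, x @ w + bias, atol=1e-4)
assert not torch.allclose(wrong, x @ w + bias, atol=1e-4)  # bias counted twice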