[Feature] Optimize DeepSeek's DeepEP on Ascend NPU (#8355)
Co-authored-by: ronnie_zheng <zl19940307@163.com>
Co-authored-by: Hexq0210 <hexq0809521@gmail.com>
@@ -50,6 +50,8 @@ from sglang.srt.utils import (
     supports_custom_op,
 )
 
+_is_npu = is_npu()
+
 
 @dataclass
 class GraphCaptureContext:
@@ -591,7 +593,7 @@ class GroupCoordinator:
         )
 
     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if not supports_custom_op():
+        if _is_npu or not supports_custom_op():
            self._all_gather_into_tensor(output, input)
         else:
            torch.ops.sglang.reg_all_gather_into_tensor(
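On NPU the collective now always takes the direct `_all_gather_into_tensor` path instead of the registered custom op, which is registered primarily for CUDA-oriented graph-capture and compile paths. What the direct path boils down to is a plain PyTorch collective; a minimal standalone sketch:

```python
import torch
import torch.distributed as dist

def direct_all_gather(input: torch.Tensor, world_size: int) -> torch.Tensor:
    # Gather equally sized shards from all ranks into one preallocated tensor.
    output = torch.empty((world_size * input.shape[0], *input.shape[1:]),
                         dtype=input.dtype, device=input.device)
    dist.all_gather_into_tensor(output, input)
    return output
```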
@@ -1127,7 +1129,7 @@ def init_model_parallel_group(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=not is_npu(),
+        use_pynccl=not _is_npu,
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
         use_hpu_communicator=True,
@@ -75,6 +75,9 @@ class AscendAttnBackend(AttentionBackend):
         )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self,
         q,
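`get_cuda_graph_seq_len_fill_value` supplies the placeholder sequence length for padded slots when a captured graph is replayed with fewer real requests than the captured batch size; 1 is a safe neutral value where 0 could trip kernels that index or divide by the length. A tiny illustration of the padding, with assumed names and numbers:

```python
import torch

captured_bs, real_bs = 8, 3
seq_lens = torch.full((captured_bs,), 1, dtype=torch.int32)  # fill value = 1
seq_lens[:real_bs] = torch.tensor([17, 5, 42], dtype=torch.int32)
# slots 3..7 keep the harmless length 1 during graph replay
```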
@@ -34,6 +34,7 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
+        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -387,7 +388,8 @@ class DeepEPMoE(EPMoE):
             return_recv_hook=True,
         )
 
-        if self.deepep_mode.enable_low_latency():
+        # NPU supports low_latency deepep without deepgemm
+        if self.deepep_mode.enable_low_latency() and not _is_npu:
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
@@ -404,7 +406,7 @@ class DeepEPMoE(EPMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        else:
+        elif not _is_npu:
             self.w13_weight_fp8 = (
                 self.w13_weight,
                 (
@@ -459,6 +461,8 @@ class DeepEPMoE(EPMoE):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
+        if _is_npu:
+            return self.forward_npu(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
@@ -723,6 +727,60 @@ class DeepEPMoE(EPMoE):
 
         return down_output
 
+    def forward_npu(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        if TYPE_CHECKING:
+            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
+        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+
+        # NOTE: Ascend's Dispatch & Combine does not support FP16
+        output_dtype = torch.bfloat16
+
+        pertoken_scale = hidden_states[1]
+        hidden_states = hidden_states[0]
+
+        group_list_type = 1
+        seg_indptr = seg_indptr.to(torch.int64)
+
+        import torch_npu
+
+        # gmm1: gate_up_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w13_weight],
+            scale=[self.w13_weight_scale.to(output_dtype)],
+            per_token_scale=[pertoken_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        # act_fn: swiglu
+        hidden_states = torch_npu.npu_swiglu(hidden_states)
+
+        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+        # gmm2: down_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w2_weight],
+            scale=[self.w2_weight_scale.to(output_dtype)],
+            per_token_scale=[swiglu_out_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        return hidden_states
+
 
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
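For readers new to `npu_grouped_matmul`: the token buffer returned by low-latency dispatch is laid out expert by expert, and `group_list` tells the kernel where each expert's slice ends. Below is a pure-PyTorch reading aid, not the kernel itself; the shapes and the per-expert-count reading of `group_list_type=1` are assumptions for illustration.

```python
import torch

def grouped_matmul_ref(x: torch.Tensor, weight: torch.Tensor,
                       group_list: torch.Tensor) -> torch.Tensor:
    """One GEMM per expert over an expert-contiguous token buffer.

    x: (total_tokens, in_dim); weight: (num_experts, in_dim, out_dim);
    group_list: per-expert token counts (assumed group_list_type=1 semantics).
    """
    outs, start = [], 0
    for e, count in enumerate(group_list.tolist()):
        outs.append(x[start : start + count] @ weight[e])
        start += count
    return torch.cat(outs, dim=0)
```

`forward_npu` chains two such grouped GEMMs around the activation: gmm1 (gate_up_proj), `npu_swiglu`, a dynamic per-token re-quantization, then gmm2 (down_proj), with the int8 inputs rescaled by `per_token_scale` inside the fused kernels.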
@@ -23,14 +23,23 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    is_npu,
+    load_json_config,
+)
+
+_is_npu = is_npu()
 
 try:
     from deep_ep import Buffer, Config
 
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
+    if not _is_npu:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
 
     use_deepep = True
 except ImportError:
@@ -80,8 +89,24 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.deepep_ll
 
 
+class AscendDeepEPLLOutput(NamedTuple):
+    """AscendDeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    seg_indptr: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
 assert isinstance(DeepEPNormalOutput, DispatchOutput)
 assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
 
 
 class DeepEPDispatchMode(IntEnum):
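The Ascend variant mirrors `DeepEPLLOutput` but carries one extra field, `seg_indptr`, which the dispatcher fills from its communication handle and `forward_npu` later feeds to the grouped matmuls. A hedged construction sketch with dummy shapes (the import path follows the `TYPE_CHECKING` import in the earlier hunk):

```python
import torch
from sglang.srt.layers.moe.token_dispatcher import AscendDeepEPLLOutput

out = AscendDeepEPLLOutput(
    (torch.zeros(8, 16, dtype=torch.int8), torch.ones(8)),  # (int8 acts, per-token scales)
    torch.zeros(8, 2, dtype=torch.long),      # topk_idx
    torch.ones(8, 2),                         # topk_weights
    torch.tensor([4, 4]),                     # masked_m
    torch.tensor([4, 4], dtype=torch.int64),  # seg_indptr, the extra field
    8,                                        # expected_m
)
# forward_npu's positional unpack depends on exactly this field order:
hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = out
```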
@@ -150,19 +175,20 @@ class DeepEPBuffer:
             else:
                 raise NotImplementedError
 
-        total_num_sms = torch.cuda.get_device_properties(
-            device="cuda"
-        ).multi_processor_count
-        if (
-            (deepep_mode != DeepEPMode.LOW_LATENCY)
-            and not global_server_args_dict["enable_two_batch_overlap"]
-            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
-        ):
-            logger.warning(
-                f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
-                f"This may result in highly suboptimal performance. "
-                f"Consider using --deepep-config to change the behavior."
-            )
+        if not _is_npu:
+            total_num_sms = torch.cuda.get_device_properties(
+                device="cuda"
+            ).multi_processor_count
+            if (
+                (deepep_mode != DeepEPMode.LOW_LATENCY)
+                and not global_server_args_dict["enable_two_batch_overlap"]
+                and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+            ):
+                logger.warning(
+                    f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                    f"This may result in highly suboptimal performance. "
+                    f"Consider using --deepep-config to change the behavior."
+                )
 
         cls._buffer = Buffer(
             group,
@@ -507,13 +533,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 masked_m
             )
 
-        return DeepEPLLOutput(
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            masked_m,
-            expected_m,
-        )
+        if _is_npu:
+            deepep_output = AscendDeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                self.handle[1],
+                expected_m,
+            )
+        else:
+            deepep_output = DeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                expected_m,
+            )
+        return deepep_output
 
     def _dispatch_core(
         self,
@@ -250,10 +250,11 @@ class TopK(CustomOp):
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if global_num_experts == 256:
+            router_logits = router_logits.to(torch.float32)
             return torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.top_k,
-                bias=self.correction_bias,
+                bias=self.correction_bias.to(torch.float32),
                 k_group=self.topk_group,
                 group_count=self.num_expert_group,
                 group_select_mode=1,
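The two `float32` casts matter because bfloat16 keeps only 8 bits of mantissa, so near-tied router logits can collapse to the same value and the correction bias can no longer separate them. A toy illustration, independent of the NPU kernel:

```python
import torch

logits = torch.tensor([1.0001, 1.0002])  # expert 1 is truly the best
print(logits.topk(1).indices)            # tensor([1]) in float32
bf16 = logits.to(torch.bfloat16)         # both values round to 1.0000
print(bf16.topk(1).indices)              # tensor([0]): the selection flips
```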
@@ -3,7 +3,18 @@ from __future__ import annotations
 
 import importlib
 import sys
 from types import MappingProxyType
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import torch
 from torch.nn.parameter import Parameter
@@ -79,22 +90,16 @@ def npu_wrapper_rmsnorm_forward(func):
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if not x.is_contiguous():
            x = x.contiguous()
-        original_dtype = x.dtype
-        x = x.to(torch.float32)
        if residual is not None:
-            x = x + residual.to(torch.float32)
-            residual = x.to(original_dtype)
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            out = out + self.bias
+            return out.to(x.dtype), residual_out
 
-        x = (
-            torch_npu.npu_rms_norm(
-                x, self.weight.to(torch.float32), self.variance_epsilon
-            )[0]
-            + self.bias
-        )
-
-        if residual is None:
-            return x.to(original_dtype)
-        return x.to(original_dtype), residual
+        out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+        out = out + self.bias
+        return out.to(x.dtype)
 
     return _rmsnorm_forward_oot
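The rewrite replaces the manual float32 upcast and add with the fused `npu_add_rms_norm`, which folds the residual add into the normalization and returns the updated residual alongside the output (its middle return value, a reciprocal std, is discarded above). A pure-PyTorch reading of the semantics, a sketch only since the kernel's internal accumulation dtype is its own business:

```python
import torch

def add_rms_norm_ref(residual: torch.Tensor, x: torch.Tensor,
                     weight: torch.Tensor, eps: float):
    h = x + residual                                   # fused residual add
    rms = torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return h * rms * weight, h                         # (normed out, new residual)
```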
@@ -571,8 +576,10 @@ class NPU_W8A8LinearMethodImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
@@ -583,8 +590,12 @@ class NPU_W8A8LinearMethodImpl:
                 -1,
                 True,
             )
 
-        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
         return torch_npu.npu_quant_matmul(
             x,
             layer.weight,
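The rationale in the new comment is the standard row-parallel one: each TP rank computes a partial product that is later summed by an all-reduce, so if every rank fused the bias the result would contain `tp_size` copies of it. A minimal sketch of the invariant, with assumed shapes and no collective wired in:

```python
import torch

def row_parallel_partial(x_shard: torch.Tensor, w_shard: torch.Tensor,
                         bias: torch.Tensor, tp_rank: int) -> torch.Tensor:
    partial = x_shard @ w_shard        # each rank's partial matmul result
    if tp_rank == 0 and bias is not None:
        partial = partial + bias       # fuse the bias exactly once
    return partial                     # all_reduce(sum) across ranks follows
```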
@@ -651,13 +662,21 @@ class NPU_W8A8LinearMethodMTImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
 
-        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
 
         return ops.quant_matmul(
             x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
         )
@@ -737,11 +756,6 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
@@ -780,7 +794,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
         original_dtype = x.dtype
         # use ATB quantize
         quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
         return torch_npu.npu_quant_matmul(
             quant_out,
@@ -863,11 +876,6 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
@@ -680,7 +680,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         )
 
         # Re-dispatch
-        if _is_hip or _is_npu:
+        if _is_hip:
             self._forward_method = self.forward_native
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
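Dropping `_is_npu` here is deliberate: with the `forward_npu` method added in the next hunk, the NPU no longer needs to be forced onto `forward_native`; the `CustomOp` base class can resolve the platform-specific method itself. A hedged sketch of that dispatch idea (sglang's actual `CustomOp` internals may differ):

```python
class _DispatchSketch:
    def __init__(self, is_hip: bool, is_npu: bool):
        if is_hip:
            self._forward_method = self.forward_native
        elif is_npu:
            self._forward_method = self.forward_npu  # now exists, see below
        else:
            self._forward_method = self.forward_cuda

    def forward_native(self): ...
    def forward_npu(self): ...
    def forward_cuda(self): ...
```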
@@ -765,6 +765,46 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             key = key_rot
         return query.to(dtype), key.to(dtype)
 
+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
+        # and generalization to more scenarios will be supported in the future.
+        if query.shape[1] * query.shape[2] > 4096:
+            return self.forward_native(positions, query, key, offsets)
+        num_tokens = query.shape[0]
+        rotary_mode = "half" if self.is_neox_style else "interleave"
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+
+        query_rot, key_rot = torch_npu.npu_mrope(
+            torch.add(positions, offsets) if offsets is not None else positions,
+            query_rot.reshape(num_tokens, -1),
+            key_rot.reshape(num_tokens, -1),
+            self.cos_sin_cache,
+            self.rotary_dim,
+            mrope_section=[0, 0, 0],
+            rotary_mode=rotary_mode,
+        )
+        query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
+        key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
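For reference, the math `npu_mrope` applies to the rotary slice in "half" (neox) mode is ordinary rotary embedding; the pass-through slice beyond `rotary_dim` is concatenated back afterwards, exactly as the diff shows. A pure-PyTorch sketch with an assumed cos/sin broadcast layout:

```python
import torch

def rope_half_ref(x_rot: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # x_rot: (num_tokens, num_heads, rotary_dim); cos/sin broadcast over heads
    half = x_rot.shape[-1] // 2
    x1, x2 = x_rot[..., :half], x_rot[..., half:]
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
```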