[Feature] Support DeepEP normal & Redundant Experts on NPU (#9881)

This commit is contained in:
Even Zhou
2025-09-11 11:35:26 +08:00
committed by GitHub
parent 5b7448de77
commit 5b64f006ec
15 changed files with 319 additions and 111 deletions

View File

@@ -55,7 +55,7 @@ class EPLBManager:
enable_timing = self._rebalance_layers_per_chunk is None
if enable_timing:
torch.cuda.synchronize()
torch.get_device_module().synchronize()
time_start = time.time()
dump_record_output = get_global_expert_distribution_recorder().dump_record(
@@ -85,7 +85,7 @@ class EPLBManager:
msg = f"[EPLBManager] rebalance end"
if enable_timing:
torch.cuda.synchronize()
torch.get_device_module().synchronize()
time_end = time.time()
msg += f" time={time_end - time_start:.3f}s"
logger.info(msg)

View File

@@ -30,7 +30,9 @@ import torch.distributed
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var
from sglang.srt.utils import Withable, get_bool_env_var, is_npu
_is_npu = is_npu()
if TYPE_CHECKING:
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
@@ -216,7 +218,9 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
def _on_hook(self, hook_name: str, **kwargs):
if self._disable_all:
return
if not (self._recording or torch.cuda.is_current_stream_capturing()):
if not (
self._recording or torch.get_device_module().is_current_stream_capturing()
):
return
gatherer = self._single_pass_gatherers[
self._accumulator.get_single_pass_gatherer_key(
@@ -451,6 +455,10 @@ def _list_sum(a: List, b: List) -> List:
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
super().__init__(*args, **kwargs)
if not _is_npu:
device = "cuda"
else:
device = "npu"
self._enable_global_physical_experts = enable_global_physical_experts
self._data = torch.zeros(
(
@@ -462,7 +470,7 @@ class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
),
),
dtype=torch.int,
device="cuda",
device=device,
)
def reset(self):
@@ -784,7 +792,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
if self._first_dump:
self._first_dump = False
torch.cuda.empty_cache()
torch.get_device_module().empty_cache()
torch.distributed.all_reduce(
logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM

View File

@@ -47,7 +47,7 @@ class ExpertLocationUpdater:
):
if self._first_execution:
self._first_execution = False
torch.cuda.empty_cache()
torch.get_device_module().empty_cache()
old_expert_location_metadata = get_global_expert_location_metadata()
assert old_expert_location_metadata is not None

View File

@@ -10,6 +10,7 @@ from torch.nn.functional import scaled_dot_product_attention
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_bool_env_var
@@ -33,6 +34,7 @@ class ForwardMetadata:
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_list: Optional[List[int]] = None
seq_lens_list_cumsum: Optional[List[int]] = None
class AscendAttnBackend(AttentionBackend):
@@ -83,6 +85,7 @@ class AscendAttnBackend(AttentionBackend):
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
tp_size = get_attention_tp_size()
self.forward_metadata = ForwardMetadata()
self.forward_metadata.block_tables = (
@@ -96,9 +99,13 @@ class AscendAttnBackend(AttentionBackend):
forward_batch.extend_seq_lens.cpu().int()
)
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
forward_batch.extend_seq_lens_cpu
)
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
if forward_batch.is_extend_in_batch:
seq_lens_list_cumsum[-1] = (
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
) * tp_size
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
self.graph_mode = False

View File

@@ -35,7 +35,6 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
DeepEPLLOutput,
DeepEPNormalOutput,
DispatchOutput,
@@ -454,7 +453,7 @@ class DeepEPMoE(EPMoE):
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
return self.forward_aiter(dispatch_output)
if _is_npu:
assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
return self.forward_npu(dispatch_output)
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
@@ -718,63 +717,124 @@ class DeepEPMoE(EPMoE):
def forward_npu(
self,
dispatch_output: DeepEPLLOutput,
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
):
if TYPE_CHECKING:
assert isinstance(dispatch_output, AscendDeepEPLLOutput)
hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
# NOTE: Ascend's Dispatch & Combine does not support FP16
output_dtype = torch.bfloat16
pertoken_scale = hidden_states[1]
hidden_states = hidden_states[0]
group_list_type = 1
seg_indptr = seg_indptr.to(torch.int64)
import torch_npu
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=seg_indptr,
output_dtype=torch.int32,
)[0]
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=seg_indptr,
activate_left=True,
quant_mode=1,
)
# NOTE: Ascend's Dispatch & Combine does not support FP16
output_dtype = torch.bfloat16
group_list_type = 1
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=seg_indptr,
output_dtype=output_dtype,
)[0]
def _forward_normal(dispatch_output: DeepEPNormalOutput):
if TYPE_CHECKING:
assert isinstance(dispatch_output, DeepEPNormalOutput)
hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
return hidden_states
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
else:
# dynamic quant
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
hidden_states.device
)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
def _forward_ll(dispatch_output: DeepEPLLOutput):
if TYPE_CHECKING:
assert isinstance(dispatch_output, DeepEPLLOutput)
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
group_list = group_list.to(torch.int64)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32,
)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=per_token_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
return _forward_normal(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
return _forward_ll(dispatch_output)
else:
raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):

View File

@@ -9,7 +9,6 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
DispatchOutputFormat,
)
from sglang.srt.layers.moe.token_dispatcher.deepep import (
AscendDeepEPLLOutput,
DeepEPConfig,
DeepEPDispatcher,
DeepEPLLCombineInput,
@@ -23,7 +22,6 @@ from sglang.srt.layers.moe.token_dispatcher.standard import (
)
__all__ = [
"AscendDeepEPLLOutput",
"BaseDispatcher",
"BaseDispatcherConfig",
"CombineInput",

View File

@@ -8,7 +8,6 @@ import torch
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
DeepEPLLCombineInput,
DeepEPLLOutput,
DeepEPNormalCombineInput,
@@ -47,19 +46,12 @@ class DispatchOutputChecker:
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
return dispatch_output.format.is_deepep()
@staticmethod
def format_is_ascent_ll(
dispatch_output: DispatchOutput,
) -> TypeGuard[AscendDeepEPLLOutput]:
return dispatch_output.format.is_ascent_ll()
class DispatchOutputFormat(Enum):
STANDARD = "standard"
DEEPEP_NORMAL = "deepep_normal"
DEEPEP_LL = "deepep_ll"
ASCENT_LL = "ascent_ll"
def is_standard(self) -> bool:
return self == DispatchOutputFormat.STANDARD
@@ -76,9 +68,6 @@ class DispatchOutputFormat(Enum):
DispatchOutputFormat.DEEPEP_LL,
]
def is_ascent_ll(self) -> bool:
return self == DispatchOutputFormat.ASCENT_LL
@runtime_checkable
class DispatchOutput(Protocol):

View File

@@ -77,24 +77,8 @@ class DeepEPLLOutput(NamedTuple):
return DispatchOutputFormat.DEEPEP_LL
class AscendDeepEPLLOutput(NamedTuple):
"""AscendDeepEP low latency dispatch output."""
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
topk_idx: torch.Tensor
topk_weights: torch.Tensor
masked_m: torch.Tensor
seg_indptr: torch.Tensor
expected_m: int
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.ASCENT_LL
assert isinstance(DeepEPNormalOutput, DispatchOutput)
assert isinstance(DeepEPLLOutput, DispatchOutput)
assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
class DeepEPNormalCombineInput(NamedTuple):
@@ -434,12 +418,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
):
from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_post_reorder_triton_kernel,
)
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
output = hidden_states
else:
if hidden_states.shape[0] > 0:
@@ -553,23 +536,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
masked_m
)
if _is_npu:
deepep_output = AscendDeepEPLLOutput(
hidden_states,
topk_idx,
topk_weights,
masked_m,
self.handle[1],
expected_m,
)
else:
deepep_output = DeepEPLLOutput(
hidden_states,
topk_idx,
topk_weights,
masked_m,
expected_m,
)
deepep_output = DeepEPLLOutput(
hidden_states,
topk_idx,
topk_weights,
masked_m,
expected_m,
)
return deepep_output
def _dispatch_core(

View File

@@ -330,6 +330,14 @@ class TopK(CustomOp):
)
topk_weights = topk_weights / topk_weights_sum
if expert_location_dispatch_info is not None:
topk_ids = topk_ids_logical_to_physical(
topk_ids, expert_location_dispatch_info
)
get_global_expert_distribution_recorder().on_select_experts(
topk_ids=topk_ids
)
return StandardTopKOutput(topk_weights, topk_ids, _)
else:
self.topk_config.torch_native = True