[Feature] Support DeepEP normal & Redundant Experts on NPU (#9881)

Author: Even Zhou
Date: 2025-09-11 11:35:26 +08:00 (committed by GitHub)
Parent: 5b7448de77
Commit: 5b64f006ec

15 changed files with 319 additions and 111 deletions


@@ -30,7 +30,9 @@ import torch.distributed
 
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import Withable, get_bool_env_var
+from sglang.srt.utils import Withable, get_bool_env_var, is_npu
 
+_is_npu = is_npu()
+
 if TYPE_CHECKING:
     from sglang.srt.eplb.expert_location import ExpertLocationMetadata
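
The new module-level `_is_npu` flag comes from `sglang.srt.utils.is_npu`, whose body is not part of this diff. A minimal sketch of what such a probe typically looks like (the exact sglang implementation may differ):

import importlib.util

import torch


def is_npu() -> bool:
    # Ascend NPU support requires the torch_npu package, which registers
    # the `torch.npu` module as an import side effect. Without it, report False.
    if importlib.util.find_spec("torch_npu") is None:
        return False
    import torch_npu  # noqa: F401  (import side effect: registers torch.npu)

    return hasattr(torch, "npu") and torch.npu.is_available()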
@@ -216,7 +218,9 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
     def _on_hook(self, hook_name: str, **kwargs):
         if self._disable_all:
             return
-        if not (self._recording or torch.cuda.is_current_stream_capturing()):
+        if not (
+            self._recording or torch.get_device_module().is_current_stream_capturing()
+        ):
             return
         gatherer = self._single_pass_gatherers[
             self._accumulator.get_single_pass_gatherer_key(
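
Replacing `torch.cuda.is_current_stream_capturing()` with `torch.get_device_module().is_current_stream_capturing()` works because `torch.get_device_module()` (available in recent PyTorch releases) returns the backend module for the current accelerator, and torch_npu mirrors the CUDA stream-capture API. A hedged usage sketch, assuming an accelerator is present:

import torch

# Resolve the backend module for the active accelerator:
# torch.cuda on CUDA builds, torch.npu once torch_npu is imported.
backend = torch.get_device_module()

# Mirrored APIs can then be called without device-specific branches.
if backend.is_current_stream_capturing():
    print("inside graph capture; skip host-side bookkeeping")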
@@ -451,6 +455,10 @@ def _list_sum(a: List, b: List) -> List:
 class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
     def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
         super().__init__(*args, **kwargs)
+        if not _is_npu:
+            device = "cuda"
+        else:
+            device = "npu"
         self._enable_global_physical_experts = enable_global_physical_experts
         self._data = torch.zeros(
             (
@@ -462,7 +470,7 @@ class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
                 ),
             ),
             dtype=torch.int,
-            device="cuda",
+            device=device,
         )
 
     def reset(self):
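
The two hunks above thread a `device` string derived from the `_is_npu` flag into the gatherer's buffer allocation. The same selection can be written as a conditional expression; a self-contained sketch (the flag stand-in and buffer sizes are illustrative, not from the commit, and the allocation assumes a CUDA or NPU device is actually available):

import torch

_is_npu = hasattr(torch, "npu")  # stand-in for sglang's is_npu() probe
num_layers, num_physical_experts = 4, 8  # illustrative sizes

# Equivalent to the commit's if/else device selection.
device = "npu" if _is_npu else "cuda"

# The recorder's count buffer then lands on whichever backend is active.
data = torch.zeros(
    (num_layers, num_physical_experts), dtype=torch.int, device=device
)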
@@ -784,7 +792,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
         if self._first_dump:
             self._first_dump = False
-            torch.cuda.empty_cache()
+            torch.get_device_module().empty_cache()
 
         torch.distributed.all_reduce(
             logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
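
The same `torch.get_device_module()` indirection makes the one-time cache flush before the first dump backend-agnostic, since both `torch.cuda` and `torch.npu` expose `empty_cache()`. A condensed sketch of the surrounding pattern, with illustrative names:

import torch
import torch.distributed as dist


def dump_buffered_counts(counts: torch.Tensor, first_dump: bool) -> torch.Tensor:
    # On the first dump, return cached allocator blocks to the backend so
    # the collective below has memory headroom on CUDA and NPU alike.
    if first_dump:
        torch.get_device_module().empty_cache()
    # Sum per-rank expert counts across all ranks, in place.
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    return counts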