[Feature] Support DeepEP normal & Redundant Experts on NPU (#9881)
This commit is contained in:
@@ -30,7 +30,9 @@ import torch.distributed
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import Withable, get_bool_env_var
|
||||
from sglang.srt.utils import Withable, get_bool_env_var, is_npu
|
||||
|
||||
_is_npu = is_npu()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
||||
@@ -216,7 +218,9 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
|
||||
def _on_hook(self, hook_name: str, **kwargs):
|
||||
if self._disable_all:
|
||||
return
|
||||
if not (self._recording or torch.cuda.is_current_stream_capturing()):
|
||||
if not (
|
||||
self._recording or torch.get_device_module().is_current_stream_capturing()
|
||||
):
|
||||
return
|
||||
gatherer = self._single_pass_gatherers[
|
||||
self._accumulator.get_single_pass_gatherer_key(
|
||||
@@ -451,6 +455,10 @@ def _list_sum(a: List, b: List) -> List:
|
||||
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
|
||||
def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if not _is_npu:
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "npu"
|
||||
self._enable_global_physical_experts = enable_global_physical_experts
|
||||
self._data = torch.zeros(
|
||||
(
|
||||
@@ -462,7 +470,7 @@ class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
|
||||
),
|
||||
),
|
||||
dtype=torch.int,
|
||||
device="cuda",
|
||||
device=device,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
@@ -784,7 +792,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
|
||||
|
||||
if self._first_dump:
|
||||
self._first_dump = False
|
||||
torch.cuda.empty_cache()
|
||||
torch.get_device_module().empty_cache()
|
||||
|
||||
torch.distributed.all_reduce(
|
||||
logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
|
||||
|
||||
Reference in New Issue
Block a user