[Feature] Support for Ascend NPU backend (#3853)

Signed-off-by: Song Zhang <gepin.zs@antgroup.com>
Co-authored-by: 22dimensions <waitingwind@foxmail.com>
This commit is contained in:
Song Zhang
2025-05-07 11:32:53 +08:00
committed by GitHub
parent cb69194562
commit 00c2c1f08b
7 changed files with 92 additions and 3 deletions

View File

@@ -42,6 +42,7 @@ from torch.distributed import Backend, ProcessGroup
from sglang.srt.utils import (
direct_register_custom_op,
is_cuda_alike,
is_npu,
supports_custom_op,
)
@@ -206,6 +207,7 @@ class GroupCoordinator:
use_custom_allreduce: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
use_npu_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
@@ -244,6 +246,7 @@ class GroupCoordinator:
self.use_custom_allreduce = use_custom_allreduce
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
self.use_npu_communicator = use_npu_communicator
self.use_message_queue_broadcaster = use_message_queue_broadcaster
# lazy import to avoid documentation build error
@@ -291,6 +294,14 @@ class GroupCoordinator:
if use_xpu_communicator and self.world_size > 1:
self.xpu_communicator = XpuCommunicator(group=self.device_group)
from sglang.srt.distributed.device_communicators.npu_communicator import (
NpuCommunicator,
)
self.npu_communicator: Optional[NpuCommunicator] = None
if use_npu_communicator and self.world_size > 1:
self.npu_communicator = NpuCommunicator(group=self.device_group)
from sglang.srt.distributed.device_communicators.shm_broadcast import (
MessageQueue,
)
@@ -418,6 +429,9 @@ class GroupCoordinator:
if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
return self.xpu_communicator.all_reduce(input_)
if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)
if (
self.ca_comm is not None
and not self.ca_comm.disabled
@@ -497,6 +511,11 @@ class GroupCoordinator:
if hpu_comm is not None and not hpu_comm.disabled:
return hpu_comm.all_gather(input_, dim)
# For NPUs, use NPU communicator.
npu_comm = self.npu_communicator
if npu_comm is not None and not npu_comm.disabled:
return npu_comm.all_gather(input_, dim)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
@@ -941,6 +960,7 @@ def init_world_group(
use_custom_allreduce=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
use_npu_communicator=False,
group_name="world",
)
@@ -959,10 +979,11 @@ def init_model_parallel_group(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=True,
use_pynccl=not is_npu(),
use_custom_allreduce=use_custom_allreduce,
use_hpu_communicator=True,
use_xpu_communicator=True,
use_npu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)