[Feature] Support for Ascend NPU backend (#3853)
Signed-off-by: Song Zhang <gepin.zs@antgroup.com>
Co-authored-by: 22dimensions <waitingwind@foxmail.com>
@@ -42,6 +42,7 @@ from torch.distributed import Backend, ProcessGroup
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda_alike,
+    is_npu,
     supports_custom_op,
 )
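For reference, the new is_npu import gates the pynccl setup further down. A minimal sketch of what such a helper typically checks, assuming Ascend support is detected through the torch_npu package; the actual helper in sglang.srt.utils may differ:

    # Sketch only: assumes NPU detection via the torch_npu package; the real
    # is_npu() in sglang.srt.utils may use a different check.
    import importlib.util

    def is_npu() -> bool:
        # torch_npu is the PyTorch adapter for Ascend NPUs, so its presence is
        # a reasonable proxy for the NPU backend being usable.
        return importlib.util.find_spec("torch_npu") is not None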
@@ -206,6 +207,7 @@ class GroupCoordinator:
         use_custom_allreduce: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
+        use_npu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -244,6 +246,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_npu_communicator = use_npu_communicator
         self.use_message_queue_broadcaster = use_message_queue_broadcaster
 
         # lazy import to avoid documentation build error
@@ -291,6 +294,14 @@ class GroupCoordinator:
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
 
+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+
+        self.npu_communicator: Optional[NpuCommunicator] = None
+        if use_npu_communicator and self.world_size > 1:
+            self.npu_communicator = NpuCommunicator(group=self.device_group)
+
         from sglang.srt.distributed.device_communicators.shm_broadcast import (
             MessageQueue,
         )
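The NpuCommunicator import is deliberately lazy, matching the shm_broadcast import that follows, so merely importing this module does not pull in device-specific packages. A hypothetical skeleton of the communicator interface the coordinator relies on (a disabled flag plus collective methods); the real class in device_communicators/npu_communicator.py may differ:

    # Hypothetical skeleton of the device-communicator interface used here;
    # not the actual NpuCommunicator implementation.
    import torch
    import torch.distributed as dist
    from torch.distributed import ProcessGroup

    class NpuCommunicator:
        def __init__(self, group: ProcessGroup):
            self.group = group
            # A single-rank group has nothing to communicate.
            self.disabled = dist.get_world_size(group) == 1

        def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
            # Sum-reduce in place across ranks; on Ascend this runs over HCCL.
            dist.all_reduce(x, group=self.group)
            return x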
@@ -418,6 +429,9 @@ class GroupCoordinator:
         if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
             return self.xpu_communicator.all_reduce(input_)
 
+        if self.npu_communicator is not None and not self.npu_communicator.disabled:
+            return self.npu_communicator.all_reduce(input_)
+
         if (
             self.ca_comm is not None
             and not self.ca_comm.disabled
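With this hunk, all_reduce tries each device-specific communicator before the custom-allreduce and generic paths. An illustrative condensation of that dispatch order; the method name and fallback helper below are invented for illustration, not sglang APIs:

    # Illustrative condensation of the dispatch order; _all_reduce_sketch and
    # _fallback_all_reduce are invented names.
    def _all_reduce_sketch(self, input_):
        # Device-specific communicators (HPU, XPU, and now NPU) get first refusal.
        for comm in (self.hpu_communicator, self.xpu_communicator, self.npu_communicator):
            if comm is not None and not comm.disabled:
                return comm.all_reduce(input_)
        # Otherwise fall through to custom allreduce / pynccl / torch.distributed.
        return self._fallback_all_reduce(input_)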
@@ -497,6 +511,11 @@ class GroupCoordinator:
         if hpu_comm is not None and not hpu_comm.disabled:
             return hpu_comm.all_gather(input_, dim)
 
+        # For NPUs, use NPU communicator.
+        npu_comm = self.npu_communicator
+        if npu_comm is not None and not npu_comm.disabled:
+            return npu_comm.all_gather(input_, dim)
+
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
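Note that the negative-dim normalization above runs only on the generic path, so a device communicator receives dim as passed. A sketch of how a dim-aware all_gather can be built from torch.distributed primitives, as an assumption about what the NPU communicator does rather than its actual code:

    # Sketch: dim-aware all_gather from torch.distributed primitives; the real
    # NpuCommunicator.all_gather may be implemented differently.
    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = dist.get_world_size(self.group)
        if dim < 0:
            dim += x.dim()
        # Gather every rank's shard into a new leading dimension.
        out = torch.empty((world_size,) + x.shape, dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(out, x, group=self.group)
        # Moving the rank axis next to dim and merging the two is equivalent to
        # concatenating the shards along dim.
        out = out.movedim(0, dim)
        return out.reshape(x.shape[:dim] + (world_size * x.shape[dim],) + x.shape[dim + 1 :])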
@@ -941,6 +960,7 @@ def init_world_group(
         use_custom_allreduce=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
+        use_npu_communicator=False,
         group_name="world",
     )
 
@@ -959,10 +979,11 @@ def init_model_parallel_group(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=True,
+        use_pynccl=not is_npu(),
         use_custom_allreduce=use_custom_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
+        use_npu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
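Callers need no changes: init_model_parallel_group now fills in the NPU flags itself and disables pynccl on NPU hosts. A hedged usage sketch; rank layout, backend choice, and group name are illustrative, and hccl is the torch.distributed backend name provided by torch_npu:

    # Usage sketch only; argument values are illustrative.
    tp_group = init_model_parallel_group(
        group_ranks=[[0, 1, 2, 3]],  # one tensor-parallel group of four ranks
        local_rank=0,
        backend="hccl" if is_npu() else "nccl",
        use_custom_allreduce=False,
        use_message_queue_broadcaster=True,
        group_name="tp",
    )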