diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index e8dab5c80..57d966f70 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -52,6 +52,8 @@ from sglang.srt.utils import (
 
 _is_npu = is_npu()
 
+IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
+
 
 @dataclass
 class GraphCaptureContext:
@@ -223,10 +225,12 @@ class GroupCoordinator:
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
+        # Set group info
         group_name = group_name or "anonymous"
         self.unique_name = _get_unique_name(group_name)
         _register_group(self)
 
+        # Set rank info
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
         self.device_group = None
@@ -250,14 +254,16 @@ class GroupCoordinator:
         assert self.cpu_group is not None
         assert self.device_group is not None
 
+        device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
         if is_cuda_alike():
-            self.device = torch.device(f"cuda:{local_rank}")
+            self.device = torch.device(f"cuda:{device_id}")
         elif _is_npu:
-            self.device = torch.device(f"npu:{local_rank}")
+            self.device = torch.device(f"npu:{device_id}")
         else:
             self.device = torch.device("cpu")
         self.device_module = torch.get_device_module(self.device)
 
+        # Import communicators
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
@@ -270,6 +276,9 @@ class GroupCoordinator:
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce,
         )
+        from sglang.srt.distributed.device_communicators.pymscclpp import (
+            PyMscclppCommunicator,
+        )
         from sglang.srt.distributed.device_communicators.pynccl import (
             PyNcclCommunicator,
         )
@@ -287,10 +296,6 @@ class GroupCoordinator:
                 device=self.device,
             )
 
-        from sglang.srt.distributed.device_communicators.pymscclpp import (
-            PyMscclppCommunicator,
-        )
-
         self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
         if use_pymscclpp and self.world_size > 1:
             self.pymscclpp_comm = PyMscclppCommunicator(
@@ -325,30 +330,30 @@ class GroupCoordinator:
             except Exception as e:
                 logger.warning(f"Failed to initialize QuickAllReduce: {e}")
 
+        # Create communicator for other hardware backends
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
         )
+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+        from sglang.srt.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator,
+        )
 
         self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)
 
-        from sglang.srt.distributed.device_communicators.xpu_communicator import (
-            XpuCommunicator,
-        )
-
         self.xpu_communicator: Optional[XpuCommunicator] = None
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
 
-        from sglang.srt.distributed.device_communicators.npu_communicator import (
-            NpuCommunicator,
-        )
-
         self.npu_communicator: Optional[NpuCommunicator] = None
         if use_npu_communicator and self.world_size > 1:
             self.npu_communicator = NpuCommunicator(group=self.device_group)
 
+        # Create message queue
         from sglang.srt.distributed.device_communicators.shm_broadcast import (
             MessageQueue,
         )
@@ -848,6 +853,11 @@ class GroupCoordinator:
         )
         return obj_list
 
+    def all_gather_object(self, obj: Any) -> List[Any]:
+        objs = [None] * self.world_size
+        torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
+        return objs
+
     def send_object(self, obj: Any, dst: int) -> None:
         """Send the input object list to the destination rank."""
         """NOTE: `dst` is the local rank of the destination rank."""
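
Usage note (not part of the diff): the new `all_gather_object` helper wraps `torch.distributed.all_gather_object` over the coordinator's CPU group, so each rank contributes one picklable Python object and every rank receives the full list, indexed by rank. Separately, when the `SGLANG_ONE_DEVICE_PER_PROCESS` env var is set, the device index is pinned to 0 instead of `local_rank`, matching a setup where each process sees exactly one visible accelerator. The sketch below is a minimal standalone illustration of the helper's semantics, not SGLang code: `all_gather_object_sketch` is a hypothetical stand-in for the method, and it assumes the processes are launched with `torchrun --nproc_per_node=2` and that the Gloo backend is available.

# Minimal standalone sketch mirroring GroupCoordinator.all_gather_object,
# which gathers over self.cpu_group. Assumptions: launched via torchrun,
# Gloo backend available.
import torch.distributed as dist

def all_gather_object_sketch(obj, group=None):
    # Pre-size the output list; rank i's object lands at index i.
    objs = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(objs, obj, group=group)
    return objs

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # CPU group, analogous to cpu_group
    rank = dist.get_rank()
    gathered = all_gather_object_sketch({"rank": rank})
    # Every rank sees the same list; its own entry sits at its rank index.
    assert gathered[rank] == {"rank": rank}
    dist.destroy_process_group()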