From 097725bb665f02c6494b8c99bcd84898a123dc78 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 2 Oct 2025 01:09:13 -0700 Subject: [PATCH] Clean up parallel_state.py (#11148) --- .../sglang/srt/distributed/parallel_state.py | 183 ++++++++++-------- 1 file changed, 101 insertions(+), 82 deletions(-) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 4f410570d..25dd4e511 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -4,7 +4,7 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -"""vLLM distributed state. +"""Distributed state. It takes over the control of the distributed environment from PyTorch. The typical workflow is: @@ -53,19 +53,26 @@ from sglang.srt.utils import ( _is_npu = is_npu() _is_cpu = is_cpu() +_supports_custom_op = supports_custom_op() IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS") +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + +# use int value instead of ReduceOp.SUM to support torch compile +REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM) + + @dataclass class GraphCaptureContext: stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream -TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) - -# use int value instead of ReduceOp.SUM to support torch compile -REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM) +@dataclass +class P2PWork: + work: Optional[torch.distributed.Work] + payload: Optional[torch.Tensor] def _split_tensor_dict( @@ -117,7 +124,7 @@ def _register_group(group: "GroupCoordinator") -> None: _groups[group.unique_name] = weakref.ref(group) -if supports_custom_op(): +if _supports_custom_op: def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: assert group_name in _groups, f"Group {group_name} is not found." @@ -277,7 +284,7 @@ class GroupCoordinator: self.use_npu_communicator = use_npu_communicator self.use_message_queue_broadcaster = use_message_queue_broadcaster - # lazy import to avoid documentation build error + # Lazy import to avoid documentation build error from sglang.srt.distributed.device_communicators.custom_all_reduce import ( CustomAllreduce, ) @@ -497,7 +504,7 @@ class GroupCoordinator: torch.distributed.all_reduce(input_, group=self.device_group) return input_ - if not supports_custom_op(): + if not _supports_custom_op: self._all_reduce_in_place(input_) return input_ @@ -523,23 +530,24 @@ class GroupCoordinator: outplace_all_reduce_method = None if ( - self.qr_comm is not None - and not self.qr_comm.disabled - and self.qr_comm.should_quick_allreduce(input_) - ): - outplace_all_reduce_method = "qr" - elif ( self.ca_comm is not None and not self.ca_comm.disabled and self.ca_comm.should_custom_ar(input_) ): outplace_all_reduce_method = "ca" + elif ( + self.qr_comm is not None + and not self.qr_comm.disabled + and self.qr_comm.should_quick_allreduce(input_) + ): + outplace_all_reduce_method = "qr" elif ( self.pymscclpp_comm is not None and not self.pymscclpp_comm.disabled and self.pymscclpp_comm.should_mscclpp_allreduce(input_) ): outplace_all_reduce_method = "pymscclpp" + if outplace_all_reduce_method is not None: return torch.ops.sglang.outplace_all_reduce( input_, @@ -553,16 +561,16 @@ class GroupCoordinator: def _all_reduce_out_place( self, input_: torch.Tensor, outplace_all_reduce_method: str ) -> torch.Tensor: - qr_comm = self.qr_comm ca_comm = self.ca_comm + qr_comm = self.qr_comm pymscclpp_comm = self.pymscclpp_comm assert any([qr_comm, ca_comm, pymscclpp_comm]) - if outplace_all_reduce_method == "qr": - assert not qr_comm.disabled - out = qr_comm.quick_all_reduce(input_) - elif outplace_all_reduce_method == "ca": + if outplace_all_reduce_method == "ca": assert not ca_comm.disabled out = ca_comm.custom_all_reduce(input_) + elif outplace_all_reduce_method == "qr": + assert not qr_comm.disabled + out = qr_comm.quick_all_reduce(input_) else: assert not pymscclpp_comm.disabled out = pymscclpp_comm.all_reduce(input_) @@ -637,7 +645,7 @@ class GroupCoordinator: ) def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): - if _is_npu or not supports_custom_op(): + if _is_npu or not _supports_custom_op: self._all_gather_into_tensor(output, input) else: torch.ops.sglang.reg_all_gather_into_tensor( @@ -697,15 +705,13 @@ class GroupCoordinator: ) # All-gather. - if input_.is_cpu and is_shm_available( - input_.dtype, self.world_size, self.local_size - ): - return torch.ops.sgl_kernel.shm_allgather(input_, dim) - if input_.is_cpu: - torch.distributed.all_gather_into_tensor( - output_tensor, input_, group=self.device_group - ) + if is_shm_available(input_.dtype, self.world_size, self.local_size): + return torch.ops.sgl_kernel.shm_allgather(input_, dim) + else: + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) else: self.all_gather_into_tensor(output_tensor, input_) @@ -861,45 +867,63 @@ class GroupCoordinator: torch.distributed.all_gather_object(objs, obj, group=self.cpu_group) return objs - def send_object(self, obj: Any, dst: int) -> None: - """Send the input object list to the destination rank.""" - """NOTE: `dst` is the local rank of the destination rank.""" + def send_object( + self, + obj: Any, + dst: int, + async_send: bool = False, + ) -> List[P2PWork]: + """ + Send the input object list to the destination rank. + This function uses the CPU group for all communications. + + TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group), + use other functions (e.g., send), or implement a new function (e.g., send_object_device). + + NOTE: `dst` is the local rank of the destination rank. + """ assert dst < self.world_size, f"Invalid dst rank ({dst})" - assert dst != self.rank_in_group, ( "Invalid destination rank. Destination rank is the same " "as the current rank." ) + send_func = torch.distributed.isend if async_send else torch.distributed.send # Serialize object to tensor and get the size as well - object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda( - device=torch.cuda.current_device() - ) - + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) size_tensor = torch.tensor( - [object_tensor.numel()], - dtype=torch.long, - device="cpu", + [object_tensor.numel()], dtype=torch.long, device="cpu" ) + # Send object size - torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) - - # Send object - torch.distributed.send( - object_tensor, - dst=self.ranks[dst], - group=self.device_group, + p2p_work = [] + size_work = send_func( + size_tensor, + self.ranks[dst], + group=self.cpu_group, ) + if async_send: + p2p_work.append(P2PWork(size_work, size_tensor)) - return None + object_work = send_func( + object_tensor, + self.ranks[dst], + group=self.cpu_group, + ) + if async_send: + p2p_work.append(P2PWork(object_work, object_tensor)) - def recv_object(self, src: int) -> Any: + return p2p_work + + def recv_object( + self, + src: int, + ) -> Any: """Receive the input object list from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" assert src < self.world_size, f"Invalid src rank ({src})" - assert ( src != self.rank_in_group ), "Invalid source rank. Source rank is the same as the current rank." @@ -907,27 +931,25 @@ class GroupCoordinator: size_tensor = torch.empty(1, dtype=torch.long, device="cpu") # Receive object size - rank_size = torch.distributed.recv( + # We have to use irecv here to make it work for both isend and send. + work = torch.distributed.irecv( size_tensor, src=self.ranks[src], group=self.cpu_group ) + work.wait() # Tensor to receive serialized objects into. - object_tensor = torch.empty( # type: ignore[call-overload] + object_tensor: Any = torch.empty( # type: ignore[call-overload] size_tensor.item(), # type: ignore[arg-type] dtype=torch.uint8, - device=torch.cuda.current_device(), + device="cpu", ) - rank_object = torch.distributed.recv( - object_tensor, src=self.ranks[src], group=self.device_group + work = torch.distributed.irecv( + object_tensor, src=self.ranks[src], group=self.cpu_group ) + work.wait() - assert ( - rank_object == rank_size - ), "Received object sender rank does not match the size sender rank." - - obj = pickle.loads(object_tensor.cpu().numpy()) - + obj = pickle.loads(object_tensor.numpy()) return obj def broadcast_tensor_dict( @@ -1017,12 +1039,13 @@ class GroupCoordinator: tensor_dict: Dict[str, Union[torch.Tensor, Any]], dst: Optional[int] = None, all_gather_group: Optional["GroupCoordinator"] = None, - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + async_send: bool = False, + ) -> Optional[List[P2PWork]]: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. """ # Bypass the function if we are using only 1 GPU. - if not torch.distributed.is_initialized() or self.world_size == 1: + if self.world_size == 1: return tensor_dict all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size @@ -1047,7 +1070,10 @@ class GroupCoordinator: # 1. Superior D2D transfer bandwidth # 2. Ability to overlap send and recv operations # Thus the net performance gain justifies this approach. - self.send_object(metadata_list, dst=dst) + + send_func = torch.distributed.isend if async_send else torch.distributed.send + p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send) + for tensor in tensor_list: if tensor.numel() == 0: # Skip sending empty tensors. @@ -1057,15 +1083,10 @@ class GroupCoordinator: if all_gather_group is not None and tensor.numel() % all_gather_size == 0: tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] - if tensor.is_cpu: - # use metadata_group for CPU tensors - torch.distributed.send( - tensor, dst=self.ranks[dst], group=metadata_group - ) - else: - # use group for GPU tensors - torch.distributed.send(tensor, dst=self.ranks[dst], group=group) - return None + comm_group = metadata_group if tensor.is_cpu else group + work = send_func(tensor, self.ranks[dst], group=comm_group) + p2p_works.append(P2PWork(work, tensor)) + return p2p_works def recv_tensor_dict( self, @@ -1111,17 +1132,15 @@ class GroupCoordinator: orig_shape = tensor.shape tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] - if tensor.is_cpu: - # use metadata_group for CPU tensors - torch.distributed.recv( - tensor, src=self.ranks[src], group=metadata_group - ) - else: - # use group for GPU tensors - torch.distributed.recv(tensor, src=self.ranks[src], group=group) + # We have to use irecv here to make it work for both isend and send. + comm_group = metadata_group if tensor.is_cpu else group + work = torch.distributed.irecv( + tensor, src=self.ranks[src], group=comm_group + ) + work.wait() + if use_all_gather: - # do the allgather - tensor = all_gather_group.all_gather(tensor, dim=0) # type: ignore + tensor = all_gather_group.all_gather(tensor, dim=0) tensor = tensor.reshape(orig_shape) tensor_dict[key] = tensor