# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import torch
from torch.distributed import Backend, ProcessGroup

from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.parallel_state import (
    GroupCoordinator,
    TensorMetadata,
    _get_unique_name,
    _register_group,
    _split_tensor_dict,
)
from vllm.distributed.utils import (
    StatelessProcessGroup,
    stateless_destroy_torch_distributed_process_group,
    stateless_init_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname

logger = init_logger(__name__)


class StatelessGroupCoordinator(GroupCoordinator):
    """
    A stateless version of the GroupCoordinator class in parallel_state. It
    creates CPU, device, and TCPStore-based communication groups that are
    independent of PyTorch's WORLD group, so communication groups with a
    different set of participant GPUs can be created without destroying the
    existing ones.
    """

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: str | Backend,
        use_device_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: str | None = None,
        host: str = "127.0.0.1",
        group_ports: list[list[int]] | None = None,
        global_rank: int = 0,
        global_world_size: int = 1,
    ):
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = global_rank
        self.local_rank = local_rank

        self_device_group = None
        self_cpu_group = None
        self_tcp_store_group = None

        from vllm.platforms import current_platform

        backend = str(torch_distributed_backend)
        self.backend = backend

        assert group_ports is not None, "group_ports must be provided"
        # Locate the sub-group containing this rank and create its device,
        # CPU (gloo), and TCPStore groups, each on a dedicated port.
        for idx, ranks in enumerate(group_ranks):
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                ports = group_ports[idx]
                device_port = ports[0]
                cpu_port = ports[1]
                tcp_store_port = ports[2]
                device_group = stateless_init_torch_distributed_process_group(
                    host=host,
                    port=device_port,
                    rank=self.rank_in_group,
                    world_size=self.world_size,
                    backend=backend,
                    group_name=f"{self.unique_name}_device",
                )
                cpu_group = stateless_init_torch_distributed_process_group(
                    host=host,
                    port=cpu_port,
                    rank=self.rank_in_group,
                    world_size=self.world_size,
                    backend="gloo",
                    group_name=f"{self.unique_name}_cpu",
                )
                tcp_store_group = StatelessProcessGroup.create(
                    host=host,
                    port=tcp_store_port,
                    rank=self.rank_in_group,
                    world_size=self.world_size,
                )
                self_device_group = device_group
                self_cpu_group = cpu_group
                self_tcp_store_group = tcp_store_group

        assert self_cpu_group is not None
        assert self_device_group is not None
        assert self_tcp_store_group is not None
        self.cpu_group = self_cpu_group
        self.device_group = self_device_group
        self.tcp_store_group = self_tcp_store_group

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_xpu():
            self.device = torch.device(f"xpu:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator
        self.device_communicator = None
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            assert device_comm_cls == CudaCommunicator
            self.device_communicator = CudaCommunicator(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
                global_ranks=self.ranks,
                global_world_size=global_world_size,
                tcp_store_group=self.tcp_store_group,
            )

        self.mq_broadcaster = None

        self.use_custom_op_call = (
            current_platform.is_cuda_alike() or current_platform.is_tpu()
        )
        self.use_cpu_custom_send_recv = False

    def destroy(self):
        if self.device_communicator:
            self.device_communicator.destroy()
        if self.device_group:
            stateless_destroy_torch_distributed_process_group(self.device_group)
        if self.cpu_group:
            stateless_destroy_torch_distributed_process_group(self.cpu_group)

    def size(self) -> int:
        """Return the world size of this group."""
        return self.world_size

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        if self.world_size == 1:
            return input_
        if self.device_communicator and input_.is_cuda:
            return self.device_communicator.broadcast(input_, src)
        else:
            return self.tcp_store_group.broadcast(input_, src)

    def broadcast_object(self, obj=None, src: int = 0):
        if self.world_size == 1:
            return obj
        return self.tcp_store_group.broadcast_obj(obj, src)

    def broadcast_object_list(
        self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
    ):
        assert src < self.world_size
        if self.world_size == 1:
            return obj_list
        if self.rank_in_group == src:
            for obj in obj_list:
                self.tcp_store_group.broadcast_obj(obj, src)
        else:
            for i in range(len(obj_list)):
                obj_list[i] = self.tcp_store_group.broadcast_obj(None, src)
        return obj_list

    def broadcast_tensor_dict(
        self,
        tensor_dict: dict[str, torch.Tensor | Any] | None = None,
        src: int = 0,
        group: ProcessGroup | None = None,
        metadata_group: ProcessGroup | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        if self.world_size == 1:
            return tensor_dict

        if self.rank_in_group == src:
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        else:
            metadata_list = None
            tensor_list = []

        # Metadata travels over the TCPStore group; tensor payloads follow below.
        recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj(
            metadata_list, src
        )

        if self.rank_in_group != src:
            tensor_dict = {}
            for key, value in recv_metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
                    tensor_list.append(tensor)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value

        for tensor in tensor_list:
            if tensor.numel() == 0:
                continue
            if self.device_communicator and tensor.is_cuda:
                tensor.copy_(self.device_communicator.broadcast(tensor, src))
            else:
                tensor.copy_(self.tcp_store_group.broadcast(tensor, src))

        return tensor_dict

    def send_object(self, obj, dst: int) -> None:
        assert dst < self.world_size
        assert dst != self.rank_in_group
        self.tcp_store_group.send_obj(obj, dst)

    def recv_object(self, src: int):
        assert src < self.world_size
        assert src != self.rank_in_group
        return self.tcp_store_group.recv_obj(src)

    def send_tensor_dict(
        self,
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        if self.world_size == 1:
            return tensor_dict

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size

        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        self.tcp_store_group.send_obj(metadata_list, dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                continue
            if self.device_communicator and tensor.is_cuda:
                self.device_communicator.send(tensor, dst)
            else:
                self.tcp_store_group.send(tensor, dst)
        return None

    def recv_tensor_dict(
        self,
        src: int | None = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        if self.world_size == 1:
            return None

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size

        recv_metadata_list = self.tcp_store_group.recv_obj(src)
        tensor_dict = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(
                    value.size, dtype=value.dtype, device=value.device
                )
                if tensor.numel() > 0:
                    if self.device_communicator and tensor.is_cuda:
                        tensor = self.device_communicator.recv(
                            tensor.size(), tensor.dtype, src
                        )
                    else:
                        tensor = self.tcp_store_group.recv(tensor, src)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict

    def barrier(self):
        self.tcp_store_group.barrier()

    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
    ) -> torch.Tensor | None:
        if self.world_size == 1:
            return input_
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        if self.rank_in_group == dst:
            # Gather via point-to-point recv from every other rank in the group.
            gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)]
            gathered_list[self.rank_in_group] = input_
            for src_rank in range(self.world_size):
                if src_rank != self.rank_in_group:
                    gathered_list[src_rank] = self.device_communicator.recv(
                        input_.size(), input_.dtype, src_rank
                    )
            return torch.cat(gathered_list, dim=dim)
        else:
            self.device_communicator.send(input_, dst)
            return None
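

# Illustrative sketch of the constructor's call shape for a two-rank group.
# The helper below is hypothetical and never called; the host and port values
# are placeholders. Each inner list of ``group_ports`` must supply three free
# ports, consumed in order by the device group, the CPU (gloo) group, and the
# TCPStore group. Every participating process would call the constructor with
# the same ``group_ranks``/``group_ports`` and its own ``global_rank`` and
# ``local_rank``.
def _example_build_coordinator(
    global_rank: int, local_rank: int
) -> StatelessGroupCoordinator:
    """Hypothetical example showing how a caller might build a coordinator."""
    return StatelessGroupCoordinator(
        group_ranks=[[0, 1]],  # one group containing global ranks 0 and 1
        local_rank=local_rank,
        torch_distributed_backend="nccl",
        use_device_communicator=True,
        group_name="example_group",
        host="127.0.0.1",  # placeholder rendezvous host
        group_ports=[[29600, 29601, 29602]],  # placeholder device/cpu/store ports
        global_rank=global_rank,
        global_world_size=2,
    )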