################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Monkey-patches for ``vllm.distributed.parallel_state`` on the SUPA backend.

This module replaces vllm's CUDA-centric graph capture, model-parallel group
initialization and tensor-dict send/recv with SUPA-aware implementations.
Importing it has side effects: it rebinds attributes on
``vllm.distributed.parallel_state`` and ``GroupCoordinator``.
"""

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional, Union

import torch
import torch.distributed
import torch_br
import vllm
import vllm.distributed.parallel_state
from vllm.distributed import GroupCoordinator
from vllm.distributed.parallel_state import (_WORLD, TensorMetadata,
                                             _split_tensor_dict, get_pp_group,
                                             get_tp_group, get_world_group,
                                             init_model_parallel_group, logger)

from vllm_br import envs


@dataclass
class GraphCaptureContext:
    # Stream that the SUPA graph capture runs on.
    stream: torch_br.supa.Stream


@contextmanager
# @patch_to(GroupCoordinator.graph_capture)
def graph_capture_(
        self,
        graph_capture_context: Optional[GraphCaptureContext] = None):
    """SUPA replacement for ``GroupCoordinator.graph_capture``.

    Yields a :class:`GraphCaptureContext` whose stream is made current for
    the duration of the ``with`` block, so kernels launched inside it are
    captured on a dedicated stream rather than the default one.
    """
    if graph_capture_context is None:
        stream = torch_br.supa.Stream()
        graph_capture_context = GraphCaptureContext(stream)
    else:
        stream = graph_capture_context.stream

    # only supa uses this function,
    # so we don't abstract it into the base class
    # maybe_ca_context = nullcontext()
    # from vllm_br.distributed.communicator import SUPACommunicator
    # if self.device_communicator is not None:
    #     assert isinstance(self.device_communicator, SUPACommunicator)
    #     ca_comm = self.device_communicator.ca_comm
    #     if ca_comm is not None:
    #         maybe_ca_context = ca_comm.capture()  # type: ignore

    # ensure all initialization operations complete before attempting to
    # capture the graph on another stream
    curr_stream = torch_br.supa.current_stream()
    if curr_stream != stream:
        stream.wait_stream(curr_stream)

    with torch_br.supa.stream(stream):
        yield graph_capture_context


vllm.distributed.parallel_state.GroupCoordinator.graph_capture = graph_capture_


@contextmanager
# @patch_to(graph_capture)
def graph_capture_supa(device: torch.device):
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the SUPA graph. Its main purpose is to ensure that the
    some operations will be run after the graph is captured, before the graph
    is replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current SUPA stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default stream,
    in order to explicitly distinguish the kernels to capture
    from other kernels possibly launched on background in the default stream.
    """
    context = GraphCaptureContext(torch_br.supa.Stream(device=device))
    # Enter capture on both TP and PP groups with the same shared context.
    with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
            context):
        yield context


vllm.distributed.parallel_state.graph_capture = graph_capture_supa


def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
              Returns True if distributed is not initialized (single process).
    """
    try:
        # If world group is available, use it for the most accurate check
        if _WORLD is not None:
            return _WORLD.is_first_rank

        # If torch distributed is not initialized, assume single process
        if not torch.distributed.is_initialized():
            return True

        # Fallback to torch's global rank
        return torch.distributed.get_rank() == 0
    except Exception:
        # If anything goes wrong, assume this is the first rank
        return True


def generate_multi_node_parallel_groups(
    total_procs: int,
    tp_size: int,
    pp_size: int,
    dp_size: int,
) -> dict:
    """Return hard-coded supernode rank groupings for each parallel dimension.

    Only the 16-process tp8/pp2/dp1 layout is supported; any other
    configuration raises ``ValueError``.

    Returns:
        dict with keys ``tp_groups``, ``pp_groups``, ``dp_groups`` and
        ``ep_groups``, each a list of rank lists.
    """
    if total_procs == 16 and tp_size == 8 and pp_size == 2 and dp_size == 1:
        tp_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
        pp_groups = [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13],
                     [10, 14], [11, 15]]
        dp_groups = [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10],
                     [11], [12], [13], [14], [15]]
        ep_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
    else:
        # f-string fix: the placeholders were previously never interpolated.
        raise ValueError(
            "Unsupported VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE parallel"
            f" config of tp_size: {tp_size} pp_size: {pp_size}"
            f" dp_size: {dp_size}. "
            "Currently only 'tp8pp2dp1' is allowed.")
    return {
        "tp_groups": tp_groups,
        "pp_groups": pp_groups,
        "dp_groups": dp_groups,
        "ep_groups": ep_groups,
    }


# sync v0.11 api update, while code logic possibly need sync with vllm
# original code implementation
def initialize_model_parallel_cross_tp(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    decode_context_model_parallel_size: Optional[int] = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: name of torch distributed communication backend.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel
    groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size,
        tensor_model_parallel_size)  # noqa

    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        groups = generate_multi_node_parallel_groups(
            world_size, tensor_model_parallel_size,
            pipeline_model_parallel_size, data_parallel_size)
        logger.info("supernode reorganized groups: %s", groups)

    # Build the tensor model-parallel groups.
    assert vllm.distributed.parallel_state._TP is None, (
        "tensor model parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['tp_groups']
    else:
        group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]

    # message queue broadcaster is only used in tensor model parallel group
    vllm.distributed.parallel_state._TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp")

    # Build the DCP model-parallel groups.
    # global _DCP
    assert vllm.distributed.parallel_state._DCP is None, (
        "decode context model parallel group is already initialized")

    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    group_ranks = all_ranks.reshape(
        -1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp")

    # Build the pipeline model-parallel groups.
    assert vllm.distributed.parallel_state._PP is None, (
        "pipeline model parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['pp_groups']
    else:
        group_ranks = all_ranks.transpose(2, 3).reshape(
            -1, pipeline_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp")

    # Build the data-parallel groups.
    assert vllm.distributed.parallel_state._DP is None, (
        "data parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['dp_groups']
    else:
        group_ranks = all_ranks.transpose(1, 3).reshape(
            -1, data_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._DP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="dp")

    # Build the expert-parallel groups (spans DP x TP).
    assert vllm.distributed.parallel_state._EP is None, (
        "expert parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['ep_groups']
    else:
        group_ranks = all_ranks.transpose(1, 2).reshape(
            -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._EP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="ep")

    logger.info(
        "rank %s in world size %s is assigned as (br) "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
        vllm.distributed.parallel_state._DP.rank_in_group,
        vllm.distributed.parallel_state._PP.rank_in_group,
        vllm.distributed.parallel_state._TP.rank_in_group,
        vllm.distributed.parallel_state._EP.rank_in_group)


vllm.distributed.parallel_state.initialize_model_parallel = \
    initialize_model_parallel_cross_tp


def send_tensor_dict(
    self,
    tensor_dict: dict[str, Union[torch.Tensor, Any]],
    dst: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
    """Send the input tensor dictionary.

    NOTE: `dst` is the local rank of the destination rank inside this group.

    all_gather_group: The group for the all-gather operation. If provided,
        an optimization is enabled where each rank in the group sends a
        slice of a tensor and the receiver reconstructs it using an
        all-gather, which can improve performance. This is typically the
        tensor-parallel group.
    all_gather_tensors: A dictionary to specify which tensors should use
        the all-gather optimization, which is only effective when
        `all_gather_group` is provided. By default, this optimization is
        on for any tensor whose size is divisible by the
        `all_gather_group`'s world size. However, it should be disabled
        for tensors that are not fully replicated across the group (e.g.,
        the residual tensor when sequence parallelism is enabled). This
        dictionary allows overriding the default behavior on a per-tensor
        basis.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict

    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)

    group = self.device_group
    metadata_group = self.cpu_group

    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    if self.use_cpu_custom_send_recv:
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        self.device_communicator.send_tensor_dict(  # type: ignore
            tensor_dict, dst)
        return None

    metadata_list: list[tuple[Any, Any]] = []
    assert isinstance(
        tensor_dict,
        dict), f"Expecting a dictionary, got {type(tensor_dict)}"
    metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    # `metadata_list` lives in CPU memory.
    # `send_object_list` has serialization & deserialization,
    # all happening on CPU. Therefore, we can use the CPU group.
    self.send_object(metadata_list, dst=dst)

    tensor_keys = [
        k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
    ]
    assert len(tensor_keys) == len(tensor_list)

    for key, tensor in zip(tensor_keys, tensor_list):
        if tensor.numel() == 0:
            # Skip sending empty tensors.
            continue

        # send-allgather: send only a slice, then do allgather.
        use_all_gather = (all_gather_group is not None
                          and tensor.numel() % all_gather_size == 0)
        use_all_gather = all_gather_tensors.get(key, use_all_gather) \
            if all_gather_tensors else use_all_gather

        if use_all_gather:
            tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

        if tensor.is_cpu:
            # use metadata_group for CPU tensors
            torch.distributed.send(tensor,
                                   dst=self.ranks[dst],
                                   group=metadata_group)
        else:
            # ensure tensor is ready
            torch.supa.synchronize()
            # use group for GPU tensors
            torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
    return None


def recv_tensor_dict(
    self,
    src: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
    """Recv the input tensor dictionary.

    NOTE: `src` is the local rank of the source rank.

    all_gather_group: The group for the all-gather operation. If provided,
        an optimization is enabled where each rank in the group sends a
        slice of a tensor and the receiver reconstructs it using an
        all-gather, which can improve performance. This is typically the
        tensor-parallel group.
    all_gather_tensors: A dictionary to specify which tensors should use
        the all-gather optimization, which is only effective when
        `all_gather_group` is provided. By default, this optimization is
        on for any tensor whose size is divisible by the
        `all_gather_group`'s world size. However, it should be disabled
        for tensors that are not fully replicated across the group (e.g.,
        the residual tensor when sequence parallelism is enabled). This
        dictionary allows overriding the default behavior on a per-tensor
        basis.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return None

    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)

    group = self.device_group
    metadata_group = self.cpu_group

    if src is None:
        src = (self.rank_in_group - 1) % self.world_size
    assert src < self.world_size, f"Invalid src rank ({src})"

    if self.use_cpu_custom_send_recv:
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        return self.device_communicator.recv_tensor_dict(  # type: ignore
            src)

    recv_metadata_list = self.recv_object(src=src)
    tensor_dict: dict[str, Any] = {}
    for key, value in recv_metadata_list:
        if isinstance(value, TensorMetadata):
            tensor = torch.empty(value.size,
                                 dtype=value.dtype,
                                 device=value.device)
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors.
                tensor_dict[key] = tensor
                continue

            # send-allgather: send only a slice, then do allgather.
            use_all_gather = (all_gather_group is not None
                              and tensor.numel() % all_gather_size == 0)
            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
                if all_gather_tensors else use_all_gather

            if use_all_gather:
                orig_shape = tensor.shape
                tensor = tensor.reshape(all_gather_size,
                                        -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.recv(tensor,
                                       src=self.ranks[src],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.recv(tensor,
                                       src=self.ranks[src],
                                       group=group)
                # ensure recv is done
                torch.supa.synchronize()
            if use_all_gather:
                # do the allgather
                tensor = all_gather_group.all_gather(  # type: ignore
                    tensor, dim=0)
                tensor = tensor.reshape(orig_shape)

            tensor_dict[key] = tensor
        else:
            tensor_dict[key] = value
    return tensor_dict


vllm.distributed.GroupCoordinator.send_tensor_dict = send_tensor_dict
vllm.distributed.GroupCoordinator.recv_tensor_dict = recv_tensor_dict