Files
enginex-biren-vllm/vllm_br/distributed/parallel_state.py
2026-03-10 13:31:25 +08:00

474 lines
20 KiB
Python

################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
import torch.distributed
import torch_br
import vllm
import vllm.distributed.parallel_state
from vllm.distributed import GroupCoordinator
from vllm.distributed.parallel_state import (_WORLD, TensorMetadata,
_split_tensor_dict, get_pp_group,
get_tp_group, get_world_group,
init_model_parallel_group, logger)
from vllm_br import envs
@dataclass
class GraphCaptureContext:
    """Carries the state needed during SUPA graph capture.

    Mirrors upstream vllm's GraphCaptureContext, but holds a SUPA stream
    (``torch_br.supa.Stream``) instead of a CUDA stream.
    """
    # The stream on which all kernels to be captured are launched.
    stream: torch_br.supa.Stream
@contextmanager
def graph_capture_(self,
                   graph_capture_context: Optional[GraphCaptureContext] = None
                   ):
    """SUPA-aware replacement for ``GroupCoordinator.graph_capture``.

    Yields a :class:`GraphCaptureContext` whose stream becomes the current
    SUPA stream for the duration of the ``with`` body. The capture stream
    first waits on the stream that was current on entry, so all pending
    initialization work completes before capture begins.
    """
    if graph_capture_context is not None:
        capture_stream = graph_capture_context.stream
    else:
        capture_stream = torch_br.supa.Stream()
        graph_capture_context = GraphCaptureContext(capture_stream)

    # Only SUPA uses this function, so it is not abstracted into a base
    # class. Ensure everything already queued on the entry stream is
    # finished before kernels are captured on the capture stream.
    entry_stream = torch_br.supa.current_stream()
    if entry_stream != capture_stream:
        capture_stream.wait_stream(entry_stream)

    with torch_br.supa.stream(capture_stream):
        yield graph_capture_context


# Monkey-patch the upstream coordinator to use the SUPA-aware version.
vllm.distributed.parallel_state.GroupCoordinator.graph_capture = graph_capture_
@contextmanager
def graph_capture_supa(device: torch.device):
    """Surround SUPA graph-capture code with the required group contexts.

    Creates a fresh capture stream on *device* and enters the TP and PP
    groups' ``graph_capture`` context managers with it, so the capture runs
    on a stream separate from the default stream. This explicitly
    distinguishes the kernels to capture from kernels possibly launched in
    the background on the default stream. Yields the shared
    :class:`GraphCaptureContext`; on exit the previous stream is restored.
    """
    ctx = GraphCaptureContext(torch_br.supa.Stream(device=device))
    tp_capture = get_tp_group().graph_capture(ctx)
    pp_capture = get_pp_group().graph_capture(ctx)
    with tp_capture, pp_capture:
        yield ctx


# Replace the module-level helper so callers get the SUPA implementation.
vllm.distributed.parallel_state.graph_capture = graph_capture_supa
def is_global_first_rank() -> bool:
    """Return True iff this process is global rank 0.

    Unlike group-specific checks such as
    ``get_tensor_model_parallel_rank() == 0`` or
    ``get_pp_group().is_first_rank``, this checks the global rank across
    every parallelism dimension (PP, TP, DP, EP, ...). Defaults to True
    when distributed is uninitialized (single process) or on any error.
    """
    try:
        # Prefer the world group when it exists — most accurate check.
        if _WORLD is not None:
            return _WORLD.is_first_rank
        # Fall back to torch's notion of global rank when initialized;
        # otherwise assume a single, first process.
        if torch.distributed.is_initialized():
            return torch.distributed.get_rank() == 0
        return True
    except Exception:
        # Best effort: on any failure, behave as the first rank.
        return True
def generate_multi_node_parallel_groups(
    total_procs: int,
    tp_size: int,
    pp_size: int,
    dp_size: int,
) -> dict:
    """Return hand-crafted supernode rank groups for each parallel dimension.

    Only the 16-process tp8/pp2/dp1 layout is supported: TP (and EP)
    groups span both 8-rank halves, PP pairs rank i with rank i+4 within
    each half, and every rank forms its own DP group.

    Args:
        total_procs: total number of processes (world size).
        tp_size: tensor-model-parallel group size.
        pp_size: pipeline-model-parallel group size.
        dp_size: data-parallel group size.

    Returns:
        dict with keys "tp_groups", "pp_groups", "dp_groups", "ep_groups",
        each a list of rank lists.

    Raises:
        ValueError: for any configuration other than tp8/pp2/dp1 on 16
            processes.
    """
    if total_procs == 16 and tp_size == 8 and pp_size == 2 and dp_size == 1:
        tp_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
        pp_groups = [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13],
                     [10, 14], [11, 15]]
        dp_groups = [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10],
                     [11], [12], [13], [14], [15]]
        ep_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
    else:
        # Bug fix: the message previously was not an f-string, so the
        # {tp_size}/{pp_size}/{dp_size} placeholders were emitted literally,
        # and the missing separator glued "dp_size: 1" onto "Currently".
        raise ValueError(
            f"Unsupported VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE parallel "
            f"config of tp_size: {tp_size} pp_size: {pp_size} "
            f"dp_size: {dp_size}. Currently only 'tp8pp2dp1' is allowed.")
    return {
        "tp_groups": tp_groups,
        "pp_groups": pp_groups,
        "dp_groups": dp_groups,
        "ep_groups": ep_groups,
    }
# sync v0.11 api update, while code logic possibly need sync with vllm original code implementation
def initialize_model_parallel_cross_tp(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    decode_context_model_parallel_size: Optional[int] = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.
    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        decode_context_model_parallel_size: number of GPUs reused from the TP
            group for decode context parallelism (DCP).
        backend: name of torch distributed communication backend.
    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    When ``envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE`` is set, the TP/PP/
    DP/EP group layouts are replaced by the hand-crafted supernode layouts
    from :func:`generate_multi_node_parallel_groups`.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    # DP size comes from the live vllm config when one is available;
    # defaults to 1 otherwise.
    data_parallel_size = 1
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size
    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size,
        tensor_model_parallel_size)  # noqa
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        # Supernode mode: override the derived layouts with the fixed
        # multi-node group tables (only tp8/pp2/dp1 is supported there).
        groups = generate_multi_node_parallel_groups(
            world_size, tensor_model_parallel_size,
            pipeline_model_parallel_size, data_parallel_size)
        logger.info("supernode reorganized groups: %s", groups)
    # Build the tensor model-parallel groups.
    assert vllm.distributed.parallel_state._TP is None, (
        "tensor model parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        # Already plain lists of rank lists — no tensor conversion needed.
        group_ranks = groups['tp_groups']
    else:
        group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    # message queue broadcaster is only used in tensor model parallel group
    vllm.distributed.parallel_state._TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp")
    # Build the DCP model-parallel groups.
    # global _DCP
    assert vllm.distributed.parallel_state._DCP is None, (
        "decode context model parallel group is already initialized")
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    # NOTE(review): assumes decode_context_model_parallel_size is a valid
    # int here (the Optional default is 1) — a None would fail the reshape.
    group_ranks = all_ranks.reshape(
        -1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp")
    # Build the pipeline model-parallel groups.
    assert vllm.distributed.parallel_state._PP is None, (
        "pipeline model parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['pp_groups']
    else:
        group_ranks = all_ranks.transpose(2, 3).reshape(
            -1, pipeline_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp")
    # Build the data-parallel groups.
    assert vllm.distributed.parallel_state._DP is None, (
        "data parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['dp_groups']
    else:
        group_ranks = all_ranks.transpose(1, 3).reshape(
            -1, data_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._DP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="dp")
    # Build the expert-parallel groups (EP spans DP x TP).
    assert vllm.distributed.parallel_state._EP is None, (
        "expert parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['ep_groups']
    else:
        group_ranks = all_ranks.transpose(1, 2).reshape(
            -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._EP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="ep")
    logger.info(
        "rank %s in world size %s is assigned as (br) "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
        vllm.distributed.parallel_state._DP.rank_in_group,
        vllm.distributed.parallel_state._PP.rank_in_group,
        vllm.distributed.parallel_state._TP.rank_in_group,
        vllm.distributed.parallel_state._EP.rank_in_group)


# Monkey-patch upstream vllm to use this cross-TP-aware initializer.
vllm.distributed.parallel_state.initialize_model_parallel = initialize_model_parallel_cross_tp
def send_tensor_dict(
    self,
    tensor_dict: dict[str, Union[torch.Tensor, Any]],
    dst: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
    """Send the input tensor dictionary.
    NOTE: `dst` is the local rank of the source rank.

    all_gather_group: The group for the all-gather operation. If provided,
        an optimization is enabled where each rank in the group sends a
        slice of a tensor and the receiver reconstructs it using an
        all-gather, which can improve performance. This is typically the
        tensor-parallel group.
    all_gather_tensors: A dictionary to specify which tensors should use
        the all-gather optimization, which is only effective when
        `all_gather_group` is provided. By default, this optimization is
        on for any tensor whose size is divisible by the
        `all_gather_group`'s world size. However, it should be disabled
        for tensors that are not fully replicated across the group (e.g.,
        the residual tensor when sequence parallelism is enabled). This
        dictionary allows overriding the default behavior on a per-tensor
        basis.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict
    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)
    group = self.device_group
    metadata_group = self.cpu_group
    # Default destination: next rank in the ring.
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    assert dst < self.world_size, f"Invalid dst rank ({dst})"
    if self.use_cpu_custom_send_recv:
        # Delegate the whole dict transfer to the device communicator.
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        self.device_communicator.send_tensor_dict(  # type: ignore
            tensor_dict, dst)
        return None
    metadata_list: list[tuple[Any, Any]] = []
    assert isinstance(tensor_dict,
                      dict), f"Expecting a dictionary, got {type(tensor_dict)}"
    metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    # `metadata_list` lives in CPU memory.
    # `send_object_list` has serialization & deserialization,
    # all happening on CPU. Therefore, we can use the CPU group.
    self.send_object(metadata_list, dst=dst)
    # Keys of tensor-valued entries, in dict order, matching tensor_list.
    tensor_keys = [
        k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
    ]
    assert len(tensor_keys) == len(tensor_list)
    for key, tensor in zip(tensor_keys, tensor_list):
        if tensor.numel() == 0:
            # Skip sending empty tensors.
            continue
        # send-allgather: send only a slice, then do allgather.
        use_all_gather = (all_gather_group is not None
                          and tensor.numel() % all_gather_size == 0)
        # Per-tensor override takes precedence over the divisibility default.
        use_all_gather = all_gather_tensors.get(key, use_all_gather) \
            if all_gather_tensors else use_all_gather
        if use_all_gather:
            tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
        if tensor.is_cpu:
            # use metadata_group for CPU tensors
            torch.distributed.send(tensor,
                                   dst=self.ranks[dst],
                                   group=metadata_group)
        else:
            # ensure tensor is ready
            torch.supa.synchronize()
            # use group for GPU tensors
            torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
    return None
def recv_tensor_dict(
    self,
    src: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
    """Recv the input tensor dictionary.
    NOTE: `src` is the local rank of the source rank.

    all_gather_group: The group for the all-gather operation. If provided,
        an optimization is enabled where each rank in the group sends a
        slice of a tensor and the receiver reconstructs it using an
        all-gather, which can improve performance. This is typically the
        tensor-parallel group.
    all_gather_tensors: A dictionary to specify which tensors should use
        the all-gather optimization, which is only effective when
        `all_gather_group` is provided. By default, this optimization is
        on for any tensor whose size is divisible by the
        `all_gather_group`'s world size. However, it should be disabled
        for tensors that are not fully replicated across the group (e.g.,
        the residual tensor when sequence parallelism is enabled). This
        dictionary allows overriding the default behavior on a per-tensor
        basis.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return None
    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)
    group = self.device_group
    metadata_group = self.cpu_group
    # Default source: previous rank in the ring.
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size
    assert src < self.world_size, f"Invalid src rank ({src})"
    if self.use_cpu_custom_send_recv:
        # Delegate the whole dict transfer to the device communicator.
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        return self.device_communicator.recv_tensor_dict(  # type: ignore
            src)
    # Metadata arrives first and tells us which entries are tensors,
    # their shapes/dtypes/devices — must mirror send_tensor_dict's order.
    recv_metadata_list = self.recv_object(src=src)
    tensor_dict: dict[str, Any] = {}
    for key, value in recv_metadata_list:
        if isinstance(value, TensorMetadata):
            tensor = torch.empty(value.size,
                                 dtype=value.dtype,
                                 device=value.device)
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors.
                tensor_dict[key] = tensor
                continue
            # send-allgather: send only a slice, then do allgather.
            use_all_gather = (all_gather_group is not None
                              and tensor.numel() % all_gather_size == 0)
            # Per-tensor override takes precedence over the default.
            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
                if all_gather_tensors else use_all_gather
            if use_all_gather:
                # Receive only this rank's slice; remember the full shape
                # so the gathered result can be restored below.
                orig_shape = tensor.shape
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.recv(tensor,
                                       src=self.ranks[src],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.recv(tensor,
                                       src=self.ranks[src],
                                       group=group)
                # ensure recv is done
                torch.supa.synchronize()
            if use_all_gather:
                # do the allgather
                tensor = all_gather_group.all_gather(  # type: ignore
                    tensor, dim=0)
                tensor = tensor.reshape(orig_shape)
            tensor_dict[key] = tensor
        else:
            # Non-tensor metadata entries are passed through unchanged.
            tensor_dict[key] = value
    return tensor_dict


# Monkey-patch upstream vllm with the SUPA-aware send/recv implementations.
vllm.distributed.GroupCoordinator.send_tensor_dict = send_tensor_dict
vllm.distributed.GroupCoordinator.recv_tensor_dict = recv_tensor_dict