first commit
This commit is contained in:
473
vllm_br/distributed/parallel_state.py
Normal file
473
vllm_br/distributed/parallel_state.py
Normal file
@@ -0,0 +1,473 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch_br
|
||||
|
||||
import vllm
|
||||
import vllm.distributed.parallel_state
|
||||
from vllm.distributed import GroupCoordinator
|
||||
from vllm.distributed.parallel_state import (_WORLD, TensorMetadata,
|
||||
_split_tensor_dict, get_pp_group,
|
||||
get_tp_group, get_world_group,
|
||||
init_model_parallel_group, logger)
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
@dataclass
class GraphCaptureContext:
    """Carries the state needed while capturing a SUPA graph.

    Currently this is only the side stream the capture runs on; it is
    created lazily by ``graph_capture_`` when the caller does not supply
    a context of its own.
    """
    # Stream on which all kernels to be captured are launched.
    stream: torch_br.supa.Stream
|
||||
|
||||
|
||||
@contextmanager
def graph_capture_(self,
                   graph_capture_context: Optional[GraphCaptureContext] = None
                   ):
    """SUPA-aware replacement for ``GroupCoordinator.graph_capture``.

    Yields a :class:`GraphCaptureContext` whose stream is made current for
    the duration of the ``with`` block, so that every kernel launched
    inside it lands on the capture stream rather than the default stream.

    Args:
        graph_capture_context: Reuse an existing context; when ``None`` a
            fresh SUPA stream and context are created.
    """
    if graph_capture_context is not None:
        stream = graph_capture_context.stream
    else:
        stream = torch_br.supa.Stream()
        graph_capture_context = GraphCaptureContext(stream)

    # Make sure everything already queued on the active stream has been
    # ordered before the capture stream starts, so initialization work
    # completes before we attempt to capture on another stream.
    active_stream = torch_br.supa.current_stream()
    if active_stream != stream:
        stream.wait_stream(active_stream)

    with torch_br.supa.stream(stream):
        yield graph_capture_context


# Monkey-patch vLLM's GroupCoordinator to use the SUPA capture logic.
vllm.distributed.parallel_state.GroupCoordinator.graph_capture = graph_capture_
|
||||
|
||||
|
||||
@contextmanager
def graph_capture_supa(device: torch.device):
    """Context manager surrounding SUPA graph-capture code.

    Ensures that some operations run after the graph is captured and
    before it is replayed. Yields a `GraphCaptureContext` holding the
    stream the capture runs on: a fresh stream on *device*, made current
    while the context is entered and restored afterwards. Running capture
    on a separate stream explicitly distinguishes the kernels to capture
    from kernels possibly launched in the background on the default
    stream.
    """
    capture_stream = torch_br.supa.Stream(device=device)
    context = GraphCaptureContext(capture_stream)
    tp_capture = get_tp_group().graph_capture(context)
    pp_capture = get_pp_group().graph_capture(context)
    with tp_capture, pp_capture:
        yield context


# Route vLLM's module-level graph_capture through the SUPA variant.
vllm.distributed.parallel_state.graph_capture = graph_capture_supa
|
||||
|
||||
|
||||
def is_global_first_rank() -> bool:
    """Return whether this process is the first rank globally.

    Unlike group-specific checks such as
    ``get_tensor_model_parallel_rank() == 0`` or
    ``get_pp_group().is_first_rank``, this checks the global rank across
    all parallelism dimensions (PP, TP, DP, EP, ...).

    Returns:
        bool: True for the global first rank (rank 0), and True when
        distributed is not initialized (single process) or when the check
        fails for any reason.
    """
    try:
        # Prefer the world group when available: most accurate check.
        world = _WORLD
        if world is not None:
            return world.is_first_rank

        if torch.distributed.is_initialized():
            # Fall back to torch's notion of the global rank.
            return torch.distributed.get_rank() == 0

        # No distributed runtime at all: single process counts as first.
        return True
    except Exception:
        # Best-effort: on any failure, behave as the first rank.
        return True
|
||||
|
||||
|
||||
def generate_multi_node_parallel_groups(
    total_procs: int,
    tp_size: int,
    pp_size: int,
    dp_size: int,
) -> dict:
    """Return hand-crafted supernode rank groupings for each parallel dim.

    Only the 16-process ``tp8pp2dp1`` layout is supported: TP (and EP)
    groups span both nodes (ranks {0-3, 8-11} and {4-7, 12-15}), PP pairs
    a rank on the first node half with the matching rank on the second,
    and DP is trivial (one rank per group).

    Args:
        total_procs: total number of processes (world size).
        tp_size: tensor-model-parallel group size.
        pp_size: pipeline-model-parallel group size.
        dp_size: data-parallel group size.

    Returns:
        dict with keys ``tp_groups``, ``pp_groups``, ``dp_groups`` and
        ``ep_groups``, each a list of rank lists.

    Raises:
        ValueError: for any configuration other than tp8pp2dp1 over 16
            processes.
    """
    if total_procs == 16 and tp_size == 8 and pp_size == 2 and dp_size == 1:
        tp_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
        pp_groups = [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13],
                     [10, 14], [11, 15]]
        dp_groups = [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10],
                     [11], [12], [13], [14], [15]]
        ep_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
    else:
        # NOTE: the original message was not an f-string, so the
        # {tp_size}/{pp_size}/{dp_size} placeholders were never filled in.
        raise ValueError(
            "Unsupported VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE parallel "
            f"config of tp_size: {tp_size} pp_size: {pp_size} "
            f"dp_size: {dp_size}. Currently only 'tp8pp2dp1' is allowed.")
    return {
        "tp_groups": tp_groups,
        "pp_groups": pp_groups,
        "dp_groups": dp_groups,
        "ep_groups": ep_groups,
    }
|
||||
|
||||
|
||||
# sync v0.11 api update, while code logic possibly need sync with vllm original code implementation
def initialize_model_parallel_cross_tp(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    decode_context_model_parallel_size: Optional[int] = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Replaces vLLM's ``initialize_model_parallel``. When
    ``VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE`` is set, TP/PP/DP/EP group
    rank lists come from ``generate_multi_node_parallel_groups`` (cross-
    node TP layout) instead of the default contiguous-rank layout.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        decode_context_model_parallel_size: size of each decode-context
            parallel (DCP) group carved out of the world ranks.
        backend: name of torch distributed communication backend.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    # Default to whatever backend the world group was created with.
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # DP size is not a parameter here; it is pulled from the current vLLM
    # config when one is available, otherwise defaults to 1.
    data_parallel_size = 1
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size,
        tensor_model_parallel_size)  # noqa
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        # Only the tp8pp2dp1 layout is supported; raises otherwise.
        groups = generate_multi_node_parallel_groups(
            world_size, tensor_model_parallel_size,
            pipeline_model_parallel_size, data_parallel_size)
        logger.info("supernode reorganized groups: %s", groups)
    # Build the tensor model-parallel groups.
    assert vllm.distributed.parallel_state._TP is None, (
        "tensor model parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['tp_groups']
    else:
        group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    # message queue broadcaster is only used in tensor model parallel group
    vllm.distributed.parallel_state._TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp")

    # Build the DCP model-parallel groups.
    # global _DCP
    assert vllm.distributed.parallel_state._DCP is None, (
        "decode context model parallel group is already initialized")
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    # NOTE(review): DCP groups are NOT reorganized in supernode mode —
    # they always use the default contiguous layout; confirm intended.
    group_ranks = all_ranks.reshape(
        -1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp")

    # Build the pipeline model-parallel groups.
    assert vllm.distributed.parallel_state._PP is None, (
        "pipeline model parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['pp_groups']
    else:
        group_ranks = all_ranks.transpose(2, 3).reshape(
            -1, pipeline_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp")

    # Build the data-parallel groups.
    assert vllm.distributed.parallel_state._DP is None, (
        "data parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['dp_groups']
    else:
        group_ranks = all_ranks.transpose(1, 3).reshape(
            -1, data_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._DP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="dp")

    # Build the expert-parallel groups (spanning DP x TP).
    assert vllm.distributed.parallel_state._EP is None, (
        "expert parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['ep_groups']
    else:
        group_ranks = all_ranks.transpose(1, 2).reshape(
            -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._EP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="ep")
    logger.info(
        "rank %s in world size %s is assigned as (br) "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
        vllm.distributed.parallel_state._DP.rank_in_group,
        vllm.distributed.parallel_state._PP.rank_in_group,
        vllm.distributed.parallel_state._TP.rank_in_group,
        vllm.distributed.parallel_state._EP.rank_in_group)


# Replace vLLM's model-parallel initializer with the cross-TP variant.
vllm.distributed.parallel_state.initialize_model_parallel = initialize_model_parallel_cross_tp
|
||||
|
||||
|
||||
def send_tensor_dict(
    self,
    tensor_dict: dict[str, Union[torch.Tensor, Any]],
    dst: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
    """Send the input tensor dictionary.
    NOTE: `dst` is the local rank of the source rank.

    Wire protocol: metadata (tensor shapes/dtypes plus non-tensor values)
    is sent first via `send_object`, then each non-empty tensor is sent
    point-to-point — the CPU group for CPU tensors, the device group for
    GPU tensors. Must be paired with `recv_tensor_dict` on rank `dst`.

    all_gather_group: The group for the all-gather operation. If provided,
        an optimization is enabled where each rank in the group sends a
        slice of a tensor and the receiver reconstructs it using an
        all-gather, which can improve performance. This is typically the
        tensor-parallel group.
    all_gather_tensors: A dictionary to specify which tensors should use
        the all-gather optimization, which is only effective when
        `all_gather_group` is provided. By default, this optimization is
        on for any tensor whose size is divisible by the
        `all_gather_group`'s world size. However, it should be disabled
        for tensors that are not fully replicated across the group (e.g.,
        the residual tensor when sequence parallelism is enabled). This
        dictionary allows overriding the default behavior on a per-tensor
        basis.

    Returns:
        `tensor_dict` unchanged when running single-process; otherwise
        ``None``.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict
    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)

    group = self.device_group
    metadata_group = self.cpu_group

    # Default destination: the next rank in the group (ring order).
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    # Optional fast path: hand the whole dict to the device communicator.
    if self.use_cpu_custom_send_recv:
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        self.device_communicator.send_tensor_dict(  # type: ignore
            tensor_dict, dst)
        return None

    metadata_list: list[tuple[Any, Any]] = []
    assert isinstance(tensor_dict,
                      dict), f"Expecting a dictionary, got {type(tensor_dict)}"
    metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    # `metadata_list` lives in CPU memory.
    # `send_object_list` has serialization & deserialization,
    # all happening on CPU. Therefore, we can use the CPU group.
    self.send_object(metadata_list, dst=dst)

    # Keys in the same order _split_tensor_dict emitted the tensors, so
    # the per-tensor all-gather override can be looked up by name.
    tensor_keys = [
        k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
    ]
    assert len(tensor_keys) == len(tensor_list)

    for key, tensor in zip(tensor_keys, tensor_list):
        if tensor.numel() == 0:
            # Skip sending empty tensors; the receiver reconstructs them
            # from metadata alone.
            continue

        # send-allgather: send only a slice, then do allgather.
        use_all_gather = (all_gather_group is not None
                          and tensor.numel() % all_gather_size == 0)
        # Per-tensor override takes precedence over the divisibility rule.
        use_all_gather = all_gather_tensors.get(key, use_all_gather) \
            if all_gather_tensors else use_all_gather
        if use_all_gather:
            # Send only this rank's slice; receiver all-gathers the rest.
            tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

        if tensor.is_cpu:
            # use metadata_group for CPU tensors
            torch.distributed.send(tensor,
                                   dst=self.ranks[dst],
                                   group=metadata_group)
        else:
            # ensure tensor is ready
            # NOTE(review): full-device synchronize before every GPU send
            # is heavy-handed; presumably required by the SUPA backend —
            # confirm. Also assumes torch_br registers `torch.supa`.
            torch.supa.synchronize()
            # use group for GPU tensors
            torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
    return None
|
||||
|
||||
|
||||
def recv_tensor_dict(
    self,
    src: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
    """Recv the input tensor dictionary.
    NOTE: `src` is the local rank of the source rank.

    Counterpart of `send_tensor_dict`: first receives the metadata list,
    then receives each non-empty tensor point-to-point (CPU group for CPU
    tensors, device group for GPU tensors), and reassembles the dict.

    all_gather_group: The group for the all-gather operation. If provided,
        an optimization is enabled where each rank in the group sends a
        slice of a tensor and the receiver reconstructs it using an
        all-gather, which can improve performance. This is typically the
        tensor-parallel group.
    all_gather_tensors: A dictionary to specify which tensors should use
        the all-gather optimization, which is only effective when
        `all_gather_group` is provided. By default, this optimization is
        on for any tensor whose size is divisible by the
        `all_gather_group`'s world size. However, it should be disabled
        for tensors that are not fully replicated across the group (e.g.,
        the residual tensor when sequence parallelism is enabled). This
        dictionary allows overriding the default behavior on a per-tensor
        basis.

    Returns:
        The reconstructed dict, or ``None`` when running single-process.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return None
    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)

    group = self.device_group
    metadata_group = self.cpu_group

    # Default source: the previous rank in the group (ring order),
    # mirroring send_tensor_dict's default destination.
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Optional fast path: let the device communicator handle everything.
    if self.use_cpu_custom_send_recv:
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        return self.device_communicator.recv_tensor_dict(  # type: ignore
            src)

    recv_metadata_list = self.recv_object(src=src)
    tensor_dict: dict[str, Any] = {}
    for key, value in recv_metadata_list:
        if isinstance(value, TensorMetadata):
            # Allocate the destination buffer from the metadata alone.
            tensor = torch.empty(value.size,
                                 dtype=value.dtype,
                                 device=value.device)
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors.
                tensor_dict[key] = tensor
                continue

            # send-allgather: send only a slice, then do allgather.
            use_all_gather = (all_gather_group is not None
                              and tensor.numel() % all_gather_size == 0)
            # Per-tensor override must match the sender's decision, or
            # the two sides will disagree on the wire size.
            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
                if all_gather_tensors else use_all_gather

            if use_all_gather:
                # Receive only this rank's slice; remember the full shape
                # so it can be restored after the all-gather.
                orig_shape = tensor.shape
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.recv(tensor,
                                       src=self.ranks[src],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.recv(tensor,
                                       src=self.ranks[src],
                                       group=group)
                # ensure recv is done
                # NOTE(review): device-wide synchronize per tensor;
                # presumably needed by the SUPA backend — confirm.
                torch.supa.synchronize()
            if use_all_gather:
                # do the allgather
                tensor = all_gather_group.all_gather(  # type: ignore
                    tensor, dim=0)
                tensor = tensor.reshape(orig_shape)

            tensor_dict[key] = tensor
        else:
            # Non-tensor values travel inside the metadata list.
            tensor_dict[key] = value
    return tensor_dict


# Monkey-patch vLLM's GroupCoordinator with the SUPA-aware send/recv.
vllm.distributed.GroupCoordinator.send_tensor_dict = send_tensor_dict
vllm.distributed.GroupCoordinator.recv_tensor_dict = recv_tensor_dict
|
||||
Reference in New Issue
Block a user