Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -25,6 +25,7 @@ If you only need to use the distributed environment without model/pipeline
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
import pickle
|
||||
import weakref
|
||||
from collections import namedtuple
|
||||
@@ -33,7 +34,7 @@ from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Any, Protocol
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -54,6 +55,10 @@ from vllm.utils.system_utils import suppress_stdout
|
||||
from vllm.utils.torch_utils import (
|
||||
direct_register_custom_op,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
import ixformer.distributed as ixfd
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
@@ -327,6 +332,8 @@ class GroupCoordinator:
|
||||
self.rank = torch.distributed.get_rank()
|
||||
self.local_rank = local_rank
|
||||
|
||||
use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in {"1", "Y", "y"}
|
||||
|
||||
self_device_group = None
|
||||
self_cpu_group = None
|
||||
|
||||
@@ -339,7 +346,7 @@ class GroupCoordinator:
|
||||
with suppress_stdout():
|
||||
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||
if self.rank in ranks:
|
||||
self.ixfd_group = ixfd.init_comm_with_store(device_group)
|
||||
self.ixfd_group = ixfd.init_comm_with_store(device_group) if use_vllm_comm else None
|
||||
self.ranks = ranks
|
||||
self.world_size = len(ranks)
|
||||
self.rank_in_group = ranks.index(self.rank)
|
||||
@@ -372,8 +379,7 @@ class GroupCoordinator:
|
||||
self.device_communicator = device_comm_cls(
|
||||
cpu_group=self.cpu_group,
|
||||
device=self.device,
|
||||
# device_group=self.device_group,
|
||||
device_group=self.ixfd_group if envs.VLLM_FORCE_NCCL_COMM else self.device_group,
|
||||
device_group=self.ixfd_group if use_vllm_comm else self.device_group,
|
||||
unique_name=self.unique_name,
|
||||
)
|
||||
|
||||
@@ -385,11 +391,6 @@ class GroupCoordinator:
|
||||
self.cpu_group, 1 << 22, 6
|
||||
)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# self.use_custom_op_call = (
|
||||
# current_platform.is_cuda_alike() or current_platform.is_tpu()
|
||||
# )
|
||||
self.use_custom_op_call = False
|
||||
|
||||
self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
|
||||
@@ -468,14 +469,12 @@ class GroupCoordinator:
|
||||
# only cuda uses this function,
|
||||
# so we don't abstract it into the base class
|
||||
maybe_ca_context = nullcontext()
|
||||
# from vllm.distributed.device_communicators.cuda_communicator import (
|
||||
# CudaCommunicator,
|
||||
# )
|
||||
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
|
||||
from vllm.distributed.device_communicators.cuda_communicator import (
|
||||
CudaCommunicator,
|
||||
)
|
||||
|
||||
if self.device_communicator is not None:
|
||||
# assert isinstance(self.device_communicator, CudaCommunicator)
|
||||
assert isinstance(self.device_communicator, DeviceCommunicatorBase)
|
||||
assert isinstance(self.device_communicator, CudaCommunicator)
|
||||
ca_comm = self.device_communicator.ca_comm
|
||||
if ca_comm is not None:
|
||||
maybe_ca_context = ca_comm.capture() # type: ignore
|
||||
@@ -608,9 +607,9 @@ class GroupCoordinator:
|
||||
src=self.ranks[src],
|
||||
group=self.device_group)
|
||||
else:
|
||||
torch.distributed.broadcast(input_,
|
||||
src=self.ranks[src],
|
||||
group=self.device_group)
|
||||
torch.distributed.broadcast(
|
||||
input_, src=self.ranks[src], group=self.device_group
|
||||
)
|
||||
return input_
|
||||
|
||||
def broadcast_object(self, obj: Any | None = None, src: int = 0):
|
||||
@@ -764,10 +763,9 @@ class GroupCoordinator:
|
||||
group=group,
|
||||
async_op=True)
|
||||
else:
|
||||
handle = torch.distributed.broadcast(tensor,
|
||||
src=self.ranks[src],
|
||||
group=group,
|
||||
async_op=True)
|
||||
handle = torch.distributed.broadcast(
|
||||
tensor, src=self.ranks[src], group=group, async_op=True
|
||||
)
|
||||
async_handles.append(handle)
|
||||
for async_handle in async_handles:
|
||||
async_handle.wait()
|
||||
@@ -802,10 +800,8 @@ class GroupCoordinator:
|
||||
async_op=True)
|
||||
else:
|
||||
handle = torch.distributed.broadcast(
|
||||
tensor,
|
||||
src=self.ranks[src],
|
||||
group=group,
|
||||
async_op=True)
|
||||
tensor, src=self.ranks[src], group=group, async_op=True
|
||||
)
|
||||
async_handles.append(handle)
|
||||
tensor_dict[key] = tensor
|
||||
else:
|
||||
@@ -876,6 +872,10 @@ class GroupCoordinator:
|
||||
if self.world_size <= 1:
|
||||
return []
|
||||
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
||||
|
||||
if self.use_cpu_custom_send_recv:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
@@ -893,10 +893,6 @@ class GroupCoordinator:
|
||||
group = self.device_group
|
||||
metadata_group = self.cpu_group
|
||||
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
||||
|
||||
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
||||
self.send_object(metadata_list, dst=dst)
|
||||
|
||||
@@ -917,6 +913,7 @@ class GroupCoordinator:
|
||||
handle = torch.distributed.isend(
|
||||
tensor, dst=self.ranks[dst], group=comm_group
|
||||
)
|
||||
|
||||
if tensor.is_cuda:
|
||||
tensor.record_stream(torch.cuda.current_stream(tensor.device))
|
||||
handles.append(handle)
|
||||
@@ -973,6 +970,11 @@ class GroupCoordinator:
|
||||
]:
|
||||
if not torch.distributed.is_initialized() or self.world_size == 1:
|
||||
return None, [], []
|
||||
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
if self.use_cpu_custom_send_recv:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
@@ -990,10 +992,6 @@ class GroupCoordinator:
|
||||
group = self.device_group
|
||||
metadata_group = self.cpu_group
|
||||
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
recv_metadata_list = self.recv_object(src=src)
|
||||
tensor_dict: dict[str, Any] = {}
|
||||
handles: list[Handle] = []
|
||||
@@ -1072,14 +1070,13 @@ class GroupCoordinator:
|
||||
return self.device_communicator.recv(size, dtype, src)
|
||||
|
||||
def destroy(self):
|
||||
if hasattr(self, "device_group"):
|
||||
# torch.distributed.destroy_process_group(self.device_group)
|
||||
if self.device_group is not None:
|
||||
if self.device_communicator and self.device_communicator.use_vllm_comm:
|
||||
ixfd.destroy_process_group(self.device_group)
|
||||
else:
|
||||
torch.distributed.destroy_process_group(self.device_group)
|
||||
del self.device_group
|
||||
if hasattr(self, "cpu_group"):
|
||||
self.device_group = None
|
||||
if self.cpu_group is not None:
|
||||
torch.distributed.destroy_process_group(self.cpu_group)
|
||||
del self.cpu_group
|
||||
if self.device_communicator is not None:
|
||||
@@ -1094,7 +1091,6 @@ class GroupCoordinator:
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
extra_residual:torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
@@ -1105,13 +1101,12 @@ class GroupCoordinator:
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.dispatch_router_logits(
|
||||
hidden_states,
|
||||
extra_residual,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors,
|
||||
)
|
||||
else:
|
||||
return hidden_states, extra_residual, router_logits
|
||||
return hidden_states, router_logits
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
@@ -1189,6 +1184,55 @@ def init_model_parallel_group(
|
||||
)
|
||||
|
||||
|
||||
def _init_stateless_group(
|
||||
group_ranks: list[list[int]],
|
||||
group_name: str,
|
||||
group_ports: list[list[int]],
|
||||
host: str,
|
||||
backend: str,
|
||||
use_device_communicator: bool = True,
|
||||
) -> "StatelessGroupCoordinator":
|
||||
"""Create a StatelessGroupCoordinator with the given parameters."""
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
world = get_world_group()
|
||||
return StatelessGroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=world.local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_device_communicator=use_device_communicator,
|
||||
group_name=group_name,
|
||||
host=host,
|
||||
group_ports=group_ports,
|
||||
global_rank=world.rank,
|
||||
global_world_size=world.world_size,
|
||||
)
|
||||
|
||||
|
||||
def _replace_active_groups(
|
||||
*,
|
||||
world: GroupCoordinator | None,
|
||||
dp: GroupCoordinator | None,
|
||||
ep: GroupCoordinator | None,
|
||||
eplb: GroupCoordinator | None,
|
||||
node_count: int | None,
|
||||
) -> None:
|
||||
"""Destroy the current DP/EP/WORLD/EPLB groups and replace them.
|
||||
|
||||
Destruction is collective — all ranks in the old groups must call this
|
||||
function together. Pass all-``None`` to tear down without replacement.
|
||||
"""
|
||||
global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT
|
||||
for group in (_DP, _EP, _WORLD, _EPLB):
|
||||
if group is not None:
|
||||
group.destroy()
|
||||
_WORLD = world
|
||||
_DP = dp
|
||||
_EP = ep
|
||||
_EPLB = eplb
|
||||
_NODE_COUNT = node_count
|
||||
|
||||
|
||||
_TP: GroupCoordinator | None = None
|
||||
|
||||
|
||||
@@ -1286,6 +1330,39 @@ def set_custom_all_reduce(enable: bool):
|
||||
_ENABLE_CUSTOM_ALL_REDUCE = enable
|
||||
|
||||
|
||||
def _init_elastic_ep_world(
|
||||
config, local_rank: int, backend: str, rank: int, world_size: int
|
||||
) -> None:
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
global _WORLD, _NODE_COUNT
|
||||
assert _WORLD is None, "world group already initialized"
|
||||
parallel_config = config.parallel_config
|
||||
global_rank = parallel_config.data_parallel_rank * world_size + rank
|
||||
global_world_size = parallel_config.world_size_across_dp
|
||||
all_ranks = list(range(global_world_size))
|
||||
group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
|
||||
if global_rank in all_ranks:
|
||||
group_ranks = [all_ranks]
|
||||
group_ports = [parallel_config.get_next_stateless_world_group_port()]
|
||||
world = StatelessGroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_device_communicator=False,
|
||||
group_name="world",
|
||||
host=parallel_config.data_parallel_master_ip,
|
||||
group_ports=group_ports,
|
||||
global_rank=global_rank,
|
||||
global_world_size=global_world_size,
|
||||
)
|
||||
assert parallel_config.nnodes_within_dp == 1, (
|
||||
"Elastic EP is not supported with multi-node TP/PP"
|
||||
)
|
||||
_NODE_COUNT = _node_count(world.tcp_store_group)
|
||||
_WORLD = world
|
||||
|
||||
|
||||
def init_distributed_environment(
|
||||
world_size: int = -1,
|
||||
rank: int = -1,
|
||||
@@ -1305,6 +1382,7 @@ def init_distributed_environment(
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
config = get_current_vllm_config_or_none()
|
||||
enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
|
||||
if (
|
||||
config is not None
|
||||
and config.parallel_config.distributed_executor_backend != "external_launcher"
|
||||
@@ -1312,6 +1390,7 @@ def init_distributed_environment(
|
||||
config.parallel_config.nnodes > 1
|
||||
or config.parallel_config.data_parallel_size > 1
|
||||
)
|
||||
and not enable_elastic_ep
|
||||
):
|
||||
parallel_config = config.parallel_config
|
||||
# adjust to take into account data parallelism
|
||||
@@ -1365,6 +1444,18 @@ def init_distributed_environment(
|
||||
rank=rank,
|
||||
timeout=timeout,
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
tp_pp_cpu_group = torch.distributed.new_group(
|
||||
backend="gloo", timeout=timeout
|
||||
)
|
||||
if _node_count(tp_pp_cpu_group) > 1:
|
||||
# NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
|
||||
# to initialize all DP/EP groups, hence all ranks within TP/PP group
|
||||
# must reside on the same node
|
||||
raise RuntimeError(
|
||||
"Elastic EP is not yet supported with multi-node TP/PP"
|
||||
)
|
||||
|
||||
# set the local rank
|
||||
# local_rank is not available in torch ProcessGroup,
|
||||
# see https://github.com/pytorch/pytorch/issues/122816
|
||||
@@ -1373,6 +1464,9 @@ def init_distributed_environment(
|
||||
# setting, where we can use rank as local rank
|
||||
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
|
||||
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
|
||||
if enable_elastic_ep:
|
||||
_init_elastic_ep_world(config, local_rank, backend, rank, world_size)
|
||||
return
|
||||
if _WORLD is None:
|
||||
ranks = list(range(torch.distributed.get_world_size()))
|
||||
_WORLD = init_world_group(ranks, local_rank, backend)
|
||||
@@ -1436,16 +1530,33 @@ def initialize_model_parallel(
|
||||
"""
|
||||
# Get world size and rank. Ensure some consistencies.
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size: int = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
|
||||
|
||||
data_parallel_size = 1
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
config = get_current_vllm_config_or_none()
|
||||
if config is not None:
|
||||
data_parallel_size = config.parallel_config.data_parallel_size
|
||||
config = get_current_vllm_config()
|
||||
data_parallel_size = config.parallel_config.data_parallel_size
|
||||
enable_elastic_ep = config.parallel_config.enable_elastic_ep
|
||||
if enable_elastic_ep:
|
||||
# Use stateless world group for global information
|
||||
world_size = get_world_group().world_size
|
||||
rank = get_world_group().rank
|
||||
backend = backend or "nccl"
|
||||
tp_pp_pcp_size = (
|
||||
tensor_model_parallel_size
|
||||
* pipeline_model_parallel_size
|
||||
* prefill_context_model_parallel_size
|
||||
)
|
||||
local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
|
||||
pipeline_model_parallel_size,
|
||||
prefill_context_model_parallel_size,
|
||||
tensor_model_parallel_size,
|
||||
)
|
||||
else:
|
||||
world_size = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
backend = backend or torch.distributed.get_backend(
|
||||
get_world_group().device_group
|
||||
)
|
||||
|
||||
# the layout order is: ExternalDP x DP x PP x TP
|
||||
# ExternalDP is the data parallel group that is not part of the model,
|
||||
@@ -1469,7 +1580,9 @@ def initialize_model_parallel(
|
||||
assert _TP is None, "tensor model parallel group is already initialized"
|
||||
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
|
||||
if enable_elastic_ep:
|
||||
group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
# message queue broadcaster is only used in tensor model parallel group
|
||||
_TP = init_model_parallel_group(
|
||||
group_ranks,
|
||||
@@ -1488,6 +1601,11 @@ def initialize_model_parallel(
|
||||
# TP group into tp_size//dcp_size DCP groups.
|
||||
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if enable_elastic_ep:
|
||||
group_ranks = local_all_ranks.reshape(
|
||||
-1, decode_context_model_parallel_size
|
||||
).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_DCP = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
@@ -1504,6 +1622,13 @@ def initialize_model_parallel(
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if enable_elastic_ep:
|
||||
group_ranks = (
|
||||
local_all_ranks.transpose(1, 2)
|
||||
.reshape(-1, prefill_context_model_parallel_size)
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_PCP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
|
||||
)
|
||||
@@ -1515,6 +1640,13 @@ def initialize_model_parallel(
|
||||
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if enable_elastic_ep:
|
||||
group_ranks = (
|
||||
local_all_ranks.transpose(0, 2)
|
||||
.reshape(-1, pipeline_model_parallel_size)
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_PP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="pp"
|
||||
)
|
||||
@@ -1523,14 +1655,27 @@ def initialize_model_parallel(
|
||||
assert _DP is None, "data parallel group is already initialized"
|
||||
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_DP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="dp"
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
parallel_config = config.parallel_config
|
||||
dp_ports = [
|
||||
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
|
||||
]
|
||||
_DP = _init_stateless_group(
|
||||
group_ranks,
|
||||
"dp",
|
||||
dp_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
)
|
||||
else:
|
||||
_DP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="dp"
|
||||
)
|
||||
|
||||
global _EP
|
||||
assert _EP is None, "expert parallel group is already initialized"
|
||||
# Don't create EP group for dense models.
|
||||
if config is None or config.model_config is None or config.model_config.is_moe:
|
||||
if config.model_config is None or config.model_config.is_moe:
|
||||
group_ranks = (
|
||||
all_ranks.transpose(1, 2)
|
||||
.reshape(
|
||||
@@ -1542,9 +1687,22 @@ def initialize_model_parallel(
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_EP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="ep"
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
parallel_config = config.parallel_config
|
||||
ep_ports = [
|
||||
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
|
||||
]
|
||||
_EP = _init_stateless_group(
|
||||
group_ranks,
|
||||
"ep",
|
||||
ep_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
)
|
||||
else:
|
||||
_EP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="ep"
|
||||
)
|
||||
|
||||
# Create EPLB group with the same ranks as EP if EPLB is enabled.
|
||||
# This is a separate process group to isolate EPLB communications
|
||||
@@ -1557,10 +1715,25 @@ def initialize_model_parallel(
|
||||
and config.parallel_config is not None
|
||||
and config.parallel_config.enable_eplb
|
||||
):
|
||||
# Reuse the same group_ranks from EP
|
||||
_EPLB = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="eplb"
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
eplb_ports = [
|
||||
parallel_config.get_next_stateless_eplb_group_port()
|
||||
for _ in group_ranks
|
||||
]
|
||||
_EPLB = _init_stateless_group(
|
||||
group_ranks,
|
||||
"eplb",
|
||||
eplb_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
)
|
||||
else:
|
||||
_EPLB = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="eplb",
|
||||
)
|
||||
# If no EP group needed, _EP remains None
|
||||
# If no EPLB group needed, _EPLB remains None
|
||||
|
||||
@@ -1590,7 +1763,11 @@ def ensure_model_parallel_initialized(
|
||||
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
|
||||
values if the model parallel groups are initialized.
|
||||
"""
|
||||
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
|
||||
world_group = get_world_group()
|
||||
if hasattr(world_group, "backend"):
|
||||
backend = backend or world_group.backend
|
||||
else:
|
||||
backend = backend or torch.distributed.get_backend(world_group.device_group)
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size,
|
||||
|
||||
Reference in New Issue
Block a user