diff --git a/tests/ut/patch/worker/patch_common/test_patch_distributed.py b/tests/ut/patch/worker/patch_common/test_patch_distributed.py
index 4975313..f3d9509 100644
--- a/tests/ut/patch/worker/patch_common/test_patch_distributed.py
+++ b/tests/ut/patch/worker/patch_common/test_patch_distributed.py
@@ -29,7 +29,7 @@ class TestPatchDistributed(TestBase):
         self.mock_group_ranks = [[0, 1]]
         self.mock_local_rank = 0
         self.mock_backend = "hccl"
-        self.mock_use_device_comm = True
+        self.mock_use_device_comm = False
 
         patcher_get_rank = patch("torch.distributed.get_rank", return_value=0)
         patcher_new_group = patch("torch.distributed.new_group",
@@ -39,16 +39,24 @@ class TestPatchDistributed(TestBase):
         patcher_device_comm_cls = patch(
             "vllm.distributed.parallel_state.resolve_obj_by_qualname",
             return_value=MagicMock())
+        patcher_calculate_dp_buffer = patch(
+            "vllm_ascend.utils.calculate_dp_buffer_size", return_value=64)
+        patcher_npu_current_device = patch("torch.npu.current_device",
+                                           return_value=MagicMock())
 
         self.mock_get_rank = patcher_get_rank.start()
         self.mock_new_group = patcher_new_group.start()
         self.mock_is_cuda_alike = patcher_is_cuda_alike.start()
         self.mock_resolve_obj = patcher_device_comm_cls.start()
+        self.mock_calculate_dp_buffer = patcher_calculate_dp_buffer.start()
+        self.mock_npu_current_device = patcher_npu_current_device.start()
 
         self.addCleanup(patcher_get_rank.stop)
         self.addCleanup(patcher_new_group.stop)
         self.addCleanup(patcher_is_cuda_alike.stop)
         self.addCleanup(patcher_device_comm_cls.stop)
+        self.addCleanup(patcher_calculate_dp_buffer.stop)
+        self.addCleanup(patcher_npu_current_device.stop)
 
         self.group_coordinator = GroupCoordinatorPatch(
             group_ranks=self.mock_group_ranks,
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index bf02390..6cb9004 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -87,6 +87,19 @@
 # ** File: worker/patch_common/patch_distributed.py **
 #    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # 1. `vllm.distributed.parallel_state.GroupCoordinator`
+#    (1) __init__()
+#    Why:
+#    The original GroupCoordinator initialization lacks pg_options to generate new
+#    process group with customized options.
+#    How:
+#    Inject HCCL options during process group initialization.
+#    Related PR (if no, explain why):
+#    Need a PR to vllm to support a dictionary as input while initializing distributed
+#    environment (e.g., Dict[str, torch.distributed.ProcessGroupHCCL.Options])
+#    https://github.com/vllm-project/vllm/pull/25417
+#    Future Plan:
+#    Remove this patch when vllm merges this PR.
+#    (2) all_to_all()
 #    Why:
 #    vllm doesn't support all_to_all for GroupCoordinator.
 #    How:
diff --git a/vllm_ascend/patch/worker/patch_common/patch_distributed.py b/vllm_ascend/patch/worker/patch_common/patch_distributed.py
index 846d82c..2b46290 100644
--- a/vllm_ascend/patch/worker/patch_common/patch_distributed.py
+++ b/vllm_ascend/patch/worker/patch_common/patch_distributed.py
@@ -15,17 +15,82 @@
 # limitations under the License.
 #
 
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import torch
 import vllm
-from vllm.distributed.parallel_state import GroupCoordinator
+from torch.distributed import Backend
+from vllm.distributed.parallel_state import (GroupCoordinator,
+                                             _get_unique_name, _register_group)
+
+from vllm_ascend.distributed.communicator import NPUCommunicator
+from vllm_ascend.utils import create_hccl_pg_options
 
 
 class GroupCoordinatorPatch(GroupCoordinator):
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        group_ranks: list[list[int]],
+        local_rank: int,
+        torch_distributed_backend: Union[str, Backend],
+        use_device_communicator: bool,  # whether to use device communicator
+        use_message_queue_broadcaster: bool = False,
+        group_name: Optional[str] = None,
+    ):
+        group_name = group_name or "anonymous"
+        self.unique_name = _get_unique_name(group_name)
+        _register_group(self)
+
+        self.rank = torch.distributed.get_rank()
+        self.local_rank = local_rank
+
+        self_device_group = None
+        self_cpu_group = None
+        hccl_pg_options = create_hccl_pg_options(group_name)
+
+        for ranks in group_ranks:
+            device_group = torch.distributed.new_group(
+                ranks,
+                backend=torch_distributed_backend,
+                pg_options=hccl_pg_options)
+
+            # a group with `gloo` backend, to allow direct coordination between
+            # processes through the CPU.
+            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
+            if self.rank in ranks:
+                self.ranks = ranks
+                self.world_size = len(ranks)
+                self.rank_in_group = ranks.index(self.rank)
+                self_device_group = device_group
+                self_cpu_group = cpu_group
+
+        assert self_cpu_group is not None
+        assert self_device_group is not None
+
+        self.cpu_group = self_cpu_group
+        self.device_group = self_device_group
+        self.device = torch.npu.current_device()
+
+        self.use_device_communicator = use_device_communicator
+        self.device_communicator = None
+        if use_device_communicator and self.world_size > 1:
+            self.device_communicator = NPUCommunicator(
+                cpu_group=self.cpu_group,
+                device=self.device,
+                device_group=self.device_group,
+                unique_name=self.unique_name,
+            )
+
+        from vllm.distributed.device_communicators.shm_broadcast import \
+            MessageQueue
+        self.mq_broadcaster: Optional[MessageQueue] = None
+        if use_message_queue_broadcaster and self.world_size > 1:
+            self.mq_broadcaster = MessageQueue.create_from_process_group(
+                self.cpu_group, 1 << 22, 6)
+
+        self.use_custom_op_call = False
+        self.use_cpu_custom_send_recv = False
 
     def all_to_all(self,
                    input_: torch.Tensor,
@@ -41,9 +106,10 @@ class GroupCoordinatorPatch(GroupCoordinator):
         assert -input_.dim() <= gather_dim < input_.dim(), (
             f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}"
         )
+        assert self.device_communicator is not None, "device_communicator should be initialized when world_size > 1"
         return self.device_communicator.all_to_all(input_, scatter_dim,
                                                    gather_dim, scatter_sizes,
                                                    gather_sizes)
 
 
-vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch  # Note: check the GroupCoordinator with online serving
\ No newline at end of file
+vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index b591413..8382cf7 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -53,6 +53,8 @@ _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
+_DEFAULT_BUFFER_SIZE = 200
+_MIN_DP_BUFFER_SIZE = 50
 
 
 def is_310p():
@@ -648,3 +650,51 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
         return nullcontext()
     assert target_stream is not None
     return torch.npu.stream(target_stream)
+
+
+def create_hccl_pg_options(group_name: str):
+    options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
+    hccl_config = get_hccl_config_for_pg_options(group_name)
+    if hccl_config is not None:
+        options.hccl_config = hccl_config
+    return options
+
+
+def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
+    """
+    Get HCCL process group options for the given communication group name.
+
+    Args:
+        group_name: Name of the communication group
+
+    Returns:
+        HCCL pg_options or None for mc2 group
+    """
+    # FIXME: Current mc2 operators only perform communication space partitioning
+    # based on HCCL_BUFFSIZE configuration. Using pg_options with mc2 group would
+    # result in memory misalignment problems.
+    if group_name and "mc2" in group_name:
+        return None
+    hccl_config_map = {
+        "dp": {
+            "hccl_buffer_size": calculate_dp_buffer_size()
+        },
+    }
+    return hccl_config_map.get(group_name, get_default_buffer_config())
+
+
+def get_default_buffer_config() -> dict:
+    return {"hccl_buffer_size": _DEFAULT_BUFFER_SIZE}
+
+
+def calculate_dp_buffer_size() -> int:
+    """
+    formula of dp buffer size:
+    dp_size + 2 (flags: with_prefill and enable_dbo)
+    """
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    int32_size = torch.iinfo(torch.int32).bits // 8
+    dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024))
+    return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)