[BugFix] Fix data parallel (#940)

### What this PR does / why we need it?
With this PR, we can migrate to the native `data_parallel.py` in the vllm
examples and remove the duplicated version maintained in vllm-ascend.

At present, `ASCEND_RT_VISIBLE_DEVICES` introduces considerable
difficulties, so for now we employ a temporary workaround and specify the
device manually, as sketched below.
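A minimal sketch of that workaround, assuming `torch_npu` is installed; the helper name and its parameters are illustrative, not code from this PR:

```python
# Illustrative workaround: pin each data-parallel rank to its NPU explicitly
# instead of relying on ASCEND_RT_VISIBLE_DEVICES.
import torch
import torch_npu  # noqa: F401  (registers the torch.npu device backend)


def bind_npu_for_rank(local_dp_rank: int, tp_size: int) -> torch.device:
    # Each DP rank owns a contiguous slice of NPUs; make its first one current.
    device = torch.device(f"npu:{local_dp_rank * tp_size}")
    torch.npu.set_device(device)
    return device
```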

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Author: yiz-liu
Date: 2025-06-09 14:08:18 +08:00
Committed by: GitHub
Parent: eec6068187
Commit: 6003afa6d2
5 changed files with 191 additions and 115 deletions

@@ -18,10 +18,13 @@
 import gc
 import logging
 import os
+from datetime import timedelta
 from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
 import vllm.envs as envs
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import PrefixStore
 from vllm.logger import logger
 from vllm.platforms import Platform, PlatformEnum
@@ -262,3 +265,45 @@ class NPUPlatform(Platform):
         Get piecewise backend class for piecewise graph.
         """
         return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend"  # noqa
+
+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        from torch.distributed import is_hccl_available
+        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
+
+        assert is_hccl_available()
+
+        # TODO(Yizhou): The reason we need to set options here while vllm does
+        # not seems to be related to the PyTorch version: the latest version
+        # needs no options, but older versions (2.5.1 specifically) do.
+        options = ProcessGroup.Options(backend=backend)
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+            options,
+        )
+
+        backend_options = ProcessGroupHCCL.Options()
+        backend_options._timeout = timeout
+        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        device = torch.device("npu")
+        # TODO(Yizhou): As mentioned above, _set_default_backend is not
+        # implemented in PyTorch 2.5.1; we will need to call it once a newer
+        # version is adopted.
+        # pg._set_default_backend(backend_type)
+        backend_class._set_sequence_number_for_group()
+
+        backend_type = ProcessGroup.BackendType.CUSTOM
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
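
For reference, a hypothetical usage sketch of the new hook; the store setup and the module path (`vllm_ascend.platform`) are assumptions for illustration, not part of this diff:

```python
# Hypothetical usage: build a stateless HCCL process group via the new hook.
from datetime import timedelta

from torch.distributed import TCPStore
from torch.distributed.distributed_c10d import PrefixStore

from vllm_ascend.platform import NPUPlatform  # assumed module path

rank, world_size = 0, 2  # per-process values, shown here as constants
store = TCPStore("127.0.0.1", 29500, world_size, is_master=(rank == 0))
pg = NPUPlatform.stateless_init_device_torch_dist_pg(
    backend="hccl",
    prefix_store=PrefixStore("dp_group/", store),
    group_rank=rank,
    group_size=world_size,
    timeout=timedelta(minutes=5),
)
```

Because the group is created from a store rather than from the default process group, each data-parallel replica can set one up independently, which is the point of the stateless path.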