[BugFix] Fix data parallel (#940)
### What this PR does / why we need it? With this PR, we can migrate to the native `data_parallel.py` in vllm examples and remove the version in vllm-ascend. At present, `ASCEND_RT_VISIBLE_DEVICES` introduces considerable difficulties; therefore, we must employ a temporary workaround and manually specify the device. Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -18,10 +18,13 @@
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import PrefixStore
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import Platform, PlatformEnum
|
||||
|
||||
@@ -262,3 +265,45 @@ class NPUPlatform(Platform):
|
||||
Get piecewise backend class for piecewise graph.
|
||||
"""
|
||||
return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend" # noqa
|
||||
|
||||
@classmethod
|
||||
def stateless_init_device_torch_dist_pg(
|
||||
cls,
|
||||
backend: str,
|
||||
prefix_store: PrefixStore,
|
||||
group_rank: int,
|
||||
group_size: int,
|
||||
timeout: timedelta,
|
||||
) -> ProcessGroup:
|
||||
from torch.distributed import is_hccl_available
|
||||
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
|
||||
|
||||
assert is_hccl_available()
|
||||
|
||||
# TODO(Yizhou): The reason we need to set options while vllm does not
|
||||
# seems to be related to the version of PyTorch. In the latest version,
|
||||
# there is no need to set options. While in the older version, 2.5.1
|
||||
# specifically, we need to set options.
|
||||
options = ProcessGroup.Options(backend=backend)
|
||||
pg: ProcessGroup = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
options,
|
||||
)
|
||||
|
||||
backend_options = ProcessGroupHCCL.Options()
|
||||
backend_options._timeout = timeout
|
||||
|
||||
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
|
||||
backend_options)
|
||||
device = torch.device("npu")
|
||||
# TODO(Yizhou): Like we mentioned above, _set_default_backend is not
|
||||
# implemented in the 2.5.1 version of PyTorch. But we need to set it
|
||||
# after the latest version is released.
|
||||
# pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
backend_type = ProcessGroup.BackendType.CUSTOM
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
return pg
|
||||
|
||||
Reference in New Issue
Block a user