[BugFix] Fix data parallel (#940)

### What this PR does / why we need it?
With this PR, we can migrate to the native `data_parallel.py` in vllm
examples and remove the version in vllm-ascend.

At present, `ASCEND_RT_VISIBLE_DEVICES` introduces considerable
difficulties; therefore, we must employ a temporary workaround and
manually specify the device.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-06-09 14:08:18 +08:00
committed by GitHub
parent eec6068187
commit 6003afa6d2
5 changed files with 191 additions and 115 deletions

View File

@@ -14,3 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import vllm_ascend.patch.platform.patch_0_9_0.patch_distributed # noqa

View File

@@ -0,0 +1,116 @@
import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout,
is_nccl_available)
from torch.distributed.rendezvous import rendezvous
from vllm.distributed import utils
def stateless_init_torch_distributed_process_group(
host: str, port: int, rank: int, world_size: int,
backend: str) -> ProcessGroup:
"""
A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. The created ProcessGroup object can be used for
some operations such as `allreduce`, because it does not depend on the
global rank. However, some operations such as `broadcast` cannot be used
because it depends on the global rank.
# TODO: ask for help from PyTorch team if we need the `broadcast` operation.
This function is useful when we are not sure about the total number of
processes in the process group. For example, we may have process
1, 2, ..., 8 who want to communicate, and process 9 might be the same
process as process 1, or it might be a different process; process 10
might be the same process as process 5, or it might be a different process.
In this case, how can we reliably form a communication channel within
process 9 and 10, without affecting the communication channel within
process 1, 2, ..., 8?
One possible solution is to figure out if process 9 and 10 are the same
as process 1 and 5 beforehand, and then form a communication channel
based on the information, adjusting the ranks and world_size etc. However,
figuring out the information is not always easy, and it will interfere
with the main communication channel.
Our solution is to always form a communication channel with process 1, 2,
..., 8, and then use this function to form another communication channel
with process 9 and 10. This way, regardless of whether process 9 and 10
are the same as process 1 and 5, the main communication channel is
always formed with process 1, 2, ..., 8, and the additional communication
channel is formed with process 9 and 10.
"""
init_method = f"tcp://{host}:{port}"
backend = Backend(backend) # it is basically string
timeout = _get_default_timeout(backend)
store, rank, world_size = next(
rendezvous(init_method, rank, world_size, timeout=timeout))
store.set_timeout(timeout)
group_rank = rank
group_size = world_size
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store)
# TODO(Yizhou): The reason we need to set options while vllm does not
# seems to be related to the version of PyTorch. In the latest version,
# there is no need to set options. While in the older version, 2.5.1
# specifically, we need to set options.
options = ProcessGroup.Options(backend=backend)
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
options,
)
if backend == "gloo":
from torch.distributed.distributed_c10d import ProcessGroupGloo
backend_class = ProcessGroupGloo(prefix_store,
group_rank,
group_size,
timeout=timeout)
backend_type = ProcessGroup.BackendType.GLOO
device = torch.device("cpu")
elif backend == "nccl":
assert is_nccl_available()
from torch.distributed.distributed_c10d import ProcessGroupNCCL
backend_options = ProcessGroupNCCL.Options()
backend_options._timeout = timeout
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
backend_options)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")
elif backend == "hccl":
from torch.distributed import is_hccl_available
assert is_hccl_available()
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
backend_options = ProcessGroupHCCL.Options()
backend_options._timeout = timeout
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
backend_options)
device = torch.device("npu")
backend_class._set_sequence_number_for_group()
backend_type = ProcessGroup.BackendType.CUSTOM
pg._register_backend(device, backend_type, backend_class)
return pg
else:
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
# TODO(Yizhou): Like we mentioned above, _set_default_backend is not
# implemented in the 2.5.1 version of PyTorch. But we need to set it
# after the latest version is released.
# pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
return pg
utils.stateless_init_torch_distributed_process_group = stateless_init_torch_distributed_process_group

View File

@@ -17,16 +17,14 @@
# Adapted from vllm/model_executor/models/qwen2_vl.py
# This file is a part of the vllm-ascend project.
import torch
import vllm
import vllm.distributed
import vllm.envs as envs
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout,
is_nccl_available)
from torch.distributed.rendezvous import rendezvous
from vllm.config import ParallelConfig
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed.utils import \
stateless_init_torch_distributed_process_group
from vllm.v1.engine.core import DPEngineCoreProc
def ascend_destroy_model_parallel():
@@ -48,112 +46,6 @@ def ascend_destroy_model_parallel():
destory_ascend_model_parallel()
def stateless_init_torch_distributed_process_group(
host: str, port: int, rank: int, world_size: int,
backend: str) -> ProcessGroup:
"""
A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. The created ProcessGroup object can be used for
some operations such as `allreduce`, because it does not depend on the
global rank. However, some operations such as `broadcast` cannot be used
because it depends on the global rank.
# TODO: ask for help from PyTorch team if we need the `broadcast` operation.
This function is useful when we are not sure about the total number of
processes in the process group. For example, we may have process
1, 2, ..., 8 who want to communicate, and process 9 might be the same
process as process 1, or it might be a different process; process 10
might be the same process as process 5, or it might be a different process.
In this case, how can we reliably form a communication channel within
process 9 and 10, without affecting the communication channel within
process 1, 2, ..., 8?
One possible solution is to figure out if process 9 and 10 are the same
as process 1 and 5 beforehand, and then form a communication channel
based on the information, adjusting the ranks and world_size etc. However,
figuring out the information is not always easy, and it will interfere
with the main communication channel.
Our solution is to always form a communication channel with process 1, 2,
..., 8, and then use this function to form another communication channel
with process 9 and 10. This way, regardless of whether process 9 and 10
are the same as process 1 and 5, the main communication channel is
always formed with process 1, 2, ..., 8, and the additional communication
channel is formed with process 9 and 10.
"""
init_method = f"tcp://{host}:{port}"
backend = Backend(backend) # it is basically string
timeout = _get_default_timeout(backend)
store, rank, world_size = next(
rendezvous(init_method, rank, world_size, timeout=timeout))
store.set_timeout(timeout)
group_rank = rank
group_size = world_size
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store)
# TODO(Yizhou): The reason we need to set options while vllm does not
# seems to be related to the version of PyTorch. In the latest version,
# there is no need to set options. While in the older version, 2.5.1
# specifically, we need to set options.
options = ProcessGroup.Options(backend=backend)
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
options,
)
if backend == "gloo":
from torch.distributed.distributed_c10d import ProcessGroupGloo
backend_class = ProcessGroupGloo(prefix_store,
group_rank,
group_size,
timeout=timeout)
backend_type = ProcessGroup.BackendType.GLOO
device = torch.device("cpu")
elif backend == "nccl":
assert is_nccl_available()
from torch.distributed.distributed_c10d import ProcessGroupNCCL
backend_options = ProcessGroupNCCL.Options()
backend_options._timeout = timeout
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
backend_options)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")
elif backend == "hccl":
from torch.distributed import is_hccl_available
assert is_hccl_available()
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
backend_options = ProcessGroupHCCL.Options()
backend_options._timeout = timeout
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
backend_options)
device = torch.device("npu")
backend_class._set_sequence_number_for_group()
backend_type = ProcessGroup.BackendType.CUSTOM
pg._register_backend(device, backend_type, backend_class)
return pg
else:
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
# TODO(Yizhou): Like we mentioned above, _set_default_backend is not
# implemented in the 2.5.1 version of PyTorch. But we need to set it
# after the latest version is released.
# pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
return pg
def parallel_config_get_dp_port(self) -> int:
"""
We might need to initialize process groups in multiple
@@ -171,7 +63,7 @@ def parallel_config_get_dp_port(self) -> int:
return port
def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
def stateless_init_dp_group(self) -> "ProcessGroup":
# TODO(Yizhou): Currently we have to set the backend to gloo
# because in vllm.config.ParallelConfig.has_unfinished_dp the
# device is set to cpu. We need to fix this in the future.
@@ -187,6 +79,21 @@ def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
return dp_group
def _init_data_parallel(self, vllm_config: VllmConfig):
# Configure NPUs and stateless process group for data parallel.
dp_rank = vllm_config.parallel_config.data_parallel_rank
dp_size = vllm_config.parallel_config.data_parallel_size
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
assert dp_size > 1
assert 0 <= local_dp_rank <= dp_rank < dp_size
self.local_dp_rank = local_dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
self.current_wave = 0
vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
DPEngineCoreProc._init_data_parallel = _init_data_parallel
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
ParallelConfig.stateless_init_dp_group = stateless_init_dp_group

View File

@@ -18,10 +18,13 @@
import gc
import logging
import os
from datetime import timedelta
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import vllm.envs as envs
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import PrefixStore
from vllm.logger import logger
from vllm.platforms import Platform, PlatformEnum
@@ -262,3 +265,45 @@ class NPUPlatform(Platform):
Get piecewise backend class for piecewise graph.
"""
return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend" # noqa
@classmethod
def stateless_init_device_torch_dist_pg(
cls,
backend: str,
prefix_store: PrefixStore,
group_rank: int,
group_size: int,
timeout: timedelta,
) -> ProcessGroup:
from torch.distributed import is_hccl_available
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
assert is_hccl_available()
# TODO(Yizhou): The reason we need to set options while vllm does not
# seems to be related to the version of PyTorch. In the latest version,
# there is no need to set options. While in the older version, 2.5.1
# specifically, we need to set options.
options = ProcessGroup.Options(backend=backend)
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
options,
)
backend_options = ProcessGroupHCCL.Options()
backend_options._timeout = timeout
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
backend_options)
device = torch.device("npu")
# TODO(Yizhou): Like we mentioned above, _set_default_backend is not
# implemented in the 2.5.1 version of PyTorch. But we need to set it
# after the latest version is released.
# pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
backend_type = ProcessGroup.BackendType.CUSTOM
pg._register_backend(device, backend_type, backend_class)
return pg

View File

@@ -74,6 +74,13 @@ class NPUWorker(WorkerBase):
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker)
# NOTE(Yizhou): Since we do not set ASCEND_RT_VISIBLE_DEVICES in
# vllm_ascend, we need to set the device id manually.
local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
world_size = self.vllm_config.parallel_config.world_size
self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank
# Try to import mindie_turbo to accelerate vLLM inference.
try_register_lib(
"mindie_turbo",
@@ -112,7 +119,7 @@ class NPUWorker(WorkerBase):
def init_device(self):
if self.device_config.device.type == "npu":
self.device = torch.device(f"npu:{self.local_rank}")
self.device = torch.device(f"npu:{self.local_rank_across_dp}")
NPUPlatform.set_device(self.device)
NPUPlatform.empty_cache()
self.init_npu_memory = NPUPlatform.mem_get_info()[0]