[BugFix] Fix data parallel (#940)

### What this PR does / why we need it?
With this PR, we can migrate to the native `data_parallel.py` in vllm
examples and remove the version in vllm-ascend.

At present, `ASCEND_RT_VISIBLE_DEVICES` introduces considerable
difficulties; therefore, we must employ a temporary workaround and
manually specify the device.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-06-09 14:08:18 +08:00
committed by GitHub
parent eec6068187
commit 6003afa6d2
5 changed files with 191 additions and 115 deletions

View File

@@ -74,6 +74,13 @@ class NPUWorker(WorkerBase):
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker)
# NOTE(Yizhou): Since we do not set ASCEND_RT_VISIBLE_DEVICES in
# vllm_ascend, we need to set the device id manually.
local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
world_size = self.vllm_config.parallel_config.world_size
self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank
# Try to import mindie_turbo to accelerate vLLM inference.
try_register_lib(
"mindie_turbo",
@@ -112,7 +119,7 @@ class NPUWorker(WorkerBase):
def init_device(self):
if self.device_config.device.type == "npu":
self.device = torch.device(f"npu:{self.local_rank}")
self.device = torch.device(f"npu:{self.local_rank_across_dp}")
NPUPlatform.set_device(self.device)
NPUPlatform.empty_cache()
self.init_npu_memory = NPUPlatform.mem_get_info()[0]