diff --git a/vllm_ascend/communicator.py b/vllm_ascend/communicator.py index 543b639..0c43f1f 100644 --- a/vllm_ascend/communicator.py +++ b/vllm_ascend/communicator.py @@ -17,7 +17,6 @@ from typing import Optional import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from vllm.distributed.device_communicators.base_device_communicator import \ DeviceCommunicatorBase @@ -31,6 +30,5 @@ class NPUCommunicator(DeviceCommunicatorBase): device_group: Optional[ProcessGroup] = None, unique_name: str = ""): super().__init__(cpu_group, device, device_group, unique_name) - # init device according to local rank - local_rank = dist.get_rank(device_group) - self.device = torch.device(f"npu:{local_rank}") + # init device according to rank + self.device = torch.npu.current_device()