diff --git a/tests/ut/torchair/test_torchair_worker.py b/tests/ut/torchair/test_torchair_worker.py
index 32d5a92e..0397aee1 100644
--- a/tests/ut/torchair/test_torchair_worker.py
+++ b/tests/ut/torchair/test_torchair_worker.py
@@ -59,6 +59,7 @@ class TestNPUTorchairWorker(TestBase):
         worker.vllm_config = MagicMock()
         worker.parallel_config = MagicMock()
         worker.parallel_config.local_world_size = 0
+        worker.parallel_config.data_parallel_size = 1
 
         result = worker._init_device()
 
@@ -93,6 +94,7 @@ class TestNPUTorchairWorker(TestBase):
         worker.vllm_config = MagicMock()
         worker.parallel_config = MagicMock()
         worker.parallel_config.local_world_size = 0
+        worker.parallel_config.data_parallel_size = 1
 
         result = worker._init_device()
 
diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py
index fbc7fdc4..5a12981a 100644
--- a/tests/ut/worker/test_worker_v1.py
+++ b/tests/ut/worker/test_worker_v1.py
@@ -329,6 +329,8 @@ class TestNPUWorker(TestBase):
         worker.model_config = MagicMock()
         worker.parallel_config = MagicMock()
         worker.parallel_config.local_world_size = 0
+        worker.parallel_config.data_parallel_size = 1
+        worker.model_config.seed = 42
 
         # Test _init_device
 
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index e9000eae..df7fec60 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -208,12 +208,18 @@ class NPUWorker(WorkerBase):
         NPUPlatform.set_device(device)
         NPUPlatform.empty_cache()
 
-        visible_device_count = (torch.npu.device_count()
-                                if torch.npu.is_available() else 0)
-        assert self.parallel_config.local_world_size <= visible_device_count, (
-            f"local_world_size ({self.parallel_config.local_world_size}) must be "
-            f"less than or equal to the number of visible devices "
-            f"({visible_device_count}).")
+        if (self.parallel_config.data_parallel_size > 1
+                and self.parallel_config.data_parallel_size_local > 0
+                and self.parallel_config.distributed_executor_backend
+                not in ["ray", "external_launcher"] and
+                self.vllm_config.parallel_config.data_parallel_backend != "ray"
+                and self.vllm_config.parallel_config.nnodes_within_dp == 1):
+            visible_device_count = (torch.npu.device_count()
+                                    if torch.npu.is_available() else 0)
+            assert self.parallel_config.local_world_size <= visible_device_count, (
+                f"local_world_size ({self.parallel_config.local_world_size}) must "
+                f"be less than or equal to the number of visible devices "
+                f"({visible_device_count}).")
 
         self.init_npu_memory = NPUPlatform.mem_get_info()[0]
         # Initialize the distributed environment.
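
Reviewer note: the new guard means the visible-device assertion now runs only for
single-node data-parallel runs driven by vLLM's own launcher. With DP disabled, or
with ray/external launchers (where local_world_size can legitimately exceed the
devices visible to a single process), the check is skipped. Below is a minimal
standalone sketch of the predicate; SimpleNamespace stands in for vLLM's config
objects, and the helper name and sample values are illustrative, not part of the
codebase. Note the sketch flattens the fields the real code reads from both
self.parallel_config and self.vllm_config.parallel_config onto one namespace.

    # Illustrative sketch of the guard added in _init_device (not repo code).
    from types import SimpleNamespace


    def device_count_check_applies(cfg) -> bool:
        """Return True when the visible-device assertion should run."""
        return (cfg.data_parallel_size > 1
                and cfg.data_parallel_size_local > 0
                and cfg.distributed_executor_backend
                not in ["ray", "external_launcher"]
                and cfg.data_parallel_backend != "ray"
                and cfg.nnodes_within_dp == 1)


    # Single-node DP=2 via the default multiprocessing launcher: check runs.
    single_node_dp = SimpleNamespace(data_parallel_size=2,
                                     data_parallel_size_local=2,
                                     distributed_executor_backend="mp",
                                     data_parallel_backend="mp",
                                     nnodes_within_dp=1)
    assert device_count_check_applies(single_node_dp)

    # Ray-launched DP: check is skipped, since each Ray worker may see
    # fewer devices than local_world_size.
    ray_dp = SimpleNamespace(data_parallel_size=2,
                             data_parallel_size_local=2,
                             distributed_executor_backend="ray",
                             data_parallel_backend="ray",
                             nnodes_within_dp=1)
    assert not device_count_check_applies(ray_dp)

The test changes above follow from this: the mocked parallel_config must now
provide a concrete data_parallel_size (1, so the guard short-circuits), and the
worker_v1 test additionally pins model_config.seed since MagicMock attributes
are not valid seeds.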