diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 39f4f646..ebf2cdf7 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -95,12 +95,7 @@ class NPUWorker(WorkerBase): from vllm_ascend.utils import adapt_patch adapt_patch() - # Import _inductor for graph mode execution with triton - # This lazy import avoids torch_npu re-initialization in patch - from vllm.triton_utils import HAS_TRITON - if HAS_TRITON: - import torch_npu._inductor # noqa: F401 # Register ops when worker init. from vllm_ascend import ops @@ -253,6 +248,15 @@ class NPUWorker(WorkerBase): device = torch.device(f"npu:{self.local_rank}") torch.npu.set_device(device) + # Import _inductor for graph mode execution with triton + # This lazy import avoids torch_npu re-initialization in patch + # Note that this should be imported after torch.npu.set_device + # to avoid repeated set_device in extra processes + from vllm.triton_utils import HAS_TRITON + + if HAS_TRITON: + import torch_npu._inductor # noqa: F401 + gc.collect() torch.npu.empty_cache()