[BugFix] Fix problem of extra processes on rank0 device (#7107)

### What this PR does / why we need it?
Currently, when tp>1, we have extra processes on the tp rank0 device which
consume extra HBM memory. This is caused by running `import
torch_npu._inductor` before `set_device`, which introduces an extra
initialization of the device.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
All CI passed.

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
Hexiang Wang
2026-03-12 15:59:03 +08:00
committed by GitHub
parent e5024d0264
commit f244f3c4a9

View File

@@ -95,12 +95,7 @@ class NPUWorker(WorkerBase):
from vllm_ascend.utils import adapt_patch
adapt_patch()
# Import _inductor for graph mode execution with triton
# This lazy import avoids torch_npu re-initialization in patch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
import torch_npu._inductor # noqa: F401
# Register ops when worker init.
from vllm_ascend import ops
@@ -253,6 +248,15 @@ class NPUWorker(WorkerBase):
device = torch.device(f"npu:{self.local_rank}")
torch.npu.set_device(device)
# Import _inductor for graph mode execution with triton
# This lazy import avoids torch_npu re-initialization in patch
# Note that this should be imported after torch.npu.set_device
# to avoid repeated set_device in extra processes
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
import torch_npu._inductor # noqa: F401
gc.collect()
torch.npu.empty_cache()