[BugFix] Fix problem of extra processes on rank0 device (#7107)
### What this PR does / why we need it?
Currently, when tp>1, we have extra processes on the tp rank0 device which
consume extra HBM memory. This is caused by `import
torch_npu._inductor` being executed before `set_device`, which introduces an
extra initialization of the device.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
All CI passed.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -95,12 +95,7 @@ class NPUWorker(WorkerBase):
|
||||
from vllm_ascend.utils import adapt_patch
|
||||
|
||||
adapt_patch()
|
||||
# Import _inductor for graph mode execution with triton
|
||||
# This lazy import avoids torch_npu re-initialization in patch
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
import torch_npu._inductor # noqa: F401
|
||||
# Register ops when worker init.
|
||||
from vllm_ascend import ops
|
||||
|
||||
@@ -253,6 +248,15 @@ class NPUWorker(WorkerBase):
|
||||
device = torch.device(f"npu:{self.local_rank}")
|
||||
torch.npu.set_device(device)
|
||||
|
||||
# Import _inductor for graph mode execution with triton
|
||||
# This lazy import avoids torch_npu re-initialization in patch
|
||||
# Note that this should be imported after torch.npu.set_device
|
||||
# to avoid repeated set_device in extra processes
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
import torch_npu._inductor # noqa: F401
|
||||
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user