[BugFix] Fix problem of extra processes on rank0 device (#7107)
### What this PR does / why we need it?
Currently, when tp>1, there are extra processes on the tp rank0 device,
which consume extra HBM. This is caused by `import
torch_npu._inductor` being executed before `set_device`, which triggers
an extra initialization of the device.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
All CI checks passed.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -95,12 +95,7 @@ class NPUWorker(WorkerBase):
|
|||||||
from vllm_ascend.utils import adapt_patch
|
from vllm_ascend.utils import adapt_patch
|
||||||
|
|
||||||
adapt_patch()
|
adapt_patch()
|
||||||
# Import _inductor for graph mode execution with triton
|
|
||||||
# This lazy import avoids torch_npu re-initialization in patch
|
|
||||||
from vllm.triton_utils import HAS_TRITON
|
|
||||||
|
|
||||||
if HAS_TRITON:
|
|
||||||
import torch_npu._inductor # noqa: F401
|
|
||||||
# Register ops when worker init.
|
# Register ops when worker init.
|
||||||
from vllm_ascend import ops
|
from vllm_ascend import ops
|
||||||
|
|
||||||
@@ -253,6 +248,15 @@ class NPUWorker(WorkerBase):
|
|||||||
device = torch.device(f"npu:{self.local_rank}")
|
device = torch.device(f"npu:{self.local_rank}")
|
||||||
torch.npu.set_device(device)
|
torch.npu.set_device(device)
|
||||||
|
|
||||||
|
# Import _inductor for graph mode execution with triton
|
||||||
|
# This lazy import avoids torch_npu re-initialization in patch
|
||||||
|
# Note that this should be imported after torch.npu.set_device
|
||||||
|
# to avoid repeated set_device in extra processes
|
||||||
|
from vllm.triton_utils import HAS_TRITON
|
||||||
|
|
||||||
|
if HAS_TRITON:
|
||||||
|
import torch_npu._inductor # noqa: F401
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user