[bugfix]fix extra npu context in device 0 (#8041)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? When we launch a PD-disaggregated process and send requests, an additional processes appear on NPU 0, becasue when a thread has a primary cuda context, the child thread it creates automatically doesn't inherit the cuda context. See https://forums.developer.nvidia.com/t/when-a-thread-has-a-primary-cuda-context-does-the-child-thread-it-creates-automatically-inherit-the-cuda-context/362810. vLLM has fixed this issue in [pr-37449 ](https://github.com/vllm-project/vllm/pull/37449), but version 0.18.0 does not include the fix. Therefore, we need to patch it. <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? no <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Signed-off-by: zouyida <zouyida@huawei.com> Co-authored-by: zouyida <zouyida@huawei.com>
This commit is contained in:
@@ -14,8 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import vllm_ascend.patch.platform.patch_distributed # noqa
|
import vllm_ascend.patch.platform.patch_distributed # noqa
|
||||||
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
|
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
|
||||||
import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa
|
import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa
|
||||||
@@ -27,13 +25,11 @@ if not is_310p():
|
|||||||
else:
|
else:
|
||||||
import vllm_ascend.patch.platform.patch_mamba_config_310 # noqa
|
import vllm_ascend.patch.platform.patch_mamba_config_310 # noqa
|
||||||
import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa
|
import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa
|
||||||
|
import vllm_ascend.patch.platform.patch_multiproc_executor # noqa
|
||||||
import vllm_ascend.patch.platform.patch_sched_yield # noqa
|
import vllm_ascend.patch.platform.patch_sched_yield # noqa
|
||||||
import vllm_ascend.patch.platform.patch_torch_accelerator # noqa
|
import vllm_ascend.patch.platform.patch_torch_accelerator # noqa
|
||||||
import vllm_ascend.patch.platform.patch_minimax_usage_accounting # noqa
|
import vllm_ascend.patch.platform.patch_minimax_usage_accounting # noqa
|
||||||
import vllm_ascend.patch.platform.patch_glm_tool_call_parser # noqa
|
import vllm_ascend.patch.platform.patch_glm_tool_call_parser # noqa
|
||||||
|
|
||||||
if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
|
|
||||||
import vllm_ascend.patch.platform.patch_multiproc_executor # noqa
|
|
||||||
|
|
||||||
if envs.VLLM_ASCEND_BALANCE_SCHEDULING:
|
if envs.VLLM_ASCEND_BALANCE_SCHEDULING:
|
||||||
import vllm_ascend.patch.platform.patch_balance_schedule # noqa
|
import vllm_ascend.patch.platform.patch_balance_schedule # noqa
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
import weakref
|
import weakref
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from multiprocessing.synchronize import Lock as LockType
|
from multiprocessing.synchronize import Lock as LockType
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
import vllm.v1.executor.multiproc_executor
|
import vllm.v1.executor.multiproc_executor
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
|
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
|
||||||
|
from vllm.envs import enable_envs_cache
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.tracing import instrument
|
||||||
from vllm.utils.network_utils import get_distributed_init_method, get_loopback_ip, get_open_port
|
from vllm.utils.network_utils import get_distributed_init_method, get_loopback_ip, get_open_port
|
||||||
from vllm.utils.system_utils import get_mp_context
|
from vllm.utils.system_utils import get_mp_context
|
||||||
from vllm.v1.executor.abstract import FailureCallback
|
from vllm.v1.executor.abstract import FailureCallback
|
||||||
@@ -19,6 +25,7 @@ from vllm.v1.executor.multiproc_executor import (
|
|||||||
WorkerProc,
|
WorkerProc,
|
||||||
set_multiprocessing_worker_envs,
|
set_multiprocessing_worker_envs,
|
||||||
)
|
)
|
||||||
|
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||||
|
|
||||||
|
|
||||||
class AscendMultiprocExecutor(MultiprocExecutor):
|
class AscendMultiprocExecutor(MultiprocExecutor):
|
||||||
@@ -159,6 +166,79 @@ class AscendMultiprocExecutor(MultiprocExecutor):
|
|||||||
|
|
||||||
|
|
||||||
class AscendWorkerProc(WorkerProc):
|
class AscendWorkerProc(WorkerProc):
|
||||||
|
@instrument(span_name="Worker init")
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
local_rank: int,
|
||||||
|
rank: int,
|
||||||
|
distributed_init_method: str,
|
||||||
|
input_shm_handle: Handle,
|
||||||
|
shared_worker_lock: LockType,
|
||||||
|
is_driver_worker: bool,
|
||||||
|
):
|
||||||
|
self.rank = rank
|
||||||
|
wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
|
||||||
|
# TODO: move `init_worker` to executor level as a collective rpc call
|
||||||
|
all_kwargs: list[dict] = [{} for _ in range(vllm_config.parallel_config.world_size)]
|
||||||
|
all_kwargs[local_rank] = {
|
||||||
|
"vllm_config": vllm_config,
|
||||||
|
"local_rank": local_rank,
|
||||||
|
"rank": rank,
|
||||||
|
"distributed_init_method": distributed_init_method,
|
||||||
|
"is_driver_worker": is_driver_worker,
|
||||||
|
"shared_worker_lock": shared_worker_lock,
|
||||||
|
}
|
||||||
|
wrapper.init_worker(all_kwargs)
|
||||||
|
self.worker = wrapper
|
||||||
|
|
||||||
|
self.setup_proc_title_and_log_prefix(enable_ep=vllm_config.parallel_config.enable_expert_parallel)
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
|
||||||
|
if not is_eep_new_worker:
|
||||||
|
self.worker.init_device()
|
||||||
|
# Update process title now that parallel groups are initialized
|
||||||
|
self.setup_proc_title_and_log_prefix(enable_ep=vllm_config.parallel_config.enable_expert_parallel)
|
||||||
|
self.worker.load_model()
|
||||||
|
|
||||||
|
scheduler_config = vllm_config.scheduler_config
|
||||||
|
self.use_async_scheduling = scheduler_config.async_scheduling
|
||||||
|
if self.use_async_scheduling:
|
||||||
|
self.async_output_queue: queue.Queue = queue.Queue()
|
||||||
|
self.async_output_copy_thread = Thread(
|
||||||
|
target=self.async_output_busy_loop,
|
||||||
|
daemon=True,
|
||||||
|
name="WorkerAsyncOutputCopy",
|
||||||
|
)
|
||||||
|
self.async_output_copy_thread.start()
|
||||||
|
|
||||||
|
# Set block size based on the attention backends
|
||||||
|
current_platform.update_block_size_for_backend(vllm_config)
|
||||||
|
|
||||||
|
# Initialize message queues after init_device() since multi-node setups
|
||||||
|
# (nnodes_within_dp > 1) require distributed groups to be initialized
|
||||||
|
self._init_message_queues(input_shm_handle, vllm_config)
|
||||||
|
|
||||||
|
# Enable environment variable cache (e.g. assume no more
|
||||||
|
# environment variable overrides after this point)
|
||||||
|
enable_envs_cache()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def worker_main(*args, **kwargs):
|
||||||
|
from vllm_ascend.utils import adapt_patch
|
||||||
|
|
||||||
|
adapt_patch(is_global_patch=True)
|
||||||
|
WorkerProc.worker_main(*args, **kwargs)
|
||||||
|
|
||||||
|
def async_output_busy_loop(self):
|
||||||
|
"""Entrypoint for the thread which handles outputs asynchronously."""
|
||||||
|
if hasattr(self.worker, "device"):
|
||||||
|
current_platform.set_device(self.worker.device)
|
||||||
|
while True:
|
||||||
|
output = self.async_output_queue.get()
|
||||||
|
self.enqueue_output(output)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make_worker_process(
|
def make_worker_process(
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
@@ -192,11 +272,15 @@ class AscendWorkerProc(WorkerProc):
|
|||||||
"inherited_fds": inherited_fds if inherited_fds is not None else [],
|
"inherited_fds": inherited_fds if inherited_fds is not None else [],
|
||||||
}
|
}
|
||||||
# Run EngineCore busy loop in background process.
|
# Run EngineCore busy loop in background process.
|
||||||
|
daemon_mode = not (
|
||||||
|
os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1")
|
||||||
|
or os.getenv("EXPERT_MAP_RECORD", "false") == "true"
|
||||||
|
)
|
||||||
proc = context.Process(
|
proc = context.Process(
|
||||||
target=WorkerProc.worker_main,
|
target=AscendWorkerProc.worker_main,
|
||||||
kwargs=process_kwargs,
|
kwargs=process_kwargs,
|
||||||
name=f"VllmWorker-{rank}",
|
name=f"VllmWorker-{rank}",
|
||||||
daemon=False,
|
daemon=daemon_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
proc.start()
|
proc.start()
|
||||||
@@ -209,3 +293,4 @@ class AscendWorkerProc(WorkerProc):
|
|||||||
|
|
||||||
|
|
||||||
vllm.v1.executor.multiproc_executor.MultiprocExecutor = AscendMultiprocExecutor
|
vllm.v1.executor.multiproc_executor.MultiprocExecutor = AscendMultiprocExecutor
|
||||||
|
vllm.v1.executor.multiproc_executor.WorkerProc = AscendWorkerProc
|
||||||
|
|||||||
Reference in New Issue
Block a user