diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index 22cc7a00..306a63e3 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - import vllm_ascend.patch.platform.patch_distributed # noqa import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa @@ -27,13 +25,11 @@ if not is_310p(): else: import vllm_ascend.patch.platform.patch_mamba_config_310 # noqa import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa +import vllm_ascend.patch.platform.patch_multiproc_executor # noqa import vllm_ascend.patch.platform.patch_sched_yield # noqa import vllm_ascend.patch.platform.patch_torch_accelerator # noqa import vllm_ascend.patch.platform.patch_minimax_usage_accounting # noqa import vllm_ascend.patch.platform.patch_glm_tool_call_parser # noqa -if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true": - import vllm_ascend.patch.platform.patch_multiproc_executor # noqa - if envs.VLLM_ASCEND_BALANCE_SCHEDULING: import vllm_ascend.patch.platform.patch_balance_schedule # noqa diff --git a/vllm_ascend/patch/platform/patch_multiproc_executor.py b/vllm_ascend/patch/platform/patch_multiproc_executor.py index 56e64040..3facc008 100644 --- a/vllm_ascend/patch/platform/patch_multiproc_executor.py +++ b/vllm_ascend/patch/platform/patch_multiproc_executor.py @@ -1,14 +1,20 @@ from __future__ import annotations +import os +import queue import weakref from collections import deque from collections.abc import Callable from multiprocessing.synchronize import Lock as LockType +from threading import Thread import vllm.v1.executor.multiproc_executor from vllm import envs from vllm.config import VllmConfig from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue +from vllm.envs import enable_envs_cache +from vllm.platforms import current_platform +from vllm.tracing import instrument from vllm.utils.network_utils import get_distributed_init_method, get_loopback_ip, get_open_port from vllm.utils.system_utils import get_mp_context from vllm.v1.executor.abstract import FailureCallback @@ -19,6 +25,7 @@ from vllm.v1.executor.multiproc_executor import ( WorkerProc, set_multiprocessing_worker_envs, ) +from vllm.v1.worker.worker_base import WorkerWrapperBase class AscendMultiprocExecutor(MultiprocExecutor): @@ -159,6 +166,79 @@ class AscendMultiprocExecutor(MultiprocExecutor): class AscendWorkerProc(WorkerProc): + @instrument(span_name="Worker init") + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + input_shm_handle: Handle, + shared_worker_lock: LockType, + is_driver_worker: bool, + ): + self.rank = rank + wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank) + # TODO: move `init_worker` to executor level as a collective rpc call + all_kwargs: list[dict] = [{} for _ in range(vllm_config.parallel_config.world_size)] + all_kwargs[local_rank] = { + "vllm_config": vllm_config, + "local_rank": local_rank, + "rank": rank, + "distributed_init_method": distributed_init_method, + "is_driver_worker": is_driver_worker, + "shared_worker_lock": shared_worker_lock, + } + wrapper.init_worker(all_kwargs) + self.worker = wrapper + + self.setup_proc_title_and_log_prefix(enable_ep=vllm_config.parallel_config.enable_expert_parallel) + + # Load model + is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH + if not is_eep_new_worker: + self.worker.init_device() + # Update process title now that parallel groups are initialized + self.setup_proc_title_and_log_prefix(enable_ep=vllm_config.parallel_config.enable_expert_parallel) + self.worker.load_model() + + scheduler_config = vllm_config.scheduler_config + self.use_async_scheduling = scheduler_config.async_scheduling + if self.use_async_scheduling: + self.async_output_queue: queue.Queue = queue.Queue() + self.async_output_copy_thread = Thread( + target=self.async_output_busy_loop, + daemon=True, + name="WorkerAsyncOutputCopy", + ) + self.async_output_copy_thread.start() + + # Set block size based on the attention backends + current_platform.update_block_size_for_backend(vllm_config) + + # Initialize message queues after init_device() since multi-node setups + # (nnodes_within_dp > 1) require distributed groups to be initialized + self._init_message_queues(input_shm_handle, vllm_config) + + # Enable environment variable cache (e.g. assume no more + # environment variable overrides after this point) + enable_envs_cache() + + @staticmethod + def worker_main(*args, **kwargs): + from vllm_ascend.utils import adapt_patch + + adapt_patch(is_global_patch=True) + WorkerProc.worker_main(*args, **kwargs) + + def async_output_busy_loop(self): + """Entrypoint for the thread which handles outputs asynchronously.""" + if hasattr(self.worker, "device"): + current_platform.set_device(self.worker.device) + while True: + output = self.async_output_queue.get() + self.enqueue_output(output) + @staticmethod def make_worker_process( vllm_config: VllmConfig, @@ -192,11 +272,15 @@ class AscendWorkerProc(WorkerProc): "inherited_fds": inherited_fds if inherited_fds is not None else [], } # Run EngineCore busy loop in background process. + daemon_mode = not ( + os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") + or os.getenv("EXPERT_MAP_RECORD", "false") == "true" + ) proc = context.Process( - target=WorkerProc.worker_main, + target=AscendWorkerProc.worker_main, kwargs=process_kwargs, name=f"VllmWorker-{rank}", - daemon=False, + daemon=daemon_mode, ) proc.start() @@ -209,3 +293,4 @@ class AscendWorkerProc(WorkerProc): vllm.v1.executor.multiproc_executor.MultiprocExecutor = AscendMultiprocExecutor +vllm.v1.executor.multiproc_executor.WorkerProc = AscendWorkerProc