184 lines
7.1 KiB
Python
184 lines
7.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import os
|
|
from collections.abc import Callable
|
|
from concurrent.futures import Future, ThreadPoolExecutor
|
|
from functools import cached_property
|
|
from multiprocessing import Lock
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
import vllm.envs as envs
|
|
from vllm.logger import init_logger
|
|
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
|
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
|
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
|
from vllm.v1.executor.abstract import Executor
|
|
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
|
|
from vllm.v1.serial_utils import run_method
|
|
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
|
|
|
# Module-level logger, created by `init_logger` with this module's `__name__`.
logger = init_logger(__name__)
|
|
|
|
|
|
class UniProcExecutor(Executor):
    """Executor that runs every call against a single driver worker.

    One `WorkerWrapperBase` (``rpc_rank=0``) is created in `_init_executor`
    and kept in ``self.driver_worker``; `collective_rpc` dispatches to it
    through `run_method`.
    """

    def _init_executor(self) -> None:
        """Create the driver worker, initialize its device and load the model."""
        self.driver_worker = WorkerWrapperBase(
            vllm_config=self.vllm_config,
            rpc_rank=0,
        )

        init_method, rank, local_rank = self._distributed_args()
        worker_kwargs = {
            "vllm_config": self.vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": init_method,
            "is_driver_worker": True,
            "shared_worker_lock": Lock(),
        }

        # A one-thread pool is created only when more than one batch may be
        # in flight; `collective_rpc` uses it to resolve async runner outputs.
        self.async_output_thread: ThreadPoolExecutor | None = None
        needs_output_thread = self.max_concurrent_batches > 1
        if needs_output_thread:
            self.async_output_thread = ThreadPoolExecutor(
                max_workers=1,
                thread_name_prefix="WorkerAsyncOutput",
            )

        driver = self.driver_worker
        driver.init_worker(all_kwargs=[worker_kwargs])
        driver.init_device()
        driver.load_model()

    def _distributed_args(self) -> tuple[str, int, int]:
        """Return (distributed_init_method, rank, local_rank).

        The rank is always 0. The local rank is the integer following the
        first ``:`` in the configured device's string form, or 0 when the
        string contains no ``:``.
        """
        init_method = get_distributed_init_method(get_ip(), get_open_port())

        # Use the device index as the local rank when one is specified.
        device_parts = str(self.vllm_config.device_config.device).split(":")
        local_rank = 0
        if len(device_parts) > 1:
            local_rank = int(device_parts[1])

        return init_method, 0, local_rank

    @cached_property
    def max_concurrent_batches(self) -> int:
        """2 when ``scheduler_config.async_scheduling`` is truthy, otherwise 1.

        Computed once and cached on the instance.
        """
        if self.scheduler_config.async_scheduling:
            return 2
        return 1

    def collective_rpc(  # type: ignore[override]
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        non_block: bool = False,
        single_value: bool = False,
    ) -> Any | list[Any] | Future[Any | list[Any]]:
        """Run `method` on the driver worker via `run_method`.

        Args:
            method: Passed through to `run_method` unchanged.
            timeout: Not consulted by this implementation.
            args: Positional arguments forwarded to `run_method`.
            kwargs: Keyword arguments forwarded to `run_method`; ``None`` is
                replaced by an empty dict.
            non_block: When False the result is returned directly. When True
                a `Future` is returned instead.
            single_value: When True the bare result is produced; otherwise
                it is wrapped in a one-element list.

        Returns:
            The result (or ``[result]``), or a `Future` resolving to it when
            `non_block` is True. In non-blocking mode any `Exception` raised
            while producing the result is stored on the returned future
            rather than propagated.
        """
        call_kwargs = {} if kwargs is None else kwargs

        if not non_block:
            output = run_method(self.driver_worker, method, args, call_kwargs)
            if single_value:
                return output
            return [output]

        try:
            output = run_method(self.driver_worker, method, args, call_kwargs)
            if isinstance(output, AsyncModelRunnerOutput):
                pool = self.async_output_thread
                if pool is not None:
                    # Resolve the async output on the pool thread; wrap it
                    # in a list there unless a single value was requested.
                    fetch = output.get_output
                    if single_value:
                        return pool.submit(fetch)
                    return pool.submit(lambda: [fetch()])
                # No pool available: resolve the output on this thread.
                output = output.get_output()
            future = Future[Any]()
            if single_value:
                future.set_result(output)
            else:
                future.set_result([output])
        except Exception as exc:
            future = Future[Any]()
            future.set_exception(exc)
        return future

    def execute_model(  # type: ignore[override]
        self, scheduler_output: SchedulerOutput, non_block: bool = False
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Call ``execute_model`` on the worker with `scheduler_output`.

        Returns the unwrapped result, or a `Future` of it when `non_block`
        is True.
        """
        rpc_args = (scheduler_output,)
        return self.collective_rpc(
            "execute_model",
            args=rpc_args,
            non_block=non_block,
            single_value=True,
        )

    def sample_tokens(  # type: ignore[override]
        self, grammar_output: GrammarOutput | None, non_block: bool = False
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Call ``sample_tokens`` on the worker with `grammar_output`.

        Returns the unwrapped result, or a `Future` of it when `non_block`
        is True.
        """
        rpc_args = (grammar_output,)
        return self.collective_rpc(
            "sample_tokens",
            args=rpc_args,
            non_block=non_block,
            single_value=True,
        )

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the worker's ``take_draft_token_ids`` result, unwrapped."""
        draft_ids = self.collective_rpc("take_draft_token_ids", single_value=True)
        return draft_ids

    def check_health(self) -> None:
        """No-op health check."""
        # UniProcExecutor will always be healthy as long as it's running.
        return None

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Forward the reconfiguration to the worker, then shut down if asked.

        `shutdown` is called when the request's ``new_data_parallel_rank``
        equals ``ReconfigureRankType.SHUTDOWN_CURRENT_RANK``.
        """
        self.driver_worker.reinitialize_distributed(reconfig_request)

        new_rank = reconfig_request.new_data_parallel_rank
        if new_rank == ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            self.shutdown()

    def shutdown(self) -> None:
        """Shut down the driver worker if ``self.driver_worker`` is truthy."""
        driver = self.driver_worker
        if not driver:
            return
        driver.shutdown()
|
|
|
|
|
|
class ExecutorWithExternalLauncher(UniProcExecutor):
    """Executor for engines started by an external, torchrun-compatible launcher.

    Intended for offline inference with tensor parallelism. See
    https://github.com/vllm-project/vllm/issues/11400 for the motivation,
    and examples/offline_inference/torchrun_example.py for a usage example.

    The key idea: even though inference is tensor-parallel, each executor
    creates only one worker. Users launch multiple engines with a
    torchrun-compatible launcher, and all of these engines work together
    to process the same prompts. When scheduling is deterministic, all the
    engines generate the same outputs, so they don't need to synchronize
    state with each other.
    """

    def _init_executor(self) -> None:
        """Check that V1 multiprocessing is off, then initialize as the parent does."""
        assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
            "To get deterministic execution, "
            "please set VLLM_ENABLE_V1_MULTIPROCESSING=0"
        )
        super()._init_executor()

    def _distributed_args(self) -> tuple[str, int, int]:
        """Return ("env://", RANK, LOCAL_RANK) read from the environment.

        Raises:
            KeyError: If ``RANK`` or ``LOCAL_RANK`` is not set.
            ValueError: If either value is not an integer string.
        """
        # Engines are launched by a torchrun-compatible launcher, so the
        # env:// init method can be used. Required env vars:
        # - RANK
        # - LOCAL_RANK
        # - MASTER_ADDR
        # - MASTER_PORT
        environ = os.environ
        return "env://", int(environ["RANK"]), int(environ["LOCAL_RANK"])

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return the parent's memory figure reduced with MIN over the world group.

        The value from the parent implementation is placed in a CPU int64
        tensor and all-reduced with ``ReduceOp.MIN`` over the world group's
        ``cpu_group``; the result is returned as a one-element list.
        """
        local_memory = super().determine_available_memory()

        from vllm.distributed.parallel_state import get_world_group

        # Take the minimum across all ranks.
        world_cpu_group = get_world_group().cpu_group
        reduced = torch.tensor([local_memory], device="cpu", dtype=torch.int64)
        dist.all_reduce(reduced, group=world_cpu_group, op=dist.ReduceOp.MIN)
        min_memory = reduced.item()
        return [min_memory]
|