Sync from v0.13
414 vllm/v1/utils.py Normal file
@@ -0,0 +1,414 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import multiprocessing
import time
import weakref
from collections.abc import Callable, Sequence
from contextlib import AbstractContextManager
from multiprocessing import connection
from multiprocessing.process import BaseProcess
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Optional,
    TypeVar,
    Union,
    overload,
)

import torch
from torch.autograd.profiler import record_function

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri
from vllm.utils.system_utils import kill_process_tree

if TYPE_CHECKING:
    import numpy as np

    from vllm.v1.engine.coordinator import DPCoordinator
    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager

logger = init_logger(__name__)

T = TypeVar("T")


class ConstantList(Generic[T], Sequence):
    def __init__(self, x: list[T]) -> None:
        self._x = x

    def append(self, item):
        raise TypeError("Cannot append to a constant list")

    def extend(self, item):
        raise TypeError("Cannot extend a constant list")

    def insert(self, item):
        raise TypeError("Cannot insert into a constant list")

    def pop(self, item):
        raise TypeError("Cannot pop from a constant list")

    def remove(self, item):
        raise TypeError("Cannot remove from a constant list")

    def clear(self):
        raise TypeError("Cannot clear a constant list")

    def index(self, item: T, start: int = 0, stop: int | None = None) -> int:
        return self._x.index(item, start, stop if stop is not None else len(self._x))

    @overload
    def __getitem__(self, item: int) -> T: ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]: ...

    def __getitem__(self, item: int | slice) -> T | list[T]:
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T): ...

    @overload
    def __setitem__(self, s: slice, value: T, /): ...

    def __setitem__(self, item: int | slice, value: T | list[T]):
        raise TypeError("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise TypeError("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self):
        return f"ConstantList({self._x})"

    def copy(self) -> list[T]:
        return self._x.copy()
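
# Illustrative usage: a ConstantList is a read-only view over an existing
# list, so reads behave as usual while every mutating method raises.
#
#     cl = ConstantList([1, 2, 3])
#     assert cl[0] == 1 and 2 in cl and len(cl) == 3
#     cl.append(4)  # TypeError: Cannot append to a constant list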


class CpuGpuBuffer:
    """Buffer to easily copy tensors between CPU and GPU."""

    def __init__(
        self,
        *size: int | torch.SymInt,
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
        with_numpy: bool = True,
    ) -> None:
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory)
        self.gpu = torch.zeros_like(self.cpu, device=device)
        self.np: np.ndarray
        # To keep type hints simple (avoiding generics and subclasses), we
        # only conditionally create the numpy array attribute. This can cause
        # AttributeError if `self.np` is accessed when `with_numpy=False`.
        if with_numpy:
            if dtype == torch.bfloat16:
                raise ValueError(
                    "Bfloat16 torch tensors cannot be directly cast to a "
                    "numpy array, so call CpuGpuBuffer with with_numpy=False"
                )
            self.np = self.cpu.numpy()

    def copy_to_gpu(self, n: int | None = None) -> torch.Tensor:
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

    def copy_to_cpu(self, n: int | None = None) -> torch.Tensor:
        """NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU."""
        if n is None:
            return self.cpu.copy_(self.gpu, non_blocking=True)
        return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)
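
# Illustrative usage (assumes a CUDA device; pinned host memory needs CUDA):
#
#     buf = CpuGpuBuffer(
#         1024, dtype=torch.int32, device=torch.device("cuda"), pin_memory=True
#     )
#     buf.np[:4] = [1, 2, 3, 4]      # stage values on the CPU side
#     gpu_view = buf.copy_to_gpu(4)  # non-blocking H2D copy of the first 4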


def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str:
    """Assign a new ZMQ socket address.

    If local_only is True, participants are colocated, so a unique IPC
    address will be returned.

    Otherwise, the provided host and port will be used to construct a TCP
    address (port == 0 means assign an available port)."""

    return (
        get_open_zmq_ipc_path()
        if local_only
        else (get_tcp_uri(host, port or get_open_port()))
    )
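
# Illustrative results: with local_only=True this returns a fresh IPC address
# (an "ipc://..." path); with local_only=False and e.g. host "10.0.0.1" and
# port 5555 it returns "tcp://10.0.0.1:5555".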


class APIServerProcessManager:
    """Manages a group of API server processes.

    Handles creation, monitoring, and termination of API server worker
    processes. Also monitors extra processes to check if they are healthy.
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: str | None = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        # Start API servers
        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        for i, in_addr, out_addr in zip(
            range(num_servers), input_addresses, output_addresses
        ):
            client_config = {
                "input_address": in_addr,
                "output_address": out_addr,
                "client_count": num_servers,
                "client_index": i,
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

            proc = spawn_context.Process(
                target=target_server_fn,
                name=f"ApiServer_{i}",
                args=(listen_address, sock, args, client_config),
            )
            self.processes.append(proc)
            proc.start()

        logger.info("Started %d API server processes", len(self.processes))

        # Shutdown only the API server processes on garbage collection.
        # The extra processes are managed by their owners.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

    def close(self) -> None:
        self._finalizer()


def wait_for_completion_or_failure(
    api_server_manager: APIServerProcessManager,
    engine_manager: Union["CoreEngineProcManager", "CoreEngineActorManager"]
    | None = None,
    coordinator: Optional["DPCoordinator"] = None,
) -> None:
    """Wait for all processes to complete or detect if any fail.

    Raises an exception if any process exits with a non-zero status.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.
    """

    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager

    try:
        logger.info("Waiting for API servers to complete ...")
        # Create a mapping of sentinels to their corresponding processes
        # for efficient lookup
        sentinel_to_proc: dict[Any, BaseProcess] = {
            proc.sentinel: proc for proc in api_server_manager.processes
        }

        if coordinator:
            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

        actor_run_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            for proc in engine_manager.processes:
                sentinel_to_proc[proc.sentinel] = proc
        elif isinstance(engine_manager, CoreEngineActorManager):
            actor_run_refs = engine_manager.get_run_refs()

        # Check if any process terminates
        while sentinel_to_proc or actor_run_refs:
            # Wait for any process to terminate
            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5)

            # Process any terminated processes
            for sentinel in ready_sentinels:
                proc = sentinel_to_proc.pop(sentinel)

                # Check if process exited with error
                if proc.exitcode != 0:
                    raise RuntimeError(
                        f"Process {proc.name} (PID: {proc.pid}) "
                        f"died with exit code {proc.exitcode}"
                    )

            if actor_run_refs:
                import ray

                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s", str(e))
        raise
    finally:
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if engine_manager:
            engine_manager.close()


# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(procs: list[BaseProcess]):
    # Ask each process to terminate gracefully.
    for proc in procs:
        if proc.is_alive():
            proc.terminate()

    # Allow 5 seconds for remaining procs to terminate.
    deadline = time.monotonic() + 5
    for proc in procs:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        if proc.is_alive():
            proc.join(remaining)

    # Forcibly kill anything that is still alive after the grace period.
    for proc in procs:
        if proc.is_alive() and (pid := proc.pid) is not None:
            kill_process_tree(pid)


def copy_slice(
    from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int
) -> torch.Tensor:
    """
    Copy the first `length` elements of a tensor into another tensor in a
    non-blocking manner.

    Used to copy pinned CPU tensor data to pre-allocated GPU tensors.

    Returns the sliced target tensor.
    """
    return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
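
# Illustrative usage: stage values in a pinned CPU tensor, then copy only the
# populated prefix into a persistent GPU tensor without synchronizing:
#
#     cpu_buf = torch.zeros(8192, dtype=torch.int64, pin_memory=True)
#     gpu_buf = torch.zeros(8192, dtype=torch.int64, device="cuda")
#     cpu_buf[:3] = torch.tensor([7, 8, 9])
#     gpu_slice = copy_slice(cpu_buf, gpu_buf, 3)  # view of gpu_buf[:3]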


def report_usage_stats(
    vllm_config, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT
) -> None:
    """Report usage statistics if enabled."""

    if not is_usage_stats_enabled():
        return

    from vllm.model_executor.model_loader import get_architecture_class_name

    parallel_config = vllm_config.parallel_config

    # Prepare KV connector string if applicable
    kv_connector = None
    if vllm_config.kv_transfer_config is not None:
        kv_connector = vllm_config.kv_transfer_config.kv_connector

    usage_message.report_usage(
        get_architecture_class_name(vllm_config.model_config),
        usage_context,
        extra_kvs={
            # Common configuration
            "dtype": str(vllm_config.model_config.dtype),
            "block_size": vllm_config.cache_config.block_size,
            "gpu_memory_utilization": vllm_config.cache_config.gpu_memory_utilization,
            "kv_cache_memory_bytes": vllm_config.cache_config.kv_cache_memory_bytes,
            # Quantization
            "quantization": vllm_config.model_config.quantization,
            "kv_cache_dtype": str(vllm_config.cache_config.cache_dtype),
            # Feature flags
            "enable_lora": bool(vllm_config.lora_config),
            "enable_prefix_caching": vllm_config.cache_config.enable_prefix_caching,
            "enforce_eager": vllm_config.model_config.enforce_eager,
            "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce,
            # Distributed parallelism settings
            "tensor_parallel_size": parallel_config.tensor_parallel_size,
            "data_parallel_size": parallel_config.data_parallel_size,
            "pipeline_parallel_size": parallel_config.pipeline_parallel_size,
            "enable_expert_parallel": parallel_config.enable_expert_parallel,
            # All2All backend for MoE expert parallel
            "all2all_backend": parallel_config.all2all_backend,
            # KV connector used
            "kv_connector": kv_connector,
        },
    )


# Lazily-resolved profiler scope factory (see record_function_or_nullcontext).
_PROFILER_FUNC = None


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    global _PROFILER_FUNC

    # Fast path: assume the profiler function has already been resolved.
    if _PROFILER_FUNC is not None:
        return _PROFILER_FUNC(name)

    func = contextlib.nullcontext
    if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
        func = record_function
    elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING:
        import nvtx

        func = nvtx.annotate

    _PROFILER_FUNC = func
    return func(name)
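
# Illustrative usage: wrap a region so it appears as a named range in traces
# when one of the profiling env flags is set, at (near) zero cost otherwise:
#
#     with record_function_or_nullcontext("prepare_inputs"):
#         run_prepare_inputs()  # hypothetical workload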


def tensor_data(tensor: torch.Tensor) -> memoryview:
    """Get the raw data of a tensor as a uint8 memoryview, useful for
    serializing and hashing.

    Args:
        tensor: The input tensor.

    Returns:
        A memoryview of the tensor data as uint8.
    """
    return tensor.flatten().contiguous().view(torch.uint8).numpy().data
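
# Illustrative usage: hash a CPU tensor's raw bytes; device tensors must be
# moved to the CPU first since .numpy() only works on CPU tensors.
#
#     import hashlib
#     digest = hashlib.sha256(tensor_data(torch.arange(4))).hexdigest()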