init
This commit is contained in:
0
vllm/v1/worker/__init__.py
Normal file
0
vllm/v1/worker/__init__.py
Normal file
BIN
vllm/v1/worker/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm/v1/worker/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/v1/worker/__pycache__/worker_base.cpython-312.pyc
Normal file
BIN
vllm/v1/worker/__pycache__/worker_base.cpython-312.pyc
Normal file
Binary file not shown.
210
vllm/v1/worker/block_table.py
Normal file
210
vllm/v1/worker/block_table.py
Normal file
@@ -0,0 +1,210 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockTable:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
max_num_reqs: int,
|
||||
max_num_blocks_per_req: int,
|
||||
max_num_batched_tokens: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.block_size = block_size
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
|
||||
self.block_table = self._make_buffer(max_num_reqs,
|
||||
max_num_blocks_per_req,
|
||||
dtype=torch.int32)
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
self.slot_mapping = self._make_buffer(self.max_num_batched_tokens,
|
||||
dtype=torch.int64)
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
|
||||
def append_row(
|
||||
self,
|
||||
block_ids: list[int],
|
||||
row_idx: int,
|
||||
) -> None:
|
||||
if not block_ids:
|
||||
return
|
||||
num_blocks = len(block_ids)
|
||||
start = self.num_blocks_per_row[row_idx]
|
||||
self.num_blocks_per_row[row_idx] += num_blocks
|
||||
self.block_table.np[row_idx, start:start + num_blocks] = block_ids
|
||||
|
||||
def add_row(self, block_ids: list[int], row_idx: int) -> None:
|
||||
self.num_blocks_per_row[row_idx] = 0
|
||||
self.append_row(block_ids, row_idx)
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks = self.num_blocks_per_row[src]
|
||||
block_table_np = self.block_table.np
|
||||
block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
|
||||
self.num_blocks_per_row[tgt] = num_blocks
|
||||
|
||||
def swap_row(self, src: int, tgt: int) -> None:
|
||||
src_tgt, tgt_src = [src, tgt], [tgt, src]
|
||||
self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
|
||||
self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
|
||||
|
||||
def compute_slot_mapping(self, req_indices: np.ndarray,
|
||||
positions: np.ndarray) -> None:
|
||||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
|
||||
# where K is the max_num_blocks_per_req and the block size is 2.
|
||||
# NOTE(woosuk): We can't simply use `token_indices // block_size`
|
||||
# here because M (max_model_len) is not necessarily divisible by
|
||||
# block_size.
|
||||
if self.dcp_world_size > 1:
|
||||
# Note(hc): The DCP implement store kvcache with an interleave
|
||||
# style, the kvcache for the token whose token_idx is i is
|
||||
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
|
||||
|
||||
# Use a "virtual block" which equals to world_size * block_size
|
||||
# for block_table_indices calculation.
|
||||
virtual_block_size = self.block_size * self.dcp_world_size
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions // virtual_block_size)
|
||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||
# Use virtual_block_size for mask calculation, which marks local
|
||||
# tokens.
|
||||
virtual_block_offsets = positions % virtual_block_size
|
||||
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
|
||||
# Calculate local block_offsets
|
||||
block_offsets = virtual_block_offsets // self.dcp_world_size
|
||||
# Calculate slot_mapping
|
||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||
# Write final slots, use -1 for not-local
|
||||
self.slot_mapping.np[:req_indices.shape[0]] = np.where(
|
||||
mask, slot_mapping, -1)
|
||||
else:
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions // self.block_size)
|
||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||
block_offsets = positions % self.block_size
|
||||
np.add(block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=self.slot_mapping.np[:req_indices.shape[0]])
|
||||
|
||||
def commit_block_table(self, num_reqs: int) -> None:
|
||||
self.block_table.copy_to_gpu(num_reqs)
|
||||
|
||||
def commit_slot_mapping(self, num_tokens: int) -> None:
|
||||
self.slot_mapping.copy_to_gpu(num_tokens)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.block_table.gpu.fill_(0)
|
||||
self.block_table.cpu.fill_(0)
|
||||
|
||||
def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
|
||||
"""Returns the device tensor of the block table."""
|
||||
return self.block_table.gpu[:num_reqs]
|
||||
|
||||
def get_cpu_tensor(self) -> torch.Tensor:
|
||||
"""Returns the CPU tensor of the block table."""
|
||||
return self.block_table.cpu
|
||||
|
||||
def get_numpy_array(self) -> np.ndarray:
|
||||
"""Returns the numpy array of the block table."""
|
||||
return self.block_table.np
|
||||
|
||||
def _make_buffer(self, *size: Union[int, torch.SymInt],
|
||||
dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(*size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory)
|
||||
|
||||
|
||||
class MultiGroupBlockTable:
|
||||
"""The BlockTables for each KV cache group."""
|
||||
|
||||
def __init__(self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
block_sizes: list[int],
|
||||
num_speculative_tokens: int = 0) -> None:
|
||||
# Note(hc): each dcp rank only store
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
# so the block_size which used for calc max_num_blocks_per_req
|
||||
# must be multiplied by dcp_world_size.
|
||||
try:
|
||||
dcp_world_size = get_dcp_group().world_size
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
dcp_world_size = 1
|
||||
|
||||
self.block_tables = [
|
||||
BlockTable(
|
||||
block_size, max_num_reqs,
|
||||
max(cdiv(max_model_len, block_size * dcp_world_size),
|
||||
1 + num_speculative_tokens), max_num_batched_tokens,
|
||||
pin_memory, device) for block_size in block_sizes
|
||||
]
|
||||
|
||||
def append_row(self, block_ids: tuple[list[int], ...],
|
||||
row_idx: int) -> None:
|
||||
for i, block_table in enumerate(self.block_tables):
|
||||
block_table.append_row(block_ids[i], row_idx)
|
||||
|
||||
def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
|
||||
for i, block_table in enumerate(self.block_tables):
|
||||
block_table.add_row(block_ids[i], row_idx)
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.move_row(src, tgt)
|
||||
|
||||
def swap_row(self, src: int, tgt: int) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.swap_row(src, tgt)
|
||||
|
||||
def compute_slot_mapping(self, req_indices: np.ndarray,
|
||||
positions: np.ndarray) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.compute_slot_mapping(req_indices, positions)
|
||||
|
||||
def commit_block_table(self, num_reqs: int) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.commit_block_table(num_reqs)
|
||||
|
||||
def commit_slot_mapping(self, num_tokens: int) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.commit_slot_mapping(num_tokens)
|
||||
|
||||
def clear(self) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.clear()
|
||||
|
||||
def __getitem__(self, idx: int) -> "BlockTable":
|
||||
"""Returns the BlockTable for the i-th KV cache group."""
|
||||
return self.block_tables[idx]
|
||||
175
vllm/v1/worker/cpu_model_runner.py
Normal file
175
vllm/v1/worker/cpu_model_runner.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CPUModelRunner(GPUModelRunner):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
with _torch_cuda_wrapper():
|
||||
super().__init__(vllm_config, device)
|
||||
|
||||
assert device == torch.device("cpu")
|
||||
assert self.speculative_config is None, "spec decode is not supported."
|
||||
|
||||
self.use_cuda_graph = False
|
||||
self.cascade_attn_enabled = False
|
||||
|
||||
self._postprocess_tensors()
|
||||
|
||||
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
"""
|
||||
Update the order of requests in the batch based on the attention
|
||||
backend's needs. For example, some attention backends (namely MLA) may
|
||||
want to separate requests based on if the attention computation will be
|
||||
compute-bound or memory-bound.
|
||||
|
||||
Args:
|
||||
scheduler_output: The scheduler output.
|
||||
"""
|
||||
# Attention free models have zero kv_cache_groups, however models
|
||||
# like Mamba are also attention free but use the kv_cache for
|
||||
# keeping its internal state. This is why we check the number
|
||||
# of kv_cache groups instead of solely checking
|
||||
# for self.model_config.is_attention_free.
|
||||
if len(self.kv_cache_config.kv_cache_groups) == 0:
|
||||
return
|
||||
|
||||
if len(self.kv_cache_config.kv_cache_groups) > 1:
|
||||
raise ValueError("Multiple KVCacheGroups is not"
|
||||
"currently supported with CPU model runner.")
|
||||
|
||||
# Guard against encoder-only / pooling models where `attn_groups`
|
||||
# may be empty or lack the expected metadata_builder.
|
||||
# Without this check, accessing `attn_groups[0][0]` would trigger
|
||||
# an AssertionError on CPU backend.
|
||||
if not hasattr(self, "attn_groups") or not self.attn_groups:
|
||||
return
|
||||
if not self.attn_groups[0]:
|
||||
return
|
||||
|
||||
mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
|
||||
if isinstance(mb, list):
|
||||
if not isinstance(mb[0], TorchSDPAMetadataBuilderV1):
|
||||
return
|
||||
mb[0].reorder_batch(self.input_batch, scheduler_output)
|
||||
return
|
||||
elif not isinstance(mb, TorchSDPAMetadataBuilderV1):
|
||||
# Encoder-only / rerank models do not benefit from reordering,
|
||||
# so we safely skip here.
|
||||
return
|
||||
|
||||
# Safe path for decoder/attention-heavy models
|
||||
mb.reorder_batch(self.input_batch, scheduler_output)
|
||||
|
||||
def _postprocess_tensors(self) -> None:
|
||||
# Note: replace device tensors with cpu tensors
|
||||
def replace_tensor(obj: Any, cpu_attr_name: str,
|
||||
device_attr_name) -> None:
|
||||
cpu_tensor = getattr(obj, cpu_attr_name, None)
|
||||
device_tensor = getattr(obj, device_attr_name, None)
|
||||
if cpu_tensor is not None and device_tensor is not None:
|
||||
assert isinstance(cpu_tensor, torch.Tensor)
|
||||
assert isinstance(device_tensor, torch.Tensor)
|
||||
setattr(obj, device_attr_name, cpu_tensor)
|
||||
|
||||
for v in vars(self).values():
|
||||
if isinstance(v, CpuGpuBuffer):
|
||||
v.gpu = v.cpu
|
||||
|
||||
for k, v in vars(self.input_batch).items():
|
||||
if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
|
||||
replace_tensor(self.input_batch, k, k[:-11])
|
||||
|
||||
for block_table in self.input_batch.block_table.block_tables:
|
||||
for v in vars(block_table).values():
|
||||
if isinstance(v, CpuGpuBuffer):
|
||||
v.gpu = v.cpu
|
||||
|
||||
def load_model(self, eep_scale_up: bool = False) -> None:
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
self.model = get_model(vllm_config=self.vllm_config)
|
||||
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(self.model, self.vllm_config,
|
||||
self.device)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def warming_up_model(self) -> None:
|
||||
logger.info("Warming up model for the compilation...")
|
||||
# Only generate graph for the generic shape
|
||||
with _set_global_compilation_settings(self.vllm_config):
|
||||
self._dummy_run(max(16, self.max_num_reqs))
|
||||
logger.info("Warming up done.")
|
||||
|
||||
def _init_device_properties(self) -> None:
|
||||
pass
|
||||
|
||||
def _sync_device(self) -> None:
|
||||
pass
|
||||
|
||||
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
|
||||
return sampled_token_ids.tolist()
|
||||
|
||||
def get_dp_padding(self,
|
||||
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
|
||||
# Note: For CPU backend, dp padding is not required for now.
|
||||
return 0, None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _torch_cuda_wrapper():
|
||||
|
||||
class _EventPlaceholder:
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self.record = lambda: None
|
||||
self.synchronize = lambda: None
|
||||
|
||||
class _StreamPlaceholder:
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
cuda_event = torch.cuda.Event
|
||||
cuda_stream = torch.cuda.Stream
|
||||
try:
|
||||
torch.cuda.Event = _EventPlaceholder
|
||||
torch.cuda.Stream = _StreamPlaceholder
|
||||
yield
|
||||
finally:
|
||||
torch.cuda.Event = cuda_event
|
||||
torch.cuda.Stream = cuda_stream
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _set_global_compilation_settings(config: VllmConfig):
|
||||
import torch._inductor.config as torch_inductor_config
|
||||
|
||||
inductor_config = config.compilation_config.inductor_compile_config
|
||||
# Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
|
||||
freezing_value = torch_inductor_config.freezing
|
||||
try:
|
||||
if inductor_config.get("max_autotune", False):
|
||||
torch_inductor_config.freezing = True
|
||||
yield
|
||||
finally:
|
||||
torch_inductor_config.freezing = freezing_value
|
||||
156
vllm/v1/worker/cpu_worker.py
Normal file
156
vllm/v1/worker/cpu_worker.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
import platform
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
|
||||
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
|
||||
from vllm.v1.worker.gpu_worker import (Worker,
|
||||
init_worker_distributed_environment)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CPUWorker(Worker):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False):
|
||||
super().__init__(vllm_config,
|
||||
local_rank,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
is_driver_worker=is_driver_worker)
|
||||
|
||||
self.parallel_config.disable_custom_all_reduce = True
|
||||
|
||||
def init_device(self):
|
||||
# Setup OpenMP threads affinity.
|
||||
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
|
||||
if omp_cpuids == "auto" and platform.system() == "Linux":
|
||||
cpu_arch = current_platform.get_cpu_architecture()
|
||||
if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X):
|
||||
# For S390X/POWERPC SMT-8/4/2
|
||||
self.local_omp_cpuid = self._get_autobind_cpu_ids(
|
||||
lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4])
|
||||
elif current_platform.get_cpu_architecture() == CpuArchEnum.X86:
|
||||
# For x86 SMT-2, use 1 CPU per core
|
||||
self.local_omp_cpuid = self._get_autobind_cpu_ids(
|
||||
lambda cpus: cpus[-1:])
|
||||
else:
|
||||
self.local_omp_cpuid = "all"
|
||||
else:
|
||||
local_dp_rank = self.parallel_config.data_parallel_rank_local
|
||||
omp_cpuids = omp_cpuids.split("|")
|
||||
if local_dp_rank is not None:
|
||||
world_size = self.parallel_config.world_size
|
||||
omp_cpuids = omp_cpuids[local_dp_rank *
|
||||
world_size:(local_dp_rank + 1) *
|
||||
world_size]
|
||||
self.local_omp_cpuid = omp_cpuids[self.rank]
|
||||
|
||||
if self.local_omp_cpuid != "all":
|
||||
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
|
||||
if ret:
|
||||
logger.info(ret)
|
||||
|
||||
# Note: unique identifier for creating allreduce shared memory
|
||||
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
|
||||
":")[-1]
|
||||
# Initialize the distributed environment.
|
||||
init_worker_distributed_environment(self.vllm_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank,
|
||||
current_platform.dist_backend)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
# Construct the model runner
|
||||
self.model_runner: CPUModelRunner = CPUModelRunner(
|
||||
self.vllm_config, torch.device("cpu"))
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
logger.warning("sleep mode is not supported on CPU, ignore it.")
|
||||
pass
|
||||
|
||||
def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
||||
logger.warning("sleep mode is not supported on CPU, ignore it.")
|
||||
pass
|
||||
|
||||
def determine_available_memory(self) -> int:
|
||||
return self.cache_config.cpu_kvcache_space_bytes # type: ignore
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
self.model_runner.warming_up_model()
|
||||
|
||||
def _get_autobind_cpu_ids(
|
||||
self, cpu_selector: Callable[[list[LogicalCPUInfo]],
|
||||
list[LogicalCPUInfo]]
|
||||
) -> str:
|
||||
"""
|
||||
Return CPU ids to bind based on NUMA nodes.
|
||||
Currently for rank N, only CPU ids on the N-th node in available NUMA
|
||||
node list will be selected.
|
||||
Args:
|
||||
cpu_selector: a callable object to select CPUs from a CPU list
|
||||
of a physical core. The input is a LogicalCPUInfo list, sorted by
|
||||
the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be
|
||||
returned.
|
||||
"""
|
||||
|
||||
allowed_numa_nodes, logical_cpu_list = \
|
||||
CpuPlatform.get_allowed_cpu_core_node_list()
|
||||
assert len(allowed_numa_nodes) >= self.parallel_config.world_size, (
|
||||
f"No enough allowed NUMA nodes to bind threads of "
|
||||
f"{self.parallel_config.world_size} CPUWorkers. "
|
||||
f"Allowed NUMA nodes are {allowed_numa_nodes}. "
|
||||
"Please try to bind threads manually.")
|
||||
|
||||
# Get CPUs on NUMA node `allowed_numa_nodes[local_rank]``
|
||||
selected_numa_node = allowed_numa_nodes[
|
||||
self.local_rank] # type: ignore
|
||||
logical_cpu_list = [
|
||||
x for x in logical_cpu_list if x.numa_node == selected_numa_node
|
||||
]
|
||||
|
||||
# Select CPUs from each physical core via cpu_selector
|
||||
core_to_cpus: dict[int, list[LogicalCPUInfo]] = {}
|
||||
for cpu_info in logical_cpu_list:
|
||||
if cpu_info.physical_core not in core_to_cpus:
|
||||
core_to_cpus[cpu_info.physical_core] = []
|
||||
core_to_cpus[cpu_info.physical_core].append(cpu_info)
|
||||
logical_cpu_list = []
|
||||
for cpu_list in core_to_cpus.values():
|
||||
cpu_list = sorted(cpu_list, key=lambda x: x.id)
|
||||
logical_cpu_list.extend(cpu_selector(cpu_list))
|
||||
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.id)
|
||||
|
||||
# Reserve CPUs for other processes
|
||||
reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU
|
||||
if reserve_cpu_num is None:
|
||||
need_reserve = (self.parallel_config.world_size > 1 or
|
||||
self.parallel_config.data_parallel_size_local > 1)
|
||||
reserve_cpu_num = 1 if need_reserve else 0
|
||||
assert len(logical_cpu_list) > reserve_cpu_num, (
|
||||
f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) "
|
||||
f"should less than {len(logical_cpu_list)}.")
|
||||
if reserve_cpu_num != 0:
|
||||
logical_cpu_list = logical_cpu_list[:-reserve_cpu_num]
|
||||
|
||||
logger.info("auto thread-binding list (id, physical core): %s",
|
||||
[(x.id, x.physical_core) for x in logical_cpu_list])
|
||||
return ",".join([str(x.id) for x in logical_cpu_list])
|
||||
863
vllm/v1/worker/gpu_input_batch.py
Normal file
863
vllm/v1/worker/gpu_input_batch.py
Normal file
@@ -0,0 +1,863 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Datastructures defining a GPU input batch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
|
||||
LogitsProcessors,
|
||||
MoveDirectionality)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestState:
|
||||
|
||||
req_id: str
|
||||
prompt_token_ids: Optional[list[int]]
|
||||
mm_features: list[MultiModalFeatureSpec]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
generator: Optional[torch.Generator]
|
||||
|
||||
block_ids: tuple[list[int], ...]
|
||||
num_computed_tokens: int
|
||||
output_token_ids: list[int]
|
||||
|
||||
mrope_positions: Optional[torch.Tensor] = None
|
||||
mrope_position_delta: Optional[int] = None
|
||||
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
prompt_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
self.prompt_token_ids, self.prompt_embeds)
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return self.num_prompt_tokens + len(self.output_token_ids)
|
||||
|
||||
# Temporary back-compatibility for plugins that define model runner
|
||||
@property
|
||||
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
|
||||
"removed in v0.13. Please use `mm_kwargs` instead.")
|
||||
def mm_inputs(self) -> list[MultiModalKwargsItems]:
|
||||
return [
|
||||
MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
|
||||
if f.data is not None
|
||||
]
|
||||
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
if self.prompt_token_ids is None:
|
||||
raise ValueError(
|
||||
f"Tried to access token index {idx}, but that token was "
|
||||
"provided via prompt_embeds, and its ID is unknown.")
|
||||
return self.prompt_token_ids[idx]
|
||||
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
|
||||
return self.output_token_ids[idx - self.num_prompt_tokens]
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
||||
class InputBatch:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
logitsprocs: Optional[LogitsProcessors] = None,
|
||||
is_spec_decode: bool = False,
|
||||
is_pooling_model: bool = False,
|
||||
num_speculative_tokens: int = 0,
|
||||
):
|
||||
self.is_pooling_model = is_pooling_model
|
||||
self.is_spec_decode = is_spec_decode
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self._req_ids: list[Optional[str]] = []
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
|
||||
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
||||
# Find a way to reduce the CPU memory usage.
|
||||
# This buffer is not directly transferred to the GPU, so it does not
|
||||
# need to be pinned.
|
||||
self.token_ids_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=False,
|
||||
)
|
||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=bool,
|
||||
pin_memory=False)
|
||||
# Store prompt embeddings per request to avoid OOM from large upfront
|
||||
# allocation if max_model_len is big.
|
||||
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
|
||||
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
|
||||
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, ),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.num_computed_tokens_cpu = \
|
||||
self.num_computed_tokens_cpu_tensor.numpy()
|
||||
|
||||
# Block table.
|
||||
self.block_table = MultiGroupBlockTable(
|
||||
max_num_reqs=max_num_reqs,
|
||||
max_model_len=max_model_len,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
block_sizes=block_sizes,
|
||||
num_speculative_tokens=num_speculative_tokens,
|
||||
)
|
||||
|
||||
# Sampling-related.
|
||||
self.temperature = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
||||
self.greedy_reqs: set[str] = set()
|
||||
self.random_reqs: set[str] = set()
|
||||
|
||||
self.top_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
||||
self.top_p_reqs: set[str] = set()
|
||||
|
||||
self.top_k = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
||||
self.top_k_reqs: set[str] = set()
|
||||
|
||||
# IDs of requests which do not support spec decoding
|
||||
self.spec_decode_unsupported_reqs: set[str] = set()
|
||||
|
||||
# Frequency penalty related data structures
|
||||
self.frequency_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.frequency_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.frequency_penalties_cpu = \
|
||||
self.frequency_penalties_cpu_tensor.numpy()
|
||||
self.frequency_penalties_reqs: set[str] = set()
|
||||
|
||||
# Presence penalty related data structures
|
||||
self.presence_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
|
||||
)
|
||||
self.presence_penalties_reqs: set[str] = set()
|
||||
|
||||
# Repetition penalty related data structures
|
||||
self.repetition_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.repetition_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.repetition_penalties_cpu = \
|
||||
self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_reqs: set[str] = set()
|
||||
|
||||
# Speculative decoding
|
||||
self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ),
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.num_accepted_tokens_cpu = \
|
||||
self.num_accepted_tokens_cpu_tensor.numpy()
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
||||
dtype=np.int32)
|
||||
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
||||
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
||||
|
||||
# req_index -> generator
|
||||
# NOTE(woosuk): The indices of the requests that do not have their own
|
||||
# generator should not be included in the dictionary.
|
||||
self.generators: dict[int, torch.Generator] = {}
|
||||
|
||||
self.num_logprobs: dict[str, int] = {}
|
||||
# NOTE(rob): num_prompt_logprobs only includes reqs
|
||||
# that are currently in the prefill phase.
|
||||
self.num_prompt_logprobs: dict[str, int] = {}
|
||||
|
||||
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
||||
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
||||
|
||||
# Internal representation of per-step batch state changes, used for
|
||||
# reordering persistent batch and generating logitsprocs batch state
|
||||
# updates. Should reset each step.
|
||||
self.batch_update_builder = BatchUpdateBuilder()
|
||||
|
||||
# TODO convert this to LogitsProcessor
|
||||
self.has_allowed_token_ids: set[str] = set()
|
||||
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
||||
# the value is False. Since we use masked_fill_ to set -inf.
|
||||
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
||||
|
||||
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
|
||||
dtype=bool)
|
||||
|
||||
self.req_output_token_ids: list[Optional[list[int]]] = []
|
||||
|
||||
# Store provided logitsprocs. If none are provided, initialize empty
|
||||
# data structure
|
||||
self.logitsprocs = logitsprocs or LogitsProcessors()
|
||||
|
||||
# This is updated each time the batch constituents change.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
|
||||
self.pooling_params: dict[str, PoolingParams] = {}
|
||||
|
||||
# Cached reference to the GPU tensor of previously sampled tokens
|
||||
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
||||
self.prev_req_id_to_index: Optional[dict[str, int]] = None
|
||||
|
||||
@property
|
||||
def req_ids(self) -> list[str]:
|
||||
# None elements should only be present transiently
|
||||
# while performing state updates to the batch.
|
||||
return cast(list[str], self._req_ids)
|
||||
|
||||
def _register_add_request(self, request: "CachedRequestState") -> int:
|
||||
"""Track add-request operations for logits processors.
|
||||
Not applicable to pooling models.
|
||||
"""
|
||||
|
||||
# Fill the next empty index if there is one.
|
||||
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
|
||||
# Append to end otherwise.
|
||||
new_req_index = self.num_reqs
|
||||
|
||||
assert new_req_index < self.max_num_reqs
|
||||
self.batch_update_builder.batch_changed = True
|
||||
if request.sampling_params:
|
||||
# Detailed added request metadata is only required for non-pooling
|
||||
# models, to support logitsprocs.
|
||||
self.batch_update_builder.added.append(
|
||||
(new_req_index, request.sampling_params,
|
||||
request.prompt_token_ids, request.output_token_ids))
|
||||
|
||||
return new_req_index
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
) -> int:
|
||||
req_index = self._register_add_request(request)
|
||||
|
||||
req_id = request.req_id
|
||||
if req_index == len(self._req_ids):
|
||||
self._req_ids.append(req_id)
|
||||
self.req_output_token_ids.append(request.output_token_ids)
|
||||
else:
|
||||
self._req_ids[req_index] = req_id
|
||||
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||
|
||||
self.req_id_to_index[req_id] = req_index
|
||||
|
||||
# Copy the prompt token ids and output token ids.
|
||||
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
request.prompt_token_ids, request.prompt_embeds)
|
||||
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
||||
start_idx = num_prompt_tokens
|
||||
end_idx = start_idx + len(request.output_token_ids)
|
||||
if request.prompt_token_ids is not None:
|
||||
self.token_ids_cpu[
|
||||
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
||||
self.is_token_ids[req_index, :num_prompt_tokens] = True
|
||||
else:
|
||||
self.is_token_ids[req_index, :num_prompt_tokens] = False
|
||||
if request.prompt_embeds is not None:
|
||||
self.req_prompt_embeds[req_index] = request.prompt_embeds
|
||||
self.token_ids_cpu[req_index,
|
||||
start_idx:end_idx] = request.output_token_ids
|
||||
self.is_token_ids[req_index, start_idx:end_idx] = True
|
||||
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
|
||||
# NOTE(woosuk): This may include spec decode tokens.
|
||||
self.num_tokens[req_index] = request.num_tokens
|
||||
# Number of tokens without spec decode tokens.
|
||||
self.num_tokens_no_spec[req_index] = request.num_tokens
|
||||
|
||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
if sampling_params := request.sampling_params:
|
||||
if (self.is_spec_decode
|
||||
and is_spec_decode_unsupported(sampling_params)):
|
||||
self.spec_decode_unsupported_reqs.add(req_id)
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Should avoid division by zero later when apply_temperature.
|
||||
self.temperature_cpu[req_index] = 0.0
|
||||
self.greedy_reqs.add(req_id)
|
||||
else:
|
||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||
self.random_reqs.add(req_id)
|
||||
|
||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||
if sampling_params.top_p < 1:
|
||||
self.top_p_reqs.add(req_id)
|
||||
top_k = sampling_params.top_k
|
||||
if 0 < top_k < self.vocab_size:
|
||||
self.top_k_reqs.add(req_id)
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k_cpu[req_index] = top_k
|
||||
self.frequency_penalties_cpu[
|
||||
req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[
|
||||
req_index] = sampling_params.presence_penalty
|
||||
if sampling_params.presence_penalty != 0.0:
|
||||
self.presence_penalties_reqs.add(req_id)
|
||||
self.repetition_penalties_cpu[
|
||||
req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
|
||||
# NOTE(woosuk): self.generators should not include the requests that
|
||||
# do not have their own generator.
|
||||
if request.generator is not None:
|
||||
self.generators[req_index] = request.generator
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
self.num_logprobs[req_id] = (self.vocab_size
|
||||
if sampling_params.logprobs == -1
|
||||
else sampling_params.logprobs)
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[req_id] = (
|
||||
self.vocab_size if sampling_params.prompt_logprobs == -1
|
||||
else sampling_params.prompt_logprobs)
|
||||
|
||||
if sampling_params.allowed_token_ids:
|
||||
self.has_allowed_token_ids.add(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is None:
|
||||
# Lazy allocation for this tensor, which can be large.
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device)
|
||||
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device="cpu")
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||
sampling_params.allowed_token_ids] = False
|
||||
|
||||
if sampling_params.bad_words_token_ids:
|
||||
self.bad_words_token_ids[
|
||||
req_index] = sampling_params.bad_words_token_ids
|
||||
elif pooling_params := request.pooling_params:
|
||||
self.pooling_params[req_id] = pooling_params
|
||||
self.logits_processing_needs_token_ids[req_index] = (
|
||||
pooling_params.requires_token_ids)
|
||||
else:
|
||||
raise NotImplementedError("Unrecognized request type")
|
||||
|
||||
# Speculative decoding: by default 1 token is generated.
|
||||
self.num_accepted_tokens_cpu[req_index] = 1
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
lora_id = request.lora_request.lora_int_id
|
||||
if lora_id not in self.lora_id_to_request_ids:
|
||||
self.lora_id_to_request_ids[lora_id] = set()
|
||||
|
||||
self.request_lora_mapping[req_index] = lora_id
|
||||
self.lora_id_to_request_ids[lora_id].add(request.req_id)
|
||||
self.lora_id_to_lora_request[lora_id] = request.lora_request
|
||||
else:
|
||||
# No LoRA
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
return req_index
|
||||
|
||||
def remove_request(self, req_id: str) -> Optional[int]:
|
||||
"""This method must always be followed by a call to condense().
|
||||
|
||||
Args:
|
||||
req_id: request to remove
|
||||
|
||||
Returns:
|
||||
Removed request index, or `None` if `req_id` not recognized
|
||||
"""
|
||||
|
||||
req_index = self.req_id_to_index.pop(req_id, None)
|
||||
if req_index is None:
|
||||
return None
|
||||
|
||||
self.batch_update_builder.removed_append(req_index)
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
lora_req_ids = self.lora_id_to_request_ids[lora_id]
|
||||
lora_req_ids.discard(req_id)
|
||||
if not lora_req_ids:
|
||||
del self.lora_id_to_request_ids[lora_id]
|
||||
del self.lora_id_to_lora_request[lora_id]
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
if self.is_pooling_model:
|
||||
self.pooling_params.pop(req_id, None)
|
||||
return req_index
|
||||
|
||||
self.greedy_reqs.discard(req_id)
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
self.top_k_reqs.discard(req_id)
|
||||
self.spec_decode_unsupported_reqs.discard(req_id)
|
||||
self.frequency_penalties_reqs.discard(req_id)
|
||||
self.presence_penalties_reqs.discard(req_id)
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
self.generators.pop(req_index, None)
|
||||
self.num_logprobs.pop(req_id, None)
|
||||
self.num_prompt_logprobs.pop(req_id, None)
|
||||
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
||||
|
||||
self.has_allowed_token_ids.discard(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
||||
self.bad_words_token_ids.pop(req_index, None)
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
old_id_i1 = self._req_ids[i1]
|
||||
old_id_i2 = self._req_ids[i2]
|
||||
self._req_ids[i1], self._req_ids[i2] =\
|
||||
self._req_ids[i2], self._req_ids[i1] # noqa
|
||||
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
||||
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
||||
assert old_id_i1 is not None and old_id_i2 is not None
|
||||
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
||||
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
||||
self.num_tokens[i1], self.num_tokens[i2] =\
|
||||
self.num_tokens[i2], self.num_tokens[i1]
|
||||
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
|
||||
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
|
||||
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
|
||||
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
|
||||
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
|
||||
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
|
||||
|
||||
# NOTE: the following is unsafe
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
||||
# instead, we need to temporiarily copy the data for one of the indices
|
||||
# TODO(lucas): optimize this by only copying valid indices
|
||||
tmp = self.token_ids_cpu[i1, ...].copy()
|
||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
|
||||
|
||||
# Swap prompt embeddings if they exist
|
||||
embeds_i1 = self.req_prompt_embeds.get(i1)
|
||||
embeds_i2 = self.req_prompt_embeds.get(i2)
|
||||
if embeds_i1 is not None:
|
||||
self.req_prompt_embeds[i2] = embeds_i1
|
||||
else:
|
||||
self.req_prompt_embeds.pop(i2, None)
|
||||
if embeds_i2 is not None:
|
||||
self.req_prompt_embeds[i1] = embeds_i2
|
||||
else:
|
||||
self.req_prompt_embeds.pop(i1, None)
|
||||
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
|
||||
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
||||
|
||||
if self.is_pooling_model:
|
||||
# Sampling and logits parameters don't apply to pooling models.
|
||||
return
|
||||
|
||||
# For autoregressive models, track detailed request reordering info
|
||||
# to support logitsprocs.
|
||||
self.batch_update_builder.moved.append(
|
||||
(i1, i2, MoveDirectionality.SWAP))
|
||||
|
||||
self.temperature_cpu[i1], self.temperature_cpu[i2] = \
|
||||
self.temperature_cpu[i2], self.temperature_cpu[i1]
|
||||
self.top_p_cpu[i1], self.top_p_cpu[i2] = \
|
||||
self.top_p_cpu[i2], self.top_p_cpu[i1]
|
||||
self.top_k_cpu[i1], self.top_k_cpu[i2] = \
|
||||
self.top_k_cpu[i2], self.top_k_cpu[i1]
|
||||
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
|
||||
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
|
||||
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
|
||||
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
||||
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
|
||||
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
||||
self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\
|
||||
self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1]
|
||||
|
||||
swap_dict_values(self.generators, i1, i2)
|
||||
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2] =\
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1]
|
||||
|
||||
def condense(self) -> None:
|
||||
"""Slide non-empty requests down into lower, empty indices.
|
||||
|
||||
Any consecutive empty indices at the very end of the list are not
|
||||
filled.
|
||||
|
||||
Returns:
|
||||
swaps: list of (from,to) swap tuples for moved requests
|
||||
empty_req_indices: indices not filled by condensation
|
||||
"""
|
||||
num_reqs = self.num_reqs
|
||||
|
||||
if not (empty_req_indices := self.batch_update_builder.removed):
|
||||
# All removed requests were replaced by added requests, or else no
|
||||
# requests were removed at all. No condense() needed
|
||||
return
|
||||
if num_reqs == 0:
|
||||
# The batched states are empty.
|
||||
self._req_ids.clear()
|
||||
self.req_output_token_ids.clear()
|
||||
return
|
||||
|
||||
# NOTE(woosuk): This function assumes that the empty_req_indices
|
||||
# is sorted in descending order.
|
||||
last_req_index = num_reqs + len(empty_req_indices) - 1
|
||||
while empty_req_indices:
|
||||
# Find the largest non-empty index.
|
||||
while last_req_index in empty_req_indices:
|
||||
last_req_index -= 1
|
||||
|
||||
# Find the smallest empty index.
|
||||
empty_index = self.batch_update_builder.peek_removed()
|
||||
assert empty_index is not None
|
||||
if empty_index >= last_req_index:
|
||||
break
|
||||
|
||||
# Move active request down into empty request
|
||||
# index.
|
||||
self.batch_update_builder.pop_removed()
|
||||
req_id = self._req_ids[last_req_index]
|
||||
output_token_ids = self.req_output_token_ids[last_req_index]
|
||||
assert req_id is not None
|
||||
self._req_ids[empty_index] = req_id
|
||||
self._req_ids[last_req_index] = None
|
||||
self.req_output_token_ids[empty_index] = output_token_ids
|
||||
self.req_output_token_ids[last_req_index] = None
|
||||
self.req_id_to_index[req_id] = empty_index
|
||||
|
||||
num_tokens = self.num_tokens[last_req_index]
|
||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||
last_req_index, :num_tokens]
|
||||
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
|
||||
last_req_index, :num_tokens]
|
||||
if last_req_index in self.req_prompt_embeds:
|
||||
self.req_prompt_embeds[
|
||||
empty_index] = self.req_prompt_embeds.pop(last_req_index)
|
||||
self.num_tokens[empty_index] = num_tokens
|
||||
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
||||
last_req_index]
|
||||
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
|
||||
last_req_index]
|
||||
self.num_computed_tokens_cpu[
|
||||
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
||||
self.block_table.move_row(last_req_index, empty_index)
|
||||
|
||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||
last_req_index]
|
||||
|
||||
if self.is_pooling_model:
|
||||
last_req_index -= 1
|
||||
# Sampling state not used by pooling models.
|
||||
continue
|
||||
|
||||
# Autoregressive models require detailed tracking of condense
|
||||
# operations to support logitsprocs
|
||||
self.batch_update_builder.moved.append(
|
||||
(last_req_index, empty_index,
|
||||
MoveDirectionality.UNIDIRECTIONAL))
|
||||
|
||||
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
||||
last_req_index]
|
||||
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
||||
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
|
||||
self.frequency_penalties_cpu[
|
||||
empty_index] = self.frequency_penalties_cpu[last_req_index]
|
||||
self.presence_penalties_cpu[
|
||||
empty_index] = self.presence_penalties_cpu[last_req_index]
|
||||
self.repetition_penalties_cpu[
|
||||
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
||||
self.num_accepted_tokens_cpu[
|
||||
empty_index] = self.num_accepted_tokens_cpu[last_req_index]
|
||||
generator = self.generators.pop(last_req_index, None)
|
||||
if generator is not None:
|
||||
self.generators[empty_index] = generator
|
||||
|
||||
# TODO convert these to LogitsProcessors
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[
|
||||
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
|
||||
last_req_index]
|
||||
|
||||
bad_words_token_ids = self.bad_words_token_ids.pop(
|
||||
last_req_index, None)
|
||||
if bad_words_token_ids is not None:
|
||||
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
||||
|
||||
# Decrement last_req_index since it is now empty.
|
||||
last_req_index -= 1
|
||||
|
||||
# Trim lists to the batch size.
|
||||
del self._req_ids[num_reqs:]
|
||||
del self.req_output_token_ids[num_reqs:]
|
||||
|
||||
def refresh_metadata(self):
|
||||
"""Apply any batch updates to sampling metadata."""
|
||||
|
||||
if self.is_pooling_model:
|
||||
batch_changed = self.batch_update_builder.reset()
|
||||
if batch_changed:
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
return
|
||||
|
||||
# For non-pooling models - generate and apply logitsprocs update;
|
||||
# reset batch update tracking.
|
||||
# Update sampling metadata if batch state is changed.
|
||||
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
|
||||
for logit_proc in self.logitsprocs.all:
|
||||
logit_proc.update_state(batch_update)
|
||||
if batch_update:
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
|
||||
def _make_sampling_metadata(self) -> SamplingMetadata:
|
||||
num_reqs = self.num_reqs
|
||||
if not self.all_greedy:
|
||||
temperature = copy_slice(self.temperature_cpu_tensor,
|
||||
self.temperature, num_reqs)
|
||||
else:
|
||||
temperature = None
|
||||
if not self.no_top_p:
|
||||
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
||||
if not self.no_top_k:
|
||||
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
|
||||
|
||||
if not self.no_penalties:
|
||||
# Since syncing these tensors is expensive only copy them
|
||||
# if necessary i.e. if there are requests which require
|
||||
# penalties to be applied during sampling.
|
||||
copy_slice(self.frequency_penalties_cpu_tensor,
|
||||
self.frequency_penalties, num_reqs)
|
||||
copy_slice(self.presence_penalties_cpu_tensor,
|
||||
self.presence_penalties, num_reqs)
|
||||
copy_slice(self.repetition_penalties_cpu_tensor,
|
||||
self.repetition_penalties, num_reqs)
|
||||
|
||||
needs_prompt_token_ids = (
|
||||
not self.no_penalties
|
||||
or self.logits_processing_needs_token_ids[:num_reqs].any())
|
||||
if needs_prompt_token_ids:
|
||||
# The prompt tokens are used only for applying penalties or
|
||||
# step pooling during the sampling/pooling process.
|
||||
# Hence copy these tensors only when there are requests which
|
||||
# need penalties/step_pooler to be applied.
|
||||
prompt_token_ids = self._make_prompt_token_ids_tensor()
|
||||
else:
|
||||
prompt_token_ids = None
|
||||
|
||||
allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||
if not self.no_allowed_token_ids:
|
||||
assert self.allowed_token_ids_mask is not None
|
||||
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
|
||||
self.allowed_token_ids_mask, num_reqs)
|
||||
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=temperature,
|
||||
all_greedy=self.all_greedy,
|
||||
all_random=self.all_random,
|
||||
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
||||
top_k=None if self.no_top_k else self.top_k[:num_reqs],
|
||||
generators=self.generators,
|
||||
max_num_logprobs=self.max_num_logprobs,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
frequency_penalties=self.frequency_penalties[:num_reqs],
|
||||
presence_penalties=self.presence_penalties[:num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
||||
no_penalties=self.no_penalties,
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=self.bad_words_token_ids,
|
||||
logitsprocs=self.logitsprocs,
|
||||
)
|
||||
|
||||
def get_pooling_params(self) -> list[PoolingParams]:
|
||||
assert len(self.req_ids) == len(self.pooling_params)
|
||||
return [self.pooling_params[req_id] for req_id in self.req_ids]
|
||||
|
||||
def get_pooling_metadata(self) -> PoolingMetadata:
|
||||
pooling_params = self.get_pooling_params()
|
||||
|
||||
return PoolingMetadata(
|
||||
prompt_lens=torch.from_numpy(
|
||||
self.num_prompt_tokens[:self.num_reqs]),
|
||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
num_reqs = self.num_reqs
|
||||
max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
|
||||
prompt_token_ids_cpu_tensor = torch.empty(
|
||||
(self.num_reqs, max_prompt_len),
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
|
||||
prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
|
||||
# Use the value of vocab_size as a pad since we don't have a
|
||||
# token_id of this value.
|
||||
for i in range(num_reqs):
|
||||
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
|
||||
return prompt_token_ids_cpu_tensor.to(device=self.device,
|
||||
non_blocking=True)
|
||||
|
||||
def make_lora_inputs(
|
||||
self, num_scheduled_tokens: np.ndarray
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
||||
"""
|
||||
Given the num_scheduled_tokens for each request in the batch, return
|
||||
datastructures used to activate the current LoRAs.
|
||||
Returns:
|
||||
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
|
||||
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
|
||||
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
|
||||
where, token_lora_mapping[i] is the LoRA id to use for ith token.
|
||||
3. lora_requests: Set of relevant LoRA requests.
|
||||
"""
|
||||
|
||||
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
|
||||
prompt_lora_mapping = tuple(req_lora_mapping)
|
||||
token_lora_mapping = tuple(
|
||||
req_lora_mapping.repeat(num_scheduled_tokens))
|
||||
active_lora_requests: set[LoRARequest] = set(
|
||||
self.lora_id_to_lora_request.values())
|
||||
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
|
||||
@property
|
||||
def all_greedy(self) -> bool:
|
||||
return len(self.random_reqs) == 0
|
||||
|
||||
@property
|
||||
def all_random(self) -> bool:
|
||||
return len(self.greedy_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_top_p(self) -> bool:
|
||||
return len(self.top_p_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_top_k(self) -> bool:
|
||||
return len(self.top_k_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_penalties(self) -> bool:
|
||||
return (len(self.presence_penalties_reqs) == 0
|
||||
and len(self.frequency_penalties_reqs) == 0
|
||||
and len(self.repetition_penalties_reqs) == 0)
|
||||
|
||||
@property
|
||||
def max_num_logprobs(self) -> Optional[int]:
|
||||
return max(self.num_logprobs.values()) if self.num_logprobs else None
|
||||
|
||||
@property
|
||||
def no_prompt_logprob(self) -> bool:
|
||||
return not self.num_prompt_logprobs
|
||||
|
||||
@property
|
||||
def no_allowed_token_ids(self) -> bool:
|
||||
return len(self.has_allowed_token_ids) == 0
|
||||
4160
vllm/v1/worker/gpu_model_runner.py
Normal file
4160
vllm/v1/worker/gpu_model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
399
vllm/v1/worker/gpu_ubatch_wrapper.py
Normal file
399
vllm/v1/worker/gpu_ubatch_wrapper.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id)
|
||||
from vllm.forward_context import (create_forward_context, get_forward_context,
|
||||
override_forward_context)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UbatchMetadata:
|
||||
context: UBatchContext
|
||||
input_ids: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
inputs_embeds: Optional[torch.Tensor]
|
||||
intermediate_tensors: Optional[IntermediateTensors]
|
||||
num_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class CUDAGraphMetaData:
|
||||
cudagraph: torch.cuda.CUDAGraph
|
||||
ubatch_metadata: UbatchMetadata
|
||||
outputs: Optional[Any] = None
|
||||
|
||||
|
||||
class SMControlContextManager:
|
||||
|
||||
def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None],
|
||||
set_compute_sms: Callable[[int], None]):
|
||||
"""
|
||||
Context manager for controlling SM (Streaming Multiprocessor)
|
||||
allocation. Upon entering the context, it sets the number of SMs
|
||||
allocated for communication and computation to comm_sms and
|
||||
total_sms - comm_sms respectively. Upon exiting, it restores the
|
||||
allocation to use all available SMs (i.e. total_sms).
|
||||
|
||||
Args:
|
||||
comm_sms (int): The number of SMs to allocate for communication.
|
||||
(The remainder will be used for computation.)
|
||||
set_comm_sms (Callable[[int], None]):
|
||||
A function that sets the number of SMs for communication.
|
||||
set_compute_sms (Callable[[int], None]):
|
||||
A function that sets the number of SMs for computation.
|
||||
"""
|
||||
|
||||
assert current_platform.is_cuda(), \
|
||||
"SM control is currently only supported on CUDA"
|
||||
|
||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||
total_sms = props.multi_processor_count
|
||||
|
||||
assert comm_sms < total_sms
|
||||
self.total_sms = total_sms
|
||||
self.compute_sms = total_sms - comm_sms
|
||||
self.comm_sms = comm_sms
|
||||
self.set_comm_sms = set_comm_sms
|
||||
self.set_compute_sms = set_compute_sms
|
||||
|
||||
def __enter__(self):
|
||||
self.set_comm_sms(self.comm_sms)
|
||||
self.set_compute_sms(self.compute_sms)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.set_comm_sms(self.total_sms)
|
||||
self.set_compute_sms(self.total_sms)
|
||||
|
||||
|
||||
class UBatchWrapper:
|
||||
|
||||
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode, device: torch.cuda.device):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.comm_stream = torch.cuda.Stream(device=device)
|
||||
# Two ubatch threads plus the main thread
|
||||
self.ready_barrier = threading.Barrier(3)
|
||||
|
||||
self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
|
||||
|
||||
self.cudagraph_wrapper = None
|
||||
self.graph_pool = None
|
||||
if runtime_mode is not CUDAGraphMode.NONE:
|
||||
self.cudagraph_wrapper = CUDAGraphWrapper(
|
||||
runnable, vllm_config, runtime_mode=runtime_mode)
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
self.sm_control = self._create_sm_control_context(vllm_config)
|
||||
self.device = device
|
||||
|
||||
@staticmethod
|
||||
def _create_sm_control_context(vllm_config: VllmConfig):
|
||||
comm_sms = envs.VLLM_DBO_COMM_SMS
|
||||
|
||||
set_comm_sms = lambda sms: None
|
||||
if vllm_config.parallel_config.enable_expert_parallel:
|
||||
# Currently only DeepEP highthroughput supports SM control so this
|
||||
# only affects that case.
|
||||
all2all_manager = get_ep_group(
|
||||
).device_communicator.all2all_manager
|
||||
|
||||
if all2all_manager.max_sms_used() is not None:
|
||||
comm_sms = min(comm_sms, all2all_manager.max_sms_used())
|
||||
|
||||
if comm_sms > 0:
|
||||
set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)
|
||||
|
||||
# TODO(lucas): support other kernels besides DeepGEMM
|
||||
set_compute_sms = lambda sms: None
|
||||
if has_deep_gemm() and comm_sms > 0:
|
||||
import deep_gemm as dg
|
||||
set_compute_sms = lambda sms: dg.set_num_sms(sms)
|
||||
|
||||
return SMControlContextManager(comm_sms=comm_sms,
|
||||
set_comm_sms=set_comm_sms,
|
||||
set_compute_sms=set_compute_sms)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||
f"cudagraph wrapper: {self.runnable}")
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
return self.runnable
|
||||
|
||||
def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
|
||||
"""
|
||||
Capture a cudagraph for a microbatched run.
|
||||
|
||||
The logic here is somewhat complicated because we need to make sure that
|
||||
each of the ubatch threads initialize the cuda context before we start
|
||||
the graph capture.
|
||||
|
||||
The flow is as follows:
|
||||
1. The main thread starts up each ubatch thread. Each thread will
|
||||
initialize its cuda context (torch.cuda.current_blas_handle())
|
||||
before going to sleep upon entering the ubatch_context.
|
||||
|
||||
2. The main thread starts the graph capture and wakes up the first
|
||||
ubatch thread.
|
||||
|
||||
3. Each ubatch thread runs the model to completion and returns the
|
||||
completed output tensors back to the main thread.
|
||||
|
||||
4. The main thread stores the captured cudagraph along with its metadata
|
||||
and returns
|
||||
"""
|
||||
|
||||
@torch.inference_mode()
|
||||
def _capture_ubatch_thread(results, ubatch_metadata):
|
||||
torch.cuda.set_device(self.device)
|
||||
ubatch_context = ubatch_metadata.context
|
||||
with torch.cuda.stream(ubatch_context.compute_stream):
|
||||
_ = torch.cuda.current_blas_handle()
|
||||
with torch.cuda.stream(ubatch_context.comm_stream):
|
||||
_ = torch.cuda.current_blas_handle()
|
||||
with ubatch_context:
|
||||
model_output = model(
|
||||
input_ids=ubatch_metadata.input_ids,
|
||||
positions=ubatch_metadata.positions,
|
||||
intermediate_tensors=ubatch_metadata.intermediate_tensors,
|
||||
inputs_embeds=ubatch_metadata.inputs_embeds,
|
||||
)
|
||||
|
||||
results.append((ubatch_metadata.context.id, model_output))
|
||||
|
||||
results: list[tuple[int, torch.Tensor]] = []
|
||||
compute_stream = ubatch_metadata[0].context.compute_stream
|
||||
num_tokens = ubatch_metadata[0].num_tokens + \
|
||||
ubatch_metadata[1].num_tokens
|
||||
|
||||
# Ubatches will manually manage the forward context, so we override
|
||||
# it to None here so we can have it restored correctly later
|
||||
with override_forward_context(None):
|
||||
ubatch_threads = []
|
||||
for metadata in ubatch_metadata:
|
||||
thread = threading.Thread(target=_capture_ubatch_thread,
|
||||
args=(
|
||||
results,
|
||||
metadata,
|
||||
))
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
self.ready_barrier.wait() # Wait for both threads to be ready
|
||||
|
||||
# Capture the cudagraph
|
||||
cudagraph_metadata = \
|
||||
CUDAGraphMetaData(
|
||||
cudagraph=torch.cuda.CUDAGraph(),
|
||||
ubatch_metadata=ubatch_metadata,
|
||||
)
|
||||
if self.graph_pool is not None:
|
||||
set_graph_pool_id(self.graph_pool)
|
||||
else:
|
||||
set_graph_pool_id(current_platform.graph_pool_handle())
|
||||
with torch.cuda.graph(cudagraph_metadata.cudagraph,
|
||||
stream=compute_stream,
|
||||
pool=self.graph_pool):
|
||||
ubatch_metadata[0].context.cpu_wait_event.set()
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
sorted_results = [value for position, value in sorted(results)]
|
||||
result = torch.cat(sorted_results, dim=0)
|
||||
cudagraph_metadata.outputs = result
|
||||
self.cudagraphs[num_tokens] = cudagraph_metadata
|
||||
return cudagraph_metadata.outputs
|
||||
|
||||
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
|
||||
|
||||
@torch.inference_mode()
|
||||
def _ubatch_thread(results, model, ubatch_metadata):
|
||||
with ubatch_metadata.context:
|
||||
model_output = model(
|
||||
input_ids=ubatch_metadata.input_ids,
|
||||
positions=ubatch_metadata.positions,
|
||||
intermediate_tensors=ubatch_metadata.intermediate_tensors,
|
||||
inputs_embeds=ubatch_metadata.inputs_embeds,
|
||||
)
|
||||
results.append((ubatch_metadata.context.id, model_output))
|
||||
|
||||
results: list[tuple[int, torch.Tensor]] = []
|
||||
|
||||
# Ubatch threads will manually manage the forward context, so we
|
||||
# override it to None here so we can have it restored correctly
|
||||
# after both threads have finished
|
||||
with override_forward_context(None):
|
||||
ubatch_threads = []
|
||||
for metadata in ubatch_metadata:
|
||||
thread = threading.Thread(target=_ubatch_thread,
|
||||
args=(
|
||||
results,
|
||||
model,
|
||||
metadata,
|
||||
))
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
self.ready_barrier.wait() # Wait for both threads to be ready
|
||||
ubatch_metadata[0].context.cpu_wait_event.set()
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
sorted_results = [value for position, value in sorted(results)]
|
||||
result = torch.cat(sorted_results, dim=0)
|
||||
return result
|
||||
|
||||
def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
|
||||
positions, inputs_embeds, intermediate_tensors,
|
||||
compute_stream, dp_metadata, batch_descriptor,
|
||||
cudagraph_runtime_mode) -> list[UbatchMetadata]:
|
||||
|
||||
# Create one forward context per ubatch
|
||||
forward_contexts = []
|
||||
for i, ubatch_slice in enumerate(ubatch_slices):
|
||||
forward_contexts.append(
|
||||
create_forward_context(
|
||||
attn_metadata[i] if attn_metadata is not None else None,
|
||||
self.vllm_config,
|
||||
dp_metadata=dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode))
|
||||
|
||||
ubatch_ctxs = make_ubatch_contexts(
|
||||
num_micro_batches=len(ubatch_slices),
|
||||
comm_stream=self.comm_stream,
|
||||
compute_stream=compute_stream,
|
||||
forward_contexts=forward_contexts,
|
||||
ready_barrier=self.ready_barrier)
|
||||
|
||||
ubatch_metadata: list[UbatchMetadata] = []
|
||||
for i, ubatch_slice in enumerate(ubatch_slices):
|
||||
sliced_input_ids, sliced_positions, sliced_inputs_embeds, \
|
||||
sliced_intermediate_tensors = \
|
||||
self._slice_model_inputs(
|
||||
ubatch_slice.token_slice, input_ids, positions,
|
||||
inputs_embeds, intermediate_tensors)
|
||||
ubatch_metadata.append(
|
||||
UbatchMetadata(
|
||||
context=ubatch_ctxs[i],
|
||||
input_ids=sliced_input_ids,
|
||||
positions=sliced_positions,
|
||||
inputs_embeds=sliced_inputs_embeds,
|
||||
intermediate_tensors=sliced_intermediate_tensors,
|
||||
num_tokens=ubatch_slice.token_slice.stop -
|
||||
ubatch_slice.token_slice.start))
|
||||
|
||||
return ubatch_metadata
|
||||
|
||||
def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions,
|
||||
inputs_embeds, intermediate_tensors):
|
||||
sliced_input_ids = input_ids[tokens_slice]
|
||||
# if we are using mrope. Mrope adds an additional dimension to the
|
||||
# positions tensor
|
||||
if positions.ndim == 2:
|
||||
sliced_positions = positions[:, tokens_slice]
|
||||
else:
|
||||
sliced_positions = positions[tokens_slice]
|
||||
sliced_inputs_embeds = inputs_embeds[
|
||||
tokens_slice] if inputs_embeds else None
|
||||
sliced_intermediate_tensors = intermediate_tensors[
|
||||
tokens_slice] if intermediate_tensors else None
|
||||
|
||||
return (sliced_input_ids, sliced_positions, sliced_inputs_embeds,
|
||||
sliced_intermediate_tensors)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
ubatch_slices = forward_context.ubatch_slices
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
# If there's no ubatching, just run the runnable object
|
||||
if ubatch_slices is None:
|
||||
|
||||
# This is to account for the case where ubatching was aborted.
|
||||
# When we capture full graphs we only capture one graph per shape,
|
||||
# meaning that if we have a ubatched cudagraph for the current
|
||||
# num_tokens, we don't have a non-ubatched one. Without this
|
||||
# check, the cudagraph wrapper will try to capture a cudagraph
|
||||
# for this shape during a normal run.
|
||||
if cudagraph_runtime_mode is CUDAGraphMode.FULL:
|
||||
assert batch_descriptor is not None
|
||||
if batch_descriptor.num_tokens in self.cudagraphs:
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
|
||||
CUDAGraphMode.PIECEWISE):
|
||||
return self.runnable(*args, **kwargs)
|
||||
else:
|
||||
assert self.cudagraph_wrapper is not None
|
||||
return self.cudagraph_wrapper(*args, **kwargs)
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
num_tokens = (ubatch_slices[0].token_slice.stop -
|
||||
ubatch_slices[0].token_slice.start) * 2
|
||||
input_ids = kwargs['input_ids']
|
||||
positions = kwargs['positions']
|
||||
intermediate_tensors = kwargs['intermediate_tensors']
|
||||
inputs_embeds = kwargs['inputs_embeds']
|
||||
compute_stream = torch.cuda.current_stream()
|
||||
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
|
||||
# We shouldn't be here unless we are running with multiple DP ranks
|
||||
assert dp_metadata is not None
|
||||
|
||||
if num_tokens not in self.cudagraphs \
|
||||
and cudagraph_runtime_mode is CUDAGraphMode.FULL:
|
||||
ubatch_metadata = self._make_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
compute_stream=compute_stream,
|
||||
dp_metadata=dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE)
|
||||
with self.sm_control:
|
||||
return self._capture_ubatches(ubatch_metadata, self.model)
|
||||
elif num_tokens in self.cudagraphs \
|
||||
and cudagraph_runtime_mode is CUDAGraphMode.FULL:
|
||||
cudagraph_metadata = self.cudagraphs[num_tokens]
|
||||
cudagraph_metadata.cudagraph.replay()
|
||||
return cudagraph_metadata.outputs
|
||||
else:
|
||||
ubatch_metadata = self._make_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
compute_stream=compute_stream,
|
||||
dp_metadata=dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE)
|
||||
with self.sm_control:
|
||||
return self._run_ubatches(ubatch_metadata, self.model)
|
||||
710
vllm/v1/worker/gpu_worker.py
Normal file
710
vllm/v1/worker/gpu_worker.py
Normal file
@@ -0,0 +1,710 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A GPU worker class."""
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||
DraftTokenIds, ModelRunnerOutput)
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
class Worker(WorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker)
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Buffers saved before sleep
|
||||
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir)
|
||||
logger.debug(
|
||||
"Profiler config: record_shapes=%s,"
|
||||
"profile_memory=%s,with_stack=%s,with_flops=%s",
|
||||
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, use_gzip=True))
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
|
||||
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
|
||||
|
||||
# Save the buffers before level 2 sleep
|
||||
if level == 2:
|
||||
model = self.model_runner.model
|
||||
self._sleep_saved_buffers = {
|
||||
name: buffer.cpu().clone()
|
||||
for name, buffer in model.named_buffers()
|
||||
}
|
||||
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
|
||||
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||
used_bytes = total - free_bytes_after_sleep
|
||||
assert freed_bytes >= 0, "Memory usage increased after sleeping."
|
||||
logger.info(
|
||||
"Sleep mode freed %.2f GiB memory, "
|
||||
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
|
||||
used_bytes / GiB_bytes)
|
||||
|
||||
def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.wake_up(tags)
|
||||
|
||||
# Restore the buffers after level 2 sleep
|
||||
if len(self._sleep_saved_buffers):
|
||||
model = self.model_runner.model
|
||||
for name, buffer in model.named_buffers():
|
||||
if name in self._sleep_saved_buffers:
|
||||
buffer.data.copy_(self._sleep_saved_buffers[name].data)
|
||||
self._sleep_saved_buffers = {}
|
||||
|
||||
def _maybe_get_memory_pool_context(self,
|
||||
tag: str) -> AbstractContextManager:
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
if tag == "weights":
|
||||
assert allocator.get_current_usage() == 0, (
|
||||
"Sleep mode can only be "
|
||||
"used for one instance per process.")
|
||||
context = allocator.use_memory_pool(tag=tag)
|
||||
else:
|
||||
context = nullcontext()
|
||||
return context
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
def init_device(self):
|
||||
if self.device_config.device.type == "cuda":
|
||||
# This env var set by Ray causes exceptions with graph building.
|
||||
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
|
||||
current_platform.check_if_supports_dtype(self.model_config.dtype)
|
||||
|
||||
# Initialize the distributed environment BEFORE taking
|
||||
# memory snapshot
|
||||
# This ensures NCCL buffers are allocated before we measure
|
||||
# available memory
|
||||
init_worker_distributed_environment(self.vllm_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank,
|
||||
current_platform.dist_backend)
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
# Now take memory snapshot after NCCL is initialized
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# take current memory snapshot
|
||||
self.init_snapshot = MemorySnapshot()
|
||||
self.requested_memory = (self.init_snapshot.total_memory *
|
||||
self.cache_config.gpu_memory_utilization)
|
||||
if self.init_snapshot.free_memory < self.requested_memory:
|
||||
GiB = lambda b: round(b / GiB_bytes, 2)
|
||||
raise ValueError(
|
||||
f"Free memory on device "
|
||||
f"({GiB(self.init_snapshot.free_memory)}/"
|
||||
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
|
||||
f"is less than desired GPU memory utilization "
|
||||
f"({self.cache_config.gpu_memory_utilization}, "
|
||||
f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
|
||||
f"utilization or reduce GPU memory used by other processes."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
|
||||
# Construct the model runner
|
||||
self.model_runner: GPUModelRunner = GPUModelRunner(
|
||||
self.vllm_config, self.device)
|
||||
|
||||
if self.rank == 0:
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
report_usage_stats(self.vllm_config)
|
||||
|
||||
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
||||
# to hijack tensor allocation.
|
||||
def load_model(self) -> None:
|
||||
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
|
||||
with self._maybe_get_memory_pool_context(tag="weights"):
|
||||
self.model_runner.load_model(eep_scale_up=eep_scale_up)
|
||||
|
||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||
self.model_runner.update_config(overrides)
|
||||
|
||||
def reload_weights(self) -> None:
|
||||
self.model_runner.reload_weights()
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_available_memory(self) -> int:
|
||||
"""Profiles the peak memory usage of the model to determine how much
|
||||
memory can be used for KV cache without OOMs.
|
||||
|
||||
The engine will first conduct a profiling of the existing memory usage.
|
||||
Then, it calculates the free memory that can be used for KV cache in
|
||||
bytes.
|
||||
|
||||
Tip:
|
||||
You may limit the usage of GPU memory
|
||||
by adjusting the `gpu_memory_utilization` parameter.
|
||||
"""
|
||||
GiB = lambda b: b / GiB_bytes
|
||||
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
|
||||
# still need a profile run which compiles the model for
|
||||
# max_num_batched_tokens
|
||||
self.model_runner.profile_run()
|
||||
|
||||
msg = (
|
||||
f"Initial free memory {GiB(self.init_snapshot.free_memory)} "
|
||||
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
|
||||
"KV Cache as specified by kv_cache_memory_bytes config and "
|
||||
"skipped memory profiling. This does does not respect the "
|
||||
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
|
||||
"config when you want manual control of KV cache memory "
|
||||
"size. If OOM'ed, check the difference of initial free "
|
||||
"memory between the current run and the previous run "
|
||||
"where kv_cache_memory_bytes is suggested and update it "
|
||||
"correspondingly.")
|
||||
logger.info(msg)
|
||||
return kv_cache_memory_bytes
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
with memory_profiling(
|
||||
self.init_snapshot,
|
||||
weights_memory=int(self.model_runner.model_memory_usage),
|
||||
) as profile_result:
|
||||
self.model_runner.profile_run()
|
||||
|
||||
self.non_torch_memory = profile_result.non_torch_increase
|
||||
self.peak_activation_memory = profile_result.torch_peak_increase
|
||||
|
||||
free_gpu_memory = profile_result.after_profile.free_memory
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
assert self.init_snapshot.free_memory > free_gpu_memory, (
|
||||
"Error in memory profiling. "
|
||||
f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
|
||||
f"current free memory {GiB(free_gpu_memory)} GiB. "
|
||||
"This happens when other processes sharing the same container "
|
||||
"release GPU memory while vLLM is profiling during initialization. "
|
||||
"To fix this, ensure consistent GPU memory allocation or "
|
||||
"isolate vLLM in its own container.")
|
||||
self.available_kv_cache_memory_bytes = self.requested_memory \
|
||||
- profile_result.non_kv_cache_memory
|
||||
|
||||
unrequested_memory = self.init_snapshot.free_memory \
|
||||
- self.requested_memory
|
||||
logger.debug(
|
||||
"Initial free memory: %.2f GiB; "
|
||||
"Requested memory: %.2f (util), %.2f GiB",
|
||||
GiB(self.init_snapshot.free_memory),
|
||||
self.cache_config.gpu_memory_utilization,
|
||||
GiB(self.requested_memory),
|
||||
)
|
||||
logger.debug(
|
||||
"Free memory after profiling: %.2f GiB (total), "
|
||||
"%.2f GiB (within requested)",
|
||||
GiB(free_gpu_memory),
|
||||
GiB(free_gpu_memory - unrequested_memory),
|
||||
)
|
||||
logger.debug(profile_result)
|
||||
logger.info("Available KV cache memory: %.2f GiB",
|
||||
GiB(self.available_kv_cache_memory_bytes))
|
||||
gc.collect()
|
||||
|
||||
return int(self.available_kv_cache_memory_bytes)
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
context = allocator.use_memory_pool(tag="kv_cache")
|
||||
else:
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
# warm up sizes that are not in cudagraph capture sizes,
|
||||
# but users still want to compile for better performance,
|
||||
# e.g. for the max-num-batched token size in chunked prefill.
|
||||
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
|
||||
if not self.model_config.enforce_eager:
|
||||
warmup_sizes = [
|
||||
x for x in warmup_sizes if x not in
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes
|
||||
]
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
for size in sorted(warmup_sizes, reverse=True):
|
||||
logger.info("Compile and warming up model for size %d", size)
|
||||
self.model_runner._dummy_run(size,
|
||||
skip_eplb=True,
|
||||
remove_lora=False)
|
||||
self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
|
||||
|
||||
# Warmup and tune the kernels used during model execution before
|
||||
# cuda graph capture.
|
||||
kernel_warmup(self)
|
||||
|
||||
cuda_graph_memory_bytes = 0
|
||||
if not self.model_config.enforce_eager:
|
||||
cuda_graph_memory_bytes = self.model_runner.capture_model()
|
||||
|
||||
if (self.cache_config.kv_cache_memory_bytes is None
|
||||
and hasattr(self, "peak_activation_memory")):
|
||||
# Suggests optimal kv cache memory size if we rely on
|
||||
# memory_profiling to guess the kv cache memory size which
|
||||
# provides peak_activation_memory and a few other memory
|
||||
# consumption. `memory_profiling` does not consider
|
||||
# CUDAGraph memory size and may not utilize all gpu memory.
|
||||
# Users may want fine-grained control to specify kv cache
|
||||
# memory size.
|
||||
GiB = lambda b: round(b / GiB_bytes, 2)
|
||||
|
||||
# empirically observed that the memory profiling may
|
||||
# slightly underestimate the memory consumption.
|
||||
# So leave a small buffer (=150MiB) to avoid OOM.
|
||||
redundancy_buffer_memory = 150 * (1 << 20)
|
||||
non_kv_cache_memory = (self.model_runner.model_memory_usage +
|
||||
self.peak_activation_memory +
|
||||
self.non_torch_memory +
|
||||
cuda_graph_memory_bytes)
|
||||
kv_cache_memory_bytes_to_gpu_limit = (
|
||||
self.init_snapshot.free_memory - non_kv_cache_memory -
|
||||
redundancy_buffer_memory)
|
||||
kv_cache_memory_bytes_to_requested_limit = (
|
||||
int(self.requested_memory) - non_kv_cache_memory -
|
||||
redundancy_buffer_memory)
|
||||
|
||||
msg = (
|
||||
f"Free memory on device "
|
||||
f"({GiB(self.init_snapshot.free_memory)}/"
|
||||
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
|
||||
f"Desired GPU memory utilization is "
|
||||
f"({self.cache_config.gpu_memory_utilization}, "
|
||||
f"{GiB(self.requested_memory)} GiB). "
|
||||
f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
|
||||
f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
|
||||
f"for peak activation, {GiB(self.non_torch_memory)} GiB "
|
||||
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
|
||||
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
|
||||
f"config with `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_requested_limit}` "
|
||||
f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
|
||||
f"into requested memory, or `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_gpu_limit}` "
|
||||
f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
|
||||
f"utilize gpu memory. Current kv cache memory in use is "
|
||||
f"{GiB(self.available_kv_cache_memory_bytes)} GiB.")
|
||||
|
||||
logger.debug(msg)
|
||||
|
||||
# Warm up sampler and preallocate memory buffer for logits and other
|
||||
# sampling related tensors of max possible shape to avoid memory
|
||||
# fragmentation issue.
|
||||
# NOTE: This is called after `capture_model` on purpose to prevent
|
||||
# memory buffers from being cleared by `torch.cuda.empty_cache`.
|
||||
if get_pp_group().is_last_rank:
|
||||
max_num_reqs = min(self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
hidden_states, last_hidden_states = \
|
||||
self.model_runner._dummy_run(
|
||||
num_tokens=max_num_reqs,
|
||||
skip_eplb=True,
|
||||
)
|
||||
if self.model_runner.is_pooling_model:
|
||||
self.model_runner._dummy_pooler_run(hidden_states)
|
||||
else:
|
||||
self.model_runner._dummy_sampler_run(
|
||||
hidden_states=last_hidden_states)
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
return self.model_runner.get_supported_tasks()
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
num_input_tokens = self.model_runner._get_num_input_tokens(
|
||||
num_scheduled_tokens)
|
||||
all_gather_tensors = {
|
||||
"residual":
|
||||
not is_residual_scattered_for_sp(self.vllm_config,
|
||||
num_input_tokens)
|
||||
}
|
||||
if forward_pass and not get_pp_group().is_first_rank:
|
||||
intermediate_tensors = IntermediateTensors(
|
||||
get_pp_group().recv_tensor_dict(
|
||||
all_gather_group=get_tp_group(),
|
||||
all_gather_tensors=all_gather_tensors))
|
||||
|
||||
output = self.model_runner.execute_model(scheduler_output,
|
||||
intermediate_tensors)
|
||||
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
|
||||
return output
|
||||
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
assert parallel_config.distributed_executor_backend != (
|
||||
"external_launcher") and not get_pp_group().is_last_rank
|
||||
|
||||
get_pp_group().send_tensor_dict(output.tensors,
|
||||
all_gather_group=get_tp_group(),
|
||||
all_gather_tensors=all_gather_tensors)
|
||||
|
||||
kv_connector_output = output.kv_connector_output
|
||||
if not kv_connector_output:
|
||||
return None
|
||||
|
||||
# In case of PP with kv transfer, we need to pass through the
|
||||
# kv_connector_output
|
||||
if (not kv_connector_output.finished_sending
|
||||
and not kv_connector_output.finished_recving):
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
output.kv_connector_output = kv_connector_output
|
||||
return output
|
||||
|
||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||
return self.model_runner.take_draft_token_ids()
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
if is_start:
|
||||
self.profiler.start()
|
||||
else:
|
||||
self.profiler.stop()
|
||||
# only print profiler results on rank 0
|
||||
if self.local_rank == 0:
|
||||
print(self.profiler.key_averages().table(
|
||||
sort_by="self_cuda_time_total"))
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self.model_runner._dummy_run(1, uniform_decode=True)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.remove_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> set[int]:
|
||||
return self.model_runner.list_loras()
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
def check_health(self) -> None:
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
def _eplb_before_scale_down(self, old_ep_size: int,
|
||||
new_ep_size: int) -> None:
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Starting expert resharding "
|
||||
"before scaling down...")
|
||||
rank_mapping = {
|
||||
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
|
||||
for old_ep_rank in range(old_ep_size)
|
||||
}
|
||||
assert self.model_runner.eplb_state is not None
|
||||
self.model_runner.eplb_state.rearrange(self.model_runner.model,
|
||||
execute_shuffle=True,
|
||||
global_expert_load=None,
|
||||
rank_mapping=rank_mapping)
|
||||
torch.cuda.synchronize()
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed!")
|
||||
|
||||
def _eplb_after_scale_up(
|
||||
self, old_ep_size: int, new_ep_size: int,
|
||||
global_expert_load: Optional[torch.Tensor]) -> None:
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Starting expert resharding "
|
||||
"after scaling up...")
|
||||
rank_mapping = {
|
||||
old_ep_rank: old_ep_rank
|
||||
for old_ep_rank in range(old_ep_size)
|
||||
}
|
||||
assert self.model_runner.eplb_state is not None
|
||||
self.model_runner.eplb_state.rearrange(
|
||||
self.model_runner.model,
|
||||
execute_shuffle=True,
|
||||
global_expert_load=global_expert_load,
|
||||
rank_mapping=rank_mapping)
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed!")
|
||||
|
||||
def _reconfigure_parallel_config(
|
||||
self, reconfig_request: ReconfigureDistributedRequest) -> None:
|
||||
"""
|
||||
Update parallel config with provided reconfig_request
|
||||
"""
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
parallel_config.data_parallel_size = \
|
||||
reconfig_request.new_data_parallel_size
|
||||
if reconfig_request.new_data_parallel_rank != \
|
||||
ReconfigureRankType.KEEP_CURRENT_RANK:
|
||||
parallel_config.data_parallel_rank = \
|
||||
reconfig_request.new_data_parallel_rank
|
||||
if reconfig_request.new_data_parallel_rank_local != \
|
||||
ReconfigureRankType.KEEP_CURRENT_RANK:
|
||||
parallel_config.data_parallel_rank_local = \
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
parallel_config.data_parallel_master_ip = \
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
parallel_config.data_parallel_master_port = \
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
|
||||
def _reconfigure_moe(self, old_ep_size: int,
|
||||
new_ep_size: int) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Reconfigure MoE modules with provided reconfig_request
|
||||
|
||||
Return the global expert load if new_ep_size > old_ep_size,
|
||||
otherwise None
|
||||
"""
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group, get_ep_group, prepare_communication_buffer_for_model)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoEParallelConfig)
|
||||
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
moe_modules = [
|
||||
module for module in self.model_runner.model.modules()
|
||||
if (module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE")
|
||||
]
|
||||
num_local_experts = moe_modules[0].moe_config.num_local_experts
|
||||
assert all(module.moe_config.num_local_experts == num_local_experts
|
||||
for module in moe_modules), (
|
||||
"All MoE modules must have the same number of experts")
|
||||
for module in moe_modules:
|
||||
module.moe_config.num_experts = num_local_experts * new_ep_size
|
||||
module.global_num_experts = module.moe_config.num_experts
|
||||
module.moe_parallel_config = FusedMoEParallelConfig.make(
|
||||
tp_size_=get_tp_group().world_size,
|
||||
dp_size_=get_dp_group().world_size,
|
||||
vllm_parallel_config=parallel_config,
|
||||
)
|
||||
module.moe_config.moe_parallel_config = module.moe_parallel_config
|
||||
if new_ep_size < old_ep_size:
|
||||
num_local_physical_experts = num_local_experts
|
||||
assert self.model_runner.eplb_state is not None
|
||||
new_physical_experts = \
|
||||
self.model_runner.eplb_state.physical_to_logical_map.shape[1]
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
new_physical_experts -
|
||||
self.model_runner.eplb_state.logical_replica_count.shape[1])
|
||||
global_expert_load = None
|
||||
else:
|
||||
num_local_physical_experts = torch.tensor([num_local_experts],
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
torch.distributed.broadcast(num_local_physical_experts,
|
||||
group=get_ep_group().cpu_group,
|
||||
group_src=0)
|
||||
num_local_physical_experts = num_local_physical_experts.item()
|
||||
new_physical_experts = num_local_physical_experts * new_ep_size
|
||||
assert self.model_runner.eplb_state is not None
|
||||
global_expert_load = self.model_runner.eplb_state.rearrange(
|
||||
self.model_runner.model, execute_shuffle=False)
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
new_physical_experts - global_expert_load.shape[1])
|
||||
prepare_communication_buffer_for_model(self.model_runner.model)
|
||||
self.model_runner.model.update_physical_experts_metadata(
|
||||
num_physical_experts=new_physical_experts,
|
||||
num_local_physical_experts=num_local_physical_experts)
|
||||
return global_expert_load
|
||||
|
||||
def reinitialize_distributed(
|
||||
self, reconfig_request: ReconfigureDistributedRequest) -> None:
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (
|
||||
cleanup_dist_env_and_memory, get_ep_group)
|
||||
|
||||
old_ep_size = get_ep_group().world_size
|
||||
old_ep_rank = get_ep_group().rank
|
||||
new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group(
|
||||
).world_size * get_pp_group().world_size
|
||||
if new_ep_size < old_ep_size:
|
||||
self._eplb_before_scale_down(old_ep_size, new_ep_size)
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
if reconfig_request.new_data_parallel_rank == \
|
||||
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
|
||||
assert old_ep_rank >= new_ep_size
|
||||
# shutdown
|
||||
return
|
||||
|
||||
self._reconfigure_parallel_config(reconfig_request)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
init_worker_distributed_environment(self.vllm_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
|
||||
global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)
|
||||
|
||||
if new_ep_size > old_ep_size:
|
||||
assert global_expert_load is not None
|
||||
self._eplb_after_scale_up(old_ep_size, new_ep_size,
|
||||
global_expert_load)
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
from vllm.model_executor.model_loader import ShardedStateLoader
|
||||
ShardedStateLoader.save_model(
|
||||
self.model_runner.model,
|
||||
path,
|
||||
pattern=pattern,
|
||||
max_size=max_size,
|
||||
)
|
||||
|
||||
def save_tensorized_model(
|
||||
self,
|
||||
tensorizer_config: "TensorizerConfig",
|
||||
) -> None:
|
||||
self.model_runner.save_tensorized_model(
|
||||
tensorizer_config=tensorizer_config, )
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if runner := getattr(self, "model_runner", None):
|
||||
runner.ensure_kv_transfer_shutdown()
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
vllm_config: VllmConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
backend: str = "nccl",
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank, backend)
|
||||
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.decode_context_parallel_size)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
132
vllm/v1/worker/kv_connector_model_runner_mixin.py
Normal file
132
vllm/v1/worker/kv_connector_model_runner_mixin.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Define KV connector functionality mixin for model runners.
|
||||
"""
|
||||
import copy
|
||||
from contextlib import AbstractContextManager, contextmanager, nullcontext
|
||||
from typing import Generator # noqa: UP035
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorStats)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
|
||||
ModelRunnerOutput)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
|
||||
class KVConnectorModelRunnerMixin:
|
||||
|
||||
@staticmethod
|
||||
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
|
||||
# Update KVConnector with the KVConnector metadata forward().
|
||||
if has_kv_transfer_group():
|
||||
kv_connector = get_kv_transfer_group()
|
||||
assert isinstance(kv_connector, KVConnectorBase)
|
||||
assert scheduler_output.kv_connector_metadata is not None
|
||||
kv_connector.bind_connector_metadata(
|
||||
scheduler_output.kv_connector_metadata)
|
||||
|
||||
# Background KV cache transfers happen here.
|
||||
# These transfers are designed to be async and the requests
|
||||
# involved may be disjoint from the running requests.
|
||||
# Do this here to save a collective_rpc.
|
||||
kv_connector.start_load_kv(get_forward_context())
|
||||
|
||||
@staticmethod
|
||||
def ensure_kv_transfer_shutdown() -> None:
|
||||
# has_kv_transfer_group can be None during interpreter shutdown.
|
||||
if has_kv_transfer_group and has_kv_transfer_group():
|
||||
ensure_kv_transfer_shutdown()
|
||||
|
||||
@staticmethod
|
||||
def maybe_wait_for_kv_save() -> None:
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().wait_for_save()
|
||||
|
||||
@staticmethod
|
||||
def get_finished_kv_transfers(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
if has_kv_transfer_group():
|
||||
return get_kv_transfer_group().get_finished(
|
||||
scheduler_output.finished_req_ids)
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
def kv_connector_no_forward(scheduler_output: "SchedulerOutput",
|
||||
vllm_config: VllmConfig) -> ModelRunnerOutput:
|
||||
# KV send/recv even if no work to do.
|
||||
with set_forward_context(
|
||||
None, vllm_config
|
||||
), KVConnectorModelRunnerMixin._get_kv_connector_output(
|
||||
scheduler_output, wait_for_save=False) as kv_connector_output:
|
||||
pass
|
||||
|
||||
if (not kv_connector_output.finished_sending
|
||||
and not kv_connector_output.finished_recving):
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
output.kv_connector_output = kv_connector_output
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def maybe_get_kv_connector_output(
|
||||
scheduler_output: "SchedulerOutput"
|
||||
) -> AbstractContextManager[Optional[KVConnectorOutput]]:
|
||||
return KVConnectorModelRunnerMixin._get_kv_connector_output(
|
||||
scheduler_output) if has_kv_transfer_group() else nullcontext()
|
||||
|
||||
# This context manager must be used within an active forward context.
|
||||
# It encapsulates the entire KV connector lifecycle within execute_model
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def _get_kv_connector_output(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
wait_for_save: bool = True
|
||||
) -> Generator[KVConnectorOutput, None, None]:
|
||||
output = KVConnectorOutput()
|
||||
|
||||
# Update KVConnector with the KVConnector metadata forward().
|
||||
kv_connector = get_kv_transfer_group()
|
||||
assert isinstance(kv_connector, KVConnectorBase)
|
||||
assert scheduler_output.kv_connector_metadata is not None
|
||||
kv_connector.bind_connector_metadata(
|
||||
scheduler_output.kv_connector_metadata)
|
||||
|
||||
# Background KV cache transfers happen here.
|
||||
# These transfers are designed to be async and the requests
|
||||
# involved may be disjoint from the running requests.
|
||||
# Do this here to save a collective_rpc.
|
||||
kv_connector.start_load_kv(get_forward_context())
|
||||
try:
|
||||
yield output
|
||||
finally:
|
||||
if wait_for_save:
|
||||
kv_connector.wait_for_save()
|
||||
|
||||
output.finished_sending, output.finished_recving = (
|
||||
kv_connector.get_finished(scheduler_output.finished_req_ids))
|
||||
|
||||
output.kv_connector_stats = KVConnectorModelRunnerMixin.\
|
||||
get_kv_connector_stats()
|
||||
kv_connector.clear_connector_metadata()
|
||||
|
||||
@staticmethod
|
||||
def get_kv_connector_stats() -> Optional[KVConnectorStats]:
|
||||
if has_kv_transfer_group():
|
||||
return get_kv_transfer_group().get_kv_connector_stats()
|
||||
return None
|
||||
183
vllm/v1/worker/lora_model_runner_mixin.py
Normal file
183
vllm/v1/worker/lora_model_runner_mixin.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Define LoRA functionality mixin for model runners.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
|
||||
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
|
||||
|
||||
InputBatch = Union[TPUInputBatch, GPUInputBatch]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Defined as a mixin for GPUModelRunner
|
||||
class LoRAModelRunnerMixin:
|
||||
|
||||
LORA_WARMUP_RANK = 8
|
||||
|
||||
def load_lora_model(self, model: nn.Module, vllm_config: VllmConfig,
|
||||
device: torch.device) -> nn.Module:
|
||||
|
||||
if not supports_lora(model):
|
||||
raise ValueError(
|
||||
f"{model.__class__.__name__} does not support LoRA yet.")
|
||||
|
||||
if supports_multimodal(model):
|
||||
logger.warning("Regarding multimodal models, vLLM currently "
|
||||
"only supports adding LoRA to language model.")
|
||||
|
||||
# Add LoRA Manager to the Model Runner
|
||||
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||
vllm_config,
|
||||
device,
|
||||
model.embedding_modules,
|
||||
model.embedding_padding_modules,
|
||||
)
|
||||
return self.lora_manager.create_lora_manager(model)
|
||||
|
||||
def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...],
|
||||
token_lora_mapping: tuple[int, ...],
|
||||
lora_requests: set[LoRARequest]) -> None:
|
||||
self._ensure_lora_enabled()
|
||||
|
||||
# Set is_prefill to True, so we always use the SGMV kernels on
|
||||
# non-cuda platforms.
|
||||
# On cuda platforms we use the same kernels for prefill and
|
||||
# decode and this flag is generally ignored.
|
||||
lora_mapping = LoRAMapping(token_lora_mapping,
|
||||
prompt_lora_mapping,
|
||||
is_prefill=True)
|
||||
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||
|
||||
def _ensure_lora_enabled(self) -> None:
|
||||
if not hasattr(self, "lora_manager"):
|
||||
raise RuntimeError(
|
||||
"LoRA is not enabled. Use --enable-lora to enable LoRA.")
|
||||
|
||||
def set_active_loras(self, input_batch: InputBatch,
|
||||
num_scheduled_tokens: np.ndarray) -> None:
|
||||
|
||||
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
|
||||
token_lora_mapping: tuple[int,
|
||||
...] # of size np.sum(num_scheduled_tokens)
|
||||
lora_requests: set[LoRARequest]
|
||||
prompt_lora_mapping, token_lora_mapping, lora_requests = \
|
||||
input_batch.make_lora_inputs(num_scheduled_tokens)
|
||||
return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
|
||||
lora_requests)
|
||||
|
||||
@contextmanager
|
||||
def maybe_setup_dummy_loras(self,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
remove_lora: bool = True):
|
||||
if lora_config is None:
|
||||
yield
|
||||
else:
|
||||
# __enter__ code
|
||||
assert self.lora_manager is not None, "LoRA is not enabled"
|
||||
|
||||
num_loras = lora_config.max_loras
|
||||
|
||||
# Make dummy lora requests
|
||||
lora_requests: set[LoRARequest] = {
|
||||
LoRARequest(lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path")
|
||||
for lora_id in range(1, num_loras + 1)
|
||||
}
|
||||
|
||||
with self.lora_manager.dummy_lora_cache():
|
||||
# Add the dummy LoRAs here so _set_active_loras doesn't try to
|
||||
# load from disk.
|
||||
for lr in lora_requests:
|
||||
self.lora_manager.add_dummy_lora(
|
||||
lr, rank=self.LORA_WARMUP_RANK)
|
||||
|
||||
yield
|
||||
|
||||
# __exit__ code
|
||||
if remove_lora:
|
||||
self.lora_manager.remove_all_adapters()
|
||||
|
||||
@contextmanager
|
||||
def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig],
|
||||
num_scheduled_tokens: np.ndarray):
|
||||
if lora_config is None:
|
||||
yield
|
||||
else:
|
||||
# __enter__ code
|
||||
assert self.lora_manager is not None, "LoRA is not enabled"
|
||||
|
||||
num_reqs = len(num_scheduled_tokens)
|
||||
num_loras = lora_config.max_loras
|
||||
|
||||
# Make prompt lora mapping
|
||||
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
|
||||
prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) %
|
||||
num_loras) + 1
|
||||
|
||||
# Make token lora mapping
|
||||
token_lora_mapping = np.repeat(prompt_lora_mapping,
|
||||
num_scheduled_tokens)
|
||||
|
||||
# Make dummy lora requests
|
||||
lora_requests: set[LoRARequest] = {
|
||||
LoRARequest(lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path")
|
||||
for lora_id in range(1, num_loras + 1)
|
||||
}
|
||||
|
||||
self._set_active_loras(tuple(prompt_lora_mapping),
|
||||
tuple(token_lora_mapping), lora_requests)
|
||||
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def maybe_dummy_run_with_lora(self,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
remove_lora: bool = True):
|
||||
with (
|
||||
self.maybe_setup_dummy_loras(lora_config, remove_lora),
|
||||
self.maybe_select_dummy_loras(lora_config,
|
||||
num_scheduled_tokens),
|
||||
):
|
||||
yield
|
||||
|
||||
def maybe_remove_all_loras(self, lora_config: Optional[LoRAConfig]):
|
||||
if lora_config is None:
|
||||
return
|
||||
self.lora_manager.remove_all_adapters()
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
self._ensure_lora_enabled()
|
||||
return self.lora_manager.add_adapter(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
self._ensure_lora_enabled()
|
||||
return self.lora_manager.remove_adapter(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
self._ensure_lora_enabled()
|
||||
return self.lora_manager.pin_adapter(lora_id)
|
||||
|
||||
def list_loras(self) -> set[int]:
|
||||
self._ensure_lora_enabled()
|
||||
return self.lora_manager.list_adapters()
|
||||
587
vllm/v1/worker/tpu_input_batch.py
Normal file
587
vllm/v1/worker/tpu_input_batch.py
Normal file
@@ -0,0 +1,587 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Datastructures defining a TPU input batch
|
||||
|
||||
from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class InputBatch:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self._req_ids: list[Optional[str]] = []
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
|
||||
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
||||
# Find a way to reduce the CPU memory usage.
|
||||
# This buffer is not directly transferred to the GPU, so it does not
|
||||
# need to be pinned.
|
||||
self.token_ids_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=False,
|
||||
)
|
||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, ),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.num_computed_tokens_cpu = \
|
||||
self.num_computed_tokens_cpu_tensor.numpy()
|
||||
|
||||
# Block table.
|
||||
self.block_table = MultiGroupBlockTable(
|
||||
max_num_reqs=max_num_reqs,
|
||||
max_model_len=max_model_len,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
# Sampling-related.
|
||||
self.temperature = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
||||
self.greedy_reqs: set[str] = set()
|
||||
self.random_reqs: set[str] = set()
|
||||
|
||||
self.top_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
||||
self.top_p_reqs: set[str] = set()
|
||||
|
||||
self.top_k = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
||||
self.top_k_reqs: set[str] = set()
|
||||
|
||||
self.min_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||
self.min_p_reqs: set[str] = set()
|
||||
|
||||
# Frequency penalty related data structures
|
||||
self.frequency_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.frequency_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.frequency_penalties_cpu = \
|
||||
self.frequency_penalties_cpu_tensor.numpy()
|
||||
self.frequency_penalties_reqs: set[str] = set()
|
||||
|
||||
# Presence penalty related data structures
|
||||
self.presence_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
|
||||
)
|
||||
self.presence_penalties_reqs: set[str] = set()
|
||||
|
||||
# Repetition penalty related data structures
|
||||
self.repetition_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.repetition_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.repetition_penalties_cpu = \
|
||||
self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_reqs: set[str] = set()
|
||||
|
||||
# req_index -> (min_tokens, stop_token_ids)
|
||||
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
||||
dtype=np.int32)
|
||||
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
||||
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
||||
|
||||
# req_index -> generator
|
||||
# NOTE(woosuk): The indices of the requests that do not have their own
|
||||
# generator should not be included in the dictionary.
|
||||
self.generators: dict[int, torch.Generator] = {}
|
||||
|
||||
self.num_logprobs: dict[str, int] = {}
|
||||
# NOTE(rob): num_prompt_logprobs only includes reqs
|
||||
# that are currently in the prefill phase.
|
||||
self.num_prompt_logprobs: dict[str, int] = {}
|
||||
|
||||
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
||||
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
||||
|
||||
self.logit_bias: list[Optional[dict[int,
|
||||
float]]] = [None] * max_num_reqs
|
||||
self.has_allowed_token_ids: set[str] = set()
|
||||
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
||||
# the value is False. Since we use masked_fill_ to set -inf.
|
||||
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
||||
|
||||
self.req_output_token_ids: list[Optional[list[int]]] = []
|
||||
|
||||
@property
|
||||
def req_ids(self) -> list[str]:
|
||||
# None elements should only be present transiently
|
||||
# while performing state updates to the batch.
|
||||
return cast(list[str], self._req_ids)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
req_index: Optional[int] = None,
|
||||
) -> None:
|
||||
if req_index is None:
|
||||
req_index = self.num_reqs
|
||||
assert req_index < self.max_num_reqs
|
||||
|
||||
req_id = request.req_id
|
||||
if req_index == len(self._req_ids):
|
||||
self._req_ids.append(req_id)
|
||||
self.req_output_token_ids.append(request.output_token_ids)
|
||||
else:
|
||||
self._req_ids[req_index] = req_id
|
||||
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||
|
||||
self.req_id_to_index[req_id] = req_index
|
||||
|
||||
# Copy the prompt token ids and output token ids.
|
||||
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
request.prompt_token_ids, request.prompt_embeds)
|
||||
# TODO: copy prompt_embeds
|
||||
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
||||
self.token_ids_cpu[
|
||||
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
||||
start_idx = num_prompt_tokens
|
||||
end_idx = start_idx + len(request.output_token_ids)
|
||||
self.token_ids_cpu[req_index,
|
||||
start_idx:end_idx] = request.output_token_ids
|
||||
# Number of token ids in token_ids_cpu.
|
||||
# NOTE(woosuk): This may include spec decode tokens.
|
||||
self.num_tokens[req_index] = request.num_tokens
|
||||
# Number of tokens without spec decode tokens.
|
||||
self.num_tokens_no_spec[req_index] = request.num_tokens
|
||||
|
||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
assert sampling_params is not None, "pooling requests not supported yet"
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Avoid later division by zero.
|
||||
self.temperature_cpu[req_index] = -1.0
|
||||
self.greedy_reqs.add(req_id)
|
||||
else:
|
||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||
self.random_reqs.add(req_id)
|
||||
|
||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||
if sampling_params.top_p < 1:
|
||||
self.top_p_reqs.add(req_id)
|
||||
top_k = sampling_params.top_k
|
||||
if 0 < top_k < self.vocab_size:
|
||||
self.top_k_reqs.add(req_id)
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k_cpu[req_index] = top_k
|
||||
self.min_p_cpu[req_index] = sampling_params.min_p
|
||||
self.frequency_penalties_cpu[
|
||||
req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.min_p > _SAMPLING_EPS:
|
||||
self.min_p_reqs.add(req_id)
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[
|
||||
req_index] = sampling_params.presence_penalty
|
||||
if sampling_params.presence_penalty != 0.0:
|
||||
self.presence_penalties_reqs.add(req_id)
|
||||
self.repetition_penalties_cpu[
|
||||
req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
if sampling_params.min_tokens:
|
||||
self.min_tokens[req_index] = (sampling_params.min_tokens,
|
||||
sampling_params.all_stop_token_ids)
|
||||
|
||||
# NOTE(woosuk): self.generators should not include the requests that
|
||||
# do not have their own generator.
|
||||
if request.generator is not None:
|
||||
self.generators[req_index] = request.generator
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
self.num_logprobs[req_id] = sampling_params.logprobs
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
|
||||
if sampling_params.logit_bias is not None:
|
||||
self.logit_bias[req_index] = sampling_params.logit_bias
|
||||
|
||||
if sampling_params.allowed_token_ids:
|
||||
self.has_allowed_token_ids.add(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is None:
|
||||
# Lazy allocation for this tensor, which can be large.
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device)
|
||||
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device="cpu")
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||
sampling_params.allowed_token_ids] = False
|
||||
|
||||
if sampling_params.bad_words_token_ids:
|
||||
self.bad_words_token_ids[
|
||||
req_index] = sampling_params.bad_words_token_ids
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
lora_id = request.lora_request.lora_int_id
|
||||
if lora_id not in self.lora_id_to_request_ids:
|
||||
self.lora_id_to_request_ids[lora_id] = set()
|
||||
|
||||
self.request_lora_mapping[req_index] = lora_id
|
||||
self.lora_id_to_request_ids[lora_id].add(request.req_id)
|
||||
self.lora_id_to_lora_request[lora_id] = request.lora_request
|
||||
else:
|
||||
# No LoRA
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
def remove_request(self, req_id: str) -> Optional[int]:
|
||||
"""This method must always be followed by a call to condense()."""
|
||||
|
||||
req_index = self.req_id_to_index.pop(req_id, None)
|
||||
if req_index is None:
|
||||
return None
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
|
||||
self.greedy_reqs.discard(req_id)
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
self.top_k_reqs.discard(req_id)
|
||||
self.min_p_reqs.discard(req_id)
|
||||
self.min_tokens.pop(req_index, None)
|
||||
self.frequency_penalties_reqs.discard(req_id)
|
||||
self.presence_penalties_reqs.discard(req_id)
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
self.generators.pop(req_index, None)
|
||||
self.num_logprobs.pop(req_id, None)
|
||||
self.num_prompt_logprobs.pop(req_id, None)
|
||||
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
||||
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
self.lora_id_to_request_ids[lora_id].discard(req_id)
|
||||
if len(self.lora_id_to_request_ids[lora_id]) == 0:
|
||||
self.lora_id_to_request_ids.pop(lora_id)
|
||||
self.lora_id_to_lora_request.pop(lora_id)
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
self.logit_bias[req_index] = None
|
||||
self.has_allowed_token_ids.discard(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
||||
self.bad_words_token_ids.pop(req_index, None)
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
old_id_i1 = self._req_ids[i1]
|
||||
old_id_i2 = self._req_ids[i2]
|
||||
self._req_ids[i1], self._req_ids[i2] =\
|
||||
self._req_ids[i2], self._req_ids[i1] # noqa
|
||||
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
||||
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
||||
assert old_id_i1 is not None and old_id_i2 is not None
|
||||
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
||||
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
||||
self.num_tokens[i1], self.num_tokens[i2] =\
|
||||
self.num_tokens[i2], self.num_tokens[i1]
|
||||
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
|
||||
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
|
||||
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
|
||||
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
|
||||
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
|
||||
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
|
||||
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
|
||||
self.temperature_cpu[i2], self.temperature_cpu[i1]
|
||||
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
|
||||
self.top_p_cpu[i2], self.top_p_cpu[i1]
|
||||
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
|
||||
self.top_k_cpu[i2], self.top_k_cpu[i1]
|
||||
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
|
||||
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
|
||||
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
|
||||
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
||||
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
|
||||
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
||||
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
|
||||
self.min_p_cpu[i2], self.min_p_cpu[i1]
|
||||
|
||||
# NOTE: the following is unsafe
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
||||
# instead, we need to temporarily copy the data for one of the indices
|
||||
# TODO(lucas): optimize this by only copying valid indices
|
||||
tmp = self.token_ids_cpu[i1, ...].copy()
|
||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
swap_dict_values(self.generators, i1, i2)
|
||||
swap_dict_values(self.min_tokens, i1, i2)
|
||||
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
||||
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
||||
self.logit_bias[i1], self.logit_bias[i2] =\
|
||||
self.logit_bias[i2], self.logit_bias[i1]
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2] =\
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1]
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
def condense(self, empty_req_indices: list[int]) -> None:
|
||||
"""Move non-empty requests down into lower, empty indices.
|
||||
|
||||
Args:
|
||||
empty_req_indices: empty batch indices, sorted descending.
|
||||
"""
|
||||
num_reqs = self.num_reqs
|
||||
if num_reqs == 0:
|
||||
# The batched states are empty.
|
||||
self._req_ids.clear()
|
||||
self.req_output_token_ids.clear()
|
||||
return
|
||||
|
||||
# NOTE(woosuk): This function assumes that the empty_req_indices
|
||||
# is sorted in descending order.
|
||||
last_req_index = num_reqs + len(empty_req_indices) - 1
|
||||
while empty_req_indices:
|
||||
# Find the largest non-empty index.
|
||||
while last_req_index in empty_req_indices:
|
||||
last_req_index -= 1
|
||||
|
||||
# Find the smallest empty index.
|
||||
empty_index = empty_req_indices.pop()
|
||||
if empty_index >= last_req_index:
|
||||
break
|
||||
|
||||
# Swap the states.
|
||||
req_id = self._req_ids[last_req_index]
|
||||
output_token_ids = self.req_output_token_ids[last_req_index]
|
||||
assert req_id is not None
|
||||
self._req_ids[empty_index] = req_id
|
||||
self._req_ids[last_req_index] = None
|
||||
self.req_output_token_ids[empty_index] = output_token_ids
|
||||
self.req_output_token_ids[last_req_index] = None
|
||||
self.req_id_to_index[req_id] = empty_index
|
||||
|
||||
num_tokens = self.num_tokens[last_req_index]
|
||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||
last_req_index, :num_tokens]
|
||||
self.num_tokens[empty_index] = num_tokens
|
||||
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
||||
last_req_index]
|
||||
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
|
||||
last_req_index]
|
||||
self.num_computed_tokens_cpu[
|
||||
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
||||
self.block_table.move_row(last_req_index, empty_index)
|
||||
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
||||
last_req_index]
|
||||
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
||||
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
|
||||
self.frequency_penalties_cpu[
|
||||
empty_index] = self.frequency_penalties_cpu[last_req_index]
|
||||
self.presence_penalties_cpu[
|
||||
empty_index] = self.presence_penalties_cpu[last_req_index]
|
||||
self.repetition_penalties_cpu[
|
||||
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
||||
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
||||
generator = self.generators.pop(last_req_index, None)
|
||||
if generator is not None:
|
||||
self.generators[empty_index] = generator
|
||||
|
||||
min_token = self.min_tokens.pop(last_req_index, None)
|
||||
if min_token is not None:
|
||||
self.min_tokens[empty_index] = min_token
|
||||
|
||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||
last_req_index]
|
||||
|
||||
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[
|
||||
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
|
||||
last_req_index]
|
||||
|
||||
bad_words_token_ids = self.bad_words_token_ids.pop(
|
||||
last_req_index, None)
|
||||
if bad_words_token_ids is not None:
|
||||
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
||||
# Decrement last_req_index since it is now empty.
|
||||
last_req_index -= 1
|
||||
|
||||
# Trim lists to the batch size.
|
||||
del self._req_ids[self.num_reqs:]
|
||||
del self.req_output_token_ids[self.num_reqs:]
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
||||
prompt_token_ids_cpu_tensor = torch.empty(
|
||||
(self.num_reqs, max_prompt_len),
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
|
||||
prompt_token_ids[:] = self.token_ids_cpu[:self.
|
||||
num_reqs, :max_prompt_len]
|
||||
# Use the value of vocab_size as a pad since we don't have a
|
||||
# token_id of this value.
|
||||
for i in range(self.num_reqs):
|
||||
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
|
||||
return prompt_token_ids_cpu_tensor.to(device=self.device,
|
||||
non_blocking=True)
|
||||
|
||||
def make_lora_inputs(
|
||||
self, num_scheduled_tokens: np.ndarray
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
||||
"""
|
||||
Given the num_scheduled_tokens for each request in the batch, return
|
||||
datastructures used to activate the current LoRAs.
|
||||
Returns:
|
||||
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
|
||||
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
|
||||
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
|
||||
where, token_lora_mapping[i] is the LoRA id to use for ith token.
|
||||
3. lora_requests: Set of relevant LoRA requests.
|
||||
"""
|
||||
|
||||
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
|
||||
prompt_lora_mapping = tuple(req_lora_mapping)
|
||||
token_lora_mapping = tuple(
|
||||
req_lora_mapping.repeat(num_scheduled_tokens))
|
||||
active_lora_requests: set[LoRARequest] = set(
|
||||
self.lora_id_to_lora_request.values())
|
||||
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
|
||||
@property
|
||||
def all_greedy(self) -> bool:
|
||||
return len(self.random_reqs) == 0
|
||||
|
||||
@property
|
||||
def all_random(self) -> bool:
|
||||
return len(self.greedy_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_top_p(self) -> bool:
|
||||
return len(self.top_p_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_top_k(self) -> bool:
|
||||
return len(self.top_k_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_min_p(self) -> bool:
|
||||
return len(self.min_p_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_penalties(self) -> bool:
|
||||
return (len(self.presence_penalties_reqs) == 0
|
||||
and len(self.frequency_penalties_reqs) == 0
|
||||
and len(self.repetition_penalties_reqs) == 0)
|
||||
|
||||
@property
|
||||
def max_num_logprobs(self) -> Optional[int]:
|
||||
return max(self.num_logprobs.values()) if self.num_logprobs else None
|
||||
|
||||
@property
|
||||
def no_prompt_logprob(self) -> bool:
|
||||
return not self.num_prompt_logprobs
|
||||
|
||||
@property
|
||||
def no_allowed_token_ids(self) -> bool:
|
||||
return len(self.has_allowed_token_ids) == 0
|
||||
1946
vllm/v1/worker/tpu_model_runner.py
Normal file
1946
vllm/v1/worker/tpu_model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
346
vllm/v1/worker/tpu_worker.py
Normal file
346
vllm/v1/worker/tpu_worker.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A TPU worker class."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
|
||||
has_kv_transfer_group)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.tpu import USE_TPU_COMMONS
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R")
|
||||
|
||||
if not USE_TPU_COMMONS:
|
||||
logger.info("tpu_commons not found, using vLLM's TPUWorker.")
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
|
||||
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
|
||||
|
||||
|
||||
class TPUWorker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
self.is_driver_worker = is_driver_worker
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.use_spmd = envs.VLLM_XLA_USE_SPMD
|
||||
self.original_parallel_config = None
|
||||
if self.use_spmd:
|
||||
# Under SPMD mode, distributed env is initialized as if there is
|
||||
# only one worker/device.
|
||||
self.original_parallel_config = self.parallel_config
|
||||
self.parallel_config.tensor_parallel_size = 1
|
||||
self.parallel_config.pipeline_parallel_size = 1
|
||||
self.parallel_config.world_size = 1
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
|
||||
if self.cache_config.cache_dtype == "auto":
|
||||
self.cache_dtype = self.model_config.dtype
|
||||
else:
|
||||
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype]
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Delay profiler initialization to the start of the profiling.
|
||||
# This is because in vLLM V1, MP runtime is initialized before the
|
||||
# TPU Worker is initialized. The profiler server needs to start after
|
||||
# MP runtime is initialized.
|
||||
self.profiler = None
|
||||
self.profile_dir = None
|
||||
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
|
||||
# For TPU, we can only have 1 active profiler session for 1 profiler
|
||||
# server. So we only profile on rank0.
|
||||
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
self.profile_dir)
|
||||
|
||||
if self.model_config.seed is None:
|
||||
self.model_config.seed = 0
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
def init_device(self):
|
||||
os.environ["PJRT_DEVICE"] = "TPU"
|
||||
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
|
||||
# ring, the xla tpu compiler flag
|
||||
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
|
||||
# fix this. It will be removed after the bug in XLA compiler is fixed.
|
||||
os.environ["LIBTPU_INIT_ARGS"] = (
|
||||
os.environ.get("LIBTPU_INIT_ARGS", "") +
|
||||
" --xla_tpu_force_1d_allreduce_at_chunk_count=1"
|
||||
" --xla_jf_conv_input_fusion=False")
|
||||
# --xla_jf_conv_input_fusion=False is used to improve the perf of
|
||||
# quantized matmul.
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_default_dtype(self.model_config.dtype)
|
||||
|
||||
# Initialize the distributed environment.
|
||||
self._init_tpu_worker_distributed_environment(
|
||||
self.vllm_config, self.rank, self.distributed_init_method,
|
||||
self.local_rank)
|
||||
|
||||
# Device initialization should happen after initializing
|
||||
# the distributed runtime.
|
||||
self.device = xm.xla_device()
|
||||
self.device_config.device = self.device
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
if self.model_config.seed is not None:
|
||||
xm.set_rng_state(self.model_config.seed, self.device)
|
||||
|
||||
# Increase the cache size limit, which is the maximum number of
|
||||
# dynamo graphs that can be compiled.
|
||||
# TODO (NickLucche) On gsm we compile 80+ graphs.
|
||||
# Re-evaluate limit, with MM we may get close to this limit.
|
||||
torch._dynamo.config.cache_size_limit = 128
|
||||
# Use persistent cache to avoid XLA recompilation.
|
||||
# NOTE(woosuk): Set per-rank cache path since different ranks
|
||||
# can have slightly different XLA graphs.
|
||||
world_size = self.parallel_config.world_size
|
||||
rank = xr.global_ordinal()
|
||||
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
|
||||
# Consequently, changes in optimization flags, which affect compilation
|
||||
# results, don't change the cache key. This can result in the wrong
|
||||
# compilation being used. To prevent this, disabling the XLA compilation
|
||||
# cache during development is recommended.We can disable it by
|
||||
# `export VLLM_XLA_CACHE_PATH=`
|
||||
if envs.VLLM_XLA_CACHE_PATH:
|
||||
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
|
||||
f"tp{world_size}_rank{rank}")
|
||||
xr.initialize_cache(per_rank_path, readonly=False)
|
||||
|
||||
# Init ModelRunner here, so that we have access to self.device.
|
||||
self.model_runner = \
|
||||
TPUModelRunner(self.vllm_config, self.device,
|
||||
self.original_parallel_config)
|
||||
|
||||
if rank == 0:
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
report_usage_stats(self.vllm_config)
|
||||
|
||||
def determine_available_memory(self) -> int:
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
kv_cache_spec = self.model_runner.get_kv_cache_spec()
|
||||
for layer_name, layer_spec in kv_cache_spec.items():
|
||||
if isinstance(layer_spec, AttentionSpec):
|
||||
dtype = layer_spec.dtype
|
||||
|
||||
# Use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
|
||||
kv_caches[layer_name] = tpu_kv_cache
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported KV cache spec '{type(layer_spec)}'")
|
||||
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
runner_kv_caches)
|
||||
|
||||
# `max_num_tokens >= max_num_batched_tokens` due to padding.
|
||||
with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
|
||||
self.model_runner.profile_run(self.model_runner.max_num_tokens)
|
||||
|
||||
# Synchronize before measuring the memory usage.
|
||||
xm.wait_device_ops()
|
||||
|
||||
# During the profiling run, the model runs without KV cache. After
|
||||
# the profiling run, the model always runs with KV cache. Here we clear
|
||||
# the dynamo cache and cached bytecode to ensure the model always has
|
||||
# one compiled bytecode. Having one FX graph/cached bytecode per
|
||||
# compiled model is required for `support_torch_compile` decorator to
|
||||
# skip dynamo guard.
|
||||
self.model_runner.reset_dynamo_cache()
|
||||
|
||||
# Get the maximum amount of memory used by the model weights and
|
||||
# intermediate activations.
|
||||
if self.use_spmd:
|
||||
# This is a workaround for the TPU SPMD mode. The get_memory_info
|
||||
# API doesn't work with SPMD mode in PyTorch/XLA.
|
||||
# TODO: use xm.get_memory_info for SPMD once it's supported in
|
||||
# PyTorch/XLA.
|
||||
import tpu_info
|
||||
chip_type, _ = tpu_info.device.get_local_chips()
|
||||
device_usage = tpu_info.metrics.get_chip_usage(chip_type)
|
||||
total_memory_size = device_usage[0].total_memory
|
||||
current_mem = device_usage[0].memory_usage
|
||||
else:
|
||||
m = xm.get_memory_info(self.device)
|
||||
total_memory_size = m["bytes_limit"]
|
||||
current_mem = m["bytes_used"]
|
||||
# Ideally we would use profiled = m["peak_bytes_used"] to
|
||||
# get weights + activations. But there is memory used during
|
||||
# compilation / weight loading that impacts the peak and
|
||||
# there is no way to reset peak memory in XLA, So we
|
||||
# use the heuristic of 2% of weights.
|
||||
profiled = current_mem * 1.02
|
||||
|
||||
# Calculate the TPU KV cache size based on profiling.
|
||||
usable_memory_size = int(total_memory_size *
|
||||
self.cache_config.gpu_memory_utilization)
|
||||
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
||||
head_size = self.model_config.get_head_size()
|
||||
if head_size > 0:
|
||||
padded_head_size = cdiv(
|
||||
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
if padded_head_size != head_size:
|
||||
logger.warning_once("head size is padded to %d",
|
||||
padded_head_size)
|
||||
# We adjust the usable memory size for the KV cache to prevent OOM
|
||||
# errors, even after padding the head_size.
|
||||
tpu_kv_cache_bytes = (tpu_kv_cache_bytes * head_size //
|
||||
padded_head_size)
|
||||
return int(tpu_kv_cache_bytes)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[ModelRunnerOutput]:
|
||||
output = self.model_runner.execute_model(scheduler_output)
|
||||
# every worker's output is needed when kv_transfer_group is set up
|
||||
return output if self.is_driver_worker or has_kv_transfer_group(
|
||||
) else None
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.rank < 1:
|
||||
if self.profile_dir is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
if is_start:
|
||||
if self.profiler is None:
|
||||
self.profiler = xp.start_server(9012)
|
||||
xp.start_trace(self.profile_dir)
|
||||
else:
|
||||
xp.stop_trace()
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model_runner.load_model()
|
||||
|
||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||
self.model_runner.update_config(overrides)
|
||||
|
||||
def reload_weights(self) -> None:
|
||||
self.model_runner.reload_weights()
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model()
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
return self.model_runner.get_supported_tasks()
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def check_health(self) -> None:
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
def _init_tpu_worker_distributed_environment(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
if self.use_spmd:
|
||||
xr.use_spmd()
|
||||
# NOTE(woosuk): This is just to initialize the TP group and broadcast
|
||||
# the input objects on CPU. The all-reduce and all-gather ops on TPU
|
||||
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
|
||||
# own context.
|
||||
parallel_config = vllm_config.parallel_config
|
||||
init_distributed_environment(
|
||||
world_size=parallel_config.world_size,
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
backend=current_platform.dist_backend,
|
||||
)
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.model_runner.ensure_kv_transfer_shutdown()
|
||||
|
||||
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
|
||||
"""Apply a function on the model inside this worker."""
|
||||
return fn(self.get_model())
|
||||
|
||||
|
||||
if USE_TPU_COMMONS:
|
||||
from tpu_commons.worker import TPUWorker as TPUCommonsWorker
|
||||
|
||||
TPUWorker = TPUCommonsWorker # type: ignore
|
||||
192
vllm/v1/worker/ubatch_splitting.py
Normal file
192
vllm/v1/worker/ubatch_splitting.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.forward_context import DPMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import round_up
|
||||
from vllm.v1.worker.ubatch_utils import (UBatchSlice, UBatchSlices,
|
||||
is_second_ubatch_empty)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def should_ubatch_with_num_tokens(
|
||||
should_ubatch: bool,
|
||||
orig_num_tokens_per_ubatch: int,
|
||||
padded_num_tokens_per_ubatch: int,
|
||||
vllm_config: VllmConfig,
|
||||
) -> tuple[bool, Optional[torch.Tensor]]:
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
return DPMetadata.should_ubatch_across_dp(should_ubatch,
|
||||
orig_num_tokens_per_ubatch,
|
||||
padded_num_tokens_per_ubatch,
|
||||
dp_size, dp_rank)
|
||||
|
||||
|
||||
def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int,
|
||||
uniform_decode: bool) -> bool:
|
||||
if not config.enable_dbo:
|
||||
return False
|
||||
if uniform_decode:
|
||||
return num_tokens >= config.dbo_decode_token_threshold
|
||||
else:
|
||||
return num_tokens >= config.dbo_prefill_token_threshold
|
||||
|
||||
|
||||
def get_dp_padding_ubatch(
|
||||
num_tokens_unpadded: int, num_tokens_padded: int,
|
||||
should_attempt_ubatching: bool,
|
||||
vllm_config: VllmConfig) -> tuple[bool, Optional[torch.Tensor]]:
|
||||
"""
|
||||
1. Decides if each DP rank is going to microbatch. Either all ranks
|
||||
run with microbatching or none of them do. If this function decides
|
||||
not to run with microbatching. It will "abort" meaning that no padding
|
||||
information will be returned to the caller. It will return (False, None)
|
||||
|
||||
2. Determines the total number of tokens that each rank will run.
|
||||
All ranks will be padded out so that the run with the same number
|
||||
of tokens
|
||||
|
||||
Returns: tuple[
|
||||
should_ubatch: Are all DP ranks going to microbatch
|
||||
num_tokens_after_padding: A tensor containing the total number of
|
||||
tokens per-microbatch for each DP rank including padding. Will be
|
||||
None if should_ubatch if False
|
||||
]
|
||||
|
||||
"""
|
||||
assert num_tokens_padded >= num_tokens_unpadded
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
if dp_size == 1:
|
||||
# Early exit.
|
||||
return False, None
|
||||
|
||||
# If this DP rank doesn't want to attempt microbatching
|
||||
if not should_attempt_ubatching:
|
||||
(should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
|
||||
False, 0, 0, vllm_config)
|
||||
assert should_ubatch is False
|
||||
assert num_tokens_across_dp is None
|
||||
return should_ubatch, num_tokens_across_dp
|
||||
|
||||
# Round up to the next multiple of two for even divisibility
|
||||
num_tokens_padded = round_up(num_tokens_padded, 2)
|
||||
num_tokens_per_ubatch = num_tokens_padded // 2
|
||||
should_ubatch = True
|
||||
|
||||
# Sanity Check that the existing padding isn't giving us an empty second
|
||||
# ubatch. Abort if so
|
||||
if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded):
|
||||
logger.debug(
|
||||
"Empty second µbatch detected: unpadded tokens: %s, padded "
|
||||
"tokens: %s", num_tokens_unpadded, num_tokens_padded)
|
||||
should_ubatch = False
|
||||
|
||||
# Note that we compute the number of padded tokens per ubatch
|
||||
(should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
|
||||
should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch,
|
||||
vllm_config)
|
||||
if not should_ubatch:
|
||||
assert num_tokens_across_dp is None
|
||||
return should_ubatch, num_tokens_across_dp
|
||||
|
||||
assert num_tokens_across_dp is not None
|
||||
|
||||
max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
|
||||
num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
|
||||
dp_size,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
return should_ubatch, num_tokens_after_padding
|
||||
|
||||
def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \
|
||||
-> UBatchSlices:
|
||||
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
|
||||
# in cu_num_tokens directly (i.e. query_start_loc)
|
||||
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
|
||||
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
|
||||
|
||||
first_ubatch_token_slice = slice(0, split_point)
|
||||
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
|
||||
|
||||
# Determine request slices using exclusive stop semantics
|
||||
# First ubatch includes requests whose tokens overlap [0, split_point)
|
||||
first_ubatch_req_stop = int(
|
||||
np.searchsorted(cu_num_tokens, split_point, side="left"))
|
||||
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
|
||||
|
||||
# Second ubatch starts at the request that contains the split_point
|
||||
# or the request starting exactly at split_point (if on boundary)
|
||||
second_ubatch_req_start = int(
|
||||
np.searchsorted(cu_num_tokens, split_point, side="right") - 1)
|
||||
second_ubatch_req_slice = slice(second_ubatch_req_start,
|
||||
len(cu_num_tokens) - 1)
|
||||
|
||||
return [
|
||||
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
|
||||
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice)
|
||||
]
|
||||
|
||||
|
||||
def ubatch_split(
|
||||
num_scheduled_tokens_per_request: np.ndarray,
|
||||
num_tokens_unpadded: int,
|
||||
num_tokens_padded: int,
|
||||
uniform_decode: bool,
|
||||
vllm_config: VllmConfig,
|
||||
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Coordinates amongst all DP ranks to determine if and how the full batch
|
||||
should be split into microbatches.
|
||||
|
||||
Returns: tuple[
|
||||
ubatch_slices: if this is set then all DP ranks have agreed to
|
||||
microbatch
|
||||
num_tokens_after_padding: A tensor containing the total number of
|
||||
tokens per-microbatch for each DP rank including padding. Will be
|
||||
None if ubatch_slices is None
|
||||
]
|
||||
|
||||
"""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
# Don't bother with the should_ubatch handshaking unless microbatching
|
||||
# is enabled
|
||||
if not parallel_config.enable_dbo:
|
||||
return (None, None)
|
||||
|
||||
# Check preconditions for microbatching
|
||||
should_attempt_ubatching = check_ubatch_thresholds(
|
||||
parallel_config,
|
||||
num_tokens_unpadded,
|
||||
uniform_decode=uniform_decode,
|
||||
)
|
||||
|
||||
# Don't microbatch unless every other DP worker is also microbatching
|
||||
should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch(
|
||||
num_tokens_unpadded,
|
||||
num_tokens_padded,
|
||||
should_attempt_ubatching,
|
||||
vllm_config,
|
||||
)
|
||||
|
||||
if not should_ubatch:
|
||||
return (None, None)
|
||||
|
||||
# This doesn't actually pad the ubatch slices. It just initializes the
|
||||
# split point to the padded value so that padding can be applied
|
||||
# to the second ubatch in pad_out_ubatch_slice after attention
|
||||
# metadata creation
|
||||
assert num_tokens_after_padding is not None
|
||||
token_split_point = int(num_tokens_after_padding[0].item())
|
||||
|
||||
ubatch_slices = create_ubatch_slices(num_scheduled_tokens_per_request,
|
||||
token_split_point)
|
||||
|
||||
return (ubatch_slices, num_tokens_after_padding)
|
||||
27
vllm/v1/worker/ubatch_utils.py
Normal file
27
vllm/v1/worker/ubatch_utils.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
|
||||
@dataclass
|
||||
class UBatchSlice:
|
||||
request_slice: slice
|
||||
token_slice: slice
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return self.request_slice.start == self.request_slice.stop \
|
||||
or self.token_slice.start == self.token_slice.stop
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return self.token_slice.stop - self.token_slice.start
|
||||
|
||||
|
||||
UBatchSlices: TypeAlias = list[UBatchSlice]
|
||||
|
||||
|
||||
def is_second_ubatch_empty(orig_num_tokens_per_ubatch: int,
|
||||
padded_num_tokens_per_ubatch: int) -> bool:
|
||||
return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch
|
||||
224
vllm/v1/worker/ubatching.py
Normal file
224
vllm/v1/worker/ubatching.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import forward_context
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import current_stream
|
||||
|
||||
_THREAD_ID_TO_CONTEXT: dict = {}
|
||||
_CURRENT_CONTEXTS: list[Optional['UBatchContext']] = [None, None]
|
||||
|
||||
|
||||
class UBatchContext:
|
||||
"""
|
||||
Context manager for micro-batching synchronization using threading events.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
id: int,
|
||||
comm_stream: torch.cuda.Stream,
|
||||
compute_stream: torch.cuda.Stream,
|
||||
forward_context: ForwardContext,
|
||||
ready_barrier: threading.Barrier,
|
||||
cpu_wait_event: threading.Event,
|
||||
cpu_signal_event: threading.Event,
|
||||
gpu_comm_done_event: torch.cuda.Event,
|
||||
gpu_compute_done_event: torch.cuda.Event,
|
||||
schedule: str = "default"):
|
||||
self.id = id
|
||||
self.comm_stream = comm_stream
|
||||
self.compute_stream = compute_stream
|
||||
self.forward_context = forward_context
|
||||
self.ready_barrier = ready_barrier
|
||||
self.cpu_wait_event = cpu_wait_event
|
||||
self.cpu_signal_event = cpu_signal_event
|
||||
self.current_stream = compute_stream
|
||||
self.gpu_comm_done_event = gpu_comm_done_event
|
||||
self.gpu_compute_done_event = gpu_compute_done_event
|
||||
self.schedule = schedule
|
||||
self.recv_hook = None
|
||||
|
||||
def __enter__(self):
|
||||
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
|
||||
_THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
|
||||
_CURRENT_CONTEXTS[self.id] = self
|
||||
self.ready_barrier.wait()
|
||||
|
||||
self.cpu_wait_event.wait()
|
||||
self.cpu_wait_event.clear()
|
||||
self._restore_context()
|
||||
# Assume we want to start on the compute stream
|
||||
self.update_stream(self.compute_stream)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
|
||||
_CURRENT_CONTEXTS[self.id] = None
|
||||
del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
|
||||
self.maybe_run_recv_hook()
|
||||
self.cpu_signal_event.set()
|
||||
self.cpu_wait_event.clear()
|
||||
return False
|
||||
|
||||
def _restore_context(self):
|
||||
forward_context._forward_context = self.forward_context
|
||||
|
||||
def update_stream(self, stream):
|
||||
self.current_stream = stream
|
||||
if current_stream() != self.current_stream:
|
||||
torch.cuda.set_stream(self.current_stream)
|
||||
|
||||
def _signal_comm_done(self):
|
||||
self.gpu_comm_done_event.record(self.comm_stream)
|
||||
|
||||
def _signal_compute_done(self):
|
||||
self.gpu_compute_done_event.record(self.compute_stream)
|
||||
|
||||
def _wait_compute_done(self):
|
||||
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||
|
||||
def _wait_comm_done(self):
|
||||
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||
|
||||
def _cpu_yield(self):
|
||||
# It is critical for correctness that only one thread is running
|
||||
# at a time. These asserts just make sure that this is the only
|
||||
# thread running before waking the other one up and going to sleep
|
||||
assert forward_context._forward_context == self.forward_context
|
||||
assert current_stream() == self.current_stream
|
||||
assert not self.cpu_wait_event.is_set()
|
||||
|
||||
self.cpu_signal_event.set()
|
||||
self.cpu_wait_event.wait()
|
||||
self.cpu_wait_event.clear()
|
||||
self._restore_context()
|
||||
|
||||
def switch_to_comm(self):
|
||||
self.update_stream(self.comm_stream)
|
||||
|
||||
def switch_to_compute(self):
|
||||
self.update_stream(self.compute_stream)
|
||||
|
||||
def switch_to_comm_sync(self):
|
||||
self._signal_compute_done()
|
||||
self.update_stream(self.comm_stream)
|
||||
self._wait_compute_done()
|
||||
|
||||
def switch_to_compute_sync(self):
|
||||
self._signal_comm_done()
|
||||
self.update_stream(self.compute_stream)
|
||||
self._wait_comm_done()
|
||||
|
||||
def maybe_run_recv_hook(self):
|
||||
if self.recv_hook is not None:
|
||||
self.recv_hook()
|
||||
self.recv_hook = None
|
||||
|
||||
def yield_(self):
|
||||
self.current_stream = current_stream()
|
||||
self._cpu_yield()
|
||||
self.update_stream(self.current_stream)
|
||||
|
||||
def yield_and_switch_from_compute_to_comm(self):
|
||||
assert current_stream() == self.compute_stream
|
||||
self._signal_compute_done()
|
||||
self._cpu_yield()
|
||||
assert self.current_stream == self.compute_stream
|
||||
self.update_stream(self.comm_stream)
|
||||
self._wait_compute_done()
|
||||
|
||||
def yield_and_switch_from_comm_to_compute(self):
|
||||
assert current_stream() == self.comm_stream
|
||||
self._signal_comm_done()
|
||||
self._cpu_yield()
|
||||
assert self.current_stream == self.comm_stream
|
||||
self.update_stream(self.compute_stream)
|
||||
self._wait_comm_done()
|
||||
|
||||
|
||||
def dbo_enabled() -> bool:
|
||||
return len(_THREAD_ID_TO_CONTEXT) > 0
|
||||
|
||||
|
||||
def dbo_current_ubatch_id() -> int:
|
||||
if len(_THREAD_ID_TO_CONTEXT) == 0:
|
||||
return 0
|
||||
return _THREAD_ID_TO_CONTEXT[threading.get_ident()]
|
||||
|
||||
|
||||
def _register_ubatch_function(func):
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if len(_THREAD_ID_TO_CONTEXT) > 0:
|
||||
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
|
||||
ctx = _CURRENT_CONTEXTS[ctx_idx]
|
||||
func(ctx, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
dbo_maybe_run_recv_hook = _register_ubatch_function(
|
||||
UBatchContext.maybe_run_recv_hook)
|
||||
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
|
||||
dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
|
||||
UBatchContext.yield_and_switch_from_compute_to_comm)
|
||||
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
|
||||
UBatchContext.yield_and_switch_from_comm_to_compute)
|
||||
dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm)
|
||||
dbo_switch_to_compute = _register_ubatch_function(
|
||||
UBatchContext.switch_to_compute)
|
||||
dbo_switch_to_comm_sync = _register_ubatch_function(
|
||||
UBatchContext.switch_to_comm_sync)
|
||||
dbo_switch_to_compute_sync = _register_ubatch_function(
|
||||
UBatchContext.switch_to_compute_sync)
|
||||
|
||||
|
||||
def dbo_register_recv_hook(recv_hook):
|
||||
if len(_THREAD_ID_TO_CONTEXT) > 0:
|
||||
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
|
||||
next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2]
|
||||
next_ctx.recv_hook = recv_hook
|
||||
|
||||
|
||||
def make_ubatch_contexts(
|
||||
num_micro_batches: int,
|
||||
compute_stream: torch.cuda.Stream,
|
||||
comm_stream: torch.cuda.Stream,
|
||||
forward_contexts: list[ForwardContext],
|
||||
ready_barrier: threading.Barrier,
|
||||
schedule: str = "default",
|
||||
) -> list[UBatchContext]:
|
||||
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
|
||||
"""
|
||||
Create a context manager for micro-batching synchronization.
|
||||
"""
|
||||
cpu_events = [threading.Event() for _ in range(num_micro_batches)]
|
||||
gpu_comm_done_events = [
|
||||
torch.cuda.Event() for _ in range(num_micro_batches)
|
||||
]
|
||||
gpu_compute_done_events = [
|
||||
torch.cuda.Event() for _ in range(num_micro_batches)
|
||||
]
|
||||
|
||||
assert len(forward_contexts) == 2
|
||||
|
||||
ctxs = []
|
||||
for i in range(num_micro_batches):
|
||||
ctx = UBatchContext(id=i,
|
||||
compute_stream=compute_stream,
|
||||
comm_stream=comm_stream,
|
||||
forward_context=forward_contexts[i],
|
||||
ready_barrier=ready_barrier,
|
||||
cpu_wait_event=cpu_events[i],
|
||||
cpu_signal_event=cpu_events[(i + 1) %
|
||||
num_micro_batches],
|
||||
gpu_comm_done_event=gpu_comm_done_events[i],
|
||||
gpu_compute_done_event=gpu_compute_done_events[i],
|
||||
schedule=schedule)
|
||||
ctxs.append(ctx)
|
||||
|
||||
return ctxs
|
||||
344
vllm/v1/worker/utils.py
Normal file
344
vllm/v1/worker/utils.py
Normal file
@@ -0,0 +1,344 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||
from vllm.multimodal.registry import MultiModalRegistry
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
|
||||
class MultiModalBudget:
|
||||
"""Helper class to calculate budget information for multi-modal models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
mm_registry: MultiModalRegistry,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.model_config = model_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.mm_registry = mm_registry
|
||||
self.cache = cache = processor_only_cache_from_config(
|
||||
model_config, mm_registry)
|
||||
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||
|
||||
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config,
|
||||
cache=cache)
|
||||
|
||||
max_tokens_by_modality = mm_registry \
|
||||
.get_max_tokens_per_item_by_nonzero_modality(model_config,
|
||||
cache=cache)
|
||||
|
||||
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
|
||||
scheduler_config,
|
||||
max_tokens_by_modality,
|
||||
)
|
||||
|
||||
self.encoder_compute_budget = encoder_compute_budget
|
||||
self.encoder_cache_size = encoder_cache_size
|
||||
|
||||
max_items_per_prompt_by_modality = dict[str, int]()
|
||||
max_items_per_batch_by_modality = dict[str, int]()
|
||||
|
||||
for modality, max_tokens in max_tokens_by_modality.items():
|
||||
(
|
||||
max_items_per_prompt,
|
||||
max_items_per_batch,
|
||||
) = self.get_max_items(modality, max_tokens)
|
||||
|
||||
max_items_per_prompt_by_modality[modality] = max_items_per_prompt
|
||||
max_items_per_batch_by_modality[modality] = max_items_per_batch
|
||||
|
||||
self.max_tokens_by_modality = max_tokens_by_modality
|
||||
self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
|
||||
self.max_items_per_batch_by_modality = max_items_per_batch_by_modality
|
||||
|
||||
def get_modality_with_max_tokens(self) -> str:
|
||||
max_tokens_by_modality = self.max_tokens_by_modality
|
||||
modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])
|
||||
|
||||
return modality
|
||||
|
||||
def get_encoder_budget(self) -> int:
|
||||
return min(self.encoder_compute_budget, self.encoder_cache_size)
|
||||
|
||||
def get_max_items(
|
||||
self,
|
||||
modality: str,
|
||||
max_tokens_per_item: int,
|
||||
) -> tuple[int, int]:
|
||||
if max_tokens_per_item == 0:
|
||||
return 0, 0
|
||||
|
||||
# Check how many items of this modality can be supported by
|
||||
# the encoder budget.
|
||||
encoder_budget = self.get_encoder_budget()
|
||||
|
||||
# TODO: handle encoder-decoder models once we support them.
|
||||
if encoder_budget == 0:
|
||||
return 0, 0
|
||||
|
||||
max_encoder_items_per_batch = encoder_budget // max_tokens_per_item
|
||||
|
||||
# Check how many items of this modality can be supported by
|
||||
# the decoder budget.
|
||||
mm_limit = self.mm_limits[modality]
|
||||
|
||||
max_items_per_prompt = max(
|
||||
1,
|
||||
min(mm_limit, self.max_model_len // max_tokens_per_item),
|
||||
)
|
||||
|
||||
scheduler_config = self.scheduler_config
|
||||
max_num_reqs = self.max_num_reqs
|
||||
|
||||
if not scheduler_config.enable_chunked_prefill:
|
||||
max_num_reqs = min(
|
||||
max_num_reqs,
|
||||
scheduler_config.max_num_batched_tokens // max_tokens_per_item,
|
||||
)
|
||||
|
||||
max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt
|
||||
|
||||
max_items_per_batch = max(
|
||||
1,
|
||||
min(max_encoder_items_per_batch, max_decoder_items_per_batch),
|
||||
)
|
||||
|
||||
return max_items_per_prompt, max_items_per_batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionGroup:
|
||||
backend: type[AttentionBackend]
|
||||
# When ubatching is enabled we will have a metadata builder for each ubatch
|
||||
# so that if they use internal persistant buffers for cudagraphs, and they
|
||||
# won't have to worry about conflicting with the other ubatches.
|
||||
metadata_builders: list[AttentionMetadataBuilder]
|
||||
layer_names: list[str]
|
||||
kv_cache_spec: KVCacheSpec
|
||||
|
||||
@staticmethod
|
||||
def create_with_metadata_builders(
|
||||
backend: type[AttentionBackend],
|
||||
layer_names: list[str],
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
num_metadata_builders: int = 1,
|
||||
) -> 'AttentionGroup':
|
||||
metadata_builders = [
|
||||
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config,
|
||||
device)
|
||||
for _ in range(num_metadata_builders)
|
||||
]
|
||||
return AttentionGroup(backend, metadata_builders, layer_names,
|
||||
kv_cache_spec)
|
||||
|
||||
def get_metadata_builder(self,
|
||||
ubatch_id: int = 0) -> AttentionMetadataBuilder:
|
||||
assert len(self.metadata_builders) > ubatch_id
|
||||
return self.metadata_builders[ubatch_id]
|
||||
|
||||
|
||||
def sanity_check_mm_encoder_outputs(
|
||||
mm_embeddings: MultiModalEmbeddings,
|
||||
expected_num_items: int,
|
||||
) -> None:
|
||||
"""
|
||||
Perform sanity checks for the result of
|
||||
[`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
|
||||
"""
|
||||
assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
|
||||
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
|
||||
f"or a single 3D tensor, but got {type(mm_embeddings)} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `get_multimodal_embeddings` method.")
|
||||
|
||||
assert len(mm_embeddings) == expected_num_items, (
|
||||
"Expected number of multimodal embeddings to match number of "
|
||||
f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `get_multimodal_embeddings` method.")
|
||||
|
||||
assert all(e.ndim == 2 for e in mm_embeddings), (
|
||||
"Expected multimodal embeddings to be a sequence of 2D tensors, "
|
||||
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `get_multimodal_embeddings` method.")
|
||||
|
||||
|
||||
def scatter_mm_placeholders(
|
||||
embeds: torch.Tensor,
|
||||
is_embed: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Scatter the multimodal embeddings into a contiguous tensor that represents
|
||||
the placeholder tokens.
|
||||
|
||||
[`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].
|
||||
|
||||
Args:
|
||||
embeds: The multimodal embeddings.
|
||||
Shape: `(num_embeds, embed_dim)`
|
||||
is_embed: A boolean mask indicating which positions in the placeholder
|
||||
tokens need to be filled with multimodal embeddings.
|
||||
Shape: `(num_placeholders, num_embeds)`
|
||||
"""
|
||||
if is_embed is None:
|
||||
return embeds
|
||||
|
||||
placeholders = embeds.new_full(
|
||||
(is_embed.shape[0], embeds.shape[-1]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
placeholders[is_embed] = embeds
|
||||
return placeholders
|
||||
|
||||
|
||||
def gather_mm_placeholders(
|
||||
placeholders: torch.Tensor,
|
||||
is_embed: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reconstructs the embeddings from the placeholder tokens.
|
||||
|
||||
This is the operation of [`scatter_mm_placeholders`]
|
||||
[vllm.v1.worker.utils.scatter_mm_placeholders].
|
||||
"""
|
||||
if is_embed is None:
|
||||
return placeholders
|
||||
|
||||
return placeholders[is_embed]
|
||||
|
||||
|
||||
def add_kv_sharing_layers_to_kv_cache_groups(
|
||||
shared_kv_cache_layers: dict[str, str],
|
||||
kv_cache_groups: list[KVCacheGroupSpec],
|
||||
runner_only_attn_layers: Optional[set[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
|
||||
for layers that do not allocate its own KV cache, based on the mapping in
|
||||
`shared_kv_cache_layers`. Adds these layers to the corresponding KV cache
|
||||
group, which is needed to ensure that attention metadata is assigned later.
|
||||
|
||||
Args:
|
||||
shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
|
||||
If an Attention layer `layer_name` is in the keys of this dict, it
|
||||
means this layer will perform attention using the keys and values
|
||||
from the KV cache of `shared_kv_cache_layers[layer_name]`.
|
||||
kv_cache_groups: The KV cache groups of the model.
|
||||
"""
|
||||
layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
|
||||
for kv_cache_group in kv_cache_groups:
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
layer_to_kv_cache_group[layer_name] = kv_cache_group
|
||||
|
||||
for layer_name, target_layer_name in shared_kv_cache_layers.items():
|
||||
tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
|
||||
tgt_kv_cache_group.layer_names.append(layer_name)
|
||||
|
||||
if runner_only_attn_layers is not None:
|
||||
runner_only_attn_layers.add(layer_name)
|
||||
|
||||
|
||||
def bind_kv_cache(
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
forward_context: dict[str, "Attention"],
|
||||
runner_kv_caches: list[torch.Tensor],
|
||||
num_attn_module: Optional[int] = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Bind the allocated KV cache to both ModelRunner and forward context so
|
||||
that the KV cache can be used in the forward pass.
|
||||
|
||||
This function:
|
||||
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
|
||||
kv_caches.
|
||||
2) Associates each attention layer in the `forward_context` with its
|
||||
corresponding KV cache in kv_caches.
|
||||
|
||||
Args:
|
||||
kv_caches: The allocated kv_caches with layer names as keys.
|
||||
forward_context: The global forward context containing all Attention
|
||||
layers with layer names as keys.
|
||||
runner_kv_caches: The kv_cache declared by ModelRunner.
|
||||
"""
|
||||
# Bind kv_caches to ModelRunner
|
||||
assert len(runner_kv_caches) == 0
|
||||
|
||||
# Convert kv_caches dict to a list of tensors in the order of layer_index.
|
||||
index2name = defaultdict(list)
|
||||
for layer_name in kv_caches:
|
||||
index2name[extract_layer_index(layer_name,
|
||||
num_attn_module)].append(layer_name)
|
||||
|
||||
for layer_index in sorted(index2name.keys()):
|
||||
layer_names = index2name[layer_index]
|
||||
if len(layer_names) > 1:
|
||||
# One typical case is encoder-decoder model, e.g., bart.
|
||||
# The cross attention and self attention in the same decoder layer
|
||||
# has different layer_name but the same layer_index.
|
||||
|
||||
# TODO - analyze where runner_kv_caches is used and the right
|
||||
# way to ensure it properly reflects multiple attention layers
|
||||
# in the same decoder block.
|
||||
if current_platform.is_cuda() or current_platform.is_xpu():
|
||||
# We know that the GPU runner is not impacted by this
|
||||
# case. Some test code depends on runner_kv_caches, but
|
||||
# not in a way that's impacted by ignoring this.
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError
|
||||
layer_name = layer_names[0]
|
||||
runner_kv_caches.append(kv_caches[layer_name])
|
||||
|
||||
# Bind kv_caches to forward context
|
||||
for layer_name, kv_cache in kv_caches.items():
|
||||
# NOTE: Use list because of v0 PP virtual engine.
|
||||
forward_context[layer_name].kv_cache = [kv_cache]
|
||||
|
||||
|
||||
def is_residual_scattered_for_sp(vllm_config: VllmConfig,
|
||||
num_input_tokens: int) -> bool:
|
||||
"""Check if the residual tensor is scattered for sequence parallelism.
|
||||
|
||||
The residual tensor is scattered across tensor parallel ranks when sequence
|
||||
parallelism and tensor parallelism is enabled, and the number of
|
||||
input tokens is one of the compilation sizes.
|
||||
"""
|
||||
if not vllm_config.compilation_config.pass_config.\
|
||||
enable_sequence_parallelism:
|
||||
return False
|
||||
|
||||
tp = vllm_config.parallel_config.tensor_parallel_size
|
||||
|
||||
if tp == 1:
|
||||
return False
|
||||
|
||||
# When sequence parallelism is enabled, we always pad num_input_tokens
|
||||
# to be a multiple of tensor_parallel_size (tp) earlier.
|
||||
assert num_input_tokens % tp == 0
|
||||
|
||||
# Currently, SP is only enabled for static size fx graphs.
|
||||
return (num_input_tokens in vllm_config.compilation_config.compile_sizes)
|
||||
65
vllm/v1/worker/worker_base.py
Normal file
65
vllm/v1/worker/worker_base.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class WorkerBase(WorkerBaseV0):
|
||||
"""
|
||||
Abstract class for v1 worker, mainly define some methods for v1.
|
||||
For methods shared by v0 and v1, define them in v0 WorkerBase
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize common worker components.
|
||||
|
||||
Args:
|
||||
vllm_config: Complete vLLM configuration
|
||||
local_rank: Local device index
|
||||
rank: Global rank in distributed setup
|
||||
distributed_init_method: Distributed initialization method
|
||||
is_driver_worker: Whether this worker handles driver
|
||||
responsibilities
|
||||
"""
|
||||
# Configuration storage
|
||||
super().__init__(vllm_config=vllm_config)
|
||||
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
# Device and model state
|
||||
self.device: Optional[torch.device] = None
|
||||
self.model_runner: Optional[nn.Module] = None
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""Get specifications for KV cache implementation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
"""Prepare model for execution through compilation/warmup."""
|
||||
raise NotImplementedError
|
||||
|
||||
def check_health(self) -> None:
|
||||
"""Basic health check (override for device-specific checks)."""
|
||||
return
|
||||
57
vllm/v1/worker/xpu_model_runner.py
Normal file
57
vllm/v1/worker/xpu_model_runner.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUModelRunner(GPUModelRunner):
|
||||
"""A model runner for XPU devices."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
with _torch_cuda_wrapper():
|
||||
super().__init__(vllm_config, device)
|
||||
# FIXME: To be verified.
|
||||
self.cascade_attn_enabled = False
|
||||
|
||||
def _init_device_properties(self) -> None:
|
||||
self.num_sms = None
|
||||
|
||||
def _sync_device(self) -> None:
|
||||
torch.xpu.synchronize()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _torch_cuda_wrapper():
|
||||
|
||||
class _EventPlaceholder:
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self.record = lambda: None
|
||||
self.synchronize = lambda: None
|
||||
|
||||
try:
|
||||
# replace cuda APIs with xpu APIs, this should work by default
|
||||
torch.cuda.Event = torch.xpu.Event
|
||||
torch.cuda.Stream = torch.xpu.Stream
|
||||
torch.cuda.default_stream = torch.xpu.current_stream
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
yield
|
||||
finally:
|
||||
# if anything goes wrong, just patch it with a placeholder
|
||||
torch.cuda.Event = _EventPlaceholder
|
||||
179
vllm/v1/worker/xpu_worker.py
Normal file
179
vllm/v1/worker/xpu_worker.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_world_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.worker.gpu_worker import (Worker,
|
||||
init_worker_distributed_environment)
|
||||
from vllm.v1.worker.xpu_model_runner import XPUModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUWorker(Worker):
|
||||
"""A XPU worker class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
super().__init__(vllm_config, local_rank, rank,
|
||||
distributed_init_method, is_driver_worker)
|
||||
device_config = self.device_config
|
||||
assert device_config.device_type == "xpu"
|
||||
assert current_platform.is_xpu()
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir)
|
||||
logger.debug(
|
||||
"Profiler config: record_shapes=%s,"
|
||||
"profile_memory=%s,with_stack=%s,with_flops=%s",
|
||||
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.XPU,
|
||||
],
|
||||
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, use_gzip=True))
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
# we provide this function due to `torch.xpu.mem_get_info()` doesn't
|
||||
# return correct free_gpu_memory on intel client GPU. We need to
|
||||
# calculate/estiamte it.
|
||||
def xpu_get_mem_info(self):
|
||||
if current_platform.is_data_center_gpu():
|
||||
return torch.xpu.mem_get_info()
|
||||
else:
|
||||
_, total_gpu_memory = torch.xpu.mem_get_info()
|
||||
# FIXME: memory_allocated() doesn't count non-torch allocations,
|
||||
# and we don't have any API to get it. so we mark it as 128MB.
|
||||
used_memory = torch.xpu.memory_allocated()
|
||||
non_torch_allocations = 128 * 1024 * 1024
|
||||
free_gpu_memory = total_gpu_memory - (used_memory +
|
||||
non_torch_allocations)
|
||||
return free_gpu_memory, total_gpu_memory
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_available_memory(self) -> int:
|
||||
"""Profiles the peak memory usage of the model to determine how many
|
||||
KV blocks may be allocated without OOMs.
|
||||
The engine will first conduct a profiling of the existing memory usage.
|
||||
Then, it calculates the maximum possible number of GPU and CPU blocks
|
||||
that can be allocated with the remaining free memory.
|
||||
.. tip::
|
||||
You may limit the usage of GPU memory
|
||||
by adjusting the `gpu_memory_utilization` parameter.
|
||||
"""
|
||||
# Profile the memory usage of the model and get the maximum number of
|
||||
# cache blocks that can be allocated with the remaining free memory.
|
||||
torch.xpu.empty_cache()
|
||||
torch.xpu.reset_peak_memory_stats()
|
||||
|
||||
free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info()
|
||||
current_allocated_bytes = torch.xpu.memory_allocated()
|
||||
msg = ("Before memory profiling run, "
|
||||
f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
|
||||
f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
|
||||
f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
|
||||
logger.info(msg)
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
self.model_runner.profile_run()
|
||||
|
||||
free_gpu_memory, _ = self.xpu_get_mem_info()
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
assert self.init_gpu_memory > free_gpu_memory, (
|
||||
"Error in memory profiling. "
|
||||
f"Initial free memory {self.init_gpu_memory}, current free memory"
|
||||
f" {free_gpu_memory}. This happens when the GPU memory was "
|
||||
"not properly cleaned up before initializing the vLLM instance.")
|
||||
|
||||
# Get the peak memory allocation recorded by torch
|
||||
peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]
|
||||
|
||||
torch.xpu.empty_cache()
|
||||
torch_allocated_bytes = torch.xpu.memory_stats(
|
||||
)["allocated_bytes.all.current"]
|
||||
total_allocated_bytes = self.xpu_get_mem_info(
|
||||
)[1] - self.xpu_get_mem_info()[0]
|
||||
|
||||
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
|
||||
if non_torch_allocations > 0:
|
||||
peak_memory += non_torch_allocations
|
||||
available_kv_cache_memory = (
|
||||
total_gpu_memory * self.cache_config.gpu_memory_utilization -
|
||||
peak_memory)
|
||||
|
||||
msg = ("After memory profiling run, "
|
||||
f"peak memory usage is {peak_memory / 1024**2:.2f} MB,"
|
||||
f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
|
||||
f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
|
||||
f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
|
||||
logger.info(msg)
|
||||
|
||||
return int(available_kv_cache_memory)
|
||||
|
||||
def init_device(self):
|
||||
if self.device_config.device.type == "xpu" and current_platform.is_xpu(
|
||||
):
|
||||
self.device = torch.device(f"xpu:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
current_platform.check_if_supports_dtype(self.model_config.dtype)
|
||||
torch.xpu.empty_cache()
|
||||
self.init_gpu_memory = torch.xpu.get_device_properties(
|
||||
self.local_rank).total_memory
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
|
||||
ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "pidfd")
|
||||
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
|
||||
ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
|
||||
str(self.parallel_config.world_size))
|
||||
os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE
|
||||
os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
|
||||
os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
|
||||
os.environ["LOCAL_RANK"] = str(self.local_rank)
|
||||
|
||||
init_worker_distributed_environment(self.vllm_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank,
|
||||
current_platform.dist_backend)
|
||||
|
||||
# global all_reduce needed for overall oneccl warm up
|
||||
torch.distributed.all_reduce(torch.zeros(1).xpu(),
|
||||
group=get_world_group().device_group)
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
# Construct the model runner
|
||||
self.model_runner = XPUModelRunner( # type: ignore
|
||||
self.vllm_config, self.device)
|
||||
Reference in New Issue
Block a user