Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

View File

@@ -0,0 +1,343 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
logger = init_logger(__name__)
class BlockTable:
def __init__(
self,
block_size: int,
max_num_reqs: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
kernel_block_size: int,
cp_kv_cache_interleave_size: int,
):
"""
Args:
block_size: Block size used for KV cache memory allocation
max_num_reqs: Maximum number of concurrent requests supported.
max_num_blocks_per_req: Maximum number of blocks per request.
max_num_batched_tokens: Maximum number of tokens in a batch.
pin_memory: Whether to pin memory for faster GPU transfers.
device: Target device for the block table.
kernel_block_size: The block_size of underlying attention kernel.
Will be the same as `block_size` if `block_size` is supported
by the attention kernel.
"""
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.pin_memory = pin_memory
self.device = device
if kernel_block_size == block_size:
# Standard case: allocation and computation use same block size
# No block splitting needed, direct mapping
self.block_size = block_size
self.blocks_per_kv_block = 1
self.use_hybrid_blocks = False
else:
# Hybrid case: allocation block size differs from kernel block size
# Memory blocks are subdivided to match kernel requirements
# Example: 32-token memory blocks with 16-token kernel blocks
# → Each memory block corresponds to 2 kernel blocks
if block_size % kernel_block_size != 0:
raise ValueError(
f"kernel_block_size {kernel_block_size} must divide "
f"kv_manager_block_size size {block_size} evenly"
)
self.block_size = kernel_block_size
self.blocks_per_kv_block = block_size // kernel_block_size
self.use_hybrid_blocks = True
self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
self.block_table = self._make_buffer(
self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping = self._make_buffer(
self.max_num_batched_tokens, dtype=torch.int64
)
if self.use_hybrid_blocks:
self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
1, -1
)
else:
self._kernel_block_arange = None
try:
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
# PCP might not be initialized in testing
self.pcp_world_size = 1
self.pcp_rank = 0
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
def append_row(
self,
block_ids: list[int],
row_idx: int,
) -> None:
if not block_ids:
return
if self.use_hybrid_blocks:
block_ids = self.map_to_kernel_blocks(
np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
)
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.num_blocks_per_row[row_idx] += num_blocks
self.block_table.np[row_idx, start : start + num_blocks] = block_ids
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
block_table_np = self.block_table.np
block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
def swap_row(self, src: int, tgt: int) -> None:
src_tgt, tgt_src = [src, tgt], [tgt, src]
self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray
) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
if total_cp_world_size > 1:
# Note(hc): The DCP implement store kvcache with an interleave
# style, the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
virtual_block_size = self.block_size * total_cp_world_size
block_table_indices = (
req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = (
virtual_block_offsets
// self.cp_kv_cache_interleave_size
% total_cp_world_size
== total_cp_rank
)
# Calculate local block_offsets
block_offsets = (
virtual_block_offsets
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
* self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping.np[: req_indices.shape[0]] = np.where(
mask, slot_mapping, -1
)
else:
block_table_indices = (
req_indices * self.max_num_blocks_per_req + positions // self.block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(
block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping.np[: req_indices.shape[0]],
)
def commit_block_table(self, num_reqs: int) -> None:
self.block_table.copy_to_gpu(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
self.slot_mapping.copy_to_gpu(num_tokens)
def clear(self) -> None:
self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0)
@staticmethod
def map_to_kernel_blocks(
kv_manager_block_ids: np.ndarray,
blocks_per_kv_block: int,
kernel_block_arange: np.ndarray,
) -> np.ndarray:
"""Convert kv_manager_block_id IDs to kernel block IDs.
Example:
# kv_manager_block_ids: 32 tokens,
# Kernel block size: 16 tokens
# blocks_per_kv_block = 2
>>> kv_manager_block_ids = np.array([0, 1, 2])
>>> Result: [0, 1, 2, 3, 4, 5]
# Each kv_manager_block_id maps to 2 kernel block id:
# kv_manager_block_id 0 → kernel block id [0, 1]
# kv_manager_block_id 1 → kernel block id [2, 3]
# kv_manager_block_id 2 → kernel block id [4, 5]
"""
if blocks_per_kv_block == 1:
return kv_manager_block_ids
kernel_block_ids = (
kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
+ kernel_block_arange
)
return kernel_block_ids.reshape(-1)
def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
"""Returns the device tensor of the block table."""
return self.block_table.gpu[:num_reqs]
def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
return self.block_table.cpu
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table.np
def _make_buffer(
self, *size: int | torch.SymInt, dtype: torch.dtype
) -> CpuGpuBuffer:
return CpuGpuBuffer(
*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
)
class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
cp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
try:
pcp_world_size = get_pcp_group().world_size
except AssertionError:
# PCP might not be initialized in testing
pcp_world_size = 1
try:
dcp_world_size = get_dcp_group().world_size
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
if len(kernel_block_sizes) != len(block_sizes):
raise ValueError(
f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
f"must match block_sizes length ({len(block_sizes)})"
)
total_cp_world_size = dcp_world_size * pcp_world_size
self.block_tables = [
BlockTable(
block_size,
max_num_reqs,
max(
cdiv(max_model_len, block_size * total_cp_world_size),
1 + num_speculative_tokens,
),
max_num_batched_tokens,
pin_memory,
device,
kernel_block_size,
cp_kv_cache_interleave_size,
)
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
]
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx)
def move_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.move_row(src, tgt)
def swap_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.swap_row(src, tgt)
def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray
) -> None:
for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions)
def commit_block_table(self, num_reqs: int) -> None:
for block_table in self.block_tables:
block_table.commit_block_table(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
for block_table in self.block_tables:
block_table.commit_slot_mapping(num_tokens)
def clear(self) -> None:
for block_table in self.block_tables:
block_table.clear()
def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx]

View File

@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, cast
from vllm.config import VllmConfig, get_layers_from_vllm_config
if TYPE_CHECKING:
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
else:
AttentionLayerBase = object
def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None:
pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
dcp_size = vllm_config.parallel_config.decode_context_parallel_size
interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
if pcp_size * dcp_size > 1:
layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(vllm_config, layer_type)
for layer in layers.values():
layer_impl = getattr(layer, "impl", None)
if layer_impl is None:
continue
if vllm_config.speculative_config is not None and interleave_size > 1:
assert layer_impl.supports_mtp_with_cp_non_trivial_interleave_size, (
"MTP with cp_kv_cache_interleave_size > 1 is not "
f"supported in {layer_impl.__class__.__name__}."
)
if dcp_size > 1:
assert layer_impl.need_to_return_lse_for_decode, (
"DCP requires attention impls to return"
" the softmax lse for decode, but the impl "
f"{layer_impl.__class__.__name__} "
"does not return the softmax lse for decode."
)
if pcp_size > 1:
assert layer_impl.supports_pcp, (
"PCP requires attention impls' support, "
f"but the impl {layer_impl.__class__.__name__} "
"does not support PCP."
)

View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
class CPUModelRunner(GPUModelRunner):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
with _torch_cuda_wrapper():
super().__init__(vllm_config, device)
assert device == torch.device("cpu")
assert self.speculative_config is None, "spec decode is not supported."
self.use_cuda_graph = False
self.cascade_attn_enabled = False
self._postprocess_tensors()
def _postprocess_tensors(self) -> None:
# Note: replace device tensors with cpu tensors
def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None:
cpu_tensor = getattr(obj, cpu_attr_name, None)
device_tensor = getattr(obj, device_attr_name, None)
if cpu_tensor is not None and device_tensor is not None:
assert isinstance(cpu_tensor, torch.Tensor)
assert isinstance(device_tensor, torch.Tensor)
setattr(obj, device_attr_name, cpu_tensor)
for v in vars(self).values():
if isinstance(v, CpuGpuBuffer):
v.gpu = v.cpu
for k, v in vars(self.input_batch).items():
if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
replace_tensor(self.input_batch, k, k[:-11])
for block_table in self.input_batch.block_table.block_tables:
for v in vars(block_table).values():
if isinstance(v, CpuGpuBuffer):
v.gpu = v.cpu
def load_model(self, eep_scale_up: bool = False) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
self.model = self.load_lora_model(self.model, self.vllm_config, self.device)
def get_model(self) -> nn.Module:
return self.model
def warming_up_model(self) -> None:
logger.info("Warming up model for the compilation...")
# Only generate graph for the generic shape
with _set_global_compilation_settings(self.vllm_config):
self._dummy_run(
min(
max(16, self.max_num_reqs),
self.scheduler_config.max_num_batched_tokens,
)
)
logger.info("Warming up done.")
def _init_device_properties(self) -> None:
pass
def _sync_device(self) -> None:
pass
def get_dp_padding(self, num_tokens: int) -> tuple[int, torch.Tensor | None]:
# Note: For CPU backend, dp padding is not required for now.
return 0, None
@contextmanager
def _torch_cuda_wrapper():
class _EventPlaceholder:
def __init__(self, *args, **kwargs) -> None:
self.record = lambda: None
self.synchronize = lambda: None
class _StreamPlaceholder:
def __init__(self, *args, **kwargs) -> None:
pass
cuda_event = torch.Event
cuda_stream = torch.cuda.Stream
try:
torch.Event = _EventPlaceholder
torch.cuda.Stream = _StreamPlaceholder
yield
finally:
torch.Event = cuda_event
torch.cuda.Stream = cuda_stream
@contextmanager
def _set_global_compilation_settings(config: VllmConfig):
import torch._inductor.config as torch_inductor_config
inductor_config = config.compilation_config.inductor_compile_config
# Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
freezing_value = torch_inductor_config.freezing
try:
if inductor_config.get("max_autotune", False):
torch_inductor_config.freezing = True
yield
finally:
torch_inductor_config.freezing = freezing_value

View File

@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import platform
from collections.abc import Callable
from typing import Any
import torch
from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
logger = init_logger(__name__)
class CPUWorker(Worker):
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
super().__init__(
vllm_config,
local_rank,
rank,
distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.parallel_config.disable_custom_all_reduce = True
# Torch profiler. Enabled and configured through profiler_config.
self.profiler: Any | None = None
profiler_config = vllm_config.profiler_config
if profiler_config.profiler == "torch":
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
self.profiler = TorchProfilerWrapper(
profiler_config,
worker_name=worker_name,
local_rank=self.local_rank,
activities=["CPU"],
)
def init_device(self):
# Setup OpenMP threads affinity.
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
if omp_cpuids == "auto" and platform.system() == "Linux":
cpu_arch = current_platform.get_cpu_architecture()
if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X):
# For S390X/POWERPC SMT-8/4/2
self.local_omp_cpuid = self._get_autobind_cpu_ids(
lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]
)
elif cpu_arch == CpuArchEnum.X86:
# For x86 SMT-2, use 1 CPU per core
self.local_omp_cpuid = self._get_autobind_cpu_ids(
lambda cpus: cpus[-1:]
)
else:
self.local_omp_cpuid = "nobind"
elif omp_cpuids == "nobind":
self.local_omp_cpuid = "nobind"
else:
local_dp_rank = self.parallel_config.data_parallel_rank_local
omp_cpuids_list = omp_cpuids.split("|")
if local_dp_rank is not None:
world_size = self.parallel_config.world_size
omp_cpuids_list = omp_cpuids_list[
local_dp_rank * world_size : (local_dp_rank + 1) * world_size
]
self.local_omp_cpuid = omp_cpuids_list[self.rank]
if self.local_omp_cpuid != "nobind":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
if ret:
logger.info(ret)
# Note: unique identifier for creating allreduce shared memory
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1]
# Initialize the distributed environment.
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
current_platform.dist_backend,
)
# Set random seed.
set_random_seed(self.model_config.seed)
# Construct the model runner
self.model_runner: CPUModelRunner = CPUModelRunner(
self.vllm_config, torch.device("cpu")
)
def sleep(self, level: int = 1) -> None:
logger.warning("sleep mode is not supported on CPU, ignore it.")
pass
def wake_up(self, tags: list[str] | None = None) -> None:
logger.warning("sleep mode is not supported on CPU, ignore it.")
pass
def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes or 0
def compile_or_warm_up_model(self) -> None:
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model()
def _get_autobind_cpu_ids(
self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]
) -> str:
"""
Return CPU ids to bind based on NUMA nodes.
Currently for rank N, only CPU ids on the N-th node in available NUMA
node list will be selected.
Args:
cpu_selector: a callable object to select CPUs from a CPU list
of a physical core. The input is a LogicalCPUInfo list, sorted by
the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be
returned.
"""
allowed_numa_nodes, logical_cpu_list = (
CpuPlatform.get_allowed_cpu_core_node_list()
)
assert len(allowed_numa_nodes) >= self.parallel_config.world_size, (
f"No enough allowed NUMA nodes to bind threads of "
f"{self.parallel_config.world_size} CPUWorkers. "
f"Allowed NUMA nodes are {allowed_numa_nodes}. "
"Please try to bind threads manually."
)
# Get CPUs on NUMA node `allowed_numa_nodes[local_rank]`
selected_numa_node = allowed_numa_nodes[self.local_rank] # type: ignore
logical_cpu_list = [
x for x in logical_cpu_list if x.numa_node == selected_numa_node
]
# Select CPUs from each physical core via cpu_selector
core_to_cpus: dict[int, list[LogicalCPUInfo]] = {}
for cpu_info in logical_cpu_list:
if cpu_info.physical_core not in core_to_cpus:
core_to_cpus[cpu_info.physical_core] = []
core_to_cpus[cpu_info.physical_core].append(cpu_info)
logical_cpu_list = []
for cpu_list in core_to_cpus.values():
cpu_list = sorted(cpu_list, key=lambda x: x.id)
logical_cpu_list.extend(cpu_selector(cpu_list))
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.id)
# Reserve CPUs for other processes
reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU
if reserve_cpu_num is None:
need_reserve = (
self.parallel_config.world_size > 1
or self.parallel_config.data_parallel_size_local > 1
)
reserve_cpu_num = 1 if need_reserve else 0
assert len(logical_cpu_list) > reserve_cpu_num, (
f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) "
f"should less than {len(logical_cpu_list)}."
)
if reserve_cpu_num != 0:
logical_cpu_list = logical_cpu_list[:-reserve_cpu_num]
logger.info(
"auto thread-binding list (id, physical core): %s",
[(x.id, x.physical_core) for x in logical_cpu_list],
)
return ",".join([str(x.id) for x in logical_cpu_list])
def profile(self, is_start: bool = True):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
self.profiler.start()
else:
self.profiler.stop()

240
vllm/v1/worker/dp_utils.py Normal file
View File

@@ -0,0 +1,240 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
import torch.distributed as dist
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
check_ubatch_thresholds,
is_second_ubatch_empty,
)
logger = init_logger(__name__)
def _get_device_and_group(parallel_config: ParallelConfig):
# Use the actual device assigned to the DP group, not just the device type
device = get_dp_group().device
group = get_dp_group().device_group
# Transferring this tensor from GPU to CPU will introduce a GPU sync
# point that could adversely affect performance of vllm with asynch
# scheduling. This environment variable exists to quickly disable
# this optimization if we run into this case.
if parallel_config.disable_nccl_for_dp_synchronization:
logger.info_once(
"Using CPU all reduce to synchronize DP padding between ranks."
)
device = "cpu"
group = get_dp_group().cpu_group
return device, group
def _run_ar(
should_ubatch: bool,
should_dp_pad: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
cudagraph_mode: int,
parallel_config: ParallelConfig,
) -> torch.Tensor:
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config)
tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0
tensor[4][dp_rank] = cudagraph_mode
dist.all_reduce(tensor, group=group)
return tensor
def _post_process_ubatch(tensor: torch.Tensor) -> bool:
orig_num_tokens_tensor = tensor[0, :]
padded_num_tokens_tensor = tensor[1, :]
# First determine if we are going to be ubatching.
should_ubatch: bool = bool(torch.all(tensor[2] == 1).item())
if not should_ubatch:
return False
# If the DP ranks are planning to ubatch, make sure that
# there are no "empty" second ubatches
orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
logger.debug(
"Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
)
should_ubatch = False
return should_ubatch
def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
num_tokens_across_dp = tensor[1, :]
if should_dp_pad:
# If DP padding is enabled, ensure that each rank is processing the same number
# of tokens
max_num_tokens = int(num_tokens_across_dp.max().item())
return torch.tensor(
[max_num_tokens] * len(num_tokens_across_dp),
device="cpu",
dtype=torch.int32,
)
else:
return num_tokens_across_dp.cpu()
def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
"""
Synchronize cudagraph_mode across DP ranks by taking the minimum.
If any rank has NONE (0), all ranks use NONE.
This ensures all ranks send consistent values (all padded or all unpadded).
"""
return int(tensor[4, :].min().item())
def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
should_attempt_dp_padding: bool,
cudagraph_mode: int,
parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None, int]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do.
2. Determines the total number of tokens that each rank will run.
When running microbatched or if should_attempt_dp_padding is True, all
ranks will be padded out so that the run with the same number of tokens
3. Synchronizes cudagraph_mode across ranks by taking the minimum.
Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including any DP padding.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
]
"""
assert num_tokens_padded >= num_tokens_unpadded
# Coordinate between the DP ranks via an All Reduce
# to determine the total number of tokens that each rank
# will run and if we are using ubatching or not.
tensor = _run_ar(
should_ubatch=should_attempt_ubatching,
should_dp_pad=should_attempt_dp_padding,
orig_num_tokens_per_ubatch=num_tokens_unpadded,
padded_num_tokens_per_ubatch=num_tokens_padded,
cudagraph_mode=cudagraph_mode,
parallel_config=parallel_config,
)
should_dp_pad = bool(torch.all(tensor[3] == 1).item())
# DP ranks should all have the same value for should_attempt_dp_padding.
assert should_attempt_dp_padding == should_dp_pad
# Check conditions for microbatching
should_ubatch = _post_process_ubatch(tensor)
if should_ubatch and not should_dp_pad:
logger.debug_once(
"Microbatching has been triggered and requires DP padding. "
"Enabling DP padding even though it has been explicitly "
"disabled.",
scope="global",
)
should_dp_pad = True
# Pad all DP ranks up to the maximum token count across ranks if
# should_dp_pad is True
num_tokens_after_padding = _post_process_dp_padding(
tensor,
should_dp_pad,
)
# Synchronize cudagraph_mode across ranks (take min)
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode
def coordinate_batch_across_dp(
num_tokens_unpadded: int,
allow_microbatching: bool,
allow_dp_padding: bool,
parallel_config: ParallelConfig,
num_tokens_padded: int | None = None,
uniform_decode: bool | None = None,
num_scheduled_tokens_per_request: np.ndarray | None = None,
cudagraph_mode: int = 0,
) -> tuple[bool, torch.Tensor | None, int]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches.
Args:
num_tokens_unpadded: Number of tokens without accounting for padding
allow_microbatching: If microbatching should be attempted
allow_dp_padding: If all DP ranks should be padded up to the same value
parallel_config: The parallel config
num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
TP, etc)
uniform_decode: Only used if allow_microbatching is True. True if the batch
only contains single token decodes
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
number of tokens per request.
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
padded up to the max value across all DP ranks when allow_dp_padding
is True.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
]
"""
if parallel_config.data_parallel_size == 1:
# Early exit.
return False, None, cudagraph_mode
# If the caller has explicitly enabled microbatching.
should_attempt_ubatching = False
if allow_microbatching:
# Check preconditions for microbatching
assert uniform_decode is not None
should_attempt_ubatching = check_ubatch_thresholds(
parallel_config,
num_tokens_unpadded,
uniform_decode=uniform_decode,
)
if num_tokens_padded is None:
num_tokens_padded = num_tokens_unpadded
(should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) = (
_synchronize_dp_ranks(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
allow_dp_padding,
cudagraph_mode,
parallel_config,
)
)
return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)

View File

@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define EC connector functionality mixin for model runners.
"""
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import (
TYPE_CHECKING, # noqa: UP035
)
import torch
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorBase
from vllm.logger import init_logger
from vllm.v1.outputs import ECConnectorOutput
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
# Defined as a EC connector functionality mixin for ModelRunner (GPU, TPU)
class ECConnectorModelRunnerMixin:
@staticmethod
def maybe_save_ec_to_connector(
encoder_cache: dict[str, torch.Tensor],
mm_hash: str,
):
if not has_ec_transfer():
logger.debug("Not have ec transfer please check")
return
connector = get_ec_transfer()
connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)
@staticmethod
def get_finished_ec_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[set[str] | None, set[str] | None]:
if has_ec_transfer():
return get_ec_transfer().get_finished(scheduler_output.finished_req_ids)
return None, None
@staticmethod
def maybe_get_ec_connector_output(
scheduler_output: "SchedulerOutput",
encoder_cache: dict[str, torch.Tensor],
**kwargs,
) -> AbstractContextManager[ECConnectorOutput | None]:
return (
ECConnectorModelRunnerMixin._get_ec_connector_output(
scheduler_output, encoder_cache, **kwargs
)
if has_ec_transfer()
else nullcontext()
)
# This context manager must be used within an active forward context.
# It encapsulates the entire EC connector lifecycle within execute_model
@staticmethod
@contextmanager
def _get_ec_connector_output(
scheduler_output: "SchedulerOutput",
encoder_cache: dict[str, torch.Tensor],
**kwargs,
) -> Generator[ECConnectorOutput, None, None]:
output = ECConnectorOutput()
ec_connector = get_ec_transfer()
assert isinstance(ec_connector, ECConnectorBase)
assert scheduler_output.ec_connector_metadata is not None
ec_connector.bind_connector_metadata(scheduler_output.ec_connector_metadata)
if not ec_connector.is_producer:
ec_connector.start_load_caches(encoder_cache, **kwargs)
try:
yield output
finally:
output.finished_sending, output.finished_recving = (
ec_connector.get_finished(scheduler_output.finished_req_ids)
)
ec_connector.clear_connector_metadata()

View File

@@ -0,0 +1,4 @@
# [Experimental] Model Runner V2
This directory contains the new model runner which is under active development.
Ping [Woosuk Kwon](https://github.com/WoosukKwon) for any changes.

View File

View File

@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import numpy as np
import torch
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
LogprobsTensors,
ModelRunnerOutput,
)
from vllm.v1.worker.gpu.sample.output import SamplerOutput
class AsyncOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput,
num_sampled_tokens: torch.Tensor,
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
# NOTE(woosuk): We must retain references to the GPU tensors,
# as the copy operations are performed on a different CUDA stream than
# the one where the tensors were created.
self.model_runner_output = model_runner_output
self.sampler_output = sampler_output
self.num_sampled_tokens = num_sampled_tokens
self.copy_stream = copy_stream
self.copy_event = copy_event
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.copy_stream):
self.copy_stream.wait_stream(default_stream)
self.sampled_token_ids = async_copy_to_np(sampler_output.sampled_token_ids)
if sampler_output.logprobs_tensors is not None:
self.logprobs_tensors: LogprobsTensors | None = (
sampler_output.logprobs_tensors.to_cpu_nonblocking()
)
else:
self.logprobs_tensors = None
if sampler_output.num_nans is not None:
self.num_nans = async_copy_to_np(sampler_output.num_nans)
else:
self.num_nans = None
self.num_sampled_tokens_np = async_copy_to_np(num_sampled_tokens)
self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
if self.model_runner_output.prompt_logprobs_dict:
for k, v in self.model_runner_output.prompt_logprobs_dict.items():
if v is not None:
self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
else:
self.prompt_logprobs_dict[k] = None
self.copy_event.record(self.copy_stream)
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
# NOTE(woosuk): The following code is to ensure compatibility with
# the existing model runner.
# Going forward, we should keep the data structures as NumPy arrays
# rather than Python lists.
sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
num_reqs = len(sampled_token_ids)
num_sampled_tokens = self.num_sampled_tokens_np.tolist()
for i in range(num_reqs):
del sampled_token_ids[i][num_sampled_tokens[i] :]
self.model_runner_output.sampled_token_ids = sampled_token_ids
if self.num_nans is not None:
num_nans = self.num_nans.tolist()
self.model_runner_output.num_nans_in_logits = {
req_id: num_nans[i]
for i, req_id in enumerate(self.model_runner_output.req_ids)
}
if self.logprobs_tensors is not None:
self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
return self.model_runner_output
@contextmanager
def async_barrier(event: torch.cuda.Event | None):
if event is not None:
event.synchronize()
try:
yield
finally:
if event is not None:
event.record()
def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
return x.to("cpu", non_blocking=True).numpy()

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any, cast
import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheConfig,
KVCacheSpec,
)
from vllm.v1.worker.utils import bind_kv_cache
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
kv_cache_spec: dict[str, KVCacheSpec] = {}
layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type)
for layer_name, attn_module in attn_layers.items():
# Skip modules that don't need KV cache (eg encoder-only attention)
if spec := attn_module.get_kv_cache_spec(vllm_config):
kv_cache_spec[layer_name] = spec
return kv_cache_spec
def init_attn_backend(
kv_cache_config: KVCacheConfig,
vllm_config: VllmConfig,
device: torch.device,
):
attn_backends: dict[str, type[AttentionBackend]] = {}
attn_metadata_builders: list[AttentionMetadataBuilder] = []
flashinfer_workspace: torch.Tensor | None = None
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
layer_names = kv_cache_group_spec.layer_names
any_layer_name = next(iter(layer_names))
layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
attn_backend = attn_layers[any_layer_name].get_attn_backend()
for layer_name in layer_names:
attn_backends[layer_name] = attn_backend
attn_metadata_builder = attn_backend.get_builder_cls()(
kv_cache_group_spec.kv_cache_spec,
layer_names,
vllm_config,
device,
)
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
if "FLASHINFER" in attn_backend.get_name():
if flashinfer_workspace is None:
flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
else:
attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
return attn_backends, attn_metadata_builders
def _allocate_kv_cache(
kv_cache_config: KVCacheConfig,
device: torch.device,
):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys()), (
"Some layers are not correctly initialized"
)
return kv_cache_raw_tensors
def _reshape_kv_cache(
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, AttentionBackend],
) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
assert isinstance(kv_cache_spec, AttentionSpec)
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
)
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
dtype = kv_cache_spec.dtype
raw_tensor = raw_tensor.view(dtype)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
return kv_caches
def init_kv_cache(
runner_kv_caches: list[torch.Tensor],
forward_context: dict[str, Any],
kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend],
device: torch.device,
) -> None:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int,
num_tokens: int,
query_start_loc_gpu: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_np: np.ndarray,
num_computed_tokens_cpu: torch.Tensor | None,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
max_query_len = int(query_start_loc_cpu.max())
seq_lens = seq_lens[:num_reqs]
seq_lens_cpu = torch.from_numpy(seq_lens_np)
max_seq_len = int(seq_lens_np.max())
attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
for i, kv_cache_spec in enumerate(kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
_seq_lens_cpu=seq_lens_cpu,
max_seq_len=max_seq_len,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
causal=True,
)
attn_metadata_builder = attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = metadata
return attn_metadata

View File

@@ -0,0 +1,314 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
class BlockTables:
def __init__(
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
pin_memory: bool,
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
self.pin_memory = pin_memory
self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[torch.Tensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = torch.zeros(
self.max_num_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_tables.append(block_table)
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
# Block tables used for model's forward pass.
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(block_table) for block_table in self.block_tables
]
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables],
dtype=torch.int64,
device=self.device,
)
self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device
)
self.num_blocks = torch.zeros(
self.num_kv_cache_groups,
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
self.slot_mappings = torch.zeros(
self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device,
)
# Misc buffers.
self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
self.cu_num_new_blocks = self._make_buffer(
self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
ptrs_tensor_cpu = torch.tensor(
[t.data_ptr() for t in x],
dtype=torch.uint64,
device="cpu",
pin_memory=self.pin_memory,
)
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids(
self,
# [num_reqs]
req_indices: list[int],
# [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks: tuple[list[int], ...],
# [num_kv_cache_groups, num_new_blocks]
new_block_ids: tuple[list[int], ...],
# [num_reqs]
overwrite: list[bool],
) -> None:
num_reqs = len(req_indices)
self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
# no clear upper bound to the number of new blocks in a single step.
# NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
# guarantee that the buffer is not freed before the copy is completed.
self.new_block_ids_cpu = torch.empty(
self.num_kv_cache_groups,
max(len(x) for x in new_block_ids),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
new_block_ids_np = self.new_block_ids_cpu.numpy()
for i in range(self.num_kv_cache_groups):
new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i]
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True)
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
self.req_indices.copy_to_gpu(num_reqs),
self.cu_num_new_blocks.copy_to_gpu(),
self.cu_num_new_blocks.gpu.stride(0),
new_block_ids_gpu,
new_block_ids_gpu.stride(0),
self.overwrite.copy_to_gpu(num_reqs),
self.block_table_strides,
self.block_table_ptrs,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024, # type: ignore
)
def gather_block_tables(
self,
idx_mapping: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
idx_mapping,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024, # type: ignore
)
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def compute_slot_mappings(
self,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
num_reqs = query_start_loc.shape[0] - 1
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
query_start_loc,
positions,
self.input_block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings,
self.slot_mappings.stride(0),
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024, # type: ignore
)
return self.slot_mappings[:, :num_tokens]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
self.slot_mappings.fill_(PAD_SLOT_ID)
return self.slot_mappings[:, :num_tokens]
@triton.jit
def _append_block_ids_kernel(
# Inputs
req_indices, # [num_reqs]
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks_stride,
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
new_block_ids_stride,
overwrite, # [num_reqs]
block_table_strides, # [num_kv_cache_groups]
# Outputs
block_table_ptrs, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
# Constants
BLOCK_SIZE: tl.constexpr,
):
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx)
group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
num_new_blocks = end_idx - start_idx
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx) if not do_overwrite else 0
dst_end_idx = dst_start_idx + num_new_blocks
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
# Destination
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
row_ptr = block_table_ptr + req_idx * block_table_stride
group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
for i in range(0, num_new_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(
group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
)
tl.store(
row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
)
@triton.jit
def _gather_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
cu_num_tokens, # [num_reqs + 1]
pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
page_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride,
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
req_idx = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if req_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
page_size = tl.load(page_sizes + group_id)
start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1)
for i in range(start_idx, end_idx, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size
block_numbers = tl.load(
block_table_ptr + req_idx * block_table_stride + block_indices
)
slot_ids = block_numbers * page_size + positions % page_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr)
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
return tl.multiple_of(ptr, 16)

View File

@@ -0,0 +1,259 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_mode: CUDAGraphMode
if self.compilation_config.cudagraph_mode is None:
self.cudagraph_mode = CUDAGraphMode.NONE
else:
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
def needs_capture(self) -> bool:
return len(self.cudagraph_sizes) > 0
def get_cudagraph_size(
self,
scheduler_output: SchedulerOutput,
num_tokens_after_padding: int,
) -> int | None:
return get_cudagraph_size(
num_tokens_after_padding,
scheduler_output.num_scheduled_tokens.values(),
self.cudagraph_sizes,
self.cudagraph_mode,
)
def capture_graph(
self,
num_tokens: int,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
attn_metadata = prepare_inputs_to_capture(
num_reqs,
num_tokens,
input_buffers,
block_tables,
attn_metadata_builders,
self.max_model_len,
kv_cache_config,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
with (
set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
),
torch.cuda.graph(graph, self.pool),
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
self.hidden_states[:num_tokens] = hidden_states
self.graphs[num_tokens] = graph
@torch.inference_mode()
def capture(
self,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
model=model,
input_buffers=input_buffers,
block_tables=block_tables,
attn_metadata_builders=attn_metadata_builders,
kv_cache_config=kv_cache_config,
)
def run(self, num_tokens: int) -> torch.Tensor:
assert num_tokens in self.graphs
self.graphs[num_tokens].replay()
assert self.hidden_states is not None
return self.hidden_states[:num_tokens]
def get_cudagraph_sizes(
capture_sizes: list[int] | None,
max_num_reqs: int,
max_num_tokens: int,
cudagraph_mode: CUDAGraphMode,
) -> dict[int, int]:
if not cudagraph_mode.has_full_cudagraphs():
return {}
if not capture_sizes:
return {}
capture_sizes = sorted(capture_sizes)
# Limit the capture sizes to the max number of requests or tokens.
upper_bound = (
max_num_reqs
if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
else max_num_tokens
)
capture_sizes = [x for x in capture_sizes if x <= upper_bound]
if not capture_sizes:
return {}
cudagraph_sizes: dict[int, int] = {}
for i in range(1, capture_sizes[-1] + 1):
for x in capture_sizes:
if i <= x:
cudagraph_sizes[i] = x
break
return cudagraph_sizes
def get_cudagraph_size(
num_tokens_after_dp_padding: int,
num_tokens_per_request: Iterable[int],
cudagraph_sizes: dict[int, int],
cudagraph_mode: CUDAGraphMode,
) -> int | None:
size = cudagraph_sizes.get(num_tokens_after_dp_padding)
if size is None:
# No CUDA graph for this size.
return None
if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
all_decode = all(x == 1 for x in num_tokens_per_request)
if not all_decode:
# Prefill is included.
return None
return size
def capture_graphs(
cudagraph_sizes: dict[int, int],
device: torch.device,
capture_fn: Callable,
**capture_kwargs,
) -> None:
# Capture larger graphs first.
sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
with graph_capture(device=device):
for size in sizes_to_capture:
capture_fn(size, **capture_kwargs)
def prepare_inputs_to_capture(
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
max_model_len: int,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc = input_buffers.query_start_loc
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
query_start_loc.np[num_reqs:] = num_tokens
query_start_loc.copy_to_gpu()
seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len. This introduces a discrepancy between
# seq_lens (on GPU) and seq_lens_np (on CPU), which may cause issues for
# certain attention backends.
input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
attn_metadata = build_attn_metadata(
attn_metadata_builders=attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
seq_lens=input_buffers.seq_lens,
seq_lens_np=seq_lens_np,
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
)
return attn_metadata

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
def get_batch_metadata_across_dp(
num_tokens: int,
cudagraph_size: int,
dp_size: int,
dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert dp_size > 1
# Use CPU group to avoid CPU-GPU synchronization.
group = get_dp_group().cpu_group
tensor = torch.zeros(2, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = cudagraph_size
dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1]
def make_num_tokens_across_dp(
dp_size: int,
num_tokens: int,
) -> torch.Tensor | None:
if dp_size == 1:
return None
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")

View File

@@ -0,0 +1,479 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.utils import random_uuid
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
class InputBuffers:
def __init__(
self,
max_num_reqs: int,
max_num_tokens: int,
inputs_embeds_size: int,
vocab_size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens
self.device = device
self.pin_memory = pin_memory
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
# Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.grammar_bitmask = self._make_buffer(
max_num_reqs, cdiv(vocab_size, 32), dtype=torch.int32
)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
@dataclass
class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
num_reqs: int
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# [num_reqs]
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
# sum(num_scheduled_tokens)
num_tokens: int
num_tokens_after_padding: int
num_draft_tokens: int
# [num_reqs + 1]
query_start_loc: torch.Tensor
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
seq_lens_np: np.ndarray
# [num_tokens_after_padding]
input_ids: torch.Tensor
# [num_tokens_after_padding]
positions: torch.Tensor
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# [total_num_logits]
logits_indices: torch.Tensor
# [num_reqs + 1]
cu_num_logits: torch.Tensor
@classmethod
def make_dummy(
cls,
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
device: torch.device,
) -> "InputBatch":
assert 0 < num_reqs <= num_tokens
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens
input_buffers.query_start_loc.np[0] = 0
input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
num_scheduled_tokens
)
input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
# seq_len equals to query_len
seq_lens_np = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
seq_lens_np[-1] += num_tokens % num_reqs
input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
input_buffers.seq_lens[num_reqs:] = 0
seq_lens = input_buffers.seq_lens[:num_reqs]
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens,
num_draft_tokens=0,
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_np=seq_lens_np,
input_ids=input_ids,
positions=positions,
attn_metadata=None, # type: ignore
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
)
@triton.jit
def _prepare_prefill_inputs_kernel(
input_ids_ptr,
next_prefill_tokens_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
prefill_token_ids_ptr,
prefill_token_ids_stride,
prefill_lens_ptr,
num_computed_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
if num_computed >= prefill_len:
# Not prefill.
return
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
prefill_ptr = prefill_token_ids_ptr + req_state_idx * prefill_token_ids_stride
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
tokens = tl.load(prefill_ptr + num_computed + block, mask=mask)
tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
next_pos = num_computed + query_len
if next_pos < prefill_len:
next_token = tl.load(prefill_ptr + next_pos)
tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
def prepare_prefill_inputs(
input_ids: torch.Tensor,
next_prefill_tokens: torch.Tensor,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
prefill_token_ids: torch.Tensor,
prefill_len: torch.Tensor,
num_computed_tokens: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_prepare_prefill_inputs_kernel[(num_reqs,)](
input_ids,
next_prefill_tokens,
idx_mapping,
query_start_loc,
prefill_token_ids,
prefill_token_ids.stride(0),
prefill_len,
num_computed_tokens,
BLOCK_SIZE=1024,
)
@triton.jit
def _prepare_pos_seq_lens_kernel(
pos_ptr,
seq_lens_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
num_computed_tokens_ptr,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
req_id = tl.program_id(0)
num_reqs = tl.num_programs(0) - 1
if req_id == num_reqs:
# Pad unused seq_lens as 0 for full CUDA graphs.
for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(seq_lens_ptr + block, 0, mask=mask)
return
req_state_idx = tl.load(idx_mapping_ptr + req_id)
num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
start = tl.load(query_start_loc_ptr + req_id)
end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = end - start
seq_len = num_computed_tokens + query_len
tl.store(seq_lens_ptr + req_id, seq_len)
for i in tl.range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
pos = num_computed_tokens + block
tl.store(pos_ptr + start + block, pos, mask=mask)
def prepare_pos_seq_lens(
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
num_computed_tokens: torch.Tensor,
pos: torch.Tensor,
seq_lens: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
# NOTE(woosuk): We do +1 because the last thread block is used
# to pad unused seq_lens as 0 for full CUDA graphs.
_prepare_pos_seq_lens_kernel[(num_reqs + 1,)](
pos,
seq_lens,
idx_mapping,
query_start_loc,
num_computed_tokens,
seq_lens.shape[0],
BLOCK_SIZE=1024,
)
@triton.jit
def _combine_sampled_and_draft_tokens_kernel(
input_ids_ptr,
idx_mapping_ptr,
last_sampled_tokens_ptr,
query_start_loc_ptr,
seq_lens_ptr,
prefill_len_ptr,
draft_tokens_ptr,
draft_tokens_stride,
cu_num_logits_ptr,
logits_indices_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
# Get the number of logits and draft tokens.
cu_num_logits_start = tl.load(cu_num_logits_ptr + batch_idx)
cu_num_logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
num_logits = cu_num_logits_end - cu_num_logits_start
num_draft_tokens = num_logits - 1
# Compute the logits indices.
block = tl.arange(0, BLOCK_SIZE)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
logits_start = query_end - num_logits
tl.store(
logits_indices_ptr + cu_num_logits_start + block,
logits_start + block,
mask=block < num_logits,
)
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if seq_len <= prefill_len:
# Handling prefill tokens. No sampled or draft tokens.
return
# Write the last sampled token ID to input_ids.
last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
tl.store(input_ids_ptr + query_end - num_logits, last_token_id)
# Write the draft tokens (if any) to input_ids.
if num_draft_tokens > 0:
mask = block < num_draft_tokens
draft_tokens = tl.load(
draft_tokens_ptr + req_state_idx * draft_tokens_stride + block,
mask=mask,
)
tl.store(
input_ids_ptr + query_end - num_draft_tokens + block,
draft_tokens,
mask=mask,
)
def combine_sampled_and_draft_tokens(
input_ids: torch.Tensor,
idx_mapping: torch.Tensor,
last_sampled_tokens: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
prefill_len: torch.Tensor,
draft_tokens: torch.Tensor,
cu_num_logits: torch.Tensor,
num_logits: int,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
num_speculative_steps = draft_tokens.shape[-1]
logits_indices = torch.empty(
num_logits,
dtype=torch.int64,
device=input_ids.device,
)
_combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
input_ids,
idx_mapping,
last_sampled_tokens,
query_start_loc,
seq_lens,
prefill_len,
draft_tokens,
draft_tokens.stride(0),
cu_num_logits,
logits_indices,
# NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token
# in addition to all draft tokens.
BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
)
return logits_indices
@triton.jit
def _get_num_sampled_and_rejected_kernel(
num_sampled_ptr,
num_rejected_ptr,
seq_lens_ptr,
cu_num_logits_ptr,
idx_mapping_ptr,
prefill_len_ptr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
is_chunked_prefilling = seq_len < prefill_len
num_sampled = tl.load(num_sampled_ptr + batch_idx)
num_sampled = tl.where(is_chunked_prefilling, 0, num_sampled)
tl.store(num_sampled_ptr + batch_idx, num_sampled)
logits_start = tl.load(cu_num_logits_ptr + batch_idx)
logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
num_logits = logits_end - logits_start
num_rejected = num_logits - num_sampled
num_rejected = tl.where(is_chunked_prefilling, 0, num_rejected)
tl.store(num_rejected_ptr + batch_idx, num_rejected)
def get_num_sampled_and_rejected(
num_sampled: torch.Tensor,
seq_lens: torch.Tensor,
cu_num_logits: torch.Tensor,
idx_mapping: torch.Tensor,
prefill_len: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = idx_mapping.shape[0]
num_rejected = torch.empty_like(num_sampled)
_get_num_sampled_and_rejected_kernel[(num_reqs,)](
num_sampled,
num_rejected,
seq_lens,
cu_num_logits,
idx_mapping,
prefill_len,
)
return num_sampled, num_rejected
@triton.jit
def _post_update_kernel(
idx_mapping_ptr,
num_computed_tokens_ptr,
last_sampled_tokens_ptr,
output_bin_counts_ptr,
output_bin_counts_stride,
sampled_tokens_ptr,
sampled_tokens_stride,
num_sampled_ptr,
num_rejected_ptr,
query_start_loc_ptr,
):
req_id = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_id)
num_sampled = tl.load(num_sampled_ptr + req_id)
if num_sampled > 0:
token_id = tl.load(
sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
)
tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
for i in range(num_sampled):
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
token_ptr = (
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
)
count = tl.load(token_ptr)
count += 1
tl.store(token_ptr, count)
query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start
num_rejected = tl.load(num_rejected_ptr + req_id)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
num_computed += query_len - num_rejected
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
def post_update(
# [num_reqs]
idx_mapping: torch.Tensor,
# [max_num_reqs]
num_computed_tokens: torch.Tensor,
# [max_num_reqs]
last_sampled_tokens: torch.Tensor,
# [max_num_reqs, vocab_size]
output_bin_counts: torch.Tensor,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens: torch.Tensor,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [num_reqs + 1]
query_start_loc: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_post_update_kernel[(num_reqs,)](
idx_mapping,
num_computed_tokens,
last_sampled_tokens,
output_bin_counts,
output_bin_counts.stride(0),
sampled_tokens,
sampled_tokens.stride(0),
num_sampled,
num_rejected,
query_start_loc,
num_warps=1,
)

View File

View File

@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch._inductor.runtime.triton_helpers import libdevice
from vllm.triton_utils import tl, triton
@triton.jit
def _num_nans_kernel(
logits_ptr,
logits_stride,
num_nans_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
num_nans = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=0
)
logits = logits.to(tl.float32)
is_nan = libdevice.isnan(logits).to(tl.int1)
num_nans += tl.sum(is_nan).to(tl.int32)
tl.store(num_nans_ptr + req_idx, num_nans)
def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device)
_num_nans_kernel[(num_reqs,)](
logits,
logits.stride(0),
num_nans,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
return num_nans

File diff suppressed because it is too large Load Diff

View File

View File

@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _gumbel_sample_kernel(
local_argmax_ptr,
local_argmax_stride,
local_max_ptr,
local_max_stride,
logits_ptr,
logits_stride,
seeds_ptr,
pos_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
req_idx = tl.program_id(0)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_idx)
pos = tl.load(pos_ptr + req_idx)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise.
r = tl.rand(gumbel_seed, block).to(tl.float64)
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
gumbel_noise = gumbel_noise.to(tl.float32)
# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
# Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
idx = tl.argmax(logits, axis=0)
token_id = block_idx * BLOCK_SIZE + idx
value = tl.max(logits, axis=0)
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
temperature: torch.Tensor, # [num_reqs]
seed: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_reqs]
apply_temperature: bool,
) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = torch.empty(
num_reqs,
num_blocks,
dtype=torch.int64,
device=logits.device,
)
local_max = torch.empty(
num_reqs,
num_blocks,
dtype=torch.float32,
device=logits.device,
)
_gumbel_sample_kernel[(num_reqs, num_blocks)](
local_argmax,
local_argmax.stride(0),
local_max,
local_max.stride(0),
logits,
logits.stride(0),
seed,
pos,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
APPLY_TEMPERATURE=apply_temperature,
)
# NOTE(woosuk): Use int64 for later indexing.
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
return sampled

View File

@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
@triton.jit
def _topk_log_softmax_kernel(
output_ptr,
logits_ptr,
logits_stride,
topk_ids_ptr,
topk,
vocab_size,
BLOCK_SIZE: tl.constexpr,
PADDED_TOPK: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore
se = 0.0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
logits = logits.to(tl.float32)
e = tl.exp(logits - max_val)
e = tl.where(block < vocab_size, e, 0.0)
se += tl.sum(e)
lse = tl.log(se)
k_offset = tl.arange(0, PADDED_TOPK)
k_mask = k_offset < topk
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
logits = logits.to(tl.float32)
o = logits - max_val - lse
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
@triton.jit
def _ranks_kernel(
output_ptr,
logits_ptr,
logits_stride,
token_ids_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
token_id = tl.load(token_ids_ptr + req_idx)
x = tl.load(row_ptr + token_id)
n = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
n += tl.sum((logits > x).to(tl.int32))
tl.store(output_ptr + req_idx, n)
def compute_token_logprobs(
logits: torch.Tensor,
token_ids: torch.Tensor,
) -> torch.Tensor:
batch_size = logits.shape[0]
vocab_size = logits.shape[1]
token_ids = token_ids.to(torch.int64)
num_logprobs = token_ids.shape[1]
logprobs = torch.empty(
batch_size,
num_logprobs,
dtype=torch.float32,
device=logits.device,
)
_topk_log_softmax_kernel[(batch_size,)](
logprobs,
logits,
logits.stride(0),
token_ids,
num_logprobs,
vocab_size,
BLOCK_SIZE=1024, # type: ignore
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
)
return logprobs
def compute_topk_logprobs(
logits: torch.Tensor,
num_logprobs: int,
sampled_token_ids: torch.Tensor,
) -> LogprobsTensors:
assert num_logprobs >= 0
batch_size, vocab_size = logits.shape
if num_logprobs == 0:
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
else:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat(
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
)
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
logprobs = compute_token_logprobs(logits, logprob_token_ids)
token_ranks = torch.empty(
batch_size,
dtype=torch.int64,
device=logits.device,
)
_ranks_kernel[(batch_size,)](
token_ranks,
logits,
logits.stride(0),
sampled_token_ids,
vocab_size,
BLOCK_SIZE=8192, # type: ignore
)
return LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=token_ranks,
)
def compute_prompt_logprobs(
prompt_token_ids: torch.Tensor,
prompt_hidden_states: torch.Tensor,
logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Since materializing the full prompt logits can take too much memory,
# we compute it in chunks.
CHUNK_SIZE = 1024
logprobs = []
ranks = []
prompt_token_ids = prompt_token_ids.to(torch.int64)
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
end_idx = start_idx + CHUNK_SIZE
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
prompt_logprobs = compute_topk_logprobs(
prompt_logits,
0, # num_logprobs
prompt_token_ids[start_idx:end_idx],
)
logprobs.append(prompt_logprobs.logprobs)
ranks.append(prompt_logprobs.selected_token_ranks)
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
return logprobs, ranks

View File

@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.triton_utils import tl, triton
@dataclass
class SamplingMetadata:
temperature: torch.Tensor
top_p: torch.Tensor | None
top_k: torch.Tensor | None
min_p: torch.Tensor | None
repetition_penalty: torch.Tensor
frequency_penalty: torch.Tensor
presence_penalty: torch.Tensor
seeds: torch.Tensor
pos: torch.Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: int | None
# For penalties
idx_mapping: torch.Tensor
prompt_bin_mask: torch.Tensor
output_bin_counts: torch.Tensor
@classmethod
def make_dummy(
cls,
num_reqs: int,
device: torch.device,
) -> "SamplingMetadata":
assert num_reqs > 0
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
temperature[0] = 0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
# Currently, they are disabled because of memory usage.
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p = None
top_k = None
min_p = torch.zeros(num_reqs, dtype=torch.float32, device=device)
# NOTE(woosuk): We must set penalties to their default values to make sure
# the penalties kernel does not touch the placeholder bin_counts tensors.
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
return cls(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping,
prompt_bin_mask=prompt_bin_mask,
output_bin_counts=output_bin_counts,
)
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
@triton.jit
def _expand_sampling_metadata_kernel(
temp_ptr,
expanded_temp_ptr,
top_p_ptr,
expanded_top_p_ptr,
top_k_ptr,
expanded_top_k_ptr,
min_p_ptr,
expanded_min_p_ptr,
rep_penalty_ptr,
expanded_rep_penalty_ptr,
freq_penalty_ptr,
expanded_freq_penalty_ptr,
pres_penalty_ptr,
expanded_pres_penalty_ptr,
seeds_ptr,
expanded_seeds_ptr,
cu_num_logits_ptr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
block = tl.arange(0, BLOCK_SIZE)
mask = block < num_tokens
temp = tl.load(temp_ptr + req_idx)
tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
if top_p_ptr is not None:
top_p = tl.load(top_p_ptr + req_idx)
tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
if top_k_ptr is not None:
top_k = tl.load(top_k_ptr + req_idx)
tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
if min_p_ptr is not None:
min_p = tl.load(min_p_ptr + req_idx)
tl.store(expanded_min_p_ptr + start_idx + block, min_p, mask=mask)
rep_penalty = tl.load(rep_penalty_ptr + req_idx)
tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask)
freq_penalty = tl.load(freq_penalty_ptr + req_idx)
tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask)
pres_penalty = tl.load(pres_penalty_ptr + req_idx)
tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask)
seed = tl.load(seeds_ptr + req_idx)
tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
def expand_sampling_metadata(
sampling_metadata: SamplingMetadata,
cu_num_logits: torch.Tensor,
max_expand_len: int,
) -> SamplingMetadata:
total_num_logits = sampling_metadata.pos.shape[0]
create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
expanded_temp = create_empty(sampling_metadata.temperature)
expanded_top_p = create_empty(sampling_metadata.top_p)
expanded_top_k = create_empty(sampling_metadata.top_k)
expanded_min_p = create_empty(sampling_metadata.min_p)
expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty)
expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty)
expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty)
expanded_seeds = create_empty(sampling_metadata.seeds)
num_reqs = cu_num_logits.shape[0] - 1
_expand_sampling_metadata_kernel[(num_reqs,)](
sampling_metadata.temperature,
expanded_temp,
sampling_metadata.top_p,
expanded_top_p,
sampling_metadata.top_k,
expanded_top_k,
sampling_metadata.min_p,
expanded_min_p,
sampling_metadata.repetition_penalty,
expanded_repetition_penalty,
sampling_metadata.frequency_penalty,
expanded_frequency_penalty,
sampling_metadata.presence_penalty,
expanded_presence_penalty,
sampling_metadata.seeds,
expanded_seeds,
cu_num_logits,
BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
)
return SamplingMetadata(
temperature=expanded_temp,
top_p=expanded_top_p,
top_k=expanded_top_k,
min_p=expanded_min_p,
seeds=expanded_seeds,
repetition_penalty=expanded_repetition_penalty,
frequency_penalty=expanded_frequency_penalty,
presence_penalty=expanded_presence_penalty,
pos=sampling_metadata.pos,
max_num_logprobs=sampling_metadata.max_num_logprobs,
# TODO(woosuk): Support penalties with spec decoding.
idx_mapping=sampling_metadata.idx_mapping,
prompt_bin_mask=sampling_metadata.prompt_bin_mask,
output_bin_counts=sampling_metadata.output_bin_counts,
)

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _min_p_kernel(
logits_ptr,
logits_stride,
min_p_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
min_p = tl.load(min_p_ptr + req_idx).to(tl.float32)
if min_p == 0.0:
return
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
)
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore
threshold = max_val + tl.log(min_p)
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
)
logits = tl.where(logits < threshold, float("-inf"), logits)
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
_min_p_kernel[(num_reqs,)](
logits,
logits.stride(0),
min_p,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)

View File

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.v1.outputs import LogprobsTensors
@dataclass
class SamplerOutput:
sampled_token_ids: torch.Tensor
logprobs_tensors: LogprobsTensors | None
num_nans: torch.Tensor | None

View File

@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
@triton.jit
def _penalties_and_temperature_kernel(
logits_ptr,
logits_stride,
repetition_penalty_ptr,
frequency_penalty_ptr,
presence_penalty_ptr,
temperature_ptr,
idx_mapping_ptr,
prompt_bin_mask_ptr,
prompt_bin_mask_stride,
output_bin_counts_ptr,
output_bin_counts_stride,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
temperature = tl.load(temperature_ptr + batch_idx)
temperature = tl.where(temperature == 0.0, 1.0, temperature)
use_rep_penalty = rep_penalty != 1.0
use_freq_penalty = freq_penalty != 0.0
use_pres_penalty = pres_penalty != 0.0
use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
use_temperature = temperature != 1.0
if not (use_penalty or use_temperature):
# Early return to avoid loading logits.
return
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
if use_penalty:
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
output_bin_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
mask=mask,
)
output_bin_mask = output_bin_counts > 0
# Apply repetition penalties.
if use_rep_penalty:
packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
packed_mask = tl.load(
prompt_bin_mask_ptr
+ req_state_idx * prompt_bin_mask_stride
+ packed_block,
mask=packed_block < tl.cdiv(vocab_size, 32),
)
prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
prompt_bin_mask = prompt_bin_mask.to(tl.int1)
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
logits *= tl.where(logits > 0, 1.0 / scale, scale)
# Apply frequency penalties.
logits -= freq_penalty * output_bin_counts
# Apply presence penalties.
logits -= pres_penalty * output_bin_mask
# Apply temperature.
logits = logits / temperature
# Store back to logits.
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
def apply_penalties_and_temperature(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_penalties_and_temperature_kernel[(num_reqs, num_blocks)](
logits,
logits.stride(0),
sampling_metadata.repetition_penalty,
sampling_metadata.frequency_penalty,
sampling_metadata.presence_penalty,
sampling_metadata.temperature,
sampling_metadata.idx_mapping,
sampling_metadata.prompt_bin_mask,
sampling_metadata.prompt_bin_mask.stride(0),
sampling_metadata.output_bin_counts,
sampling_metadata.output_bin_counts.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
def _bincount_kernel(
prefill_token_ids_ptr,
prefill_len,
prompt_len,
prompt_bin_mask_ptr,
output_bin_counts_ptr,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
if block_idx * BLOCK_SIZE >= prefill_len:
return
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if block_idx * BLOCK_SIZE < prompt_len:
mask = block < prompt_len
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
idx = prefill_tokens // 32
bit_idx = prefill_tokens % 32
bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask)
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
mask = block < prefill_len
mask &= block >= prompt_len
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
def bincount(
prefill_token_ids: torch.Tensor,
prefill_len: int,
prompt_len: int,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor,
) -> None:
prompt_bin_mask.zero_()
output_bin_counts.zero_()
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_blocks,)](
prefill_token_ids,
prefill_len,
prompt_len,
prompt_bin_mask,
output_bin_counts,
BLOCK_SIZE=BLOCK_SIZE,
)

View File

@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature
class Sampler:
def __init__(
self,
logprobs_mode: LogprobsMode = "raw_logprobs",
):
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
self.logprobs_mode = logprobs_mode
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
def __call__(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.compute_nans else None
sampled, processed_logits = self.sample(logits, sampling_metadata)
if sampling_metadata.max_num_logprobs is not None:
logits = (
processed_logits
if self.logprobs_mode == "processed_logprobs"
else logits
)
logprobs_tensors = compute_topk_logprobs(
logits,
sampling_metadata.max_num_logprobs,
sampled,
)
else:
logprobs_tensors = None
# These are GPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.view(-1, 1),
logprobs_tensors=logprobs_tensors,
num_nans=num_nans,
)
return sampler_output
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply penalties and temperature in place.
apply_penalties_and_temperature(logits, sampling_metadata)
# Apply min_p in place.
if sampling_metadata.min_p is not None:
apply_min_p(logits, sampling_metadata.min_p)
# Apply top_k and/or top_p. This might return a new tensor.
logits = apply_top_k_top_p(
logits, sampling_metadata.top_k, sampling_metadata.top_p
)
sampled = gumbel_sample(
logits,
sampling_metadata.temperature,
sampling_metadata.seeds,
sampling_metadata.pos,
apply_temperature=False,
)
return sampled, logits

View File

@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import VllmConfig
def init_speculator(
vllm_config: VllmConfig,
device: torch.device,
):
speculative_config = vllm_config.speculative_config
assert speculative_config is not None
if speculative_config.use_eagle():
from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
return EagleSpeculator(vllm_config, device)
raise NotImplementedError(f"{speculative_config.method} is not supported yet.")

View File

@@ -0,0 +1,565 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
logger = init_logger(__name__)
class EagleSpeculator:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.device = device
self.speculative_config = vllm_config.speculative_config
assert self.speculative_config is not None
self.method = self.speculative_config.method
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
self.draft_model_config = self.speculative_config.draft_model_config
self.scheduler_config = vllm_config.scheduler_config
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_model_len = vllm_config.model_config.max_model_len
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size()
self.pin_memory = is_pin_memory_available()
self.dtype = vllm_config.model_config.dtype
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
inputs_embeds_size=self.inputs_embeds_size,
vocab_size=self.vocab_size,
dtype=self.dtype,
device=device,
pin_memory=self.pin_memory,
)
self.hidden_states = torch.zeros(
self.max_num_tokens,
self.hidden_size,
dtype=self.dtype,
device=device,
)
self.temperature = torch.zeros(
self.max_num_reqs,
dtype=torch.float32,
device=device,
)
self.seeds = torch.zeros(
self.max_num_reqs,
dtype=torch.int64,
device=device,
)
self.draft_tokens = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
dtype=torch.int64,
device=device,
)
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
def load_model(self, target_model: nn.Module) -> None:
from vllm.compilation.backends import set_model_tag
with set_model_tag("eagle_head"):
self.model = get_model(
vllm_config=self.vllm_config, model_config=self.draft_model_config
)
share_lm_head = True
if share_lm_head and hasattr(target_model, "lm_head"):
if hasattr(self.model, "lm_head"):
del self.model.lm_head
self.model.lm_head = target_model.lm_head
def set_attn(
self,
kv_cache_config: KVCacheConfig,
attn_metadata_builders: list[AttentionMetadataBuilder],
block_tables: BlockTables,
) -> None:
self.kv_cache_config = kv_cache_config
self.attn_metadata_builders = attn_metadata_builders
self.block_tables = block_tables
@torch.inference_mode()
def run_model(
self,
num_tokens: int,
attn_metadata: dict[str, Any],
num_tokens_across_dp: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
):
ret_hidden_states = self.model(
input_ids=self.input_buffers.input_ids[:num_tokens],
positions=self.input_buffers.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
return last_hidden_states, hidden_states
def generate_draft(
self,
num_reqs: int,
attn_metadata: dict[str, Any],
num_tokens_across_dp: torch.Tensor | None,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
for step in range(1, self.num_speculative_steps):
# Run the eagle model.
last_hidden_states, hidden_states = self.run_model(
num_reqs, attn_metadata, num_tokens_across_dp
)
logits = self.model.compute_logits(last_hidden_states)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits,
self.temperature[:num_reqs],
self.seeds[:num_reqs],
pos + 1,
apply_temperature=True,
)
self.draft_tokens[:num_reqs, step] = draft_tokens
if step < self.num_speculative_steps - 1:
# Update the inputs for the next step.
update_eagle_inputs(
draft_tokens,
hidden_states,
self.input_buffers,
self.hidden_states,
self.max_model_len,
)
self.block_tables.compute_slot_mappings(query_start_loc, pos)
def capture_model(self) -> None:
if self.num_speculative_steps == 1:
return
logger.info("Capturing model for Eagle speculator...")
self.cudagraph_manager.capture(
self.generate_draft,
self.input_buffers,
self.block_tables,
self.attn_metadata_builders,
self.kv_cache_config,
)
@torch.inference_mode()
def propose(
self,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
# [num_tokens, hidden_size]
last_hidden_states: torch.Tensor,
# num_layers x [num_tokens, hidden_size]
aux_hidden_states: list[torch.Tensor] | None,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [num_reqs]
last_sampled: torch.Tensor,
# [num_reqs]
next_prefill_tokens: torch.Tensor,
) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and
# hidden_states the same as the target model's. This means, we pad each
# request's query length to include any rejected positions. By doing so,
# we can also reuse the attention metadata (e.g., query_start_loc,
# seq_lens) of the target model.
if aux_hidden_states:
assert self.method == "eagle3"
hidden_states = self.model.combine_hidden_states(
torch.cat(aux_hidden_states, dim=-1)
)
else:
hidden_states = last_hidden_states
num_tokens = input_batch.num_tokens_after_padding
self.hidden_states[:num_tokens] = hidden_states
# Get the input ids and last token indices for the speculator.
last_token_indices = prepare_eagle_inputs(
self.input_buffers,
input_batch,
num_sampled,
num_rejected,
last_sampled,
next_prefill_tokens,
)
# Prefill: Run the eagle speculator with eager mode.
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states, hidden_states = self.run_model(
num_tokens,
input_batch.attn_metadata,
num_tokens_across_dp=None, # FIXME
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs
cu_num_logits = input_batch.cu_num_logits[:num_reqs]
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
temperature = self.temperature[:num_reqs]
seeds = self.seeds[:num_reqs]
pos = self.input_buffers.positions[:num_reqs]
# Gather the values and copy them to the pre-allocated buffers.
torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature)
torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits, temperature, seeds, pos + 1, apply_temperature=True
)
if self.num_speculative_steps == 1:
# Early exit.
return draft_tokens.view(-1, 1)
# Save the draft tokens for the first step.
self.draft_tokens[:num_reqs, 0] = draft_tokens
# Prepare the inputs for the decode steps.
prepare_eagle_decode(
draft_tokens,
hidden_states,
last_token_indices,
input_batch.seq_lens,
num_rejected,
self.input_buffers,
self.hidden_states,
self.max_model_len,
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
if cudagraph_size is not None:
# Run CUDA graph.
self.cudagraph_manager.run(cudagraph_size)
return self.draft_tokens[:num_reqs]
# Run eager mode.
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1)
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
# HACK(woosuk)
seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
seq_lens_np=seq_lens_np,
num_computed_tokens_cpu=None, # FIXME
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None) # FIXME
return self.draft_tokens[:num_reqs]
@triton.jit
def _prepare_eagle_inputs_kernel(
last_token_indices_ptr,
eagle_input_ids_ptr,
eagle_positions_ptr,
target_input_ids_ptr,
target_positions_ptr,
last_sampled_ptr,
next_prefill_tokens_ptr,
num_sampled_ptr,
num_rejected_ptr,
query_start_loc_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
# Get the true query length and next token after accounting for rejected tokens.
num_rejected = tl.load(num_rejected_ptr + batch_idx)
query_len -= num_rejected
num_sampled = tl.load(num_sampled_ptr + batch_idx)
if num_sampled > 0:
next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32)
else:
# Chunked prefilling.
# Get the next prefill token.
next_token = tl.load(next_prefill_tokens_ptr + batch_idx)
# Shift target_input_ids by one.
for i in range(1, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
input_ids = tl.load(target_input_ids_ptr + query_start + block, mask=mask)
tl.store(eagle_input_ids_ptr + query_start + block - 1, input_ids, mask=mask)
last_token_index = query_start + query_len - 1
tl.store(last_token_indices_ptr + batch_idx, last_token_index)
tl.store(eagle_input_ids_ptr + last_token_index, next_token)
# Copy positions.
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
target_pos = tl.load(target_positions_ptr + query_start + block, mask=mask)
tl.store(eagle_positions_ptr + query_start + block, target_pos, mask=mask)
def prepare_eagle_inputs(
input_buffers: InputBuffers,
input_batch: InputBatch,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [num_reqs]
last_sampled: torch.Tensor,
# [num_reqs]
next_prefill_tokens: torch.Tensor,
) -> torch.Tensor:
num_reqs = input_batch.num_reqs
last_token_indices = torch.empty(
num_reqs,
dtype=torch.int64,
device=num_sampled.device,
)
_prepare_eagle_inputs_kernel[(num_reqs,)](
last_token_indices,
input_buffers.input_ids,
input_buffers.positions,
input_batch.input_ids,
input_batch.positions,
last_sampled,
next_prefill_tokens,
num_sampled,
num_rejected,
input_batch.query_start_loc,
BLOCK_SIZE=1024,
)
return last_token_indices
@triton.jit
def _prepare_eagle_docode_kernel(
draft_tokens_ptr,
output_hidden_states_ptr,
output_hidden_states_stride,
last_token_indices_ptr,
target_seq_lens_ptr,
num_rejected_ptr,
input_ids_ptr,
positions_ptr,
input_hidden_states_ptr,
input_hidden_states_stride,
query_start_loc_ptr,
seq_lens_ptr,
hidden_size,
max_model_len,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
num_reqs = tl.num_programs(0) - 1
if req_idx == num_reqs:
# Compute query_start_loc. Pad it with the last query_start_loc
# for CUDA graphs.
for i in range(0, max_num_reqs + 1, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
q = tl.where(block < num_reqs, block, num_reqs)
mask = block < max_num_reqs + 1
tl.store(query_start_loc_ptr + block, q, mask=mask)
# Pad seq_lens for CUDA graphs.
for i in range(req_idx, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(seq_lens_ptr + block, 0, mask=mask)
return
# draft token -> input id.
draft_token = tl.load(draft_tokens_ptr + req_idx)
tl.store(input_ids_ptr + req_idx, draft_token)
# output hidden states -> input hidden states.
src_idx = tl.load(last_token_indices_ptr + req_idx)
for i in range(0, hidden_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < hidden_size
output_hidden_states = tl.load(
output_hidden_states_ptr + src_idx * output_hidden_states_stride + block,
mask=mask,
)
tl.store(
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
output_hidden_states,
mask=mask,
)
# Compute position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position = tl.load(positions_ptr + req_idx)
position = tl.minimum(position + 1, max_model_len - 1)
tl.store(positions_ptr + req_idx, position)
target_seq_len = tl.load(target_seq_lens_ptr + req_idx)
num_rejected = tl.load(num_rejected_ptr + req_idx)
seq_len = target_seq_len - num_rejected
seq_len = tl.minimum(seq_len + 1, max_model_len)
tl.store(seq_lens_ptr + req_idx, seq_len)
def prepare_eagle_decode(
draft_tokens: torch.Tensor,
output_hidden_states: torch.Tensor,
last_token_indices: torch.Tensor,
target_seq_lens: torch.Tensor,
num_rejected: torch.Tensor,
input_buffers: InputBuffers,
input_hidden_states: torch.Tensor,
max_model_len: int,
max_num_reqs: int,
):
num_reqs = draft_tokens.shape[0]
hidden_size = output_hidden_states.shape[-1]
_prepare_eagle_docode_kernel[(num_reqs + 1,)](
draft_tokens,
output_hidden_states,
output_hidden_states.stride(0),
last_token_indices,
target_seq_lens,
num_rejected,
input_buffers.input_ids,
input_buffers.positions,
input_hidden_states,
input_hidden_states.stride(0),
input_buffers.query_start_loc.gpu,
input_buffers.seq_lens,
hidden_size,
max_model_len,
max_num_reqs,
BLOCK_SIZE=1024,
)
@triton.jit
def _update_eagle_inputs_kernel(
input_ids_ptr,
positions_ptr,
input_hidden_states_ptr,
input_hidden_states_stride,
seq_lens_ptr,
max_model_len,
draft_tokens_ptr,
output_hidden_states_ptr,
output_hidden_states_stride,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
# Draft token -> Input ID.
draft_token = tl.load(draft_tokens_ptr + req_idx)
tl.store(input_ids_ptr + req_idx, draft_token)
# Output hidden states -> Input hidden states.
for i in range(0, hidden_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < hidden_size
output_hidden_states = tl.load(
output_hidden_states_ptr + req_idx * output_hidden_states_stride + block,
mask=mask,
)
tl.store(
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
output_hidden_states,
mask=mask,
)
# Increment position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position = tl.load(positions_ptr + req_idx)
position = tl.minimum(position + 1, max_model_len - 1)
tl.store(positions_ptr + req_idx, position)
seq_len = tl.load(seq_lens_ptr + req_idx)
seq_len = tl.minimum(seq_len + 1, max_model_len)
tl.store(seq_lens_ptr + req_idx, seq_len)
def update_eagle_inputs(
draft_tokens: torch.Tensor,
output_hidden_states: torch.Tensor,
input_buffers: InputBuffers,
hidden_states: torch.Tensor,
max_model_len: int,
):
num_reqs, hidden_size = output_hidden_states.shape
_update_eagle_inputs_kernel[(num_reqs,)](
input_buffers.input_ids,
input_buffers.positions,
hidden_states,
hidden_states.stride(0),
input_buffers.seq_lens,
max_model_len,
draft_tokens,
output_hidden_states,
output_hidden_states.stride(0),
hidden_size,
BLOCK_SIZE=1024,
)

View File

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (
capture_graphs,
get_cudagraph_sizes,
prepare_inputs_to_capture,
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
class EagleCudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
cudagraph_mode: CUDAGraphMode
if self.compilation_config.cudagraph_mode is None:
cudagraph_mode = CUDAGraphMode.NONE
else:
cudagraph_mode = self.compilation_config.cudagraph_mode
if cudagraph_mode == CUDAGraphMode.FULL:
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
self.cudagraph_mode = cudagraph_mode
self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens)
def capture_graph(
self,
num_tokens: int,
generate_fn: Callable,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
num_reqs = min(num_tokens, self.max_num_reqs)
attn_metadata = prepare_inputs_to_capture(
num_reqs,
num_tokens,
input_buffers,
block_tables,
attn_metadata_builders,
self.max_model_len,
kv_cache_config,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.pool):
generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
self.graphs[num_tokens] = graph
@torch.inference_mode()
def capture(
self,
generate_fn: Callable,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
generate_fn=generate_fn,
input_buffers=input_buffers,
block_tables=block_tables,
attn_metadata_builders=attn_metadata_builders,
kv_cache_config=kv_cache_config,
)
def run(self, num_tokens: int) -> None:
assert num_tokens in self.graphs
self.graphs[num_tokens].replay()

View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _rejection_sample_kernel(
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
sampled_stride,
num_sampled_ptr, # [num_reqs]
target_sampled_ptr, # [num_draft_tokens + num_reqs]
input_ids_ptr, # [num_draft_tokens + num_reqs]
cu_num_logits_ptr, # [num_reqs + 1]
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
num_sampled = 0
rejected = False
for i in range(num_tokens - 1):
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
num_sampled += 1
if target_sampled != draft_sampled:
rejected = True
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store(
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
)
num_sampled += 1
tl.store(num_sampled_ptr + req_idx, num_sampled)
def rejection_sample(
# [num_draft_tokens + num_reqs]
target_sampled: torch.Tensor,
# [num_draft_tokens + num_reqs]
input_ids: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = torch.empty(
num_reqs,
num_speculative_steps + 1,
dtype=target_sampled.dtype,
device=target_sampled.device,
)
num_sampled = torch.empty(
num_reqs,
dtype=torch.int32,
device=target_sampled.device,
)
_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
num_sampled,
target_sampled,
input_ids,
cu_num_logits,
num_warps=1,
)
return sampled, num_sampled

View File

@@ -0,0 +1,316 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.penalties import bincount
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
NO_LORA_ID = 0
class RequestState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.num_speculative_steps = num_speculative_steps
self.vocab_size = vocab_size
self.device = device
self.pin_memory = pin_memory
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_reqs))
self.extra_data: dict[str, ExtraData] = {}
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len.
self.prefill_token_ids = UvaBuffer(
self.max_num_reqs, self.max_model_len, dtype=torch.int32
)
# NOTE(woosuk): We don't use UVA for prefill_len because its GPU view
# can be used outside of update_states and prepare_inputs.
# Without async barrier, using UVA can cause race conditions.
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
# Number of computed tokens.
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
# Last sampled tokens.
self.last_sampled_tokens = torch.zeros(
self.max_num_reqs,
1,
dtype=torch.int64,
device=device,
)
# Draft tokens.
self.draft_tokens = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
dtype=torch.int64,
device=device,
)
self.next_prefill_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
# LoRA.
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
# Sampling parameters.
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
self.top_k = self._make_param(self.max_num_reqs, torch.int32)
self.min_p = self._make_param(self.max_num_reqs, torch.float32)
self.repetition_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.frequency_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.presence_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.seeds = self._make_param(self.max_num_reqs, torch.int64)
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(-1)
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
# Statistics for penalties.
self.prompt_bin_mask = torch.zeros(
self.max_num_reqs,
cdiv(self.vocab_size, 32),
dtype=torch.int32,
device=self.device,
)
# TODO(woosuk): This tensor is rarely used but can be extremely large.
# Optimize the memory usage.
self.output_bin_counts = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
def add_request(
self,
req_id: str,
prompt_len: int,
prefill_token_ids: list[int],
num_computed_tokens: int,
sampling_params: SamplingParams,
lora_request: LoRARequest | None,
) -> None:
assert len(self.free_indices) > 0, "No free indices"
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.extra_data[req_id] = ExtraData(lora_request)
self.prompt_len[req_idx] = prompt_len
prefill_len = len(prefill_token_ids)
assert prefill_len >= prompt_len, (
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
)
self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids.np[req_idx, :prefill_len] = prefill_token_ids
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
# FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
# Optimize this.
self.num_computed_tokens[req_idx] = num_computed_tokens
if lora_request is not None:
self.lora_ids[req_idx] = lora_request.lora_int_id
else:
self.lora_ids[req_idx] = NO_LORA_ID
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
if 0 < sampling_params.top_k < self.vocab_size:
top_k = sampling_params.top_k
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.min_p.np[req_idx] = sampling_params.min_p
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
if use_penalty(sampling_params):
bincount(
self.prefill_token_ids.gpu[req_idx],
prefill_len,
prompt_len,
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
)
if sampling_params.seed is not None:
seed = sampling_params.seed
else:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds.np[req_idx] = seed
if sampling_params.logprobs is not None:
num_logprobs = sampling_params.logprobs
else:
num_logprobs = -1
self.num_logprobs[req_idx] = num_logprobs
# For now, only support prompt logprobs for the prompt tokens.
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
def make_sampling_metadata(
self,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> SamplingMetadata:
temperature = self.temperature.np[idx_mapping_np]
temperature = self.temperature.copy_np_to_gpu(temperature)
top_p = self.top_p.np[idx_mapping_np]
no_top_p = np.all(top_p == 1.0)
top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None
top_k = self.top_k.np[idx_mapping_np]
no_top_k = np.all(top_k == self.vocab_size)
top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None
min_p = self.min_p.np[idx_mapping_np]
no_min_p = np.all(min_p == 0.0)
min_p = self.min_p.copy_np_to_gpu(min_p) if not no_min_p else None
rep_penalty = self.repetition_penalty.np[idx_mapping_np]
rep_penalty = self.repetition_penalty.copy_np_to_gpu(rep_penalty)
freq_penalty = self.frequency_penalty.np[idx_mapping_np]
freq_penalty = self.frequency_penalty.copy_np_to_gpu(freq_penalty)
pres_penalty = self.presence_penalty.np[idx_mapping_np]
pres_penalty = self.presence_penalty.copy_np_to_gpu(pres_penalty)
seeds = self.seeds.np[idx_mapping_np]
seeds = self.seeds.copy_np_to_gpu(seeds)
num_logprobs = self.num_logprobs[idx_mapping_np]
max_num_logprobs: int | None = int(np.max(num_logprobs))
if max_num_logprobs == -1:
max_num_logprobs = None
return SamplingMetadata(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
repetition_penalty=rep_penalty,
frequency_penalty=freq_penalty,
presence_penalty=pres_penalty,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping,
prompt_bin_mask=self.prompt_bin_mask,
output_bin_counts=self.output_bin_counts,
)
def make_lora_inputs(
self,
req_ids: list[str],
idx_mapping: np.ndarray,
num_scheduled_tokens: np.ndarray,
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
lora_ids = self.lora_ids[idx_mapping]
prompt_lora_mapping = tuple(lora_ids)
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set()
for req_id in req_ids:
lora_request = self.extra_data[req_id].lora_request
if lora_request is not None:
active_lora_requests.add(lora_request)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
class Param:
def __init__(
self,
size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.buffer = CpuGpuBuffer(
size,
dtype=dtype,
device=device,
pin_memory=pin_memory,
)
self.np = np.zeros_like(self.buffer.np)
def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
n = x.shape[0]
self.buffer.np[:n] = x
return self.buffer.copy_to_gpu(n)
@dataclass
class ExtraData:
lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
class UvaBuffer:
def __init__(self, *size: int | torch.SymInt, dtype: torch.dtype):
assert is_uva_available()
self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
self.np = self.cpu.numpy()
self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
def use_penalty(sampling_params: SamplingParams) -> bool:
return (
sampling_params.repetition_penalty != 1.0
or sampling_params.frequency_penalty != 0.0
or sampling_params.presence_penalty != 0.0
)

View File

@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBuffers
def apply_grammar_bitmask(
logits: torch.Tensor,
req_ids: list[str],
grammar_req_ids: list[str],
grammar_bitmask: np.ndarray,
input_buffers: InputBuffers,
) -> None:
input_buffers.grammar_bitmask.np[: grammar_bitmask.shape[0]] = grammar_bitmask
input_buffers.grammar_bitmask.copy_to_gpu(grammar_bitmask.shape[0])
batch_size = logits.shape[0]
grammar_req_id_to_idx = {req_id: i for i, req_id in enumerate(grammar_req_ids)}
# logits -> bitmask mapping
mapping = [grammar_req_id_to_idx.get(req_id, -1) for req_id in req_ids]
input_buffers.bitmask_indices.np[:batch_size] = mapping
input_buffers.bitmask_indices.copy_to_gpu(batch_size)
vocab_size = logits.shape[-1]
BLOCK_SIZE = 8192
grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))
_apply_grammar_bitmask_kernel[grid](
logits,
logits.stride(0),
input_buffers.grammar_bitmask.gpu,
input_buffers.grammar_bitmask.gpu.stride(0),
input_buffers.bitmask_indices.gpu,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@triton.jit
def _apply_grammar_bitmask_kernel(
logits_ptr,
logits_stride,
bitmask_ptr,
bitmask_stride,
bitmask_indices_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
logits_idx = tl.program_id(0)
bitmask_idx = tl.load(bitmask_indices_ptr + logits_idx)
if bitmask_idx == -1:
# No bitmask to apply.
return
# Load the bitmask.
block_id = tl.program_id(1)
bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32)
packed_bitmask = tl.load(
bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset,
mask=bitmask_offset < bitmask_stride,
)
# Unpack the bitmask.
bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
bitmask = bitmask.reshape(BLOCK_SIZE)
# Apply the bitmask to the logits.
block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
logits_ptr + logits_idx * logits_stride + block_offset,
-float("inf"),
mask=bitmask & (block_offset < vocab_size),
)

View File

@@ -0,0 +1,990 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from typing import cast
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
from vllm.v1.sample.logits_processor import (
BatchUpdateBuilder,
LogitsProcessors,
MoveDirectionality,
)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
@dataclass
class CachedRequestState:
req_id: str
prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec]
sampling_params: SamplingParams | None
generator: torch.Generator | None
block_ids: tuple[list[int], ...]
num_computed_tokens: int
output_token_ids: list[int]
mrope_positions: torch.Tensor | None = None
mrope_position_delta: int | None = None
xdrope_positions: torch.Tensor | None = None
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None
# Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0
# for pooling models
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
)
if self.pooling_params is not None:
self.pooling_states = PoolingStates()
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
if self.prompt_token_ids is None:
raise ValueError(
f"Tried to access token index {idx}, but that token was "
"provided via prompt_embeds, and its ID is unknown."
)
return self.prompt_token_ids[idx]
if idx - self.num_prompt_tokens < len(self.output_token_ids):
return self.output_token_ids[idx - self.num_prompt_tokens]
return -1
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[int],
logitsprocs: LogitsProcessors | None = None,
logitsprocs_need_output_token_ids: bool = False,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
cp_kv_cache_interleave_size: int = 1,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self._req_ids: list[str | None] = []
self.req_id_to_index: dict[str, int] = {}
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
pin_memory=False,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.is_token_ids_tensor = torch.zeros(
(max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False
)
self.is_token_ids = self.is_token_ids_tensor.numpy()
# Store prompt embeddings per request to avoid OOM from large upfront
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)
# Sampling-related.
self.temperature = torch.empty(
(max_num_reqs,), dtype=torch.float32, device=device
)
self.temperature_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
self.top_p_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
)
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: set[str] = set()
self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device)
self.top_k_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: set[str] = set()
# IDs of requests which do not support spec decoding
self.spec_decode_unsupported_reqs: set[str] = set()
# Frequency penalty related data structures
self.frequency_penalties = torch.empty(
(max_num_reqs,), dtype=torch.float, device=device
)
self.frequency_penalties_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
)
self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: set[str] = set()
# Presence penalty related data structures
self.presence_penalties = torch.empty(
(max_num_reqs,), dtype=torch.float, device=device
)
self.presence_penalties_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
)
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
self.presence_penalties_reqs: set[str] = set()
# Repetition penalty related data structures
self.repetition_penalties = torch.empty(
(max_num_reqs,), dtype=torch.float, device=device
)
self.repetition_penalties_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
)
self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set()
# Speculative decoding
self.num_accepted_tokens_cpu_tensor = torch.ones(
(max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
# Internal representation of per-step batch state changes, used for
# reordering persistent batch and generating logitsprocs batch state
# updates. Should reset each step.
self.batch_update_builder = BatchUpdateBuilder()
# TODO convert this to LogitsProcessor
self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
# the value is False. Since we use masked_fill_ to set -inf.
self.allowed_token_ids_mask: torch.Tensor | None = None
self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None
# req_index -> bad_words_token_ids
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)
self.req_output_token_ids: list[list[int] | None] = []
# Store provided logitsprocs. If none are provided, initialize empty
# data structure
self.logitsprocs = logitsprocs or LogitsProcessors()
self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
# Store last speculative tokens for sampler.
self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
# for pooling models
self.pooling_params: dict[str, PoolingParams] = {}
self.pooling_states: dict[str, PoolingStates] = {}
# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: torch.Tensor | None = None
self.prev_req_id_to_index: dict[str, int] | None = None
# These are used to update output_token_ids with real sampled
# ids from prior step, if required by current sampling params
# (e.g. penalties).
self.sampled_token_ids_cpu: torch.Tensor | None = None
self.async_copy_ready_event: torch.Event | None = None
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
# while performing state updates to the batch.
return cast(list[str], self._req_ids)
def _register_add_request(self, request: "CachedRequestState") -> int:
"""Track add-request operations for logits processors.
Not applicable to pooling models.
"""
# Fill the next empty index if there is one.
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
# Append to end otherwise.
new_req_index = self.num_reqs
assert new_req_index < self.max_num_reqs
self.batch_update_builder.batch_changed = True
if request.sampling_params:
# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs.
self.batch_update_builder.added.append(
(
new_req_index,
request.sampling_params,
request.prompt_token_ids,
request.output_token_ids,
)
)
return new_req_index
def add_request(
self,
request: "CachedRequestState",
) -> int:
req_index = self._register_add_request(request)
req_id = request.req_id
if req_index == len(self._req_ids):
self._req_ids.append(req_id)
self.req_output_token_ids.append(request.output_token_ids)
self.spec_token_ids.append([])
else:
self._req_ids[req_index] = req_id
self.req_output_token_ids[req_index] = request.output_token_ids
self.spec_token_ids[req_index].clear()
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds
)
self.num_prompt_tokens[req_index] = num_prompt_tokens
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
self.is_token_ids[req_index, start_idx:end_idx] = True
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params:
if self.is_spec_decode and is_spec_decode_unsupported(sampling_params):
self.spec_decode_unsupported_reqs.add(req_id)
if sampling_params.sampling_type == SamplingType.GREEDY:
# Should avoid division by zero later when apply_temperature.
self.temperature_cpu[req_index] = 0.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[req_index] = (
sampling_params.repetition_penalty
)
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = (
self.vocab_size
if sampling_params.logprobs == -1
else sampling_params.logprobs
)
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device,
)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu",
)
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids
] = False
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[req_index] = (
sampling_params.bad_words_token_ids
)
elif pooling_params := request.pooling_params:
pooling_states = request.pooling_states
assert pooling_states is not None
self.pooling_params[req_id] = pooling_params
self.pooling_states[req_id] = pooling_states
self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids
)
else:
raise NotImplementedError("Unrecognized request type")
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
if lora_id not in self.lora_id_to_request_ids:
self.lora_id_to_request_ids[lora_id] = set()
self.request_lora_mapping[req_index] = lora_id
self.lora_id_to_request_ids[lora_id].add(request.req_id)
self.lora_id_to_lora_request[lora_id] = request.lora_request
else:
# No LoRA
self.request_lora_mapping[req_index] = 0
return req_index
def remove_request(self, req_id: str) -> int | None:
"""This method must always be followed by a call to condense().
Args:
req_id: request to remove
Returns:
Removed request index, or `None` if `req_id` not recognized
"""
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
self.spec_token_ids[req_index].clear()
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0
if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
self.pooling_states.pop(req_id, None)
return req_index
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.spec_decode_unsupported_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
if self.prev_req_id_to_index is not None:
self.prev_req_id_to_index.pop(req_id, None)
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
return req_index
def swap_states(self, i1: int, i2: int) -> None:
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
self.req_output_token_ids[i2],
self.req_output_token_ids[i1],
)
self.spec_token_ids[i1], self.spec_token_ids[i2] = (
self.spec_token_ids[i2],
self.spec_token_ids[i1],
)
assert old_id_i1 is not None and old_id_i2 is not None
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
self.req_id_to_index[old_id_i2],
self.req_id_to_index[old_id_i1],
)
self.num_tokens[i1], self.num_tokens[i2] = (
self.num_tokens[i2],
self.num_tokens[i1],
)
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
self.num_tokens_no_spec[i2],
self.num_tokens_no_spec[i1],
)
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
self.num_prompt_tokens[i2],
self.num_prompt_tokens[i1],
)
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
self.num_computed_tokens_cpu[i2],
self.num_computed_tokens_cpu[i1],
)
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
# Swap prompt embeddings if they exist
embeds_i1 = self.req_prompt_embeds.get(i1)
embeds_i2 = self.req_prompt_embeds.get(i2)
if embeds_i1 is not None:
self.req_prompt_embeds[i2] = embeds_i1
else:
self.req_prompt_embeds.pop(i2, None)
if embeds_i2 is not None:
self.req_prompt_embeds[i1] = embeds_i2
else:
self.req_prompt_embeds.pop(i1, None)
self.block_table.swap_row(i1, i2)
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
self.request_lora_mapping[i2],
self.request_lora_mapping[i1],
)
if self.is_pooling_model:
# Sampling and logits parameters don't apply to pooling models.
return
# For autoregressive models, track detailed request reordering info
# to support logitsprocs.
self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))
self.temperature_cpu[i1], self.temperature_cpu[i2] = (
self.temperature_cpu[i2],
self.temperature_cpu[i1],
)
self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
self.frequency_penalties_cpu[i2],
self.frequency_penalties_cpu[i1],
)
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
self.presence_penalties_cpu[i2],
self.presence_penalties_cpu[i1],
)
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
self.repetition_penalties_cpu[i2],
self.repetition_penalties_cpu[i1],
)
self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
self.num_accepted_tokens_cpu[i2],
self.num_accepted_tokens_cpu[i1],
)
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
if self.allowed_token_ids_mask_cpu_tensor is not None:
(
self.allowed_token_ids_mask_cpu_tensor[i1],
self.allowed_token_ids_mask_cpu_tensor[i2],
) = (
self.allowed_token_ids_mask_cpu_tensor[i2],
self.allowed_token_ids_mask_cpu_tensor[i1],
)
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
Any consecutive empty indices at the very end of the list are not
filled.
Returns:
swaps: list of (from,to) swap tuples for moved requests
empty_req_indices: indices not filled by condensation
"""
num_reqs = self.num_reqs
if not (empty_req_indices := self.batch_update_builder.removed):
# All removed requests were replaced by added requests, or else no
# requests were removed at all. No condense() needed
return
if num_reqs == 0:
# The batched states are empty.
self._req_ids.clear()
self.req_output_token_ids.clear()
self.spec_token_ids.clear()
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index = num_reqs + len(empty_req_indices) - 1
while empty_req_indices:
# Find the largest non-empty index.
while last_req_index in empty_req_indices:
last_req_index -= 1
# Find the smallest empty index.
empty_index = self.batch_update_builder.peek_removed()
assert empty_index is not None
if empty_index >= last_req_index:
break
# Move active request down into empty request
# index.
self.batch_update_builder.pop_removed()
req_id = self._req_ids[last_req_index]
output_token_ids = self.req_output_token_ids[last_req_index]
assert req_id is not None
self._req_ids[empty_index] = req_id
self._req_ids[last_req_index] = None
self.req_output_token_ids[empty_index] = output_token_ids
self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
if last_req_index != empty_index:
(
self.spec_token_ids[last_req_index],
self.spec_token_ids[empty_index],
) = (
self.spec_token_ids[empty_index],
self.spec_token_ids[last_req_index],
)
self.spec_token_ids[last_req_index].clear()
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens
]
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
last_req_index, :num_tokens
]
if last_req_index in self.req_prompt_embeds:
self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
last_req_index
)
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index
]
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
last_req_index
]
self.block_table.move_row(last_req_index, empty_index)
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index
]
if self.is_pooling_model:
last_req_index -= 1
# Sampling state not used by pooling models.
continue
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self.batch_update_builder.moved.append(
(last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
)
self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
last_req_index
]
self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
last_req_index
]
self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
last_req_index
]
self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
last_req_index
]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
# TODO convert these to LogitsProcessors
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
self.allowed_token_ids_mask_cpu_tensor[last_req_index]
)
bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
# Decrement last_req_index since it is now empty.
last_req_index -= 1
# Trim lists to the batch size.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
del self.spec_token_ids[num_reqs:]
def refresh_metadata(self):
"""Apply any batch updates to sampling metadata."""
if self.is_pooling_model:
batch_changed = self.batch_update_builder.reset()
if batch_changed:
self.sampling_metadata = self._make_sampling_metadata()
return
# For non-pooling models - generate and apply logitsprocs update;
# reset batch update tracking.
# Update sampling metadata if batch state is changed.
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
if batch_update:
self.sampling_metadata = self._make_sampling_metadata()
def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs
if not self.all_greedy:
temperature = copy_slice(
self.temperature_cpu_tensor, self.temperature, num_reqs
)
else:
temperature = None
if not self.no_top_p:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
if not self.no_top_k:
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
if not self.no_penalties:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
copy_slice(
self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs
)
copy_slice(
self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs
)
copy_slice(
self.repetition_penalties_cpu_tensor,
self.repetition_penalties,
num_reqs,
)
needs_prompt_token_ids = (
not self.no_penalties
or self.logits_processing_needs_token_ids[:num_reqs].any()
)
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids = (
self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
)
# Only set output_token_ids if required by the current requests'
# sampling parameters.
needs_output_token_ids = (
not self.no_penalties
or bool(self.bad_words_token_ids)
or self.logitsprocs_need_output_token_ids
)
output_token_ids = (
cast(list[list[int]], self.req_output_token_ids)
if needs_output_token_ids
else []
)
allowed_token_ids_mask: torch.Tensor | None = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
copy_slice(
self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask,
num_reqs,
)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=output_token_ids,
spec_token_ids=cast(list[list[int]], self.spec_token_ids),
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs,
)
def get_pooling_params(self) -> list[PoolingParams]:
assert len(self.req_ids) == len(self.pooling_params)
return [self.pooling_params[req_id] for req_id in self.req_ids]
def get_pooling_states(self) -> list[PoolingStates]:
assert len(self.req_ids) == len(self.pooling_states)
return [self.pooling_states[req_id] for req_id in self.req_ids]
def get_pooling_metadata(self) -> PoolingMetadata:
pooling_params = self.get_pooling_params()
pooling_states = self.get_pooling_states()
return PoolingMetadata(
prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
pooling_states=pooling_states,
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
num_reqs = self.num_reqs
max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
where, prompt_lora_mapping[i] is the LoRA id to use for the ith
sampled token.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values()
)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
def set_async_sampled_token_ids(
self,
sampled_token_ids_cpu: torch.Tensor,
async_copy_ready_event: torch.Event,
) -> None:
"""
In async scheduling case, store ref to sampled_token_ids_cpu
tensor and corresponding copy-ready event. Used to repair
output_token_ids prior to sampling, if needed by logits processors.
"""
if self.sampling_metadata.output_token_ids:
self.sampled_token_ids_cpu = sampled_token_ids_cpu
self.async_copy_ready_event = async_copy_ready_event
else:
self.sampled_token_ids_cpu = None
self.async_copy_ready_event = None
def update_async_output_token_ids(self) -> None:
"""
In async scheduling case, update output_token_ids in sampling metadata
from prior steps sampled token ids once they've finished copying to CPU.
This is called right before they are needed by the logits processors.
"""
output_token_ids = self.sampling_metadata.output_token_ids
if self.sampled_token_ids_cpu is None or not output_token_ids:
# Output token ids not needed or not async scheduling.
return
assert self.prev_req_id_to_index is not None
sampled_token_ids = None
for index, req_id in enumerate(self.req_ids):
prev_index = self.prev_req_id_to_index.get(req_id)
if prev_index is None:
continue
req_output_token_ids = output_token_ids[index]
if not req_output_token_ids or req_output_token_ids[-1] != -1:
# Final output id is not a placeholder, some tokens must have
# been discarded after a kv-load failure.
continue
if sampled_token_ids is None:
assert self.async_copy_ready_event is not None
self.async_copy_ready_event.synchronize()
sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()
# Replace placeholder token id with actual sampled id.
req_output_token_ids[-1] = sampled_token_ids[prev_index]
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_penalties(self) -> bool:
return (
len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0
)
@property
def max_num_logprobs(self) -> int | None:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,472 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
import torch
import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import (
DPMetadata,
create_forward_context,
get_forward_context,
override_forward_context,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import has_deep_gemm
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
logger = init_logger(__name__)
@dataclass
class UbatchMetadata:
context: UBatchContext
input_ids: torch.Tensor
positions: torch.Tensor
inputs_embeds: torch.Tensor | None
intermediate_tensors: IntermediateTensors | None
num_tokens: int
@dataclass
class CUDAGraphMetaData:
cudagraph: torch.cuda.CUDAGraph
ubatch_metadata: UbatchMetadata
outputs: Any | None = None
class SMControlContextManager:
def __init__(
self,
comm_sms: int,
set_comm_sms: Callable[[int], None],
set_compute_sms: Callable[[int], None],
):
"""
Context manager for controlling SM (Streaming Multiprocessor)
allocation. Upon entering the context, it sets the number of SMs
allocated for communication and computation to comm_sms and
total_sms - comm_sms respectively. Upon exiting, it restores the
allocation to use all available SMs (i.e. total_sms).
Args:
comm_sms (int): The number of SMs to allocate for communication.
(The remainder will be used for computation.)
set_comm_sms (Callable[[int], None]):
A function that sets the number of SMs for communication.
set_compute_sms (Callable[[int], None]):
A function that sets the number of SMs for computation.
"""
assert current_platform.is_cuda(), (
"SM control is currently only supported on CUDA"
)
props = torch.cuda.get_device_properties(torch.cuda.current_device())
total_sms = props.multi_processor_count
assert comm_sms < total_sms
self.total_sms = total_sms
self.compute_sms = total_sms - comm_sms
self.comm_sms = comm_sms
self.set_comm_sms = set_comm_sms
self.set_compute_sms = set_compute_sms
def __enter__(self):
self.set_comm_sms(self.comm_sms)
self.set_compute_sms(self.compute_sms)
def __exit__(self, exc_type, exc_value, traceback):
self.set_comm_sms(self.total_sms)
self.set_compute_sms(self.total_sms)
class UBatchWrapper:
def __init__(
self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
device: torch.cuda.device,
):
self.runnable = runnable
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.comm_stream = torch.cuda.Stream(device=device)
# Two ubatch threads plus the main thread
self.ready_barrier = threading.Barrier(3)
self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
self.cudagraph_wrapper = None
self.graph_pool = None
if runtime_mode is not CUDAGraphMode.NONE:
self.cudagraph_wrapper = CUDAGraphWrapper(
runnable, vllm_config, runtime_mode=runtime_mode
)
self.graph_pool = current_platform.get_global_graph_pool()
self.sm_control = self._create_sm_control_context(vllm_config)
self.device = device
@staticmethod
def _create_sm_control_context(vllm_config: VllmConfig):
comm_sms: int = envs.VLLM_DBO_COMM_SMS
set_comm_sms = lambda sms: None
if vllm_config.parallel_config.enable_expert_parallel:
# Currently only DeepEP highthroughput supports SM control so this
# only affects that case.
ep_group = get_ep_group()
device_communicator = ep_group.device_communicator
all2all_manager = None
if device_communicator is not None:
all2all_manager = device_communicator.all2all_manager
if all2all_manager is not None:
max_sms_used = all2all_manager.max_sms_used()
if max_sms_used is not None:
comm_sms = min(comm_sms, max_sms_used)
if comm_sms > 0 and all2all_manager is not None:
set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)
# TODO(lucas): support other kernels besides DeepGEMM
set_compute_sms = lambda sms: None
if has_deep_gemm() and comm_sms > 0:
import deep_gemm as dg
set_compute_sms = lambda sms: dg.set_num_sms(sms)
return SMControlContextManager(
comm_sms=comm_sms,
set_comm_sms=set_comm_sms,
set_compute_sms=set_compute_sms,
)
def __getattr__(self, key: str):
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
raise AttributeError(
f"Attribute {key} not exists in the runnable of "
f"cudagraph wrapper: {self.runnable}"
)
def unwrap(self) -> Callable:
# in case we need to access the original runnable.
return self.runnable
def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
"""
Capture a cudagraph for a microbatched run.
The logic here is somewhat complicated because we need to make sure that
each of the ubatch threads initialize the cuda context before we start
the graph capture.
The flow is as follows:
1. The main thread starts up each ubatch thread. Each thread will
initialize its cuda context (torch.cuda.current_blas_handle())
before going to sleep upon entering the ubatch_context.
2. The main thread starts the graph capture and wakes up the first
ubatch thread.
3. Each ubatch thread runs the model to completion and returns the
completed output tensors back to the main thread.
4. The main thread stores the captured cudagraph along with its metadata
and returns
"""
@torch.inference_mode()
def _capture_ubatch_thread(results, ubatch_metadata):
torch.cuda.set_device(self.device)
ubatch_context = ubatch_metadata.context
with torch.cuda.stream(ubatch_context.compute_stream):
_ = torch.cuda.current_blas_handle()
with torch.cuda.stream(ubatch_context.comm_stream):
_ = torch.cuda.current_blas_handle()
with ubatch_context:
model_output = model(
input_ids=ubatch_metadata.input_ids,
positions=ubatch_metadata.positions,
intermediate_tensors=ubatch_metadata.intermediate_tensors,
inputs_embeds=ubatch_metadata.inputs_embeds,
)
results.append((ubatch_metadata.context.id, model_output))
results: list[tuple[int, torch.Tensor]] = []
compute_stream = ubatch_metadata[0].context.compute_stream
num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens
# Ubatches will manually manage the forward context, so we override
# it to None here so we can have it restored correctly later
with override_forward_context(None):
ubatch_threads = []
for metadata in ubatch_metadata:
thread = threading.Thread(
target=_capture_ubatch_thread,
args=(
results,
metadata,
),
)
ubatch_threads.append(thread)
thread.start()
self.ready_barrier.wait() # Wait for both threads to be ready
# Capture the cudagraph
cudagraph_metadata = CUDAGraphMetaData(
cudagraph=torch.cuda.CUDAGraph(),
ubatch_metadata=ubatch_metadata,
)
if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
else:
set_graph_pool_id(current_platform.graph_pool_handle())
with torch.cuda.graph(
cudagraph_metadata.cudagraph,
stream=compute_stream,
pool=self.graph_pool,
):
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)
cudagraph_metadata.outputs = result
self.cudagraphs[num_tokens] = cudagraph_metadata
return cudagraph_metadata.outputs
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
@torch.inference_mode()
def _ubatch_thread(results, model, ubatch_metadata):
with ubatch_metadata.context:
model_output = model(
input_ids=ubatch_metadata.input_ids,
positions=ubatch_metadata.positions,
intermediate_tensors=ubatch_metadata.intermediate_tensors,
inputs_embeds=ubatch_metadata.inputs_embeds,
)
results.append((ubatch_metadata.context.id, model_output))
results: list[tuple[int, torch.Tensor]] = []
# Ubatch threads will manually manage the forward context, so we
# override it to None here so we can have it restored correctly
# after both threads have finished
with override_forward_context(None):
ubatch_threads = []
for metadata in ubatch_metadata:
thread = threading.Thread(
target=_ubatch_thread,
args=(
results,
model,
metadata,
),
)
ubatch_threads.append(thread)
thread.start()
self.ready_barrier.wait() # Wait for both threads to be ready
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)
return result
def _make_ubatch_metadata(
self,
ubatch_slices,
attn_metadata,
input_ids,
positions,
inputs_embeds,
intermediate_tensors,
compute_stream,
dp_metadata,
batch_descriptor,
cudagraph_runtime_mode,
) -> list[UbatchMetadata]:
# Create one forward context per ubatch
forward_contexts = []
for i, ubatch_slice in enumerate(ubatch_slices):
forward_contexts.append(
create_forward_context(
attn_metadata[i] if attn_metadata is not None else None,
self.vllm_config,
dp_metadata=dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=cudagraph_runtime_mode,
)
)
ubatch_ctxs = make_ubatch_contexts(
num_micro_batches=len(ubatch_slices),
comm_stream=self.comm_stream,
compute_stream=compute_stream,
forward_contexts=forward_contexts,
ready_barrier=self.ready_barrier,
)
ubatch_metadata: list[UbatchMetadata] = []
for i, ubatch_slice in enumerate(ubatch_slices):
(
sliced_input_ids,
sliced_positions,
sliced_inputs_embeds,
sliced_intermediate_tensors,
) = self._slice_model_inputs(
ubatch_slice.token_slice,
input_ids,
positions,
inputs_embeds,
intermediate_tensors,
)
ubatch_metadata.append(
UbatchMetadata(
context=ubatch_ctxs[i],
input_ids=sliced_input_ids,
positions=sliced_positions,
inputs_embeds=sliced_inputs_embeds,
intermediate_tensors=sliced_intermediate_tensors,
num_tokens=ubatch_slice.token_slice.stop
- ubatch_slice.token_slice.start,
)
)
return ubatch_metadata
def _slice_model_inputs(
self,
tokens_slice: slice,
input_ids,
positions,
inputs_embeds,
intermediate_tensors,
):
sliced_input_ids = input_ids[tokens_slice]
# if we are using mrope. Mrope adds an additional dimension to the
# positions tensor
if positions.ndim == 2:
sliced_positions = positions[:, tokens_slice]
else:
sliced_positions = positions[tokens_slice]
sliced_inputs_embeds = inputs_embeds[tokens_slice] if inputs_embeds else None
sliced_intermediate_tensors = (
intermediate_tensors[tokens_slice] if intermediate_tensors else None
)
return (
sliced_input_ids,
sliced_positions,
sliced_inputs_embeds,
sliced_intermediate_tensors,
)
def __call__(self, *args, **kwargs):
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
ubatch_slices = forward_context.ubatch_slices
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
# If there's no ubatching, just run the runnable object
if ubatch_slices is None:
# This is to account for the case where ubatching was aborted.
# When we capture full graphs we only capture one graph per shape,
# meaning that if we have a ubatched cudagraph for the current
# num_tokens, we don't have a non-ubatched one. Without this
# check, the cudagraph wrapper will try to capture a cudagraph
# for this shape during a normal run.
if cudagraph_runtime_mode is CUDAGraphMode.FULL:
assert batch_descriptor is not None
if batch_descriptor.num_tokens in self.cudagraphs:
cudagraph_runtime_mode = CUDAGraphMode.NONE
if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE):
return self.runnable(*args, **kwargs)
else:
assert self.cudagraph_wrapper is not None
return self.cudagraph_wrapper(*args, **kwargs)
attn_metadata = forward_context.attn_metadata
num_tokens = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
) * 2
input_ids = kwargs["input_ids"]
positions = kwargs["positions"]
intermediate_tensors = kwargs["intermediate_tensors"]
inputs_embeds = kwargs["inputs_embeds"]
compute_stream = torch.cuda.current_stream()
dp_metadata = forward_context.dp_metadata
# We shouldn't be here unless we are running with multiple DP ranks
assert dp_metadata is not None
num_tokens_per_ubatch = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
)
dp_size = self.vllm_config.parallel_config.data_parallel_size
ubatch_num_tokens_across_dp = torch.tensor(
[num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32
)
ubatch_dp_metadata = DPMetadata.make(
self.vllm_config.parallel_config,
num_tokens_per_ubatch,
ubatch_num_tokens_across_dp,
)
if (
num_tokens not in self.cudagraphs
and cudagraph_runtime_mode is CUDAGraphMode.FULL
):
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
compute_stream=compute_stream,
dp_metadata=ubatch_dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
with self.sm_control:
return self._capture_ubatches(ubatch_metadata, self.model)
elif (
num_tokens in self.cudagraphs
and cudagraph_runtime_mode is CUDAGraphMode.FULL
):
cudagraph_metadata = self.cudagraphs[num_tokens]
cudagraph_metadata.cudagraph.replay()
return cudagraph_metadata.outputs
else:
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
compute_stream=compute_stream,
dp_metadata=dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
with self.sm_control:
return self._run_ubatches(ubatch_metadata, self.model)

View File

@@ -0,0 +1,955 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.distributed import (
ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce,
)
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
get_kv_transfer_group,
has_kv_transfer_group,
)
from vllm.distributed.parallel_state import (
get_pcp_group,
get_pp_group,
get_tp_group,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
DraftTokenIds,
ModelRunnerOutput,
)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
class Worker(WorkerBase):
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
super().__init__(
vllm_config=vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
# configure float32 matmul precision according to vLLM env.
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.backends.cuda.matmul.fp32_precision = precision
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
# Torch/CUDA profiler. Enabled and configured through profiler_config.
self.profiler: Any | None = None
profiler_config = vllm_config.profiler_config
if profiler_config.profiler == "torch":
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
self.profiler = TorchProfilerWrapper(
profiler_config,
worker_name=worker_name,
local_rank=self.local_rank,
activities=["CPU", "CUDA"],
)
elif profiler_config.profiler == "cuda":
self.profiler = CudaProfilerWrapper(profiler_config)
else:
self.profiler = None
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
def sleep(self, level: int = 1) -> None:
from vllm.device_allocator.cumem import CuMemAllocator
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
# Save the buffers before level 2 sleep
if level == 2:
model = self.model_runner.model
self._sleep_saved_buffers = {
name: buffer.cpu().clone() for name, buffer in model.named_buffers()
}
allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
logger.info(
"Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
freed_bytes / GiB_bytes,
used_bytes / GiB_bytes,
)
def wake_up(self, tags: list[str] | None = None) -> None:
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags)
# Restore the buffers after level 2 sleep
if len(self._sleep_saved_buffers):
model = self.model_runner.model
for name, buffer in model.named_buffers():
if name in self._sleep_saved_buffers:
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}
# If the KV cache has just been woken up,
# the internal state of cache_engine must be reset,
# especially the FP8 scaling factor.
if (
(tags is None or "kv_cache" in tags)
and self.cache_config.cache_dtype.startswith("fp8")
and hasattr(self.model_runner, "init_fp8_kv_scales")
):
self.model_runner.init_fp8_kv_scales()
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
if tag == "weights":
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be used for one instance per process."
)
return allocator.use_memory_pool(tag=tag)
else:
return nullcontext()
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
def init_device(self):
device = self.device_config.device
if isinstance(device, torch.device) and device.type == "cuda":
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
if (
self.parallel_config.data_parallel_size > 1
and self.parallel_config.data_parallel_size_local > 0
and self.parallel_config.distributed_executor_backend
not in ["ray", "external_launcher"]
and self.vllm_config.parallel_config.data_parallel_backend != "ray"
and self.vllm_config.parallel_config.nnodes_within_dp == 1
):
# Use local DP rank if available, otherwise use global DP rank.
dp_local_rank = self.parallel_config.data_parallel_rank_local
if dp_local_rank is None:
dp_local_rank = self.parallel_config.data_parallel_rank
tp_pp_world_size = (
self.parallel_config.pipeline_parallel_size
* self.parallel_config.tensor_parallel_size
)
# DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
self.local_rank += dp_local_rank * tp_pp_world_size
assert self.local_rank < torch.cuda.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
visible_device_count = (
torch.cuda.device_count() if torch.cuda.is_available() else 0
)
assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must "
f"be less than or equal to the number of visible devices "
f"({visible_device_count})."
)
self.device = torch.device(f"cuda:{self.local_rank}")
current_platform.set_device(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)
# Initialize the distributed environment BEFORE taking
# memory snapshot
# This ensures NCCL buffers are allocated before we measure
# available memory
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
current_platform.dist_backend,
)
# Set random seed.
set_random_seed(self.model_config.seed)
# Now take memory snapshot after NCCL is initialized
gc.collect()
torch.cuda.empty_cache()
# take current memory snapshot
self.init_snapshot = MemorySnapshot()
self.requested_memory = (
self.init_snapshot.total_memory
* self.cache_config.gpu_memory_utilization
)
if self.init_snapshot.free_memory < self.requested_memory:
GiB = lambda b: round(b / GiB_bytes, 2)
raise ValueError(
f"Free memory on device "
f"({GiB(self.init_snapshot.free_memory)}/"
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
f"is less than desired GPU memory utilization "
f"({self.cache_config.gpu_memory_utilization}, "
f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
f"utilization or reduce GPU memory used by other processes."
)
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
# Initialize workspace manager
num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
init_workspace_manager(self.device, num_ubatches)
# Construct the model runner
if self.use_v2_model_runner:
from vllm.v1.worker.gpu.model_runner import (
GPUModelRunner as GPUModelRunnerV2,
)
# HACK(woosuk): This is a temporary fix to avoid type errors.
self.model_runner: GPUModelRunner = GPUModelRunnerV2( # type: ignore
self.vllm_config, self.device
)
else:
from vllm.v1.worker.gpu_model_runner import (
GPUModelRunner as GPUModelRunnerV1,
)
self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)
if self.rank == 0:
# If usage stat is enabled, collect relevant info.
report_usage_stats(self.vllm_config)
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
with self._maybe_get_memory_pool_context(tag="weights"):
self.model_runner.load_model(eep_scale_up=eep_scale_up)
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
def reload_weights(self) -> None:
self.model_runner.reload_weights()
@torch.inference_mode()
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the free memory that can be used for KV cache in
bytes.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
GiB = lambda b: b / GiB_bytes
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
# still need a profile run which compiles the model for
# max_num_batched_tokens
self.model_runner.profile_run()
msg = (
f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} "
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
"KV Cache as specified by kv_cache_memory_bytes config and "
"skipped memory profiling. This does not respect the "
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
"config when you want manual control of KV cache memory "
"size. If OOM'ed, check the difference of initial free "
"memory between the current run and the previous run "
"where kv_cache_memory_bytes is suggested and update it "
"correspondingly."
)
logger.info(msg)
return kv_cache_memory_bytes
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with memory_profiling(
self.init_snapshot,
weights_memory=int(self.model_runner.model_memory_usage),
) as profile_result:
self.model_runner.profile_run()
self.non_torch_memory = profile_result.non_torch_increase
self.peak_activation_memory = profile_result.torch_peak_increase
free_gpu_memory = profile_result.after_profile.free_memory
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
assert self.init_snapshot.free_memory > free_gpu_memory, (
"Error in memory profiling. "
f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
f"current free memory {GiB(free_gpu_memory)} GiB. "
"This happens when other processes sharing the same container "
"release GPU memory while vLLM is profiling during initialization. "
"To fix this, ensure consistent GPU memory allocation or "
"isolate vLLM in its own container."
)
self.available_kv_cache_memory_bytes = (
self.requested_memory - profile_result.non_kv_cache_memory
)
unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
logger.debug(
"Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
GiB(self.init_snapshot.free_memory),
self.cache_config.gpu_memory_utilization,
GiB(self.requested_memory),
)
logger.debug(
"Free memory after profiling: %.2f GiB (total), "
"%.2f GiB (within requested)",
GiB(free_gpu_memory),
GiB(free_gpu_memory - unrequested_memory),
)
logger.debug(profile_result)
logger.info_once(
"Available KV cache memory: %.2f GiB",
GiB(self.available_kv_cache_memory_bytes),
scope="local",
)
gc.collect()
return int(self.available_kv_cache_memory_bytes)
def get_kv_connector_handshake_metadata(self) -> dict | None:
"""Get KV connector metadata from this worker if available."""
if not has_kv_transfer_group():
return None
connector = get_kv_transfer_group()
# Return None for connectors that don't need to exchange handshake
# metadata across workers.
if (metadata := connector.get_handshake_metadata()) is None:
return None
tp_rank = get_tp_group().rank_in_group
return {tp_rank: metadata}
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
# Init kv cache connector here, because it requires
# `kv_cache_config`.
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
# because `initialize_kv_cache` will inject kv cache groups not
# related to kv cache connector (e.g. kv cache sharing layers).
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool(tag="kv_cache"):
self.model_runner.initialize_kv_cache(kv_cache_config)
else:
self.model_runner.initialize_kv_cache(kv_cache_config)
def compile_or_warm_up_model(self) -> None:
warmup_sizes = []
if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
compile_sizes = self.vllm_config.compilation_config.compile_sizes
warmup_sizes = compile_sizes.copy() if compile_sizes is not None else []
cg_capture_sizes: list[int] = []
if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
cg_capture_sizes = [] if cg_sizes is None else cg_sizes
warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes]
compile_ranges = self.vllm_config.compilation_config.get_compile_ranges()
# For each compile_range, if none of the batch sizes
# in warmup_sizes or cudagraph_capture_sizes are in the range,
# add the end of the range to ensure compilation/warmup.
all_sizes = set(cg_capture_sizes)
all_sizes.update([x for x in warmup_sizes if isinstance(x, int)])
for compile_range in compile_ranges:
if not any(x in compile_range for x in all_sizes):
warmup_sizes.append(compile_range.end)
# We skip EPLB here since we don't want to record dummy metrics
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
# Warmup and tune the kernels used during model execution before
# cuda graph capture.
kernel_warmup(self)
cuda_graph_memory_bytes = 0
if not self.model_config.enforce_eager:
cuda_graph_memory_bytes = self.model_runner.capture_model()
if self.cache_config.kv_cache_memory_bytes is None and hasattr(
self, "peak_activation_memory"
):
# Suggests optimal kv cache memory size if we rely on
# memory_profiling to guess the kv cache memory size which
# provides peak_activation_memory and a few other memory
# consumption. `memory_profiling` does not consider
# CUDAGraph memory size and may not utilize all gpu memory.
# Users may want fine-grained control to specify kv cache
# memory size.
GiB = lambda b: round(b / GiB_bytes, 2)
# empirically observed that the memory profiling may
# slightly underestimate the memory consumption.
# So leave a small buffer (=150MiB) to avoid OOM.
redundancy_buffer_memory = 150 * (1 << 20)
non_kv_cache_memory = (
self.model_runner.model_memory_usage
+ self.peak_activation_memory
+ self.non_torch_memory
+ cuda_graph_memory_bytes
)
kv_cache_memory_bytes_to_gpu_limit = (
self.init_snapshot.free_memory
- non_kv_cache_memory
- redundancy_buffer_memory
)
kv_cache_memory_bytes_to_requested_limit = (
int(self.requested_memory)
- non_kv_cache_memory
- redundancy_buffer_memory
)
msg = (
f"Free memory on device "
f"({GiB(self.init_snapshot.free_memory)}/"
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
f"Desired GPU memory utilization is "
f"({self.cache_config.gpu_memory_utilization}, "
f"{GiB(self.requested_memory)} GiB). "
f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
f"for peak activation, {GiB(self.non_torch_memory)} GiB "
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
f"config with `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_requested_limit}` "
f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
f"into requested memory, or `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_gpu_limit}` "
f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
f"utilize gpu memory. Current kv cache memory in use is "
f"{GiB(self.available_kv_cache_memory_bytes)} GiB."
)
logger.debug(msg)
# Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
if get_pp_group().is_last_rank:
max_num_reqs = min(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
)
# We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = self.model_runner._dummy_run(
num_tokens=max_num_reqs,
skip_eplb=True,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states)
else:
self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks()
def annotate_profile(self, scheduler_output):
# add trace annotation so that we can easily distinguish
# new/cached request numbers in each iteration
if not self.profiler:
return nullcontext()
self.profiler.step()
num_new = len(scheduler_output.scheduled_new_reqs)
num_cached = len(scheduler_output.scheduled_cached_reqs.req_ids)
return self.profiler.annotate_context_manager(
f"execute_new_{num_new}_cached_{num_cached}"
)
@torch.inference_mode()
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
@torch.inference_mode()
def execute_model(
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
all_gather_tensors = {}
compilation_config = self.vllm_config.compilation_config
parallel_config = self.vllm_config.parallel_config
if (
parallel_config.pipeline_parallel_size > 1
and compilation_config.pass_config.enable_sp
and forward_pass
):
# currently only supported by V1 GPUModelRunner
assert not self.use_v2_model_runner
num_scheduled_tokens_np = np.array(
list(scheduler_output.num_scheduled_tokens.values()),
dtype=np.int32,
)
# TODO(lucas): This is pretty gross; ideally we should only ever call
# `_determine_batch_execution_and_padding` once (will get called again
# in `execute_model`) but this requires a larger refactor of PP.
_, batch_desc, _, _, _ = (
self.model_runner._determine_batch_execution_and_padding(
num_tokens=num_scheduled_tokens,
num_reqs=len(num_scheduled_tokens_np),
num_scheduled_tokens_np=num_scheduled_tokens_np,
max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
use_cascade_attn=False, # TODO(lucas): Handle cascade attention
)
)
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
self.vllm_config, batch_desc.num_tokens
)
}
if forward_pass and not get_pp_group().is_first_rank:
tensor_dict = get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors,
)
assert tensor_dict is not None
intermediate_tensors = IntermediateTensors(tensor_dict)
with self.annotate_profile(scheduler_output):
output = self.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if isinstance(output, (ModelRunnerOutput, NoneType)):
return output
assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config
assert (
parallel_config.distributed_executor_backend != "external_launcher"
and not get_pp_group().is_last_rank
)
get_pp_group().send_tensor_dict(
output.tensors,
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors,
)
return None
def take_draft_token_ids(self) -> DraftTokenIds | None:
return self.model_runner.take_draft_token_ids()
def profile(self, is_start: bool = True):
if self.profiler is None:
raise RuntimeError("Profiling is not enabled.")
if is_start:
self.profiler.start()
else:
self.profiler.stop()
def execute_dummy_batch(self) -> None:
if self.use_v2_model_runner:
self.model_runner.execute_model(
SchedulerOutput.make_empty(), dummy_run=True
)
else:
self.model_runner._dummy_run(1, uniform_decode=True)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def list_loras(self) -> set[int]:
return self.model_runner.list_loras()
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def check_health(self) -> None:
# worker will always be healthy as long as it's running.
return
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info(
"[Elastic EP] Starting expert resharding before scaling down..."
)
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=None,
rank_mapping=rank_mapping,
)
torch.cuda.synchronize()
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _eplb_after_scale_up(
self,
old_ep_size: int,
new_ep_size: int,
global_expert_loads: list[torch.Tensor] | None,
) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding after scaling up...")
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=global_expert_loads,
rank_mapping=rank_mapping,
)
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _reconfigure_parallel_config(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
"""
Update parallel config with provided reconfig_request
"""
parallel_config = self.vllm_config.parallel_config
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
def _reconfigure_moe(
self, old_ep_size: int, new_ep_size: int
) -> list[torch.Tensor] | None:
"""
Reconfigure MoE modules with provided reconfig_request
Return the global expert load if new_ep_size > old_ep_size,
otherwise None
"""
from vllm.distributed.parallel_state import (
get_dp_group,
get_ep_group,
prepare_communication_buffer_for_model,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEParallelConfig,
)
parallel_config = self.vllm_config.parallel_config
def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
return [
module
for module in model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=get_tp_group().world_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
return moe_modules
model_moe_modules = get_moe_modules(self.model_runner.model)
num_local_experts = model_moe_modules[0].moe_config.num_local_experts
update_moe_modules(model_moe_modules, num_local_experts)
drafter_model = None
if hasattr(self.model_runner, "drafter") and hasattr(
self.model_runner.drafter, "model"
):
drafter_model = self.model_runner.drafter.model
if drafter_model is not None and is_mixture_of_experts(drafter_model):
drafter_moe_modules = get_moe_modules(drafter_model)
# Check if drafter and model have matching configs
assert (
drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
), "Drafter and model configs should be the same"
update_moe_modules(drafter_moe_modules, num_local_experts)
if new_ep_size < old_ep_size:
num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None
new_physical_experts = (
self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined]
)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts
- self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined]
)
global_expert_loads = None
else:
num_local_physical_experts_tensor = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_local_physical_experts_tensor,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts_tensor.item())
new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None
global_expert_loads_any = self.model_runner.eplb_state.rearrange(
execute_shuffle=False
)
global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_loads[0].shape[1]
)
prepare_communication_buffer_for_model(self.model_runner.model)
if drafter_model is not None:
prepare_communication_buffer_for_model(drafter_model)
self.model_runner.model.update_physical_experts_metadata(
num_physical_experts=new_physical_experts,
num_local_physical_experts=num_local_physical_experts,
)
return global_expert_loads
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
cleanup_dist_env_and_memory,
get_ep_group,
)
old_ep_size = get_ep_group().world_size
old_ep_rank = get_ep_group().rank
new_ep_size = (
reconfig_request.new_data_parallel_size
* get_tp_group().world_size
* get_pp_group().world_size
)
if new_ep_size < old_ep_size:
self._eplb_before_scale_down(old_ep_size, new_ep_size)
cleanup_dist_env_and_memory()
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
assert old_ep_rank >= new_ep_size
# shutdown
return
self._reconfigure_parallel_config(reconfig_request)
with set_current_vllm_config(self.vllm_config):
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
)
global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
if new_ep_size > old_ep_size:
assert global_expert_loads is not None
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
def save_sharded_state(
self,
path: str,
pattern: str | None = None,
max_size: int | None = None,
) -> None:
from vllm.model_executor.model_loader import ShardedStateLoader
ShardedStateLoader.save_model(
self.model_runner.model,
path,
pattern=pattern,
max_size=max_size,
)
def save_tensorized_model(
self,
tensorizer_config: "TensorizerConfig",
) -> None:
self.model_runner.save_tensorized_model(
tensorizer_config=tensorizer_config,
)
def shutdown(self) -> None:
if runner := getattr(self, "model_runner", None):
runner.ensure_kv_transfer_shutdown()
if self.profiler is not None:
self.profiler.shutdown()
def init_worker_distributed_environment(
vllm_config: VllmConfig,
rank: int,
distributed_init_method: str | None = None,
local_rank: int = -1,
backend: str = "nccl",
) -> None:
"""Initialize the distributed environment."""
attention_config = vllm_config.attention_config
parallel_config = vllm_config.parallel_config
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
init_batch_invariance(attention_config.backend)
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_method = distributed_init_method or "env://"
init_distributed_environment(
parallel_config.world_size, rank, init_method, local_rank, backend
)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.prefill_context_parallel_size,
parallel_config.decode_context_parallel_size,
)
# Init ec connector here before KV caches caches init
# NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
ensure_ec_transfer_initialized(vllm_config)

View File

@@ -0,0 +1,302 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define KV connector functionality mixin for model runners.
"""
import copy
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import (
TYPE_CHECKING, # noqa: UP035
)
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_shutdown,
get_kv_transfer_group,
has_kv_transfer_group,
)
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)
from vllm.v1.worker.utils import AttentionGroup
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
class KVConnectorModelRunnerMixin:
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
@staticmethod
def ensure_kv_transfer_shutdown() -> None:
# has_kv_transfer_group can be None during interpreter shutdown.
if has_kv_transfer_group and has_kv_transfer_group(): # type: ignore[truthy-function]
ensure_kv_transfer_shutdown()
@staticmethod
def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[set[str] | None, set[str] | None]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids
)
return None, None
@staticmethod
def kv_connector_no_forward(
scheduler_output: "SchedulerOutput", vllm_config: VllmConfig
) -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with (
set_forward_context(None, vllm_config),
KVConnectorModelRunnerMixin._get_kv_connector_output(
scheduler_output, wait_for_save=False
) as kv_connector_output,
):
pass
if kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
@staticmethod
def maybe_get_kv_connector_output(
scheduler_output: "SchedulerOutput",
) -> AbstractContextManager[KVConnectorOutput | None]:
return (
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
if has_kv_transfer_group()
else nullcontext()
)
# This context manager must be used within an active forward context.
# It encapsulates the entire KV connector lifecycle within execute_model
@staticmethod
@contextmanager
def _get_kv_connector_output(
scheduler_output: "SchedulerOutput", wait_for_save: bool = True
) -> Generator[KVConnectorOutput, None, None]:
output = KVConnectorOutput()
# Update KVConnector with the KVConnector metadata forward().
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
try:
yield output
finally:
if wait_for_save:
kv_connector.wait_for_save()
output.finished_sending, output.finished_recving = (
kv_connector.get_finished(scheduler_output.finished_req_ids)
)
output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
kv_connector.clear_connector_metadata()
@staticmethod
def use_uniform_kv_cache(
attn_groups: list[list[AttentionGroup]],
cache_dtype: CacheDType,
) -> bool:
"""
Determines whether a uniform KV layout should be used.
A uniform layout means all layers KV caches will share the same
underlying tensor, where for a given block number, the respective
KV data for all layers will be contiguous.
This will allow efficient KV transfer of per-block KV data for all
layers at once.
Note this layout will only be applied given 3 conditions:
1. The KV Cache config contains just a single group where all layers
have the same page size.
2. A KV connector is configured, and the KV connector instance prefers
to use this layout (prefer_cross_layer_blocks() returns True)
2. The flash attention backend supports this layout
(get_kv_cache_stride_order(True) includes a placement for a
num_layers dimension)
Note that the actual placement of the num_layers dimensions
in the unified layers tensors will be determined by the attention
backend.
Thus, the layers KV data may still not be contiguous per block
if the attention backend does not support it.
Args:
attn_groups: The list of attention groups for this model
cache_dtype: The KV cache dtype
Returns:
True if we should use a uniform KV cache layout.
"""
if not has_kv_transfer_group():
return False
if not get_kv_transfer_group().prefer_cross_layer_blocks:
return False
if len(attn_groups) != 1 or len(attn_groups[0]) != 1:
return False
attn_group = attn_groups[0][0]
kv_cache_spec = attn_group.kv_cache_spec
if not isinstance(kv_cache_spec, AttentionSpec):
return False
attn_backend = attn_group.backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
1234,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
)
except (AttributeError, NotImplementedError):
return False
# check that attention backend include a layers dimension
return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
@staticmethod
def allocate_uniform_kv_caches(
kv_cache_config: KVCacheConfig,
attn_groups: list[list[AttentionGroup]],
cache_dtype: CacheDType,
device: torch.device,
kernel_block_sizes: list[int],
) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]:
"""
Initializes and reshapes KV caches for the simple case where all
layers have the same layout.
This function assumes use_uniform_kv_cache() returned True.
Args:
kv_cache_config: The KV cache config
attn_groups: The list of attention groups for this model
cache_dtype: The KV cache dtype
device: The torch device to allocate on.
kernel_block_sizes: The kernel block sizes for each KV cache group.
Returns:
A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where:
kv_caches is a dict mapping between layer names to their
corresponding memory buffer for KV cache.
cross_layers_kv_cache is the cross layers kv cache tensor
attn_backend is the attention backend matching this tensor
"""
attn_group = attn_groups[0][0]
kv_cache_spec = attn_group.kv_cache_spec
assert isinstance(kv_cache_spec, AttentionSpec)
tensor_sizes = set(
kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors
)
assert len(tensor_sizes) == 1
tensor_size = tensor_sizes.pop()
page_size = kv_cache_spec.page_size_bytes
assert tensor_size % page_size == 0
num_blocks = tensor_size // page_size
num_layers = len(kv_cache_config.kv_cache_tensors)
total_size = tensor_size * num_layers
assert len(kernel_block_sizes) == 1
kernel_block_size = kernel_block_sizes[0]
num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
attn_backend = attn_group.backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
)
# prepend a num_layers dimension into the shape
kv_cache_shape = (num_layers,) + kv_cache_shape
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
)
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape)
# allocate one contiguous buffer for all layers
cross_layers_kv_cache = (
torch.zeros(total_size, dtype=torch.int8, device=device)
.view(kv_cache_spec.dtype)
.view(kv_cache_shape)
)
# Maintain original KV shape view.
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]
permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order)
kv_caches = {}
for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
tensor = permuted_kv_cache[i]
for layer_name in kv_cache_tensor.shared_by:
kv_caches[layer_name] = tensor
return kv_caches, cross_layers_kv_cache, attn_backend

View File

@@ -0,0 +1,212 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define LoRA functionality mixin for model runners.
"""
from contextlib import contextmanager
import numpy as np
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
InputBatch = TPUInputBatch | GPUInputBatch
logger = init_logger(__name__)
# Defined as a mixin for GPUModelRunner
class LoRAModelRunnerMixin:
def load_lora_model(
self, model: nn.Module, vllm_config: VllmConfig, device: torch.device
) -> nn.Module:
if not supports_lora(model):
raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")
if supports_multimodal(model):
logger.warning(
"Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model."
)
# Add LoRA Manager to the Model Runner
self.lora_manager = LRUCacheWorkerLoRAManager(
vllm_config,
device,
model.embedding_modules,
)
return self.lora_manager.create_lora_manager(model)
def _set_active_loras(
self,
prompt_lora_mapping: tuple[int, ...],
token_lora_mapping: tuple[int, ...],
lora_requests: set[LoRARequest],
) -> None:
self._ensure_lora_enabled()
# Set is_prefill to True, so we always use the SGMV kernels on
# non-cuda platforms.
# On cuda platforms we use the same kernels for prefill and
# decode and this flag is generally ignored.
lora_mapping = LoRAMapping(
token_lora_mapping, prompt_lora_mapping, is_prefill=True
)
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def _ensure_lora_enabled(self) -> None:
if not hasattr(self, "lora_manager"):
raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")
def set_active_loras(
self,
input_batch: InputBatch,
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray | None = None,
) -> None:
if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
prompt_lora_mapping: tuple[int, ...] # of size np.sum(num_sampled_tokens)
token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens)
lora_requests: set[LoRARequest]
prompt_lora_mapping, token_lora_mapping, lora_requests = (
input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
)
return self._set_active_loras(
prompt_lora_mapping, token_lora_mapping, lora_requests
)
@contextmanager
def maybe_setup_dummy_loras(
self, lora_config: LoRAConfig | None, remove_lora: bool = True
):
if lora_config is None:
yield
else:
# __enter__ code
assert self.lora_manager is not None, "LoRA is not enabled"
num_loras = lora_config.max_loras
lora_warmup_rank = (
lora_config.max_lora_rank if lora_config.max_lora_rank < 8 else 8
)
# Make dummy lora requests
lora_requests: set[LoRARequest] = {
LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path",
)
for lora_id in range(1, num_loras + 1)
}
with self.lora_manager.dummy_lora_cache():
# Add the dummy LoRAs here so _set_active_loras doesn't try to
# load from disk.
for lr in lora_requests:
self.lora_manager.add_dummy_lora(lr, rank=lora_warmup_rank)
yield
# __exit__ code
if remove_lora:
self.lora_manager.remove_all_adapters()
@contextmanager
def maybe_select_dummy_loras(
self,
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray | None = None,
activate_lora: bool = True,
):
if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
if lora_config is None:
yield
else:
# __enter__ code
assert self.lora_manager is not None, "LoRA is not enabled"
num_reqs = len(num_scheduled_tokens)
num_loras = lora_config.max_loras
# Make prompt lora mapping
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
if activate_lora:
prompt_lora_mapping = (
np.arange(num_reqs, dtype=np.int32) % num_loras
) + 1
else:
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
# Make sample lora mapping
sample_lora_mapping = np.repeat(prompt_lora_mapping, num_sampled_tokens)
# Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
# Make dummy lora requests
lora_requests: set[LoRARequest] = {
LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path",
)
for lora_id in range(1, num_loras + 1)
}
self._set_active_loras(
tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
)
yield
@contextmanager
def maybe_dummy_run_with_lora(
self,
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray,
activate_lora: bool = True,
remove_lora: bool = True,
):
with (
self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
),
):
yield
def maybe_remove_all_loras(self, lora_config: LoRAConfig | None):
if lora_config is None:
return
self.lora_manager.remove_all_adapters()
def add_lora(self, lora_request: LoRARequest) -> bool:
self._ensure_lora_enabled()
return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
self._ensure_lora_enabled()
return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
self._ensure_lora_enabled()
return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> set[int]:
self._ensure_lora_enabled()
return self.lora_manager.list_adapters()

View File

@@ -0,0 +1,583 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a TPU input batch
from typing import cast
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState
_SAMPLING_EPS = 1e-5
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[int],
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self._req_ids: list[str | None] = []
self.req_id_to_index: dict[str, int] = {}
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
pin_memory=False,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
)
# Sampling-related.
self.temperature = torch.empty(
(max_num_reqs,), dtype=torch.float32, device=device
)
self.temperature_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
self.top_p_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
)
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: set[str] = set()
self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device)
self.top_k_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: set[str] = set()
self.min_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
self.min_p_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
self.min_p_reqs: set[str] = set()
# Frequency penalty related data structures
self.frequency_penalties = torch.empty(
(max_num_reqs,), dtype=torch.float, device=device
)
self.frequency_penalties_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
)
self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: set[str] = set()
# Presence penalty related data structures
self.presence_penalties = torch.empty(
(max_num_reqs,), dtype=torch.float, device=device
)
self.presence_penalties_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
)
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
self.presence_penalties_reqs: set[str] = set()
# Repetition penalty related data structures
self.repetition_penalties = torch.empty(
(max_num_reqs,), dtype=torch.float, device=device
)
self.repetition_penalties_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
)
self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set()
# req_index -> (min_tokens, stop_token_ids)
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
self.logit_bias: list[dict[int, float] | None] = [None] * max_num_reqs
self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
# the value is False. Since we use masked_fill_ to set -inf.
self.allowed_token_ids_mask: torch.Tensor | None = None
self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None
# req_index -> bad_words_token_ids
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
self.req_output_token_ids: list[list[int] | None] = []
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
# while performing state updates to the batch.
return cast(list[str], self._req_ids)
def add_request(
self,
request: "CachedRequestState",
req_index: int | None = None,
) -> None:
if req_index is None:
req_index = self.num_reqs
assert req_index < self.max_num_reqs
req_id = request.req_id
if req_index == len(self._req_ids):
self._req_ids.append(req_id)
self.req_output_token_ids.append(request.output_token_ids)
else:
self._req_ids[req_index] = req_id
self.req_output_token_ids[req_index] = request.output_token_ids
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds
)
# TODO: copy prompt_embeds
self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
# Number of token ids in token_ids_cpu.
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
sampling_params = request.sampling_params
assert sampling_params is not None, "pooling requests not supported yet"
if sampling_params.sampling_type == SamplingType.GREEDY:
# Should avoid division by zero later when apply_temperature.
self.temperature_cpu[req_index] = 0.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.min_p_cpu[req_index] = sampling_params.min_p
self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
if sampling_params.min_p > _SAMPLING_EPS:
self.min_p_reqs.add(req_id)
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
if sampling_params.min_tokens:
self.min_tokens[req_index] = (
sampling_params.min_tokens,
sampling_params.all_stop_token_ids,
)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device,
)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.bool, device="cpu"
)
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids
] = False
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[req_index] = sampling_params.bad_words_token_ids
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
if lora_id not in self.lora_id_to_request_ids:
self.lora_id_to_request_ids[lora_id] = set()
self.request_lora_mapping[req_index] = lora_id
self.lora_id_to_request_ids[lora_id].add(request.req_id)
self.lora_id_to_lora_request[lora_id] = request.lora_request
else:
# No LoRA
self.request_lora_mapping[req_index] = 0
def remove_request(self, req_id: str) -> int | None:
"""This method must always be followed by a call to condense()."""
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.min_p_reqs.discard(req_id)
self.min_tokens.pop(req_index, None)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
self.lora_id_to_request_ids[lora_id].discard(req_id)
if len(self.lora_id_to_request_ids[lora_id]) == 0:
self.lora_id_to_request_ids.pop(lora_id)
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0
self.logit_bias[req_index] = None
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
return req_index
def swap_states(self, i1: int, i2: int) -> None:
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
self.req_output_token_ids[i2],
self.req_output_token_ids[i1],
)
assert old_id_i1 is not None and old_id_i2 is not None
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
self.req_id_to_index[old_id_i2],
self.req_id_to_index[old_id_i1],
)
self.num_tokens[i1], self.num_tokens[i2] = (
self.num_tokens[i2],
self.num_tokens[i1],
)
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
self.num_tokens_no_spec[i2],
self.num_tokens_no_spec[i1],
)
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
self.num_prompt_tokens[i2],
self.num_prompt_tokens[i1],
)
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
self.num_computed_tokens_cpu[i2],
self.num_computed_tokens_cpu[i1],
)
self.temperature_cpu[i1], self.temperature_cpu[i2] = (
self.temperature_cpu[i2],
self.temperature_cpu[i1],
)
self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
self.frequency_penalties_cpu[i2],
self.frequency_penalties_cpu[i1],
)
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
self.presence_penalties_cpu[i2],
self.presence_penalties_cpu[i1],
)
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
self.repetition_penalties_cpu[i2],
self.repetition_penalties_cpu[i1],
)
self.min_p_cpu[i1], self.min_p_cpu[i2] = self.min_p_cpu[i2], self.min_p_cpu[i1]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.min_tokens, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
self.request_lora_mapping[i2],
self.request_lora_mapping[i1],
)
self.logit_bias[i1], self.logit_bias[i2] = (
self.logit_bias[i2],
self.logit_bias[i1],
)
if self.allowed_token_ids_mask_cpu_tensor is not None:
(
self.allowed_token_ids_mask_cpu_tensor[i1],
self.allowed_token_ids_mask_cpu_tensor[i2],
) = (
self.allowed_token_ids_mask_cpu_tensor[i2],
self.allowed_token_ids_mask_cpu_tensor[i1],
)
self.block_table.swap_row(i1, i2)
def condense(self, empty_req_indices: list[int]) -> None:
"""Move non-empty requests down into lower, empty indices.
Args:
empty_req_indices: empty batch indices, sorted descending.
"""
num_reqs = self.num_reqs
if num_reqs == 0:
# The batched states are empty.
self._req_ids.clear()
self.req_output_token_ids.clear()
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index = num_reqs + len(empty_req_indices) - 1
while empty_req_indices:
# Find the largest non-empty index.
while last_req_index in empty_req_indices:
last_req_index -= 1
# Find the smallest empty index.
empty_index = empty_req_indices.pop()
if empty_index >= last_req_index:
break
# Swap the states.
req_id = self._req_ids[last_req_index]
output_token_ids = self.req_output_token_ids[last_req_index]
assert req_id is not None
self._req_ids[empty_index] = req_id
self._req_ids[last_req_index] = None
self.req_output_token_ids[empty_index] = output_token_ids
self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens
]
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index
]
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
last_req_index
]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
last_req_index
]
self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
last_req_index
]
self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
last_req_index
]
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
min_token = self.min_tokens.pop(last_req_index, None)
if min_token is not None:
self.min_tokens[empty_index] = min_token
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index
]
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
self.allowed_token_ids_mask_cpu_tensor[last_req_index]
)
bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
# Decrement last_req_index since it is now empty.
last_req_index -= 1
# Trim lists to the batch size.
del self._req_ids[self.num_reqs :]
del self.req_output_token_ids[self.num_reqs :]
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[: self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[: self.num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values()
)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_min_p(self) -> bool:
return len(self.min_p_reqs) == 0
@property
def no_penalties(self) -> bool:
return (
len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0
)
@property
def max_num_logprobs(self) -> int | None:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""
import os
from collections.abc import Callable
from typing import Any, TypeVar
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (
ensure_model_parallel_initialized,
init_distributed_environment,
)
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_INFERENCE
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.utils import bind_kv_cache
logger = init_logger(__name__)
_R = TypeVar("_R")
if not USE_TPU_INFERENCE:
logger.info("tpu_inference not found, using vLLM's TPUWorker.")
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
class TPUWorker:
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
self.is_driver_worker = is_driver_worker
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.use_spmd = envs.VLLM_XLA_USE_SPMD
self.original_parallel_config = None
if self.use_spmd:
# Under SPMD mode, distributed env is initialized as if there is
# only one worker/device.
self.original_parallel_config = self.parallel_config
self.parallel_config.tensor_parallel_size = 1
self.parallel_config.pipeline_parallel_size = 1
self.parallel_config.world_size = 1
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
if self.cache_config.cache_dtype == "auto":
self.cache_dtype = self.model_config.dtype
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
# Delay profiler initialization to the start of the profiling.
# This is because in vLLM V1, MP runtime is initialized before the
# TPU Worker is initialized. The profiler server needs to start after
# MP runtime is initialized.
self.profiler = None
self.profile_dir = None
if vllm_config.profiler_config.profiler == "torch" and self.rank < 1:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self.profile_dir = vllm_config.profiler_config.torch_profiler_dir
logger.info(
"Profiling enabled. Traces will be saved to: %s", self.profile_dir
)
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
# ring, the xla tpu compiler flag
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
# fix this. It will be removed after the bug in XLA compiler is fixed.
os.environ["LIBTPU_INIT_ARGS"] = (
os.environ.get("LIBTPU_INIT_ARGS", "")
+ " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
" --xla_jf_conv_input_fusion=False"
)
# --xla_jf_conv_input_fusion=False is used to improve the perf of
# quantized matmul.
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)
# Initialize the distributed environment.
self._init_tpu_worker_distributed_environment(
self.vllm_config, self.rank, self.distributed_init_method, self.local_rank
)
# Device initialization should happen after initializing
# the distributed runtime.
self.device = xm.xla_device()
self.device_config.device = self.device
# Set random seed.
set_random_seed(self.model_config.seed)
xm.set_rng_state(self.model_config.seed, self.device)
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# TODO (NickLucche) On gsm we compile 80+ graphs.
# Re-evaluate limit, with MM we may get close to this limit.
torch._dynamo.config.cache_size_limit = 128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size = self.parallel_config.world_size
rank = xr.global_ordinal()
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
# Consequently, changes in optimization flags, which affect compilation
# results, don't change the cache key. This can result in the wrong
# compilation being used. To prevent this, disabling the XLA compilation
# cache during development is recommended.We can disable it by
# `export VLLM_XLA_CACHE_PATH=`
if envs.VLLM_XLA_CACHE_PATH:
per_rank_path = os.path.join(
envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}"
)
xr.initialize_cache(per_rank_path, readonly=False)
# Init ModelRunner here, so that we have access to self.device.
self.model_runner = TPUModelRunner(
self.vllm_config, self.device, self.original_parallel_config
)
if rank == 0:
# If usage stat is enabled, collect relevant info.
report_usage_stats(self.vllm_config)
def determine_available_memory(self) -> int:
kv_caches: dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, AttentionSpec):
dtype = layer_spec.dtype
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
kv_caches[layer_name] = tpu_kv_cache
else:
raise NotImplementedError(
f"Unsupported KV cache spec '{type(layer_spec)}'"
)
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches,
)
# `max_num_tokens >= max_num_batched_tokens` due to padding.
with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
self.model_runner.profile_run(self.model_runner.max_num_tokens)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
with set_current_vllm_config(self.vllm_config):
self.model_runner.reset_dynamo_cache()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
if self.use_spmd:
# This is a workaround for the TPU SPMD mode. The get_memory_info
# API doesn't work with SPMD mode in PyTorch/XLA.
# TODO: use xm.get_memory_info for SPMD once it's supported in
# PyTorch/XLA.
import tpu_info
chip_type, _ = tpu_info.device.get_local_chips()
device_usage = tpu_info.metrics.get_chip_usage(chip_type)
total_memory_size = device_usage[0].total_memory
current_mem = device_usage[0].memory_usage
else:
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
current_mem = m["bytes_used"]
# Ideally we would use profiled = m["peak_bytes_used"] to
# get weights + activations. But there is memory used during
# compilation / weight loading that impacts the peak and
# there is no way to reset peak memory in XLA, So we
# use the heuristic of 2% of weights.
profiled = current_mem * 1.02
# Calculate the TPU KV cache size based on profiling.
usable_memory_size = int(
total_memory_size * self.cache_config.gpu_memory_utilization
)
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
head_size = self.model_config.get_head_size()
if head_size > 0:
padded_head_size = (
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
if padded_head_size != head_size:
logger.warning_once("head size is padded to %d", padded_head_size)
# We adjust the usable memory size for the KV cache to prevent OOM
# errors, even after padding the head_size.
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
return int(tpu_kv_cache_bytes)
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
def execute_model(
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
return self.model_runner.execute_model(scheduler_output)
def profile(self, is_start: bool = True):
if self.rank < 1:
if self.profile_dir is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
if self.profiler is None:
self.profiler = xp.start_server(9012)
xp.start_trace(self.profile_dir)
else:
xp.stop_trace()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def load_model(self) -> None:
self.model_runner.load_model()
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
def reload_weights(self) -> None:
self.model_runner.reload_weights()
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks()
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
self.model_runner.initialize_kv_cache(kv_cache_config)
def check_health(self) -> None:
# worker will always be healthy as long as it's running.
return
def _init_tpu_worker_distributed_environment(
self,
vllm_config: VllmConfig,
rank: int,
distributed_init_method: str | None = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
if self.use_spmd:
xr.use_spmd()
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
parallel_config = vllm_config.parallel_config
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
local_rank=local_rank,
distributed_init_method=distributed_init_method or "env://",
backend=current_platform.dist_backend,
)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
)
ensure_kv_transfer_initialized(vllm_config)
def shutdown(self) -> None:
self.model_runner.ensure_kv_transfer_shutdown()
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
"""Apply a function on the model inside this worker."""
return fn(self.get_model())
if USE_TPU_INFERENCE:
from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
TPUWorker = TpuInferenceWorker # type: ignore

View File

@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TypeAlias
import numpy as np
from vllm.config import ParallelConfig
@dataclass
class UBatchSlice:
request_slice: slice
token_slice: slice
def is_empty(self) -> bool:
return (
self.request_slice.start == self.request_slice.stop
or self.token_slice.start == self.token_slice.stop
)
@property
def num_tokens(self) -> int:
return self.token_slice.stop - self.token_slice.start
UBatchSlices: TypeAlias = list[UBatchSlice]
def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool:
return (padded_num_tokens // 2) >= orig_num_tokens
def check_ubatch_thresholds(
config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
if not config.enable_dbo:
return False
if uniform_decode:
return num_tokens >= config.dbo_decode_token_threshold
else:
return num_tokens >= config.dbo_prefill_token_threshold
# This just pads the second ubatch slice out to the total number of tokens
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def _pad_out_ubatch_slices(
ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
) -> UBatchSlices:
# TODO(lucas): handle empty second ubatch
padded_second_request_slice = slice(
ubatch_slices[1].request_slice.start, num_reqs_padded
)
padded_second_token_slice = slice(
ubatch_slices[1].token_slice.start, num_total_tokens
)
return [
ubatch_slices[0],
UBatchSlice(padded_second_request_slice, padded_second_token_slice),
]
def maybe_create_ubatch_slices(
should_ubatch: bool,
num_scheduled_tokens: np.ndarray,
num_tokens_padded: int,
num_reqs_padded: int,
split_point: int | None = None,
) -> tuple[UBatchSlices | None, UBatchSlices | None]:
if not should_ubatch:
return None, None
if split_point is None:
split_point = int(num_tokens_padded) // 2
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
first_ubatch_token_slice = slice(0, split_point)
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
# Determine request slices using exclusive stop semantics
# First ubatch includes requests whose tokens overlap [0, split_point)
first_ubatch_req_stop = int(
np.searchsorted(cu_num_tokens, split_point, side="left")
)
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
# Second ubatch starts at the request that contains the split_point
# or the request starting exactly at split_point (if on boundary)
second_ubatch_req_start = int(
np.searchsorted(cu_num_tokens, split_point, side="right") - 1
)
second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
ubatch_slices = [
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
]
ubatch_slices_padded = _pad_out_ubatch_slices(
ubatch_slices, num_tokens_padded, num_reqs_padded
)
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
return ubatch_slices, ubatch_slices_padded

231
vllm/v1/worker/ubatching.py Normal file
View File

@@ -0,0 +1,231 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Optional
import torch
from vllm import forward_context
from vllm.forward_context import ForwardContext
from vllm.utils.torch_utils import current_stream
_THREAD_ID_TO_CONTEXT: dict = {}
_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None]
class UBatchContext:
"""
Context manager for micro-batching synchronization using threading events.
"""
def __init__(
self,
id: int,
comm_stream: torch.cuda.Stream,
compute_stream: torch.cuda.Stream,
forward_context: ForwardContext,
ready_barrier: threading.Barrier,
cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event,
gpu_comm_done_event: torch.Event,
gpu_compute_done_event: torch.Event,
schedule: str = "default",
):
self.id = id
self.comm_stream = comm_stream
self.compute_stream = compute_stream
self.forward_context = forward_context
self.ready_barrier = ready_barrier
self.cpu_wait_event = cpu_wait_event
self.cpu_signal_event = cpu_signal_event
self.current_stream = compute_stream
self.gpu_comm_done_event = gpu_comm_done_event
self.gpu_compute_done_event = gpu_compute_done_event
self.schedule = schedule
self.recv_hook = None
def __enter__(self):
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
_THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
_CURRENT_CONTEXTS[self.id] = self
self.ready_barrier.wait()
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self._restore_context()
# Assume we want to start on the compute stream
self.update_stream(self.compute_stream)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
_CURRENT_CONTEXTS[self.id] = None
del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
self.maybe_run_recv_hook()
self.cpu_signal_event.set()
self.cpu_wait_event.clear()
return False
def _restore_context(self):
forward_context._forward_context = self.forward_context
def update_stream(self, stream):
self.current_stream = stream
if current_stream() != self.current_stream:
torch.cuda.set_stream(self.current_stream)
def _signal_comm_done(self):
self.gpu_comm_done_event.record(self.comm_stream)
def _signal_compute_done(self):
self.gpu_compute_done_event.record(self.compute_stream)
def _wait_compute_done(self):
self.comm_stream.wait_event(self.gpu_compute_done_event)
def _wait_comm_done(self):
self.compute_stream.wait_event(self.gpu_comm_done_event)
def _cpu_yield(self):
# It is critical for correctness that only one thread is running
# at a time. These asserts just make sure that this is the only
# thread running before waking the other one up and going to sleep
assert forward_context._forward_context == self.forward_context
assert current_stream() == self.current_stream
assert not self.cpu_wait_event.is_set()
self.cpu_signal_event.set()
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self._restore_context()
def switch_to_comm(self):
self.update_stream(self.comm_stream)
def switch_to_compute(self):
self.update_stream(self.compute_stream)
def switch_to_comm_sync(self):
self._signal_compute_done()
self.update_stream(self.comm_stream)
self._wait_compute_done()
def switch_to_compute_sync(self):
self._signal_comm_done()
self.update_stream(self.compute_stream)
self._wait_comm_done()
def maybe_run_recv_hook(self):
if self.recv_hook is not None:
self.recv_hook()
self.recv_hook = None
def yield_(self):
self.current_stream = current_stream()
self._cpu_yield()
self.update_stream(self.current_stream)
def yield_and_switch_from_compute_to_comm(self):
assert current_stream() == self.compute_stream
self._signal_compute_done()
self._cpu_yield()
assert self.current_stream == self.compute_stream
self.update_stream(self.comm_stream)
self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self):
assert current_stream() == self.comm_stream
self._signal_comm_done()
self._cpu_yield()
assert self.current_stream == self.comm_stream
self.update_stream(self.compute_stream)
self._wait_comm_done()
def dbo_enabled() -> bool:
return len(_THREAD_ID_TO_CONTEXT) > 0
def dbo_current_ubatch_id() -> int:
if len(_THREAD_ID_TO_CONTEXT) == 0:
return 0
return _THREAD_ID_TO_CONTEXT[threading.get_ident()]
def _register_ubatch_function(func):
def wrapper(*args, **kwargs):
if len(_THREAD_ID_TO_CONTEXT) > 0:
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
ctx = _CURRENT_CONTEXTS[ctx_idx]
func(ctx, *args, **kwargs)
return wrapper
dbo_maybe_run_recv_hook = _register_ubatch_function(UBatchContext.maybe_run_recv_hook)
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
UBatchContext.yield_and_switch_from_compute_to_comm
)
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
UBatchContext.yield_and_switch_from_comm_to_compute
)
dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm)
dbo_switch_to_compute = _register_ubatch_function(UBatchContext.switch_to_compute)
dbo_switch_to_comm_sync = _register_ubatch_function(UBatchContext.switch_to_comm_sync)
dbo_switch_to_compute_sync = _register_ubatch_function(
UBatchContext.switch_to_compute_sync
)
def dbo_register_recv_hook(recv_hook):
if len(_THREAD_ID_TO_CONTEXT) > 0:
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2]
next_ctx.recv_hook = recv_hook
def dbo_get_previous_event(func, *args, **kwargs):
if len(_THREAD_ID_TO_CONTEXT) > 0:
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
ctx = _CURRENT_CONTEXTS[ctx_idx]
# execute callable on the ubatch compute stream to record/wait events there
with torch.cuda.stream(ctx.compute_stream):
return func(*args, **kwargs)
def make_ubatch_contexts(
num_micro_batches: int,
compute_stream: torch.cuda.Stream,
comm_stream: torch.cuda.Stream,
forward_contexts: list[ForwardContext],
ready_barrier: threading.Barrier,
schedule: str = "default",
) -> list[UBatchContext]:
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
"""
Create a context manager for micro-batching synchronization.
"""
cpu_events = [threading.Event() for _ in range(num_micro_batches)]
gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)]
gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)]
assert len(forward_contexts) == 2
ctxs = []
for i in range(num_micro_batches):
ctx = UBatchContext(
id=i,
compute_stream=compute_stream,
comm_stream=comm_stream,
forward_context=forward_contexts[i],
ready_barrier=ready_barrier,
cpu_wait_event=cpu_events[i],
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
gpu_comm_done_event=gpu_comm_done_events[i],
gpu_compute_done_event=gpu_compute_done_events[i],
schedule=schedule,
)
ctxs.append(ctx)
return ctxs

375
vllm/v1/worker/utils.py Normal file
View File

@@ -0,0 +1,375 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass, field
import torch
from typing_extensions import deprecated
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
logger = init_logger(__name__)
class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models."""
def __init__(
self,
model_config: ModelConfig,
scheduler_config: SchedulerConfig,
mm_registry: MultiModalRegistry,
) -> None:
super().__init__()
self.model_config = model_config
self.scheduler_config = scheduler_config
self.mm_registry = mm_registry
self.cache = cache = processor_only_cache_from_config(model_config, mm_registry)
self.max_model_len = model_config.max_model_len
self.max_num_reqs = scheduler_config.max_num_seqs
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
model_config,
cache=cache,
profiler_limits=self.mm_limits,
)
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
scheduler_config,
max_tokens_by_modality,
)
self.encoder_compute_budget = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
max_items_per_prompt_by_modality = dict[str, int]()
max_items_per_batch_by_modality = dict[str, int]()
for modality, max_tokens in max_tokens_by_modality.items():
(
max_items_per_prompt,
max_items_per_batch,
) = self.get_max_items(modality, max_tokens)
max_items_per_prompt_by_modality[modality] = max_items_per_prompt
max_items_per_batch_by_modality[modality] = max_items_per_batch
self.max_tokens_by_modality = max_tokens_by_modality
self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
self.max_items_per_batch_by_modality = max_items_per_batch_by_modality
def get_modality_with_max_tokens(self) -> str:
max_tokens_by_modality = self.max_tokens_by_modality
modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])
return modality
def get_encoder_budget(self) -> int:
return min(self.encoder_compute_budget, self.encoder_cache_size)
def get_max_items(
self,
modality: str,
max_tokens_per_item: int,
) -> tuple[int, int]:
if max_tokens_per_item == 0:
return 0, 0
# Check how many items of this modality can be supported by
# the encoder budget.
encoder_budget = self.get_encoder_budget()
# TODO: handle encoder-decoder models once we support them.
if encoder_budget == 0:
return 0, 0
max_encoder_items_per_batch = encoder_budget // max_tokens_per_item
# Check how many items of this modality can be supported by
# the decoder budget.
mm_limit = self.mm_limits[modality]
max_items_per_prompt = max(
1,
min(mm_limit, self.max_model_len // max_tokens_per_item),
)
scheduler_config = self.scheduler_config
max_num_reqs = self.max_num_reqs
if not scheduler_config.enable_chunked_prefill:
max_num_reqs = min(
max_num_reqs,
scheduler_config.max_num_batched_tokens // max_tokens_per_item,
)
max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt
max_items_per_batch = max(
1,
min(max_encoder_items_per_batch, max_decoder_items_per_batch),
)
return max_items_per_prompt, max_items_per_batch
def reset_cache(self) -> None:
if self.cache is not None:
self.cache.clear_cache()
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
layer_names: list[str]
kv_cache_spec: KVCacheSpec
kv_cache_group_id: int
# When ubatching is enabled we will have a metadata builder for each ubatch
# so that if they use internal persistent buffers for cudagraphs, and they
# won't have to worry about conflicting with the other ubatches.
metadata_builders: list[AttentionMetadataBuilder] = field(
default_factory=lambda: []
)
def create_metadata_builders(
self,
vllm_config,
device,
kernel_block_size: int | None,
num_metadata_builders: int = 1,
):
kv_cache_spec_builder = (
self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
if kernel_block_size is not None
else self.kv_cache_spec
)
self.metadata_builders = [
self.backend.get_builder_cls()(
kv_cache_spec_builder,
self.layer_names,
vllm_config,
device,
)
for _ in range(num_metadata_builders)
]
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
assert len(self.metadata_builders) > ubatch_id
return self.metadata_builders[ubatch_id]
def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings,
expected_num_items: int,
) -> None:
"""
Perform sanity checks for the result of
[`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
"""
assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
f"or a single 3D tensor, but got {type(mm_embeddings)} "
"instead. This is most likely due to incorrect implementation "
"of the model's `embed_multimodal` method."
)
assert len(mm_embeddings) == expected_num_items, (
"Expected number of multimodal embeddings to match number of "
f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
"instead. This is most likely due to incorrect implementation "
"of the model's `embed_multimodal` method."
)
assert all(e.ndim == 2 for e in mm_embeddings), (
"Expected multimodal embeddings to be a sequence of 2D tensors, "
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
"instead. This is most likely due to incorrect implementation "
"of the model's `embed_multimodal` method."
)
@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.")
def scatter_mm_placeholders(
embeds: torch.Tensor,
is_embed: torch.Tensor | None,
) -> torch.Tensor:
"""
Scatter the multimodal embeddings into a contiguous tensor that represents
the placeholder tokens.
[`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].
Args:
embeds: The multimodal embeddings.
Shape: `(num_embeds, embed_dim)`
is_embed: A boolean mask indicating which positions in the placeholder
tokens need to be filled with multimodal embeddings.
Shape: `(num_placeholders, num_embeds)`
"""
if is_embed is None:
return embeds
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed] = embeds
return placeholders
@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.")
def gather_mm_placeholders(
placeholders: torch.Tensor,
is_embed: torch.Tensor | None,
) -> torch.Tensor:
"""
Reconstructs the embeddings from the placeholder tokens.
This is the operation of [`scatter_mm_placeholders`]
[vllm.v1.worker.utils.scatter_mm_placeholders].
"""
if is_embed is None:
return placeholders
return placeholders[is_embed]
def add_kv_sharing_layers_to_kv_cache_groups(
shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec],
runner_only_attn_layers: set[str] | None = None,
) -> None:
"""
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
for layers that do not allocate its own KV cache, based on the mapping in
`shared_kv_cache_layers`. Adds these layers to the corresponding KV cache
group, which is needed to ensure that attention metadata is assigned later.
Args:
shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
If an Attention layer `layer_name` is in the keys of this dict, it
means this layer will perform attention using the keys and values
from the KV cache of `shared_kv_cache_layers[layer_name]`.
kv_cache_groups: The KV cache groups of the model.
"""
layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
for kv_cache_group in kv_cache_groups:
for layer_name in kv_cache_group.layer_names:
layer_to_kv_cache_group[layer_name] = kv_cache_group
for layer_name, target_layer_name in shared_kv_cache_layers.items():
tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
tgt_kv_cache_group.layer_names.append(layer_name)
if runner_only_attn_layers is not None:
runner_only_attn_layers.add(layer_name)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, Attention],
runner_kv_caches: list[torch.Tensor],
num_attn_module: int = 1,
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
# TODO - analyze where runner_kv_caches is used and the right
# way to ensure it properly reflects multiple attention layers
# in the same decoder block.
if (
current_platform.is_cuda_alike()
or current_platform.is_xpu()
or current_platform.is_cpu()
):
# We know that the GPU / CPU runner is not impacted by this
# case. Some test code depends on runner_kv_caches, but
# not in a way that's impacted by ignoring this.
pass
else:
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
def is_residual_scattered_for_sp(
vllm_config: VllmConfig, num_input_tokens: int
) -> bool:
"""Check if the residual tensor is scattered for sequence parallelism.
The residual tensor is scattered across tensor parallel ranks when sequence
parallelism and tensor parallelism is enabled.
This follows the same logic as SequenceParallelismPass.is_applicable_for_range():
- In full-graph compilation mode (no splitting ops or using inductor graph
partition), SP is always applied
- Otherwise, SP is only applied for specific shapes in compile_sizes
"""
if not vllm_config.compilation_config.pass_config.enable_sp:
return False
tp = vllm_config.parallel_config.tensor_parallel_size
if tp == 1:
return False
# When sequence parallelism is enabled, we always pad num_input_tokens
# to be a multiple of tensor_parallel_size (tp) earlier.
assert num_input_tokens % tp == 0
if (
not vllm_config.compilation_config.splitting_ops
or vllm_config.compilation_config.use_inductor_graph_partition
):
return True
compile_sizes = vllm_config.compilation_config.compile_sizes
if compile_sizes is None:
return False
return num_input_tokens in compile_sizes

View File

@@ -0,0 +1,377 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeVar
import torch
import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.system_utils import update_environment_variables
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.serial_utils import run_method
if TYPE_CHECKING:
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput
else:
SchedulerOutput = object
GrammarOutput = object
AsyncModelRunnerOutput = object
ModelRunnerOutput = object
logger = init_logger(__name__)
_R = TypeVar("_R")
class WorkerBase:
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
) -> None:
"""
Initialize common worker components.
Args:
vllm_config: Complete vLLM configuration
local_rank: Local device index
rank: Global rank in distributed setup
distributed_init_method: Distributed initialization method
is_driver_worker: Whether this worker handles driver
responsibilities
"""
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.kv_transfer_config = vllm_config.kv_transfer_config
self.compilation_config = vllm_config.compilation_config
from vllm.platforms import current_platform
self.current_platform = current_platform
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
# Device and model state
self.device: torch.device | None = None
self.model_runner: nn.Module | None = None
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""Get specifications for KV cache implementation."""
raise NotImplementedError
def compile_or_warm_up_model(self) -> None:
"""Prepare model for execution through compilation/warmup."""
raise NotImplementedError
def check_health(self) -> None:
"""Basic health check (override for device-specific checks)."""
return
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
memory allocations.
"""
raise NotImplementedError
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks."""
raise NotImplementedError
def reset_mm_cache(self) -> None:
reset_fn = getattr(self.model_runner, "reset_mm_cache", None)
if callable(reset_fn):
reset_fn()
def get_model(self) -> nn.Module:
raise NotImplementedError
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
"""Apply a function on the model inside this worker."""
return fn(self.get_model())
def load_model(self) -> None:
"""Load model onto target device."""
raise NotImplementedError
def execute_model(
self, scheduler_output: SchedulerOutput
) -> ModelRunnerOutput | None:
"""If this method returns None, sample_tokens should be called immediately after
to obtain the ModelRunnerOutput.
Note that this design may be changed in future if/when structured outputs
parallelism is re-architected.
"""
raise NotImplementedError
def sample_tokens(
self, grammar_output: GrammarOutput
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
"""Should be called immediately after execute_model iff it returned None."""
raise NotImplementedError
def get_cache_block_size_bytes(self) -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
"""
raise NotImplementedError
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def list_loras(self) -> set[int]:
raise NotImplementedError
@property
def vocab_size(self) -> int:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()
def shutdown(self) -> None:
"""Clean up resources held by the worker."""
return
class WorkerWrapperBase:
"""
This class represents one process in an executor/engine. It is responsible
for lazily initializing the worker and handling the worker's lifecycle.
We first instantiate the WorkerWrapper, which remembers the worker module
and class name. Then, when we call `update_environment_variables`, and the
real initialization happens in `init_worker`.
"""
def __init__(
self,
vllm_config: VllmConfig,
rpc_rank: int = 0,
global_rank: int | None = None,
) -> None:
"""
Initialize the worker wrapper with the given vllm_config and rpc_rank.
Note: rpc_rank is the rank of the worker in the executor. In most cases,
it is also the rank of the worker in the distributed group. However,
when multiple executors work together, they can be different.
e.g. in the case of SPMD-style offline inference with TP=2,
users can launch 2 engines/executors, each with only 1 worker.
All workers have rpc_rank=0, but they have different ranks in the TP
group.
"""
self.rpc_rank = rpc_rank
self.global_rank = self.rpc_rank if global_rank is None else global_rank
self.worker: WorkerBase | None = None
# do not store this `vllm_config`, `init_worker` will set the final
# one.
# TODO: investigate if we can remove this field in `WorkerWrapperBase`,
# `init_cached_hf_modules` should be unnecessary now.
self.vllm_config: VllmConfig | None = None
# `model_config` can be None in tests
model_config = vllm_config.model_config
if model_config and model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
def shutdown(self) -> None:
if self.worker is not None:
self.worker.shutdown()
def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
"""
Adjust the rpc_rank based on the given mapping.
It is only used during the initialization of the executor,
to adjust the rpc_rank of workers after we create all workers.
"""
if self.rpc_rank in rank_mapping:
self.rpc_rank = rank_mapping[self.rpc_rank]
def update_environment_variables(
self,
envs_list: list[dict[str, str]],
) -> None:
envs = envs_list[self.rpc_rank]
key = "CUDA_VISIBLE_DEVICES"
if key in envs and key in os.environ:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
# suppress the warning in `update_environment_variables`
del os.environ[key]
update_environment_variables(envs)
def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
"""
Here we inject some common logic before initializing the worker.
Arguments are passed to the worker class constructor.
"""
kwargs = all_kwargs[self.rpc_rank]
self.vllm_config = kwargs.get("vllm_config")
assert self.vllm_config is not None, (
"vllm_config is required to initialize the worker"
)
self.vllm_config.enable_trace_function_call_for_thread()
from vllm.plugins import load_general_plugins
load_general_plugins()
if isinstance(self.vllm_config.parallel_config.worker_cls, str):
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls
)
else:
raise ValueError(
"passing worker_cls is no longer supported. Please pass keep the class in a separate module and pass the qualified name of the class as a string." # noqa: E501
)
if self.vllm_config.parallel_config.worker_extension_cls:
worker_extension_cls = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_extension_cls
)
extended_calls = []
if worker_extension_cls not in worker_class.__bases__:
# check any conflicts between worker and worker_extension_cls
for attr in dir(worker_extension_cls):
if attr.startswith("__"):
continue
assert not hasattr(worker_class, attr), (
f"Worker class {worker_class} already has an attribute"
f" {attr}, which conflicts with the worker"
f" extension class {worker_extension_cls}."
)
if callable(getattr(worker_extension_cls, attr)):
extended_calls.append(attr)
# dynamically inherit the worker extension class
worker_class.__bases__ = worker_class.__bases__ + (
worker_extension_cls,
)
logger.info(
"Injected %s into %s for extended collective_rpc calls %s",
worker_extension_cls,
worker_class,
extended_calls,
)
shared_worker_lock = kwargs.pop("shared_worker_lock", None)
if shared_worker_lock is None:
msg = (
"Missing `shared_worker_lock` argument from executor. "
"This argument is needed for mm_processor_cache_type='shm'."
)
mm_config = self.vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_processor_cache_type == "shm":
raise ValueError(msg)
else:
logger.warning_once(msg)
self.mm_receiver_cache = None
else:
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config,
MULTIMODAL_REGISTRY,
shared_worker_lock,
)
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
assert self.worker is not None
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
kv_cache_config = kv_cache_configs[self.global_rank]
assert self.vllm_config is not None
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self):
assert self.vllm_config is not None
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization
self.worker.init_device() # type: ignore
def execute_method(self, method: str | bytes, *args, **kwargs):
try:
# method resolution order:
# if a method is defined in this class, it will be called directly.
# otherwise, since we define `__getattr__` and redirect attribute
# query to `self.worker`, the method will be called on the worker.
return run_method(self, method, args, kwargs)
except Exception as e:
# if the driver worker also execute methods,
# exceptions in the rest worker may cause deadlock in rpc like ray
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg = (
f"Error executing method {method!r}. "
"This might cause deadlock in distributed execution."
)
logger.exception(msg)
raise e
def __getattr__(self, attr: str):
return getattr(self.worker, attr)
def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
mm_cache = self.mm_receiver_cache
if mm_cache is None:
return
for req_data in scheduler_output.scheduled_new_reqs:
req_data.mm_features = mm_cache.get_and_update_features(
req_data.mm_features
)
def execute_model(
self,
scheduler_output: SchedulerOutput,
*args,
**kwargs,
) -> ModelRunnerOutput | None:
self._apply_mm_cache(scheduler_output)
assert self.worker is not None
return self.worker.execute_model(scheduler_output, *args, **kwargs)
def reset_mm_cache(self) -> None:
mm_receiver_cache = self.mm_receiver_cache
if mm_receiver_cache is not None:
mm_receiver_cache.clear_cache()
assert self.worker is not None
self.worker.reset_mm_cache()

253
vllm/v1/worker/workspace.py Normal file
View File

@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
import os
from itertools import accumulate
from math import prod
from typing import Optional
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.math_utils import round_up
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
def _compute_bytes(shape: tuple[int, ...], dtype: torch.dtype) -> int:
return prod(shape) * dtype.itemsize
# Constants
_MB = 1024**2
_GiB = 1024**3
# Global workspace manager instance
_manager: Optional["WorkspaceManager"] = None
class WorkspaceManager:
"""Manager for workspace allocation.
Manages workspace buffers for DBO (Dual Batch Overlap) execution.
Can be locked to prevent further growth during execution.
"""
def __init__(self, device: torch.device, num_ubatches: int | None = None):
self._device = device
# Cache num ubatches at init based on configuration (default to 1)
self._num_ubatches = num_ubatches if num_ubatches is not None else 1
self._current_workspaces: list[torch.Tensor | None] = [None, None]
self._locked: bool = False
@staticmethod
def _workspace_size_bytes(workspace: torch.Tensor | None) -> int:
"""Get size of workspace in bytes."""
if workspace is None:
return 0
return workspace.numel() * workspace.element_size()
def lock(self) -> None:
"""Lock the workspace to prevent further growth.
After locking, any attempt to allocate a larger workspace will raise
an assertion error. This ensures workspace size is fixed during execution.
"""
self._locked = True
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Workspace locked. Current sizes: %s",
[
self._workspace_size_bytes(ws) / _MB
for ws in self._current_workspaces
if ws is not None
],
)
def is_locked(self) -> bool:
"""Check if workspace is locked."""
return self._locked
def get_simultaneous(
self, *shapes_and_dtypes: tuple[tuple[int, ...], torch.dtype]
) -> list[torch.Tensor]:
"""Get multiple workspace tensors simultaneously from a single allocation.
Args:
*shapes_and_dtypes: One or more (shape, dtype) tuples.
Returns:
List of tensor views into the workspace buffer, one per shape/dtype pair.
"""
actual_bytes = [_compute_bytes(s, d) for s, d in shapes_and_dtypes]
aligned_bytes = [round_up(actual, 256) for actual in actual_bytes]
total_bytes = sum(aligned_bytes)
# Calculate cumulative offsets using itertools.accumulate
offsets = list(accumulate([0] + aligned_bytes[:-1]))
current_workspace = self._ensure_workspace_size(total_bytes)
return [
current_workspace[offsets[i] : offsets[i] + actual_bytes[i]]
.view(shapes_and_dtypes[i][1])
.reshape(shapes_and_dtypes[i][0])
for i in range(len(shapes_and_dtypes))
]
def _ensure_workspace_size(self, required_bytes: int) -> torch.Tensor:
"""Ensure workspace is allocated and large enough, return current workspace.
Args:
required_bytes: The number of bytes required.
Returns:
The current workspace tensor.
"""
ubatch_id = dbo_current_ubatch_id()
current_workspace = self._current_workspaces[ubatch_id]
current_size = self._workspace_size_bytes(current_workspace)
if current_size < required_bytes:
def get_caller_info() -> str:
"""Find first frame outside WorkspaceManager."""
curr_frame = inspect.currentframe()
if curr_frame is None:
return "unknown"
# Walk up the stack skipping WorkspaceManager frames
curr_frame = curr_frame.f_back
while curr_frame is not None:
# TODO: This only catches instance methods (self), missing
# classmethods and staticmethods. Once Python 3.11+ is the
# minimum supported version, use co_qualname instead:
# qualname = curr_frame.f_code.co_qualname
# if qualname.startswith("WorkspaceManager."):
if isinstance(curr_frame.f_locals.get("self"), WorkspaceManager):
curr_frame = curr_frame.f_back
continue
filename = os.path.basename(curr_frame.f_code.co_filename)
return (
f"{filename}:{curr_frame.f_lineno}:{curr_frame.f_code.co_name}"
)
return "unknown"
if self._locked:
raise AssertionError(
f"Workspace is locked but allocation from '{get_caller_info()}' "
f"requires {required_bytes / _MB:.2f} MB, current size is "
f"{current_size / _MB:.2f} MB. "
"Workspace growth is not allowed after locking."
)
for ubatch_id in range(self._num_ubatches):
current_workspace = self._current_workspaces[ubatch_id]
if (
current_workspace is None
or self._workspace_size_bytes(current_workspace) < required_bytes
):
# Delete old tensor before allocating new one to avoid
# memory spike from resize_(). resize_() allocates new
# memory before freeing old, which can cause OOM.
# Must clear the list reference first since local var
# is just a copy of the reference.
self._current_workspaces[ubatch_id] = None
del current_workspace
self._current_workspaces[ubatch_id] = torch.empty(
(required_bytes,), dtype=torch.uint8, device=self._device
)
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Resized workspace from '%s': %.2f MB -> "
"%.2f MB (%d ubatches, total memory %.2f MB)",
get_caller_info(),
current_size / _MB,
required_bytes / _MB,
self._num_ubatches,
required_bytes * self._num_ubatches / _MB,
)
current_workspace = self._current_workspaces[dbo_current_ubatch_id()]
return current_workspace
def is_workspace_manager_initialized() -> bool:
"""Check if workspace manager has been initialized.
Returns:
True if workspace manager is initialized, False otherwise.
"""
return _manager is not None
def current_workspace_manager() -> "WorkspaceManager":
"""Get the current workspace manager instance.
Raises:
AssertionError: If workspace manager has not been initialized.
"""
assert _manager is not None, (
"WorkspaceManager not initialized. Call init_workspace_manager() "
"with a device before using workspace functions."
)
return _manager
def init_workspace_manager(
device: torch.device, num_ubatches: int | None = None
) -> None:
"""Initialize the workspace manager with a device.
Must be called before using any workspace functions. Typically called
from GPUModelRunner.__init__.
Args:
device: The device to allocate workspace on.
num_ubatches: Number of micro-batches. Defaults to 1.
"""
global _manager
if _manager is not None:
logger.warning(
"WorkspaceManager already initialized on device %s, "
"reinitializing on device %s",
_manager._device,
device,
)
_manager = WorkspaceManager(device, num_ubatches)
def lock_workspace() -> None:
"""Lock the workspace to prevent further growth.
After calling this function, any attempt to allocate a workspace larger
than the current size will raise an AssertionError. This ensures that
workspace size is fixed during execution and prevents unexpected memory
allocations in the hot path.
Example:
# During initialization
init_workspace_manager(device)
reserve_workspace(shape1, dtype1)
reserve_workspace(shape2, dtype2)
# Lock after warmup/profiling
lock_workspace()
# Now all get_workspace calls must fit in pre-allocated size
"""
current_workspace_manager().lock()
def reset_workspace_manager() -> None:
"""Reset the workspace manager to uninitialized state.
This is primarily intended for testing purposes to allow tests
to reinitialize the workspace manager cleanly.
"""
global _manager
_manager = None

View File

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if TYPE_CHECKING:
pass
logger = init_logger(__name__)
class XPUModelRunner(GPUModelRunner):
"""A model runner for XPU devices."""
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
with _torch_cuda_wrapper():
super().__init__(vllm_config, device)
# FIXME: To be verified.
self.cascade_attn_enabled = False
def _init_device_properties(self) -> None:
self.num_sms = None
def _sync_device(self) -> None:
torch.xpu.synchronize()
@contextmanager
def _torch_cuda_wrapper():
try:
# replace cuda APIs with xpu APIs, this should work by default
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.default_stream = torch.xpu.current_stream
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.stream = torch.xpu.stream
yield
finally:
pass

View File

@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Any
import torch
import torch.distributed
from vllm.config import VllmConfig
from vllm.distributed import get_world_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
from vllm.v1.worker.xpu_model_runner import XPUModelRunner
logger = init_logger(__name__)
class XPUWorker(Worker):
"""A XPU worker class."""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
super().__init__(
vllm_config, local_rank, rank, distributed_init_method, is_driver_worker
)
device_config = self.device_config
assert device_config.device_type == "xpu"
assert current_platform.is_xpu()
# Torch profiler. Enabled and configured through profiler_config.
self.profiler: Any | None = None
profiler_config = vllm_config.profiler_config
if profiler_config.profiler == "torch":
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
self.profiler = TorchProfilerWrapper(
profiler_config,
worker_name=worker_name,
local_rank=self.local_rank,
activities=["CPU", "XPU"],
)
# we provide this function due to `torch.xpu.mem_get_info()` doesn't
# return correct free_gpu_memory on intel client GPU. We need to
# calculate/estiamte it.
def xpu_get_mem_info(self):
if current_platform.is_data_center_gpu():
return torch.xpu.mem_get_info()
else:
_, total_gpu_memory = torch.xpu.mem_get_info()
# FIXME: memory_allocated() doesn't count non-torch allocations,
# and we don't have any API to get it. so we mark it as 128MB.
used_memory = torch.xpu.memory_allocated()
non_torch_allocations = 128 * 1024 * 1024
free_gpu_memory = total_gpu_memory - (used_memory + non_torch_allocations)
return free_gpu_memory, total_gpu_memory
@torch.inference_mode()
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.xpu.empty_cache()
torch.xpu.reset_peak_memory_stats()
free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info()
current_allocated_bytes = torch.xpu.memory_allocated()
msg = (
"Before memory profiling run, "
f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
)
logger.info(msg)
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
free_gpu_memory, _ = self.xpu_get_mem_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
assert self.init_gpu_memory > free_gpu_memory, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance."
)
# Get the peak memory allocation recorded by torch
peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]
torch.xpu.empty_cache()
torch_allocated_bytes = torch.xpu.memory_stats()["allocated_bytes.all.current"]
total_allocated_bytes = self.xpu_get_mem_info()[1] - self.xpu_get_mem_info()[0]
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
peak_memory += non_torch_allocations
available_kv_cache_memory = (
total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory
)
msg = (
"After memory profiling run, "
f"peak memory usage is {peak_memory / 1024**2:.2f} MB,"
f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
)
logger.info(msg)
return int(available_kv_cache_memory)
def init_device(self):
device = self.device_config.device
if (
isinstance(device, torch.device)
and device.type == "xpu"
and current_platform.is_xpu()
):
self.device = torch.device(f"xpu:{self.local_rank}")
current_platform.set_device(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)
torch.xpu.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties(
self.local_rank
).total_memory
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
ENV_LOCAL_WORLD_SIZE = os.getenv(
"LOCAL_WORLD_SIZE", str(self.parallel_config.world_size)
)
os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
os.environ["LOCAL_RANK"] = str(self.local_rank)
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
current_platform.dist_backend,
)
# global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(
torch.zeros(1).xpu(), group=get_world_group().device_group
)
# Set random seed.
set_random_seed(self.model_config.seed)
# Construct the model runner
self.model_runner = XPUModelRunner( # type: ignore
self.vllm_config, self.device
)