Add minimal vLLM 0.16.1 build repo for BI-V150

2026-04-18 10:56:22 +08:00
commit d69657327e
1895 changed files with 615301 additions and 0 deletions


@@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__)
class BlockTable:
def __init__(
self,
block_size: int,
max_num_reqs: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
kernel_block_size: int,
cp_kv_cache_interleave_size: int,
):
"""
Args:
block_size: Block size used for KV cache memory allocation
max_num_reqs: Maximum number of concurrent requests supported.
max_num_blocks_per_req: Maximum number of blocks per request.
max_num_batched_tokens: Maximum number of tokens in a batch.
pin_memory: Whether to pin memory for faster GPU transfers.
device: Target device for the block table.
kernel_block_size: The block_size of the underlying attention kernel.
Will be the same as `block_size` if `block_size` is supported
by the attention kernel.
"""
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.pin_memory = pin_memory
self.device = device
if kernel_block_size == block_size:
# Standard case: allocation and computation use same block size
# No block splitting needed, direct mapping
self.block_size = block_size
self.blocks_per_kv_block = 1
self.use_hybrid_blocks = False
else:
# Hybrid case: allocation block size differs from kernel block size
# Memory blocks are subdivided to match kernel requirements
# Example: 32-token memory blocks with 16-token kernel blocks
# → Each memory block corresponds to 2 kernel blocks
if block_size % kernel_block_size != 0:
raise ValueError(
f"kernel_block_size {kernel_block_size} must divide "
f"kv_manager_block_size size {block_size} evenly"
)
self.block_size = kernel_block_size
self.blocks_per_kv_block = block_size // kernel_block_size
self.use_hybrid_blocks = True
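# Hedged worked example (illustrative numbers): with block_size=32 and
# kernel_block_size=16, blocks_per_kv_block=2, so KV manager block 5 maps
# to kernel blocks [10, 11] (see map_to_kernel_blocks below).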
self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
self.block_table = self._make_buffer(
self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping = self._make_buffer(
self.max_num_batched_tokens, dtype=torch.int64
)
if self.use_hybrid_blocks:
self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
1, -1
)
else:
self._kernel_block_arange = None
try:
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
# PCP might not be initialized in testing
self.pcp_world_size = 1
self.pcp_rank = 0
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
def append_row(
self,
block_ids: list[int],
row_idx: int,
) -> None:
if not block_ids:
return
if self.use_hybrid_blocks:
block_ids = self.map_to_kernel_blocks(
np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
)
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.num_blocks_per_row[row_idx] += num_blocks
self.block_table.np[row_idx, start : start + num_blocks] = block_ids
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
block_table_np = self.block_table.np
block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
def swap_row(self, src: int, tgt: int) -> None:
src_tgt, tgt_src = [src, tgt], [tgt, src]
self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray
) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
if total_cp_world_size > 1:
# Note(hc): The DCP implementation stores the KV cache in an
# interleaved style: the KV cache for token i is stored on the GPU
# whose rank equals (i // cp_kv_cache_interleave_size) % cp_world_size.
# Use a "virtual block" of size world_size * block_size for the
# block_table_indices calculation.
virtual_block_size = self.block_size * total_cp_world_size
block_table_indices = (
req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = (
virtual_block_offsets
// self.cp_kv_cache_interleave_size
% total_cp_world_size
== total_cp_rank
)
# Calculate local block_offsets
block_offsets = (
virtual_block_offsets
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
* self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
)
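# Hedged worked example (illustrative numbers): block_size=16,
# total_cp_world_size=2, total_cp_rank=1, interleave_size=1, position=35:
# virtual_block_size=32 and virtual_block_offset=3; 3 // 1 % 2 == 1, so the
# token is local, and block_offset = 3 // 2 * 1 + 3 % 1 = 1, giving
# slot = block_number * 16 + 1.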
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping.np[: req_indices.shape[0]] = np.where(
mask, slot_mapping, -1
)
else:
block_table_indices = (
req_indices * self.max_num_blocks_per_req + positions // self.block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(
block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping.np[: req_indices.shape[0]],
)
def commit_block_table(self, num_reqs: int) -> None:
self.block_table.copy_to_gpu(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
self.slot_mapping.copy_to_gpu(num_tokens)
def clear(self) -> None:
self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0)
@staticmethod
def map_to_kernel_blocks(
kv_manager_block_ids: np.ndarray,
blocks_per_kv_block: int,
kernel_block_arange: np.ndarray,
) -> np.ndarray:
"""Convert kv_manager_block_id IDs to kernel block IDs.
Example:
# kv_manager_block_ids: 32 tokens,
# Kernel block size: 16 tokens
# blocks_per_kv_block = 2
>>> kv_manager_block_ids = np.array([0, 1, 2])
>>> Result: [0, 1, 2, 3, 4, 5]
# Each kv_manager_block_id maps to 2 kernel block id:
# kv_manager_block_id 0 → kernel block id [0, 1]
# kv_manager_block_id 1 → kernel block id [2, 3]
# kv_manager_block_id 2 → kernel block id [4, 5]
"""
if blocks_per_kv_block == 1:
return kv_manager_block_ids
kernel_block_ids = (
kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
+ kernel_block_arange
)
return kernel_block_ids.reshape(-1)
def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
"""Returns the device tensor of the block table."""
return self.block_table.gpu[:num_reqs]
def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
return self.block_table.cpu
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table.np
def _make_buffer(
self, *size: int | torch.SymInt, dtype: torch.dtype
) -> CpuGpuBuffer:
return CpuGpuBuffer(
*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
)
class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
kernel_block_sizes: list[int],
max_num_blocks: list[int] | None = None,
cp_kv_cache_interleave_size: int = 1,
) -> None:
if len(kernel_block_sizes) != len(block_sizes):
raise ValueError(
f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
f"must match block_sizes length ({len(block_sizes)})"
)
if max_num_blocks is None:
# Note(hc): each CP rank stores only
# (max_model_len // total_cp_world_size) tokens in its KV cache,
# so the block_size used to compute max_num_blocks_per_req must be
# multiplied by total_cp_world_size.
total_cp_world_size = get_total_cp_world_size()
max_num_blocks = [
cdiv(max_model_len, block_size * total_cp_world_size)
for block_size in block_sizes
]
if len(max_num_blocks) != len(block_sizes):
raise ValueError(
f"max_num_blocks length ({len(max_num_blocks)}) "
f"must match block_sizes length ({len(block_sizes)})"
)
self.block_tables = [
BlockTable(
block_size,
max_num_reqs,
max_num_blocks_per_req,
max_num_batched_tokens,
pin_memory,
device,
kernel_block_size,
cp_kv_cache_interleave_size,
)
for block_size, kernel_block_size, max_num_blocks_per_req in zip(
block_sizes, kernel_block_sizes, max_num_blocks
)
]
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx)
def move_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.move_row(src, tgt)
def swap_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.swap_row(src, tgt)
def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray
) -> None:
for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions)
def commit_block_table(self, num_reqs: int) -> None:
for block_table in self.block_tables:
block_table.commit_block_table(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
for block_table in self.block_tables:
block_table.commit_slot_mapping(num_tokens)
def clear(self) -> None:
for block_table in self.block_tables:
block_table.clear()
def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx]


@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, cast
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed import get_dcp_group, get_pcp_group
if TYPE_CHECKING:
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
else:
AttentionLayerBase = object
def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None:
pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
dcp_size = vllm_config.parallel_config.decode_context_parallel_size
interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
if pcp_size * dcp_size > 1:
layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(vllm_config, layer_type)
for layer in layers.values():
layer_impl = getattr(layer, "impl", None)
if layer_impl is None:
continue
if vllm_config.speculative_config is not None and interleave_size > 1:
assert layer_impl.supports_mtp_with_cp_non_trivial_interleave_size, (
"MTP with cp_kv_cache_interleave_size > 1 is not "
f"supported in {layer_impl.__class__.__name__}."
)
if dcp_size > 1:
assert layer_impl.need_to_return_lse_for_decode, (
"DCP requires attention impls to return"
" the softmax lse for decode, but the impl "
f"{layer_impl.__class__.__name__} "
"does not return the softmax lse for decode."
)
if pcp_size > 1:
assert layer_impl.supports_pcp, (
"PCP requires attention impls' support, "
f"but the impl {layer_impl.__class__.__name__} "
"does not support PCP."
)
def get_total_cp_world_size():
try:
pcp_world_size = get_pcp_group().world_size
except AssertionError:
# PCP might not be initialized in testing
pcp_world_size = 1
try:
dcp_world_size = get_dcp_group().world_size
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
return dcp_world_size * pcp_world_size
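# Illustrative example: with pcp_world_size=2 and dcp_world_size=4, the total
# CP world size is 8, and consumers such as BlockTable derive
# total_cp_rank = pcp_rank * dcp_world_size + dcp_rank.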


@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.tracing import instrument
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
class CPUModelRunner(GPUModelRunner):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
with _torch_cuda_wrapper():
super().__init__(vllm_config, device)
assert device == torch.device("cpu")
assert self.speculative_config is None, "spec decode is not supported."
self.use_cuda_graph = False
self.cascade_attn_enabled = False
self._postprocess_tensors()
def _postprocess_tensors(self) -> None:
# Note: replace device tensors with cpu tensors
def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None:
cpu_tensor = getattr(obj, cpu_attr_name, None)
device_tensor = getattr(obj, device_attr_name, None)
if cpu_tensor is not None and device_tensor is not None:
assert isinstance(cpu_tensor, torch.Tensor)
assert isinstance(device_tensor, torch.Tensor)
setattr(obj, device_attr_name, cpu_tensor)
for v in vars(self).values():
if isinstance(v, CpuGpuBuffer):
v.gpu = v.cpu
for k, v in vars(self.input_batch).items():
if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
replace_tensor(self.input_batch, k, k[:-11])
for block_table in self.input_batch.block_table.block_tables:
for v in vars(block_table).values():
if isinstance(v, CpuGpuBuffer):
v.gpu = v.cpu
@instrument(span_name="Loading (CPU)")
def load_model(self, eep_scale_up: bool = False) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
self.model = self.load_lora_model(self.model, self.vllm_config, self.device)
def get_model(self) -> nn.Module:
return self.model
@instrument(span_name="Warmup (CPU)")
def warming_up_model(self) -> None:
logger.info("Warming up model for the compilation...")
# Only generate graph for the generic shape
with _set_global_compilation_settings(self.vllm_config):
self._dummy_run(
min(
max(16, self.max_num_reqs),
self.scheduler_config.max_num_batched_tokens,
)
)
logger.info("Warming up done.")
def _init_device_properties(self) -> None:
pass
def _sync_device(self) -> None:
pass
def get_dp_padding(self, num_tokens: int) -> tuple[int, torch.Tensor | None]:
# Note: For CPU backend, dp padding is not required for now.
return 0, None
@contextmanager
def _torch_cuda_wrapper():
class _EventPlaceholder:
def __init__(self, *args, **kwargs) -> None:
self.record = lambda: None
self.synchronize = lambda: None
class _StreamPlaceholder:
def __init__(self, *args, **kwargs) -> None:
pass
cuda_event = torch.Event
cuda_stream = torch.cuda.Stream
try:
torch.Event = _EventPlaceholder
torch.cuda.Stream = _StreamPlaceholder
yield
finally:
torch.Event = cuda_event
torch.cuda.Stream = cuda_stream
@contextmanager
def _set_global_compilation_settings(config: VllmConfig):
import torch._inductor.config as torch_inductor_config
inductor_config = config.compilation_config.inductor_compile_config
# Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
freezing_value = torch_inductor_config.freezing
try:
if inductor_config.get("max_autotune", False):
torch_inductor_config.freezing = True
yield
finally:
torch_inductor_config.freezing = freezing_value


@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import platform
from collections.abc import Callable
from typing import Any
import torch
from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
logger = init_logger(__name__)
class CPUWorker(Worker):
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
super().__init__(
vllm_config,
local_rank,
rank,
distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.parallel_config.disable_custom_all_reduce = True
# Torch profiler. Enabled and configured through profiler_config.
self.profiler: Any | None = None
profiler_config = vllm_config.profiler_config
if profiler_config.profiler == "torch":
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
self.profiler = TorchProfilerWrapper(
profiler_config,
worker_name=worker_name,
local_rank=self.local_rank,
activities=["CPU"],
)
def init_device(self):
# Setup OpenMP threads affinity.
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
# Under NUMA binding, some cores are reserved for KV transfer in
# nixl_connector.py.
if omp_cpuids == "auto" and platform.system() == "Linux":
cpu_arch = current_platform.get_cpu_architecture()
if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X):
# For S390X/POWERPC SMT-8/4/2
self.local_omp_cpuid = self._get_autobind_cpu_ids(
lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]
)
elif cpu_arch == CpuArchEnum.X86:
# For x86 SMT-2, use 1 CPU per core
self.local_omp_cpuid = self._get_autobind_cpu_ids(
lambda cpus: cpus[-1:]
)
elif cpu_arch == CpuArchEnum.ARM:
# For AArch64, no SMT
self.local_omp_cpuid = self._get_autobind_cpu_ids(lambda cpus: cpus)
else:
self.local_omp_cpuid = "nobind"
elif omp_cpuids == "nobind":
self.local_omp_cpuid = "nobind"
else:
local_dp_rank = self.parallel_config.data_parallel_rank_local
omp_cpuids_list = omp_cpuids.split("|")
if local_dp_rank is not None:
world_size = self.parallel_config.world_size
omp_cpuids_list = omp_cpuids_list[
local_dp_rank * world_size : (local_dp_rank + 1) * world_size
]
self.local_omp_cpuid = omp_cpuids_list[self.rank]
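# Illustrative example (hypothetical values): with
# VLLM_CPU_OMP_THREADS_BIND="0-15|16-31" and world_size=2, rank 0 binds
# "0-15" and rank 1 binds "16-31"; with local DP, the list is first sliced
# to this DP rank's world_size-sized window.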
if self.local_omp_cpuid != "nobind":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
if ret:
logger.info(ret)
# Note: unique identifier for creating allreduce shared memory
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1]
# Initialize the distributed environment.
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
current_platform.dist_backend,
)
# Set random seed.
set_random_seed(self.model_config.seed)
# Construct the model runner
self.model_runner: CPUModelRunner = CPUModelRunner(
self.vllm_config, torch.device("cpu")
)
def sleep(self, level: int = 1) -> None:
logger.warning("Sleep mode is not supported on CPU; ignoring the request.")
def wake_up(self, tags: list[str] | None = None) -> None:
logger.warning("Sleep mode is not supported on CPU; ignoring the request.")
def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes or 0
def compile_or_warm_up_model(self) -> None:
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model()
def _get_autobind_cpu_ids(
self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]
) -> str:
"""
Return CPU ids to bind based on NUMA nodes.
Currently for rank N, only CPU ids on the N-th node in available NUMA
node list will be selected.
Args:
cpu_selector: a callable object to select CPUs from a CPU list
of a physical core. The input is a LogicalCPUInfo list, sorted by
the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be
returned.
"""
# simulate multiple numa nodes, for testing
sim_multi_numa_nodes = os.environ.get("VLLM_CPU_SIM_MULTI_NUMA", "0") != "0"
allowed_numa_nodes, logical_cpu_list = (
CpuPlatform.get_allowed_cpu_core_node_list()
)
assert (
len(allowed_numa_nodes) >= self.parallel_config.world_size
or sim_multi_numa_nodes
), (
f"Not enough allowed NUMA nodes to bind threads of "
f"{self.parallel_config.world_size} CPUWorkers. "
f"Allowed NUMA nodes are {allowed_numa_nodes}. "
"Please try to bind threads manually."
)
if not sim_multi_numa_nodes:
# Get CPUs on NUMA node `allowed_numa_nodes[local_rank]`
selected_numa_node = allowed_numa_nodes[self.local_rank] # type: ignore
logical_cpu_list = [
x for x in logical_cpu_list if x.numa_node == selected_numa_node
]
else:
# This is a bit tricky because the internal DP size
# is always 1 for non-MoE models
world_size_across_dp = (
self.parallel_config.world_size
* self.parallel_config._api_process_count
)
assert len(logical_cpu_list) >= world_size_across_dp
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
sim_cpu_num_per_node = len(logical_cpu_list) // world_size_across_dp
assert self.parallel_config.data_parallel_rank_local is not None
start_idx = (
self.local_rank
+ self.parallel_config.world_size
* self.parallel_config.data_parallel_rank_local
) * sim_cpu_num_per_node
logical_cpu_list = logical_cpu_list[
start_idx : (start_idx + sim_cpu_num_per_node)
]
# Select CPUs from each physical core via cpu_selector
core_to_cpus: dict[int, list[LogicalCPUInfo]] = {}
for cpu_info in logical_cpu_list:
if cpu_info.physical_core not in core_to_cpus:
core_to_cpus[cpu_info.physical_core] = []
core_to_cpus[cpu_info.physical_core].append(cpu_info)
logical_cpu_list = []
for cpu_list in core_to_cpus.values():
cpu_list = sorted(cpu_list, key=lambda x: x.id)
logical_cpu_list.extend(cpu_selector(cpu_list))
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.id)
# Reserve CPUs for other processes
reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU
if reserve_cpu_num is None:
need_reserve = (
self.parallel_config.world_size > 1
or self.parallel_config.data_parallel_size_local > 1
)
reserve_cpu_num = 1 if need_reserve else 0
assert len(logical_cpu_list) > reserve_cpu_num, (
f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) "
f"should less than {len(logical_cpu_list)}."
)
if reserve_cpu_num != 0:
logical_cpu_list = logical_cpu_list[:-reserve_cpu_num]
logger.info(
"auto thread-binding list (id, physical core): %s",
[(x.id, x.physical_core) for x in logical_cpu_list],
)
return ",".join([str(x.id) for x in logical_cpu_list])
def profile(self, is_start: bool = True, profile_prefix: str | None = None):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
self.profiler.start()
else:
self.profiler.stop()

vllm/v1/worker/dp_utils.py

@@ -0,0 +1,240 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
import torch.distributed as dist
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
check_ubatch_thresholds,
is_last_ubatch_empty,
)
logger = init_logger(__name__)
def _get_device_and_group(parallel_config: ParallelConfig):
# Use the actual device assigned to the DP group, not just the device type
device = get_dp_group().device
group = get_dp_group().device_group
# Transferring this tensor from GPU to CPU introduces a GPU sync
# point that could adversely affect the performance of vLLM with
# async scheduling. This environment variable exists to quickly
# disable this optimization if we run into such a case.
if parallel_config.disable_nccl_for_dp_synchronization:
logger.info_once(
"Using CPU all reduce to synchronize DP padding between ranks."
)
device = "cpu"
group = get_dp_group().cpu_group
return device, group
def _run_ar(
should_ubatch: bool,
should_dp_pad: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
cudagraph_mode: int,
parallel_config: ParallelConfig,
) -> torch.Tensor:
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config)
tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0
tensor[4][dp_rank] = cudagraph_mode
dist.all_reduce(tensor, group=group)
return tensor
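# Illustrative layout (hypothetical values, dp_size=2): after the all-reduce,
# each column holds one rank's contribution, e.g.
#   row 0 (orig tokens):    [96, 128]
#   row 1 (padded tokens):  [128, 128]
#   row 2 (should_ubatch):  [1, 0]
#   row 3 (should_dp_pad):  [1, 1]
#   row 4 (cudagraph_mode): [2, 1]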
def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool:
orig_num_tokens_tensor = tensor[0, :]
padded_num_tokens_tensor = tensor[1, :]
# First determine if we are going to be ubatching.
should_ubatch: bool = bool(torch.all(tensor[2] == 1).item())
if not should_ubatch:
return False
# If the DP ranks are planning to ubatch, make sure that
# there are no "empty" second ubatches
orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches):
logger.debug(
"Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
)
should_ubatch = False
return should_ubatch
def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
num_tokens_across_dp = tensor[1, :]
if should_dp_pad:
# If DP padding is enabled, ensure that each rank is processing the same number
# of tokens
max_num_tokens = int(num_tokens_across_dp.max().item())
return torch.tensor(
[max_num_tokens] * len(num_tokens_across_dp),
device="cpu",
dtype=torch.int32,
)
else:
return num_tokens_across_dp.cpu()
def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
"""
Synchronize cudagraph_mode across DP ranks by taking the minimum.
If any rank has NONE (0), all ranks use NONE.
This ensures all ranks send consistent values (all padded or all unpadded).
"""
return int(tensor[4, :].min().item())
def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
should_attempt_dp_padding: bool,
cudagraph_mode: int,
parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None, int]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do.
2. Determines the total number of tokens that each rank will run.
When running microbatched, or if should_attempt_dp_padding is True, all
ranks are padded out so that they run with the same number of tokens.
3. Synchronizes cudagraph_mode across ranks by taking the minimum.
Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including any DP padding.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
]
"""
assert num_tokens_padded >= num_tokens_unpadded
# Coordinate between the DP ranks via an All Reduce
# to determine the total number of tokens that each rank
# will run and if we are using ubatching or not.
tensor = _run_ar(
should_ubatch=should_attempt_ubatching,
should_dp_pad=should_attempt_dp_padding,
orig_num_tokens_per_ubatch=num_tokens_unpadded,
padded_num_tokens_per_ubatch=num_tokens_padded,
cudagraph_mode=cudagraph_mode,
parallel_config=parallel_config,
)
should_dp_pad = bool(torch.all(tensor[3] == 1).item())
# DP ranks should all have the same value for should_attempt_dp_padding.
assert should_attempt_dp_padding == should_dp_pad
# Check conditions for microbatching
should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)
if should_ubatch and not should_dp_pad:
logger.debug_once(
"Microbatching has been triggered and requires DP padding. "
"Enabling DP padding even though it has been explicitly "
"disabled.",
scope="global",
)
should_dp_pad = True
# Pad all DP ranks up to the maximum token count across ranks if
# should_dp_pad is True
num_tokens_after_padding = _post_process_dp_padding(
tensor,
should_dp_pad,
)
# Synchronize cudagraph_mode across ranks (take min)
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode
def coordinate_batch_across_dp(
num_tokens_unpadded: int,
allow_microbatching: bool,
allow_dp_padding: bool,
parallel_config: ParallelConfig,
num_tokens_padded: int | None = None,
uniform_decode: bool | None = None,
num_scheduled_tokens_per_request: np.ndarray | None = None,
cudagraph_mode: int = 0,
) -> tuple[bool, torch.Tensor | None, int]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches.
Args:
num_tokens_unpadded: Number of tokens without accounting for padding
allow_microbatching: If microbatching should be attempted
allow_dp_padding: If all DP ranks should be padded up to the same value
parallel_config: The parallel config
num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
TP, etc)
uniform_decode: Only used if allow_microbatching is True. True if the batch
only contains single token decodes
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
number of tokens per request.
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
Returns: tuple[
should_ubatch: True if all DP ranks have agreed to microbatch.
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
padded up to the max value across all DP ranks when allow_dp_padding
is True.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
]
"""
if parallel_config.data_parallel_size == 1:
# Early exit.
return False, None, cudagraph_mode
# If the caller has explicitly enabled microbatching.
should_attempt_ubatching = False
if allow_microbatching:
# Check preconditions for microbatching
assert uniform_decode is not None
should_attempt_ubatching = check_ubatch_thresholds(
parallel_config,
num_tokens_unpadded,
uniform_decode=uniform_decode,
)
if num_tokens_padded is None:
num_tokens_padded = num_tokens_unpadded
(should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) = (
_synchronize_dp_ranks(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
allow_dp_padding,
cudagraph_mode,
parallel_config,
)
)
return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)


@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Defines the EC connector functionality mixin for model runners.
"""
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import TYPE_CHECKING
import torch
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorBase
from vllm.logger import init_logger
from vllm.v1.outputs import ECConnectorOutput
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
# Defined as an EC connector functionality mixin for ModelRunners (GPU, TPU)
class ECConnectorModelRunnerMixin:
@staticmethod
def maybe_save_ec_to_connector(
encoder_cache: dict[str, torch.Tensor],
mm_hash: str,
):
if not has_ec_transfer():
logger.debug("Not have ec transfer please check")
return
connector = get_ec_transfer()
connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)
@staticmethod
def get_finished_ec_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[set[str] | None, set[str] | None]:
if has_ec_transfer():
return get_ec_transfer().get_finished(scheduler_output.finished_req_ids)
return None, None
@staticmethod
def maybe_get_ec_connector_output(
scheduler_output: "SchedulerOutput",
encoder_cache: dict[str, torch.Tensor],
**kwargs,
) -> AbstractContextManager[ECConnectorOutput | None]:
return (
ECConnectorModelRunnerMixin._get_ec_connector_output(
scheduler_output, encoder_cache, **kwargs
)
if has_ec_transfer()
else nullcontext()
)
# This context manager must be used within an active forward context.
# It encapsulates the entire EC connector lifecycle within execute_model
@staticmethod
@contextmanager
def _get_ec_connector_output(
scheduler_output: "SchedulerOutput",
encoder_cache: dict[str, torch.Tensor],
**kwargs,
) -> Generator[ECConnectorOutput, None, None]:
output = ECConnectorOutput()
ec_connector = get_ec_transfer()
assert isinstance(ec_connector, ECConnectorBase)
assert scheduler_output.ec_connector_metadata is not None
ec_connector.bind_connector_metadata(scheduler_output.ec_connector_metadata)
# Load caches for consumer or both roles
if ec_connector.is_consumer:
ec_connector.start_load_caches(encoder_cache, **kwargs)
try:
yield output
finally:
output.finished_sending, output.finished_recving = (
ec_connector.get_finished(scheduler_output.finished_req_ids)
)
ec_connector.clear_connector_metadata()
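# Hedged usage sketch (illustrative): a model runner would typically wrap its
# forward pass as
#   with self.maybe_get_ec_connector_output(
#       scheduler_output, encoder_cache=self.encoder_cache
#   ) as ec_output:
#       ...  # run the model; caches are loaded/sent around the yield
# so the connector lifecycle stays contained within execute_model.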


@@ -0,0 +1,4 @@
# [Experimental] Model Runner V2
This directory contains the new model runner which is under active development.
Ping [Woosuk Kwon](https://github.com/WoosukKwon) for any changes.


@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import numpy as np
import torch
from vllm.v1.outputs import AsyncModelRunnerOutput, LogprobsTensors, ModelRunnerOutput
from vllm.v1.worker.gpu.sample.output import SamplerOutput
class AsyncOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput,
num_sampled_tokens: torch.Tensor,
main_stream: torch.cuda.Stream,
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
# NOTE(woosuk): We must retain references to the GPU tensors,
# as the copy operations are performed on a different CUDA stream than
# the one where the tensors were created.
self.model_runner_output = model_runner_output
self.sampler_output = sampler_output
self.num_sampled_tokens = num_sampled_tokens
self.copy_event = copy_event
with stream(copy_stream, main_stream):
copy_stream.wait_stream(main_stream)
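# wait_stream makes the copy stream wait for all work already enqueued on
# the main stream, so the device-to-host copies below read fully written
# tensors without blocking the host.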
self.sampled_token_ids = async_copy_to_np(sampler_output.sampled_token_ids)
self.logprobs_tensors: LogprobsTensors | None = None
if sampler_output.logprobs_tensors is not None:
self.logprobs_tensors = (
sampler_output.logprobs_tensors.to_cpu_nonblocking()
)
self.num_nans: np.ndarray | None = None
if sampler_output.num_nans is not None:
self.num_nans = async_copy_to_np(sampler_output.num_nans)
self.num_sampled_tokens_np = async_copy_to_np(num_sampled_tokens)
self.prompt_logprobs_dict = {
k: v.to_cpu_nonblocking() if v is not None else None
for k, v in self.model_runner_output.prompt_logprobs_dict.items()
}
self.copy_event.record(copy_stream)
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
# NOTE(woosuk): The following code is to ensure compatibility with
# the existing model runner.
# Going forward, we should keep the data structures as NumPy arrays
# rather than Python lists.
sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
num_sampled_tokens: list[int] = self.num_sampled_tokens_np.tolist()
for token_ids, num_tokens in zip(sampled_token_ids, num_sampled_tokens):
del token_ids[num_tokens:]
self.model_runner_output.sampled_token_ids = sampled_token_ids
if self.num_nans is not None:
self.model_runner_output.num_nans_in_logits = dict(
zip(self.model_runner_output.req_ids, self.num_nans.tolist())
)
if self.logprobs_tensors is not None:
self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
return self.model_runner_output
def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
return x.to("cpu", non_blocking=True).numpy()
@contextlib.contextmanager
def stream(to_stream: torch.cuda.Stream, from_stream: torch.cuda.Stream):
"""Lightweight version of torch.cuda.stream() context manager which
avoids current_stream and device lookups.
"""
try:
torch.cuda.set_stream(to_stream)
yield
finally:
torch.cuda.set_stream(from_stream)


@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any, cast
import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheConfig,
KVCacheSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
kv_cache_spec: dict[str, KVCacheSpec] = {}
layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type)
for layer_name, attn_module in attn_layers.items():
# Skip modules that don't need KV cache (e.g., encoder-only attention)
if spec := attn_module.get_kv_cache_spec(vllm_config):
kv_cache_spec[layer_name] = spec
return kv_cache_spec
def init_attn_backend(
kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
):
attn_backends: dict[str, type[AttentionBackend]] = {}
attn_groups: list[list[AttentionGroup]] = []
attn_backend_workspace: torch.Tensor | None = None
for kv_cache_group_id, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups
):
layer_names = kv_cache_group_spec.layer_names
layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
group_map: dict[tuple[tuple[str, str], KVCacheSpec], AttentionGroup] = {}
group_order: list[tuple[tuple[str, str], KVCacheSpec]] = []
for layer_name in layer_names:
attn_backend = attn_layers[layer_name].get_attn_backend()
attn_backends[layer_name] = attn_backend
layer_kv_cache_spec: KVCacheSpec = kv_cache_group_spec.kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
key = (attn_backend.full_cls_name(), layer_kv_cache_spec)
if key not in group_map:
group_map[key] = AttentionGroup(
attn_backend,
[layer_name],
layer_kv_cache_spec,
kv_cache_group_id,
)
group_order.append(key)
else:
group_map[key].layer_names.append(layer_name)
groups = [group_map[key] for key in group_order]
for group in groups:
group.create_metadata_builders(
vllm_config=vllm_config,
device=device,
kernel_block_size=None,
num_metadata_builders=1,
)
builder = group.get_metadata_builder(0)
if attn_backend_workspace is None:
if hasattr(builder, "_get_workspace_buffer"):
attn_backend_workspace = builder._get_workspace_buffer()
else:
if hasattr(builder, "set_workspace_buffer"):
builder.set_workspace_buffer(attn_backend_workspace)
attn_groups.append(groups)
return attn_backends, attn_groups
def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys()), (
"Some layers are not correctly initialized"
)
return kv_cache_raw_tensors
def _reshape_kv_cache(
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, AttentionBackend],
) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
assert isinstance(kv_cache_spec, AttentionSpec)
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
)
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
dtype = kv_cache_spec.dtype
raw_tensor = raw_tensor.view(dtype)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
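# Hedged worked example (illustrative): a logical shape (4, 16, 8) with
# stride order (1, 0, 2) is materialized as (16, 4, 8); inv_order is then
# [1, 0, 2], and permute(*inv_order) restores a (4, 16, 8) view over the
# physically reordered storage.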
return kv_caches
def init_kv_cache(
runner_kv_caches: list[torch.Tensor],
forward_context: dict[str, Any],
kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend],
device: torch.device,
) -> dict[str, torch.Tensor]:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
return kv_caches
def build_slot_mappings_by_layer(
slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig
) -> dict[str, torch.Tensor]:
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
for slot_mapping, kv_cache_group in zip(slot_mappings, kv_cache_groups):
for layer_name in kv_cache_group.layer_names:
slot_mappings_by_layer[layer_name] = slot_mapping
return slot_mappings_by_layer
def build_attn_metadata(
attn_groups: list[list[AttentionGroup]],
num_reqs: int,
num_tokens: int,
query_start_loc_gpu: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
max_query_len: int,
seq_lens: torch.Tensor,
max_seq_len: int,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
dcp_local_seq_lens: torch.Tensor | None = None,
) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs]
if dcp_local_seq_lens is not None:
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
attn_metadata: dict[str, Any] = {}
num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
for i in range(num_kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
causal=True,
dcp_local_seq_lens=dcp_local_seq_lens,
)
for attn_group in attn_groups[i]:
attn_metadata_builder = attn_group.get_metadata_builder(0)
metadata = attn_metadata_builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = metadata
return attn_metadata


@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
class BlockTables:
def __init__(
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
cp_size: int = 1,
cp_rank: int = 0,
cp_interleave: int = 1,
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
self.cp_size = cp_size
self.cp_rank = cp_rank
self.cp_interleave = cp_interleave
self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[StagedWriteTensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
# When using DCP, each request's KV cache is sharded among different ranks.
# As a result, one block on the current rank covers `block_size * cp_size`
# tokens in the full, global (unsharded) sequence.
max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size)
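# Hedged worked example (illustrative): block_size=16, cp_size=4, and
# max_model_len=8192 → one local block covers 64 global tokens, so each
# rank needs cdiv(8192, 64) = 128 block-table slots per request.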
block_table = StagedWriteTensor(
(self.max_num_reqs, max_num_blocks),
dtype=torch.int32,
device=device,
)
self.block_tables.append(block_table)
self.block_table_ptrs = self._make_ptr_tensor(
[b.gpu for b in self.block_tables]
)
self.block_table_strides = torch.tensor(
[b.gpu.stride(0) for b in self.block_tables],
dtype=torch.int64,
device=self.device,
)
self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device
)
self.num_blocks = UvaBackedTensor(
(self.num_kv_cache_groups, self.max_num_reqs),
dtype=torch.int32,
)
# Block tables used for model's forward pass.
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(b.gpu) for b in self.block_tables
]
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
self.slot_mappings = torch.zeros(
self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device,
)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
return torch.tensor(
[t.data_ptr() for t in x], dtype=torch.uint64, device=self.device
)
def append_block_ids(
self,
req_index: int,
new_block_ids: tuple[list[int], ...],
overwrite: bool,
) -> None:
for i in range(self.num_kv_cache_groups):
start = self.num_blocks.np[i, req_index] if not overwrite else 0
block_ids = new_block_ids[i]
self.block_tables[i].stage_write(req_index, start, block_ids)
self.num_blocks.np[i, req_index] = start + len(block_ids)
def apply_staged_writes(self) -> None:
# TODO(woosuk): This can be inefficient since it launches one kernel per
# block table. Implement a kernel to handle all block tables at once.
for block_table in self.block_tables:
block_table.apply_write()
self.num_blocks.copy_to_uva()
def gather_block_tables(
self, idx_mapping: torch.Tensor
) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
idx_mapping,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.num_blocks.gpu,
self.num_blocks.gpu.stride(0),
BLOCK_SIZE=1024, # type: ignore
)
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def compute_slot_mappings(
self,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
idx_mapping,
query_start_loc,
positions,
self.block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings,
self.slot_mappings.stride(0),
self.cp_rank,
CP_SIZE=self.cp_size,
CP_INTERLEAVE=self.cp_interleave,
PAD_ID=PAD_SLOT_ID,
TRITON_BLOCK_SIZE=1024, # type: ignore
)
return self.slot_mappings[:, :num_tokens]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
self.slot_mappings.fill_(PAD_SLOT_ID)
return self.slot_mappings[:, :num_tokens]
@triton.jit
def _gather_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
idx_mapping, # [num_reqs]
query_start_loc, # [num_reqs + 1]
pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
block_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride,
cp_rank,
CP_SIZE: tl.constexpr,
CP_INTERLEAVE: tl.constexpr,
PAD_ID: tl.constexpr,
TRITON_BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if batch_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
block_size = tl.load(block_sizes + group_id)
req_state_idx = tl.load(idx_mapping + batch_idx)
start_idx = tl.load(query_start_loc + batch_idx)
end_idx = tl.load(query_start_loc + batch_idx + 1)
for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // (block_size * CP_SIZE)
block_offsets = positions % (block_size * CP_SIZE)
block_numbers = tl.load(
block_table_ptr + req_state_idx * block_table_stride + block_indices
)
if CP_SIZE == 1:
# Common case: Context parallelism is not used.
slot_ids = block_numbers * block_size + block_offsets
else:
# Context parallelism is used.
is_local = block_offsets // CP_INTERLEAVE % CP_SIZE == cp_rank
rounds = block_offsets // (CP_INTERLEAVE * CP_SIZE)
remainder = block_offsets % CP_INTERLEAVE
local_offsets = rounds * CP_INTERLEAVE + remainder
slot_ids = block_numbers * block_size + local_offsets
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr)
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
return tl.multiple_of(ptr, 16)


@@ -0,0 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence
from functools import partial
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import (
async_tensor_h2d,
get_accelerator_view_from_cpu_tensor,
)
def async_copy_to_gpu(
x: torch.Tensor | np.ndarray,
out: torch.Tensor | None = None,
device: torch.device | None = None,
) -> torch.Tensor:
if isinstance(x, np.ndarray):
x = torch.from_numpy(x)
assert x.is_cpu
if out is None:
assert device is not None
out = torch.empty_like(x, device=device)
# CPU-to-CPU copy
tmp = x.pin_memory()
assert tmp is not x
# CPU-to-GPU copy
return out.copy_(tmp, non_blocking=True)
class UvaBuffer:
def __init__(self, size: int | Sequence[int], dtype: torch.dtype):
if not is_uva_available():
raise RuntimeError("UVA is not available")
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=True)
self.np = self.cpu.numpy()
self.uva = get_accelerator_view_from_cpu_tensor(self.cpu)
class UvaBufferPool:
def __init__(
self,
size: int | Sequence[int],
dtype: torch.dtype,
max_concurrency: int = 2,
):
self.size = size
self.dtype = dtype
self.max_concurrency = max_concurrency
# UVA buffers for concurrency
self._uva_bufs = [UvaBuffer(size, dtype) for _ in range(max_concurrency)]
# Current buffer index
self._curr = 0
def copy_to_uva(self, x: torch.Tensor | np.ndarray | list) -> torch.Tensor:
# Round robin to the next buffer.
self._curr = (self._curr + 1) % self.max_concurrency
buf = self._uva_bufs[self._curr]
# CPU-to-CPU copy
dst = buf.cpu if isinstance(x, torch.Tensor) else buf.np
n = len(x)
dst[:n] = x
return buf.uva[:n]
def copy_to_gpu(
self,
x: torch.Tensor | np.ndarray,
out: torch.Tensor | None = None,
) -> torch.Tensor:
uva = self.copy_to_uva(x)
# CPU-to-GPU copy
return uva.clone() if out is None else out.copy_(uva, non_blocking=True)
class UvaBackedTensor:
def __init__(
self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
):
self.dtype = dtype
self.max_concurrency = max_concurrency
# Source of truth
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)
self.np = self.cpu.numpy()
# Buffers for concurrency
self.pool = UvaBufferPool(size, dtype, max_concurrency)
self.gpu = self.pool.copy_to_uva(self.np)
def copy_to_uva(self, n: int | None = None) -> torch.Tensor:
# CPU-to-CPU copy
self.gpu = self.pool.copy_to_uva(self.np[:n] if n is not None else self.np)
return self.gpu
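# Hedged usage sketch (illustrative): mutate the NumPy view in place, then
# publish it to the device-visible buffer:
#   buf = UvaBackedTensor((4, 8), dtype=torch.int32)
#   buf.np[0, :3] = [1, 2, 3]
#   gpu_view = buf.copy_to_uva()  # UVA tensor readable by GPU kernels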
class StagedWriteTensor:
def __init__(
self,
size: int | Sequence[int],
dtype: torch.dtype,
device: torch.device,
max_concurrency: int = 2,
uva_instead_of_gpu: bool = False,
):
supported_dtypes = [torch.int32, torch.int64, torch.float32]
if dtype not in supported_dtypes:
raise ValueError(
f"Unsupported dtype {dtype}: should be one of {supported_dtypes}"
)
self.num_rows = size if isinstance(size, int) else size[0]
self.dtype = dtype
self.device = device
self.max_concurrency = max_concurrency
if not uva_instead_of_gpu:
# Create a GPU tensor (default)
self.gpu = torch.zeros(size, dtype=dtype, device=device)
else:
# For a large but not-frequently-accessed tensor, we can use UVA instead of
# GPU to save GPU memory
self._uva_buf = UvaBuffer(size, dtype)
self.gpu = self._uva_buf.uva
self._staged_write_indices: list[int] = []
self._staged_write_starts: list[int] = []
self._staged_write_contents: list[int | float] = []
self._staged_write_cu_lens: list[int] = []
new_buffer = partial(UvaBufferPool, max_concurrency=max_concurrency)
self.write_indices = new_buffer(self.num_rows, dtype=torch.int32)
self.write_starts = new_buffer(self.num_rows, dtype=torch.int32)
self.write_cu_lens = new_buffer(self.num_rows, dtype=torch.int32)
def stage_write(
self, index: int, start: int, x: Iterable[int] | Iterable[float]
) -> None:
assert index >= 0
assert start >= 0
if not x:
return
self._staged_write_indices.append(index)
self._staged_write_starts.append(start)
self._staged_write_contents.extend(x)
self._staged_write_cu_lens.append(len(self._staged_write_contents))
def stage_write_elem(self, index: int, x: int) -> None:
assert index >= 0
self._staged_write_indices.append(index)
self._staged_write_starts.append(0)
self._staged_write_contents.append(x)
self._staged_write_cu_lens.append(len(self._staged_write_contents))
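# Hedged worked example (illustrative values): two staged writes,
#   stage_write(index=3, start=2, x=[7, 8, 9])
#   stage_write(index=0, start=0, x=[5])
# accumulate indices [3, 0], starts [2, 0], contents [7, 8, 9, 5], and
# cu_lens [3, 4]; apply_write() then launches one kernel program per staged
# row, writing contents[0:3] to row 3 (columns 2..4) and contents[3:4] to
# row 0 (column 0).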
def apply_write(self) -> None:
n = len(self._staged_write_indices)
if n == 0:
return
indices_uva = self.write_indices.copy_to_uva(self._staged_write_indices)
starts_uva = self.write_starts.copy_to_uva(self._staged_write_starts)
cu_lens_uva = self.write_cu_lens.copy_to_uva(self._staged_write_cu_lens)
# Special handling for write_contents
write_contents = async_tensor_h2d(
self._staged_write_contents, self.dtype, self.device, pin_memory=True
)
# Write diffs to the GPU buffer
_apply_write_kernel[(n,)](
self.gpu,
self.gpu.stride(0),
indices_uva,
starts_uva,
write_contents,
cu_lens_uva,
BLOCK_SIZE=1024,
)
# Clear the staged writes
self.clear_staged_writes()
def clear_staged_writes(self) -> None:
self._staged_write_indices.clear()
self._staged_write_starts.clear()
self._staged_write_contents.clear()
self._staged_write_cu_lens.clear()
@triton.jit
def _apply_write_kernel(
output_ptr,
output_stride,
write_indices_ptr,
write_starts_ptr,
write_contents_ptr,
write_cu_lens_ptr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = tl.load(write_indices_ptr + pid)
start_idx = tl.load(write_starts_ptr + pid)
cu_start = tl.load(write_cu_lens_ptr + pid - 1) if pid > 0 else 0
cu_end = tl.load(write_cu_lens_ptr + pid)
content_len = cu_end - cu_start
for i in range(0, content_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < content_len
content = tl.load(write_contents_ptr + cu_start + block, mask=mask)
tl.store(
output_ptr + row_idx * output_stride + start_idx + block, content, mask=mask
)
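# Illustrative smoke test (an addition, not part of the original file): stage
# a few host-side row updates, then flush them to the GPU tensor with a
# single _apply_write_kernel launch. Shapes and values here are arbitrary.
if __name__ == "__main__" and torch.cuda.is_available():
    _buf = StagedWriteTensor((8, 16), dtype=torch.int32, device=torch.device("cuda"))
    _buf.stage_write(index=2, start=0, x=[10, 11, 12])  # row 2, cols 0..2
    _buf.stage_write_elem(index=5, x=99)  # row 5, col 0
    _buf.apply_write()  # one kernel launch applies all staged diffs
    assert _buf.gpu[2, :3].tolist() == [10, 11, 12]
    assert int(_buf.gpu[5, 0]) == 99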

View File

@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
def prepare_dcp_local_seq_lens(
dcp_local_seq_lens: torch.Tensor,
seq_lens: torch.Tensor,
num_reqs: int,
dcp_size: int,
dcp_rank: int,
cp_interleave: int,
) -> None:
"""Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
if dcp_size == 1:
return
max_num_reqs = dcp_local_seq_lens.shape[0]
BLOCK_SIZE = 128
num_blocks = triton.cdiv(max_num_reqs, BLOCK_SIZE)
_dcp_local_seq_lens_kernel[(num_blocks,)](
dcp_local_seq_lens,
seq_lens,
dcp_size,
dcp_rank,
cp_interleave,
num_reqs,
max_num_reqs,
BLOCK_SIZE,
)
@triton.jit
def _dcp_local_seq_lens_kernel(
out_ptr,
seq_lens_ptr,
dcp_size,
dcp_rank,
cp_interleave,
num_reqs,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
block = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
seq_lens = tl.load(seq_lens_ptr + block, mask=block < num_reqs)
# Distribute KV cache among different ranks, in a round-robin manner.
rounds = seq_lens // (dcp_size * cp_interleave)
remainder = seq_lens % (dcp_size * cp_interleave)
remainder = tl.maximum(remainder - dcp_rank * cp_interleave, 0)
remainder = tl.minimum(remainder, cp_interleave)
local_seq_lens = rounds * cp_interleave + remainder
# For [num_reqs, max_num_reqs), pad with 0
local_seq_lens = tl.where(block < num_reqs, local_seq_lens, 0)
tl.store(out_ptr + block, local_seq_lens, mask=block < max_num_reqs)
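# Host-side reference of the arithmetic above (an illustrative addition, not
# part of the original file). Tokens are dealt to DCP ranks round-robin in
# chunks of `cp_interleave`, so each rank keeps `rounds` full chunks plus a
# clamped share of the final partial round.
def _dcp_local_seq_len_reference(
    seq_len: int, dcp_size: int, dcp_rank: int, cp_interleave: int
) -> int:
    rounds, remainder = divmod(seq_len, dcp_size * cp_interleave)
    remainder = min(max(remainder - dcp_rank * cp_interleave, 0), cp_interleave)
    return rounds * cp_interleave + remainder

# E.g., seq_len=10, dcp_size=2, cp_interleave=2 assigns chunks
# [0,1]->rank0, [2,3]->rank1, [4,5]->rank0, [6,7]->rank1, [8,9]->rank0:
assert _dcp_local_seq_len_reference(10, 2, 0, 2) == 6
assert _dcp_local_seq_len_reference(10, 2, 1, 2) == 4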

View File

@@ -0,0 +1,462 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.offloader.base import get_offloader
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.utils import AttentionGroup
class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
uses_mrope: bool,
use_aux_hidden_state_outputs: bool,
device: torch.device,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.uses_mrope = uses_mrope
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.uniform_decode_query_len = 1
spec_config = vllm_config.speculative_config
if spec_config is not None:
self.uniform_decode_query_len += spec_config.num_speculative_tokens
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_mode = self.compilation_config.cudagraph_mode
use_uniform_decode_cudagraph = (
self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.cudagraph_mode.separate_routine()
)
self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
self.uniform_decode_query_len,
use_uniform_decode_cudagraph,
)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = None
if self.cudagraph_mode != CUDAGraphMode.NONE:
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
self.aux_hidden_states: list[torch.Tensor] = []
def needs_capture(self) -> bool:
return len(self.cudagraph_sizes) > 0
def get_cudagraph_size(
self, num_tokens: int, uniform_decode: bool = False
) -> int | None:
if uniform_decode and self.uniform_decode_cudagraph_sizes:
return self.uniform_decode_cudagraph_sizes.get(num_tokens)
return self.cudagraph_sizes.get(num_tokens)
def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
model: nn.Module,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
has_lora: bool = False,
uniform_decode: bool = False,
) -> None:
# select and check capture function
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
# prepare inputs
if uniform_decode:
num_reqs = min(
cdiv(num_tokens, self.uniform_decode_query_len),
self.max_num_reqs,
)
else:
num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
if self.uses_mrope:
assert mrope_positions is not None
positions = mrope_positions[:, :num_tokens]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:num_tokens]
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
input_buffers,
block_tables,
attn_groups,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=(
self.uniform_decode_query_len if uniform_decode else 0
),
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
):
model_output = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Allocate output buffers if not already done.
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states]
capture_fn(
num_tokens=num_tokens,
num_reqs=num_reqs,
model=model,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
num_tokens_across_dp=num_tokens_across_dp,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
has_lora=has_lora,
)
def _capture_full_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
assert attn_metadata is not None
# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
with (
set_forward_context(
attn_metadata=attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
),
torch.cuda.graph(graph, self.pool),
):
model_output = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader().join_after_forward()
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Copy outputs to the output buffers.
assert self.hidden_states is not None
self.hidden_states[:num_tokens] = hidden_states
if self.use_aux_hidden_state_outputs:
for i, aux_hidden in enumerate(aux_hidden_states):
self.aux_hidden_states[i][:num_tokens] = aux_hidden
self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
# create batch descriptor for piecewise cudagraph dispatch key
batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)
# Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
with set_forward_context(
            attn_metadata=None,  # Piecewise capture does not need attn_metadata.
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings,
):
model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
@torch.inference_mode()
def capture(
self,
model: nn.Module,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
has_lora: bool = False,
) -> None:
common_kwargs = dict(
device=self.device,
capture_fn=self.capture_graph,
model=model,
input_buffers=input_buffers,
mrope_positions=mrope_positions,
inputs_embeds=inputs_embeds,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
has_lora=has_lora,
)
# Phase 1: Capture for mixed prefill-decode batches if needed.
mixed_mode = self.cudagraph_mode.mixed_mode()
if mixed_mode != CUDAGraphMode.NONE:
capture_graphs(
cudagraph_sizes=self.cudagraph_sizes,
capture_cudagraph_mode=mixed_mode,
desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
uniform_decode=False,
**common_kwargs,
)
# Phase 2: Capture FULL graphs for uniform decode batches if needed.
# This is only needed if we use a separate routine for decode batches
# and the decode_mode is FULL.
if self.uniform_decode_cudagraph_sizes:
capture_graphs(
cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
capture_cudagraph_mode=CUDAGraphMode.FULL,
desc="Capturing CUDA graphs (decode, FULL)",
uniform_decode=True,
**common_kwargs,
)
def get_cudagraph_runtime_mode(
self, num_reqs: int, num_tokens: int, max_query_len: int
) -> tuple[CUDAGraphMode, int | None]:
is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_tokens == max_query_len * num_reqs
)
cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
elif is_uniform_decode:
cudagraph_mode = self.cudagraph_mode.decode_mode()
else:
cudagraph_mode = self.cudagraph_mode.mixed_mode()
if (
cudagraph_mode == CUDAGraphMode.FULL
and cudagraph_size is not None
and cudagraph_size not in self.graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_size = None
return cudagraph_mode, cudagraph_size
def run_fullgraph(
self, num_tokens: int
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
# Sync offloader before replay - needed when transitioning from
# eager/piecewise to full cudagraph (e.g., prefill → decode).
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader().sync_prev_onload()
self.graphs[num_tokens].replay()
assert self.hidden_states is not None
hidden_states = self.hidden_states[:num_tokens]
if not self.use_aux_hidden_state_outputs:
return hidden_states
return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states]
def get_cudagraph_sizes(
capture_sizes: list[int] | None,
max_num_reqs: int,
max_num_tokens: int,
cudagraph_mode: CUDAGraphMode,
uniform_decode_query_len: int = 1,
uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
# Support both FULL and PIECEWISE cudagraph modes
if cudagraph_mode == CUDAGraphMode.NONE:
return {}, {}
    if not capture_sizes:
        return {}, {}
    capture_sizes = sorted(capture_sizes)
cudagraph_sizes: dict[int, int] = {}
for i in range(1, capture_sizes[-1] + 1):
for x in capture_sizes:
if i <= x:
cudagraph_sizes[i] = x
break
uniform_decode_cudagraph_sizes: dict[int, int] = {}
if uniform_decode_cudagraph:
max_num_tokens = max_num_reqs * uniform_decode_query_len
uniform_decode_cudagraph_sizes = {
k: v
for k, v in cudagraph_sizes.items()
if v <= max_num_tokens and v >= uniform_decode_query_len
}
return cudagraph_sizes, uniform_decode_cudagraph_sizes
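# Worked example (an illustrative addition, not part of the original file):
# with capture sizes [1, 2, 4, 8], every batch size in [1, 8] maps to the
# smallest captured size that fits it, i.e. the padded size used to pick a
# graph at runtime.
_example_sizes, _ = get_cudagraph_sizes(
    [8, 2, 4, 1],
    max_num_reqs=8,
    max_num_tokens=8,
    cudagraph_mode=CUDAGraphMode.FULL,
)
assert _example_sizes == {1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}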
def capture_graphs(
cudagraph_sizes: dict[int, int],
device: torch.device,
capture_fn: Callable,
capture_cudagraph_mode: CUDAGraphMode,
desc: str = "Capturing CUDA graphs",
**capture_kwargs,
) -> None:
# Capture larger graphs first.
sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc=desc)
with graph_capture(device=device):
for size in sizes_to_capture:
capture_fn(size, capture_cudagraph_mode, **capture_kwargs)
def prepare_inputs_to_capture(
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
max_model_len: int,
kv_cache_config: KVCacheConfig,
uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
if uniform_decode_query_len > 0:
num_tokens_per_req = uniform_decode_query_len
else:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc_np[-1] = num_tokens
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len.
input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, kv_cache_config
)
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=num_tokens_per_req,
seq_lens=input_buffers.seq_lens,
max_seq_len=max_model_len,
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
)
return attn_metadata, slot_mappings_by_layer

View File

@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
if dp_size == 1:
return None
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
def get_batch_metadata_across_dp(
num_tokens: int,
cudagraph_size: int,
cudagraph_runtime_mode: int,
dp_size: int,
dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert dp_size > 1
# Use CPU group to avoid CPU-GPU synchronization.
group = get_dp_group().cpu_group
tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = cudagraph_size
tensor[2][dp_rank] = cudagraph_runtime_mode
dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1], tensor[2]
def get_cudagraph_and_dp_padding(
num_tokens: int,
cudagraph_size: int | None,
cudagraph_runtime_mode: int,
dp_size: int,
dp_rank: int,
) -> tuple[int, torch.Tensor | None, int]:
if dp_size == 1:
if cudagraph_size is not None:
return cudagraph_size, None, cudagraph_runtime_mode
else:
return num_tokens, None, cudagraph_runtime_mode
# Convert None to -1 for sync (indicates no cudagraph available)
if num_tokens == 0:
cudagraph_size = 0
elif cudagraph_size is None:
cudagraph_size = -1
num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = (
get_batch_metadata_across_dp(
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank
)
)
if torch.all(num_tokens_across_dp == 0).item():
# All ranks have zero tokens to run.
return 0, None, 0
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item())
# Check if all ranks have valid cudagraph_size.
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
if synced_cudagraph_mode != 0 and all_have_cudagraph:
# All ranks use cudagraph. Pad to max cudagraph_size.
max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
num_tokens_across_dp[:] = max_cudagraph_size
return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
else:
# Fall back to eager mode (no cudagraph).
# Either some rank doesn't have cudagraph size or mode is NONE.
synced_cudagraph_mode = 0
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode
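# Illustrative walk-through of the fallback rule above (an addition, not part
# of the original file). The all-reduce is replaced by hand-written tensors:
# rank 1 reports -1 (no graph captured for its 17 tokens), which forces every
# rank into eager mode; had both ranks reported a size, all would pad to the
# max so the graphs replay in lock-step.
_num_tokens = torch.tensor([5, 17], dtype=torch.int32)
_graph_sizes = torch.tensor([8, -1], dtype=torch.int32)
_graph_modes = torch.tensor([2, 2], dtype=torch.int32)
_synced_mode = int(_graph_modes.min())
_all_have_graph = bool(torch.all(_graph_sizes != -1))
assert _synced_mode != 0 and not _all_have_graph  # -> eager fallback branch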

View File

@@ -0,0 +1,548 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.utils import random_uuid
class InputBuffers:
def __init__(
self,
max_num_reqs: int,
max_num_tokens: int,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens
self.device = device
self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
# DCP: per-request local seq_lens buffer
self.dcp_local_seq_lens = torch.zeros(
max_num_reqs, dtype=torch.int32, device=device
)
@dataclass
class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
num_reqs: int
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# Identical to idx_mapping except for spec decoding.
expanded_idx_mapping: torch.Tensor
# [total_num_logits] position within request for each logit
expanded_local_pos: torch.Tensor
# [num_reqs]
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
# sum(num_scheduled_tokens)
num_tokens: int
num_tokens_after_padding: int
num_draft_tokens: int
# [num_reqs + 1]
query_start_loc: torch.Tensor
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
# [num_tokens_after_padding]
input_ids: torch.Tensor
# [num_tokens_after_padding]
positions: torch.Tensor
# [3, num_tokens_after_padding]
mrope_positions: torch.Tensor | None
# [num_tokens_after_padding, hidden_size]
inputs_embeds: torch.Tensor | None
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# layer_name -> slot_mapping
slot_mappings: dict[str, torch.Tensor]
# [total_num_logits]
logits_indices: torch.Tensor
# [num_reqs + 1]
cu_num_logits: torch.Tensor
cu_num_logits_np: np.ndarray
# Whether any requests in batch use structured output.
has_structured_output_reqs: bool
@classmethod
def make_dummy(
cls,
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
device: torch.device,
) -> "InputBatch":
assert 0 < num_reqs <= num_tokens
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
expanded_idx_mapping = idx_mapping
expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens
# seq_len equals to query_len
input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
# Pad for full CUDA graph mode.
input_buffers.seq_lens[num_reqs:] = 0
seq_lens = input_buffers.seq_lens[:num_reqs]
query_start_loc_np = np.empty(num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
input_buffers.query_start_loc[:1] = 0
torch.cumsum(
seq_lens, dim=0, out=input_buffers.query_start_loc[1 : num_reqs + 1]
)
# Pad for full CUDA graph mode.
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
input_ids = input_buffers.input_ids[:num_tokens].zero_()
positions = input_buffers.positions[:num_tokens].zero_()
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping,
expanded_local_pos=expanded_local_pos,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens,
num_draft_tokens=0,
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
input_ids=input_ids,
positions=positions,
mrope_positions=None,
inputs_embeds=None,
attn_metadata=None, # type: ignore
slot_mappings=None, # type: ignore
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
has_structured_output_reqs=False,
)
@triton.jit
def _prepare_prefill_inputs_kernel(
input_ids_ptr,
next_prefill_tokens_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
all_token_ids_ptr,
all_token_ids_stride,
prefill_lens_ptr,
num_computed_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
if num_computed >= prefill_len:
# Not prefill.
return
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
request_ptr = all_token_ids_ptr + req_state_idx * all_token_ids_stride
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
tokens = tl.load(request_ptr + num_computed + block, mask=mask)
tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
next_pos = num_computed + query_len
if next_pos < prefill_len:
next_token = tl.load(request_ptr + next_pos)
tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
def prepare_prefill_inputs(
input_ids: torch.Tensor,
next_prefill_tokens: torch.Tensor,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
all_token_ids: torch.Tensor,
prefill_len: torch.Tensor,
num_computed_tokens: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_prepare_prefill_inputs_kernel[(num_reqs,)](
input_ids,
next_prefill_tokens,
idx_mapping,
query_start_loc,
all_token_ids,
all_token_ids.stride(0),
prefill_len,
num_computed_tokens,
BLOCK_SIZE=1024,
)
@triton.jit
def _prepare_pos_seq_lens_kernel(
pos_ptr,
seq_lens_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
num_computed_tokens_ptr,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
req_id = tl.program_id(0)
num_reqs = tl.num_programs(0) - 1
if req_id == num_reqs:
# Pad unused seq_lens as 0 for full CUDA graphs.
for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(seq_lens_ptr + block, 0, mask=mask)
return
req_state_idx = tl.load(idx_mapping_ptr + req_id)
num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
start = tl.load(query_start_loc_ptr + req_id)
end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = end - start
seq_len = num_computed_tokens + query_len
tl.store(seq_lens_ptr + req_id, seq_len)
for i in tl.range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
pos = num_computed_tokens + block
tl.store(pos_ptr + start + block, pos, mask=mask)
def prepare_pos_seq_lens(
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
num_computed_tokens: torch.Tensor,
pos: torch.Tensor,
seq_lens: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
# NOTE(woosuk): We do +1 because the last thread block is used
# to pad unused seq_lens as 0 for full CUDA graphs.
_prepare_pos_seq_lens_kernel[(num_reqs + 1,)](
pos,
seq_lens,
idx_mapping,
query_start_loc,
num_computed_tokens,
seq_lens.shape[0],
BLOCK_SIZE=1024,
)
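# Host-side reference of what the kernel computes (an illustrative addition,
# not part of the original file; the idx_mapping indirection is omitted):
# positions are num_computed + [0, query_len) per request and
# seq_len = num_computed + query_len.
def _pos_seq_lens_reference(
    query_start_loc: list[int], num_computed: list[int]
) -> tuple[list[int], list[int]]:
    pos: list[int] = []
    seq_lens: list[int] = []
    for i in range(len(query_start_loc) - 1):
        query_len = query_start_loc[i + 1] - query_start_loc[i]
        pos.extend(range(num_computed[i], num_computed[i] + query_len))
        seq_lens.append(num_computed[i] + query_len)
    return pos, seq_lens

# Request 0 resumes at position 4 with 2 new tokens; request 1 starts fresh
# with 3 tokens:
assert _pos_seq_lens_reference([0, 2, 5], [4, 0]) == ([4, 5, 0, 1, 2], [6, 3])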
@triton.jit
def _combine_sampled_and_draft_tokens_kernel(
input_ids_ptr,
idx_mapping_ptr,
last_sampled_tokens_ptr,
query_start_loc_ptr,
seq_lens_ptr,
prefill_len_ptr,
draft_tokens_ptr,
draft_tokens_stride,
cu_num_logits_ptr,
logits_indices_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
# Get the number of logits and draft tokens.
cu_num_logits_start = tl.load(cu_num_logits_ptr + batch_idx)
cu_num_logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
num_logits = cu_num_logits_end - cu_num_logits_start
num_draft_tokens = num_logits - 1
# Compute the logits indices.
block = tl.arange(0, BLOCK_SIZE)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
logits_start = query_end - num_logits
tl.store(
logits_indices_ptr + cu_num_logits_start + block,
logits_start + block,
mask=block < num_logits,
)
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if seq_len <= prefill_len:
# Handling prefill tokens. No sampled or draft tokens.
return
# Write the last sampled token ID to input_ids.
last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
tl.store(input_ids_ptr + query_end - num_logits, last_token_id)
# Write the draft tokens (if any) to input_ids.
if num_draft_tokens > 0:
mask = block < num_draft_tokens
draft_tokens = tl.load(
draft_tokens_ptr + req_state_idx * draft_tokens_stride + block,
mask=mask,
)
tl.store(
input_ids_ptr + query_end - num_draft_tokens + block,
draft_tokens,
mask=mask,
)
def combine_sampled_and_draft_tokens(
input_ids: torch.Tensor,
idx_mapping: torch.Tensor,
last_sampled_tokens: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
prefill_len: torch.Tensor,
draft_tokens: torch.Tensor,
cu_num_logits: torch.Tensor,
num_logits: int,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
num_speculative_steps = draft_tokens.shape[-1]
logits_indices = torch.empty(
num_logits,
dtype=torch.int64,
device=input_ids.device,
)
_combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
input_ids,
idx_mapping,
last_sampled_tokens,
query_start_loc,
seq_lens,
prefill_len,
draft_tokens,
draft_tokens.stride(0),
cu_num_logits,
logits_indices,
# NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token
# in addition to all draft tokens.
BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
)
return logits_indices
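# Worked example for the kernel above (an illustrative addition): consider one
# decode request with 2 draft tokens, so query_start_loc = [0, 3] and
# cu_num_logits = [0, 3] (num_logits = 3, num_draft_tokens = 2). The kernel
#   * writes logits_indices = [0, 1, 2] (the last num_logits query positions),
#   * stores the last sampled token at input_ids[0] (query_end - num_logits),
#   * stores the 2 draft tokens at input_ids[1:3],
# so each logit scores the token that would follow its position, which is
# exactly what rejection sampling needs to verify the draft.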
@triton.jit
def _get_num_sampled_and_rejected_kernel(
num_sampled_ptr,
num_rejected_ptr,
seq_lens_ptr,
cu_num_logits_ptr,
idx_mapping_ptr,
prefill_len_ptr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
is_chunked_prefilling = seq_len < prefill_len
num_sampled = tl.load(num_sampled_ptr + batch_idx)
num_sampled = tl.where(is_chunked_prefilling, 0, num_sampled)
tl.store(num_sampled_ptr + batch_idx, num_sampled)
logits_start = tl.load(cu_num_logits_ptr + batch_idx)
logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
num_logits = logits_end - logits_start
num_rejected = num_logits - num_sampled
num_rejected = tl.where(is_chunked_prefilling, 0, num_rejected)
tl.store(num_rejected_ptr + batch_idx, num_rejected)
def get_num_sampled_and_rejected(
num_sampled: torch.Tensor,
seq_lens: torch.Tensor,
cu_num_logits: torch.Tensor,
idx_mapping: torch.Tensor,
prefill_len: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = idx_mapping.shape[0]
num_rejected = torch.empty_like(num_sampled)
_get_num_sampled_and_rejected_kernel[(num_reqs,)](
num_sampled,
num_rejected,
seq_lens,
cu_num_logits,
idx_mapping,
prefill_len,
)
return num_sampled, num_rejected
@triton.jit
def _post_update_kernel(
idx_mapping_ptr,
num_computed_tokens_ptr,
last_sampled_tokens_ptr,
output_bin_counts_ptr,
output_bin_counts_stride,
sampled_tokens_ptr,
sampled_tokens_stride,
num_sampled_ptr,
num_rejected_ptr,
query_start_loc_ptr,
all_token_ids_ptr,
all_token_ids_stride,
total_len_ptr,
):
req_id = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_id)
total_len = tl.load(total_len_ptr + req_state_idx)
num_sampled = tl.load(num_sampled_ptr + req_id)
if num_sampled > 0:
token_id = tl.load(
sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
)
tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
tl.store(total_len_ptr + req_state_idx, total_len + num_sampled)
for i in range(num_sampled):
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
token_ptr = (
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
)
count = tl.load(token_ptr)
count += 1
tl.store(token_ptr, count)
tl.store(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
token_id,
)
query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start
num_rejected = tl.load(num_rejected_ptr + req_id)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
num_computed += query_len - num_rejected
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
def post_update(
# [num_reqs]
idx_mapping: torch.Tensor,
# [max_num_reqs]
num_computed_tokens: torch.Tensor,
# [max_num_reqs]
last_sampled_tokens: torch.Tensor,
# [max_num_reqs, vocab_size]
output_bin_counts: torch.Tensor,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens: torch.Tensor,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [num_reqs + 1]
query_start_loc: torch.Tensor,
# [max_num_reqs, max_model_len]
all_token_ids: torch.Tensor,
# [max_num_reqs]
total_len: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_post_update_kernel[(num_reqs,)](
idx_mapping,
num_computed_tokens,
last_sampled_tokens,
output_bin_counts,
output_bin_counts.stride(0),
sampled_tokens,
sampled_tokens.stride(0),
num_sampled,
num_rejected,
query_start_loc,
all_token_ids,
all_token_ids.stride(0),
total_len,
num_warps=1,
)
@triton.jit
def _expand_idx_mapping_kernel(
idx_mapping_ptr,
expanded_idx_mapping_ptr,
expanded_local_pos_ptr,
cu_num_logits_ptr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
block = tl.arange(0, BLOCK_SIZE)
mask = block < num_tokens
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
tl.store(expanded_idx_mapping_ptr + start_idx + block, req_state_idx, mask=mask)
tl.store(expanded_local_pos_ptr + start_idx + block, block, mask=mask)
def expand_idx_mapping(
idx_mapping: torch.Tensor,
total_num_logits: int,
cu_num_logits: torch.Tensor,
max_expand_len: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = idx_mapping.shape[0]
expanded_idx_mapping = idx_mapping.new_empty(total_num_logits)
expanded_local_pos = torch.empty(
total_num_logits, dtype=torch.int32, device=idx_mapping.device
)
_expand_idx_mapping_kernel[(num_reqs,)](
idx_mapping,
expanded_idx_mapping,
expanded_local_pos,
cu_num_logits,
BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
)
return expanded_idx_mapping, expanded_local_pos
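# Illustrative smoke test (an addition, not part of the original file):
# request 0 contributes 1 logit and request 1 contributes 3 (1 sampled +
# 2 draft), so cu_num_logits = [0, 1, 4].
if __name__ == "__main__" and torch.cuda.is_available():
    _idx_mapping = torch.tensor([4, 7], dtype=torch.int32, device="cuda")
    _cu_num_logits = torch.tensor([0, 1, 4], dtype=torch.int32, device="cuda")
    _expanded, _local_pos = expand_idx_mapping(
        _idx_mapping, total_num_logits=4, cu_num_logits=_cu_num_logits, max_expand_len=3
    )
    assert _expanded.tolist() == [4, 7, 7, 7]  # per-logit request-state index
    assert _local_pos.tolist() == [0, 0, 1, 2]  # per-logit position in request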

View File

@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
kv_transfer_state,
)
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.forward_context import (
get_forward_context,
is_forward_context_available,
set_forward_context,
)
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
class KVConnector:
"""KVConnector interface used by GPUModelRunner."""
def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
pass
def post_forward(
self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True
) -> KVConnectorOutput | None:
return None
def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
return EMPTY_MODEL_RUNNER_OUTPUT
def set_disabled(self, disabled: bool) -> None:
pass
class ActiveKVConnector(KVConnector):
def __init__(
self, vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
):
self.vllm_config = vllm_config
self.kv_connector = get_kv_transfer_group()
# Register kv caches with KV Connector if applicable.
# TODO: support cross_layers_kv_cache
# (see https://github.com/vllm-project/vllm/pull/27743)
self.kv_connector.register_kv_caches(kv_caches_dict)
self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks)
self._disabled = False
def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
if self._disabled:
return
if scheduler_output.preempted_req_ids:
self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
self.kv_connector.bind_connector_metadata(kv_connector_metadata)
# TODO: sort out KV Connectors' use of forward_context
if is_forward_context_available():
self.kv_connector.start_load_kv(get_forward_context())
else:
with set_forward_context(None, self.vllm_config):
self.kv_connector.start_load_kv(get_forward_context())
def post_forward(
self,
scheduler_output: "SchedulerOutput",
wait_for_save: bool = True,
clear_metadata: bool = True,
) -> KVConnectorOutput | None:
if self._disabled:
return None
output = KVConnectorOutput()
if wait_for_save:
self.kv_connector.wait_for_save()
output.finished_sending, output.finished_recving = (
self.kv_connector.get_finished(scheduler_output.finished_req_ids)
)
output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors()
output.kv_connector_stats = self.kv_connector.get_kv_connector_stats()
output.kv_cache_events = self.kv_connector.get_kv_connector_kv_cache_events()
if clear_metadata:
self.kv_connector.clear_connector_metadata()
return output
def clear_metadata(self) -> None:
"""Clear the connector metadata. Call this after draft model runs."""
if not self._disabled:
self.kv_connector.clear_connector_metadata()
def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
if self._disabled:
return EMPTY_MODEL_RUNNER_OUTPUT
self.pre_forward(scheduler_output)
kv_connector_output = self.post_forward(scheduler_output, wait_for_save=False)
if kv_connector_output is None or kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
def set_disabled(self, disabled: bool) -> None:
# Ensure that layer-wise connector hooks aren't called when disabled.
kv_transfer_state._KV_CONNECTOR_AGENT = None if disabled else self.kv_connector
self._disabled = disabled
NO_OP_KV_CONNECTOR = KVConnector()
def get_kv_connector(
vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
) -> KVConnector:
if not has_kv_transfer_group():
# No-op connector.
return NO_OP_KV_CONNECTOR
return ActiveKVConnector(vllm_config, kv_caches_dict)

View File

@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
from vllm.lora.request import LoRARequest
NO_LORA_ID = 0
class LoraState:
def __init__(self, max_num_reqs: int):
        self.lora_ids = np.full(max_num_reqs, NO_LORA_ID, dtype=np.int32)
# req_id -> lora_request
self.lora_requests: dict[str, LoRARequest] = {}
def add_request(
self, req_id: str, req_index: int, lora_request: LoRARequest | None
) -> None:
if lora_request is not None:
self.lora_requests[req_id] = lora_request
self.lora_ids[req_index] = lora_request.lora_int_id
else:
self.lora_ids[req_index] = NO_LORA_ID
def remove_request(self, req_id: str) -> None:
self.lora_requests.pop(req_id, None)
def make_lora_inputs(
self,
req_ids: list[str],
idx_mapping: np.ndarray,
num_scheduled_tokens: np.ndarray,
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
lora_ids = self.lora_ids[idx_mapping]
prompt_lora_mapping = tuple(lora_ids)
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set()
for req_id in req_ids:
lora_request = self.lora_requests.get(req_id)
if lora_request is not None:
active_lora_requests.add(lora_request)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
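# Illustrative example of the fan-out above (an addition, not part of the
# original file): np.repeat expands one LoRA id per request into one id per
# scheduled token.
_example_ids = np.array([3, NO_LORA_ID], dtype=np.int32)
_example_tokens = np.array([2, 3], dtype=np.int32)
assert tuple(_example_ids.repeat(_example_tokens)) == (3, 3, 0, 0, 0)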

View File

@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch._inductor.runtime.triton_helpers import libdevice
from vllm.triton_utils import tl, triton
@triton.jit
def _num_nans_kernel(
logits_ptr,
logits_stride,
num_nans_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
num_nans = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=0
)
logits = logits.to(tl.float32)
is_nan = libdevice.isnan(logits).to(tl.int1)
num_nans += tl.sum(is_nan).to(tl.int32)
tl.store(num_nans_ptr + req_idx, num_nans)
def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device)
_num_nans_kernel[(num_reqs,)](
logits,
logits.stride(0),
num_nans,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
return num_nans
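# Illustrative smoke test (an addition, not part of the original file).
if __name__ == "__main__" and torch.cuda.is_available():
    _logits = torch.zeros(2, 5, device="cuda")
    _logits[1, 1] = float("nan")
    _logits[1, 3] = float("nan")
    assert get_num_nans(_logits).tolist() == [0, 2]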

View File

@@ -0,0 +1,183 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
class EncoderRunner:
def __init__(
self,
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
):
self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
self.inputs_embeds = torch.zeros(
max_num_tokens, hidden_size, dtype=dtype, device=device
)
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
self.encoder_cache: dict[str, torch.Tensor] = {}
def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling,
but no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_cache.clear()
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
self.req_id_to_mm_features[req_id] = mm_features
def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_cache.pop(mm_hash, None)
def remove_request(self, req_id: str) -> None:
self.req_id_to_mm_features.pop(req_id, None)
def prepare_mm_inputs(
self, scheduled_encoder_inputs: dict[str, list[int]]
) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
mm_hashes: list[str] = []
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
mm_features = self.req_id_to_mm_features[req_id]
for mm_input_id in encoder_input_ids:
mm_feature = mm_features[mm_input_id]
if mm_feature.data is None:
continue
mm_hashes.append(mm_feature.identifier)
mm_kwargs.append((mm_feature.modality, mm_feature.data))
return mm_hashes, mm_kwargs
@torch.inference_mode()
def execute_mm_encoder(
self,
model: SupportsMultiModal,
mm_hashes: list[str],
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
) -> list[torch.Tensor]:
if not mm_hashes:
return []
encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs, device=self.device, pin_memory=False
):
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs, expected_num_items=num_items
)
encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
return encoder_outputs
def gather_mm_embeddings(
self,
req_ids: list[str],
total_num_scheduled_tokens: int,
num_scheduled_tokens: np.ndarray,
query_start_loc: np.ndarray,
prefill_lens: np.ndarray,
computed_prefill_lens: np.ndarray,
) -> tuple[list[torch.Tensor], torch.Tensor]:
is_prefilling = (computed_prefill_lens < prefill_lens).tolist()
all_decode = not any(is_prefilling)
if all_decode:
# All decode requests, so no need to gather any embeddings.
return [], torch.zeros(
total_num_scheduled_tokens, dtype=torch.bool, device=self.device
)
query_start = computed_prefill_lens.tolist()
query_end = (computed_prefill_lens + num_scheduled_tokens).tolist()
mm_embeds: list[torch.Tensor] = []
is_mm_embed = torch.zeros(
total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=True
)
for i, req_id in enumerate(req_ids):
if not is_prefilling[i]:
# OPTIMIZATION: Skip decode requests.
continue
mm_features = self.req_id_to_mm_features[req_id]
for mm_feature in mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
if start_pos >= query_end[i]:
# The encoder output is not needed in this step.
break
if start_pos + num_encoder_tokens <= query_start[i]:
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
start_idx = max(query_start[i] - start_pos, 0)
end_idx = min(query_end[i] - start_pos, num_encoder_tokens)
assert start_idx < end_idx
curr_embeds_start, curr_embeds_end = (
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
)
# If there are no embeddings in the current range, we skip
# gathering the embeddings.
if curr_embeds_start == curr_embeds_end:
continue
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
else:
mm_embeds_item = encoder_output[start_idx:end_idx]
req_start_pos = query_start_loc[i] + start_pos - query_start[i]
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True if is_embed is None else is_embed
)
mm_embeds.append(mm_embeds_item)
# Copy the is_mm_embed tensor to the GPU.
is_mm_embed = is_mm_embed.to(device=self.device, non_blocking=True)
return mm_embeds, is_mm_embed
@torch.inference_mode()
def get_inputs_embeds(
self,
model: SupportsMultiModal,
input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor],
is_mm_embed: torch.Tensor,
) -> torch.Tensor:
x = model.embed_input_ids(
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
)
# Copy to the pre-allocated buffer for CUDA graphs.
self.inputs_embeds[: x.shape[0]] = x
return self.inputs_embeds
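# Host-side sketch of the interval arithmetic in gather_mm_embeddings above
# (an illustrative addition, not part of the original file): intersect a
# placeholder span [start_pos, start_pos + num_encoder_tokens) with this
# step's scheduled window [query_start, query_end), all in prompt-token
# coordinates. Returns the slice of the encoder output needed now, or None
# if the feature lies entirely before or after the window.
def _encoder_slice_reference(
    start_pos: int, num_encoder_tokens: int, query_start: int, query_end: int
) -> tuple[int, int] | None:
    if start_pos >= query_end:
        return None  # not reached yet in this chunked-prefill step
    if start_pos + num_encoder_tokens <= query_start:
        return None  # already processed; lives in the KV cache
    start_idx = max(query_start - start_pos, 0)
    end_idx = min(query_end - start_pos, num_encoder_tokens)
    return start_idx, end_idx

# An image with 8 encoder tokens at prompt offset 10, with prompt tokens
# [12, 16) scheduled this step, needs encoder_output[2:6]:
assert _encoder_slice_reference(10, 8, 12, 16) == (2, 6)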

View File

@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.models.interfaces import SupportsMRoPE
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
class MRopeState:
def __init__(
self,
max_num_reqs: int,
max_num_tokens: int,
max_model_len: int,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens
self.max_model_len = max_model_len
self.device = device
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# wasting a lot of CPU memory.
self.prefill_mrope_positions = StagedWriteTensor(
(max_num_reqs * 3, max_model_len),
dtype=torch.int32,
device=device,
uva_instead_of_gpu=True,
)
self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
# with torch compile.
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
# the modality of inputs. For text-only inputs, each dimension has
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros(
(3, max_num_tokens + 1), dtype=torch.int64, device=device
)
def init_prefill_mrope_positions(
self,
req_idx: int,
mrope_model: SupportsMRoPE,
prefill_token_ids: list[int],
mm_features: list,
) -> None:
prefill_mrope_positions, prefill_mrope_delta = (
mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features)
)
for i in range(3):
pos = prefill_mrope_positions[i].tolist()
self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos)
self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta
def apply_staged_writes(self) -> None:
self.prefill_mrope_positions.apply_write()
self.prefill_mrope_delta.copy_to_uva()
def prepare_mrope_positions(
self,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
prefill_lens: torch.Tensor,
num_computed_tokens: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_prepare_mrope_positions_kernel[(num_reqs,)](
self.mrope_positions,
self.mrope_positions.stride(0),
self.prefill_mrope_positions.gpu,
3 * self.max_model_len,
self.max_model_len,
self.prefill_mrope_delta.gpu,
idx_mapping,
query_start_loc,
prefill_lens,
num_computed_tokens,
BLOCK_SIZE=1024,
)
@triton.jit
def _prepare_mrope_positions_kernel(
mrope_positions_ptr,
mrope_positions_stride,
prefill_mrope_positions_ptr,
prefill_mrope_positions_stride0,
prefill_mrope_positions_stride1,
prefill_mrope_delta_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
prefill_lens_ptr,
num_computed_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
is_prefill = num_computed < prefill_len
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx)
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
orig_pos = num_computed + block
for j in tl.static_range(3):
if is_prefill:
# Read from pre-computed M-RoPE positions.
pos = tl.load(
prefill_mrope_positions_ptr
+ req_state_idx * prefill_mrope_positions_stride0
+ j * prefill_mrope_positions_stride1
+ orig_pos,
mask=mask,
)
else:
# Apply M-RoPE delta.
pos = orig_pos + mrope_delta
tl.store(
mrope_positions_ptr + j * mrope_positions_stride + query_start + block,
pos,
mask=mask,
)

File diff suppressed because it is too large

View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pipeline Parallelism utils for V2 Model Runner."""
import torch
from vllm.distributed.parallel_state import get_pp_group
def pp_broadcast(
sampled_token_ids: torch.Tensor,
num_sampled: torch.Tensor,
num_rejected: torch.Tensor,
) -> None:
pp = get_pp_group()
assert pp.is_last_rank
assert sampled_token_ids.dtype == torch.int64
torch.distributed.broadcast(
sampled_token_ids.contiguous(), src=pp.last_rank, group=pp.device_group
)
combined = torch.stack((num_sampled, num_rejected), dim=0)
torch.distributed.broadcast(combined, src=pp.last_rank, group=pp.device_group)
def pp_receive(
num_reqs: int, max_sample_len: int = 1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pp = get_pp_group()
assert not pp.is_last_rank
sampled_tokens = torch.empty(
num_reqs, max_sample_len, dtype=torch.int64, device=pp.device
)
torch.distributed.broadcast(sampled_tokens, src=pp.last_rank, group=pp.device_group)
combined = torch.empty(2, num_reqs, dtype=torch.int32, device=pp.device)
torch.distributed.broadcast(combined, src=pp.last_rank, group=pp.device_group)
num_sampled, num_rejected = combined.unbind(dim=0)
return sampled_tokens, num_sampled, num_rejected

View File

@@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
from vllm.v1.worker.gpu.states import RequestState
MAX_BAD_WORDS_TOTAL_TOKENS = 1024 # Max total tokens for all bad words per request
MAX_NUM_BAD_WORDS = 128 # Max number of bad words per request
class BadWordsState:
def __init__(self, req_states: RequestState):
self.req_states = req_states
self.max_num_reqs = req_states.max_num_reqs
self.device = req_states.device
# flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS]
self.bad_word_token_ids = StagedWriteTensor(
(self.max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS),
dtype=torch.int32,
device=self.device,
)
# cumulative offsets of bad words: [max_num_reqs, MAX_NUM_BAD_WORDS + 1]
self.bad_word_offsets = StagedWriteTensor(
(self.max_num_reqs, MAX_NUM_BAD_WORDS + 1),
dtype=torch.int32,
device=self.device,
)
# number of bad words per request
self.num_bad_words = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
bad_words_token_ids = sampling_params.bad_words_token_ids
if not bad_words_token_ids:
self.num_bad_words.np[req_idx] = 0
return
num_bad_words = len(bad_words_token_ids)
if num_bad_words > MAX_NUM_BAD_WORDS:
raise ValueError(
f"Too many bad words: {num_bad_words}. "
f"The max number is {MAX_NUM_BAD_WORDS}."
)
# Flatten bad words and compute offsets
flattened_tokens: list[int] = []
offsets: list[int] = [0]
for bad_word in bad_words_token_ids:
flattened_tokens.extend(bad_word)
offsets.append(len(flattened_tokens))
if len(flattened_tokens) > MAX_BAD_WORDS_TOTAL_TOKENS:
raise ValueError(
f"Too many total bad word tokens: {len(flattened_tokens)}. "
f"The max is {MAX_BAD_WORDS_TOTAL_TOKENS}."
)
# Stage writes
self.bad_word_token_ids.stage_write(req_idx, 0, flattened_tokens)
self.bad_word_offsets.stage_write(req_idx, 0, offsets)
self.num_bad_words.np[req_idx] = num_bad_words
def apply_staged_writes(self) -> None:
self.num_bad_words.copy_to_uva()
self.bad_word_token_ids.apply_write()
self.bad_word_offsets.apply_write()
def apply_bad_words(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> None:
max_num_bad_words = int(self.num_bad_words.np[idx_mapping_np].max())
if max_num_bad_words == 0:
# No request uses bad words. Skip the kernel launch.
return
apply_bad_words(
logits,
idx_mapping,
self.bad_word_token_ids.gpu,
self.bad_word_offsets.gpu,
self.num_bad_words.gpu,
self.req_states.all_token_ids.gpu,
self.req_states.prompt_len.gpu,
self.req_states.total_len.gpu,
input_ids,
expanded_local_pos,
max_num_bad_words,
)
@triton.jit
def _bad_words_kernel(
logits_ptr,
logits_stride,
expanded_idx_mapping_ptr,
bad_word_token_ids_ptr,
bad_word_token_ids_stride,
bad_word_offsets_ptr,
bad_word_offsets_stride,
num_bad_words_ptr,
all_token_ids_ptr,
all_token_ids_stride,
prompt_len_ptr,
total_len_ptr,
input_ids_ptr,
expanded_local_pos_ptr,
):
logit_idx = tl.program_id(0)
bw_idx = tl.program_id(1)
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
if bw_idx >= num_bad_words:
return
pos = tl.load(expanded_local_pos_ptr + logit_idx)
cur_req_first_pos = logit_idx - pos
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
total_len = tl.load(total_len_ptr + req_state_idx)
output_len = total_len - prompt_len
effective_len = output_len + pos
bd_offsets_base = bad_word_offsets_ptr + req_state_idx * bad_word_offsets_stride
bd_tokens_base = bad_word_token_ids_ptr + req_state_idx * bad_word_token_ids_stride
output_base = all_token_ids_ptr + req_state_idx * all_token_ids_stride + prompt_len
start = tl.load(bd_offsets_base + bw_idx)
end = tl.load(bd_offsets_base + bw_idx + 1)
bad_word_len = end - start
prefix_len = bad_word_len - 1
if prefix_len > effective_len:
return
last_token = tl.load(bd_tokens_base + end - 1)
match = 1
for i in range(prefix_len):
expected = tl.load(bd_tokens_base + start + i)
actual_pos = effective_len - prefix_len + i
from_spec_input = actual_pos >= output_len
if from_spec_input:
spec_offset = actual_pos - output_len
actual = tl.load(input_ids_ptr + cur_req_first_pos + spec_offset)
else:
actual = tl.load(output_base + actual_pos)
match = match & (expected == actual)
if match:
tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
def apply_bad_words(
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
bad_word_token_ids: torch.Tensor,
bad_word_offsets: torch.Tensor,
num_bad_words: torch.Tensor,
all_token_ids: torch.Tensor,
prompt_len: torch.Tensor,
total_len: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
max_num_bad_words: int,
) -> None:
total_num_tokens = logits.shape[0]
_bad_words_kernel[(total_num_tokens, max_num_bad_words)](
logits,
logits.stride(0),
expanded_idx_mapping,
bad_word_token_ids,
bad_word_token_ids.stride(0),
bad_word_offsets,
bad_word_offsets.stride(0),
num_bad_words,
all_token_ids,
all_token_ids.stride(0),
prompt_len,
total_len,
input_ids,
expanded_local_pos,
)
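# Host-side sketch of the matching rule in _bad_words_kernel (an illustrative
# addition, not part of the original file): a bad word of length L bans its
# last token only when the L-1 most recent generated/speculated tokens equal
# its prefix; single-token bad words are always banned.
def _banned_last_tokens_reference(
    generated: list[int], bad_words: list[list[int]]
) -> set[int]:
    banned: set[int] = set()
    for word in bad_words:
        prefix, last = word[:-1], word[-1]
        if len(prefix) <= len(generated) and (
            not prefix or generated[len(generated) - len(prefix) :] == prefix
        ):
            banned.add(last)
    return banned

# With output [5, 7] so far: [7, 9] matches (ban 9), [8, 9] does not, and the
# single-token bad word [3] is always banned.
assert _banned_last_tokens_reference([5, 7], [[7, 9], [8, 9], [3]]) == {9, 3}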

View File

@@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _temperature_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
temperature_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
if temperature == 0.0 or temperature == 1.0:
# Early return to avoid loading logits.
return
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
logits = logits / temperature
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
def apply_temperature(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
temperature: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_temperature_kernel[(num_reqs, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
@triton.jit
def _gumbel_sample_kernel(
local_argmax_ptr,
local_argmax_stride,
local_max_ptr,
local_max_stride,
logits_ptr,
logits_stride,
idx_mapping_ptr,
seeds_ptr,
pos_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + batch_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise in FP32.
u = tl.rand(gumbel_seed, block)
u = tl.maximum(u, 1e-7)
gumbel_noise = -tl.log(-tl.log(u))
# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
# Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
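    # Each program instance writes its block-local max and argmax below;
    # gumbel_sample() reduces across blocks with torch ops after the kernel.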
value, idx = tl.max(logits, axis=0, return_indices=True)
token_id = block_idx * BLOCK_SIZE + idx
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
idx_mapping: torch.Tensor, # [max_num_reqs]
temperature: torch.Tensor, # [max_num_reqs]
seed: torch.Tensor, # [max_num_reqs]
pos: torch.Tensor, # [num_reqs]
apply_temperature: bool,
) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = torch.empty(
num_reqs,
num_blocks,
dtype=torch.int64,
device=logits.device,
)
local_max = torch.empty(
num_reqs,
num_blocks,
dtype=torch.float32,
device=logits.device,
)
_gumbel_sample_kernel[(num_reqs, num_blocks)](
local_argmax,
local_argmax.stride(0),
local_max,
local_max.stride(0),
logits,
logits.stride(0),
idx_mapping,
seed,
pos,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
APPLY_TEMPERATURE=apply_temperature,
)
# NOTE(woosuk): Use int64 for later indexing.
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
return sampled
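# --- Illustrative sketch (reference semantics, not the kernel) ---
# gumbel_sample() implements the Gumbel-max trick: with g ~ Gumbel(0, 1),
# argmax(logits / T + g) is distributed exactly like a sample from
# softmax(logits / T). A hypothetical eager-PyTorch reference:
def _reference_gumbel_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    if temperature == 0.0:
        # Greedy decoding; the kernel likewise skips the noise when T == 0.
        return logits.argmax(dim=-1)
    u = torch.rand_like(logits, dtype=torch.float32).clamp_min(1e-7)
    gumbel_noise = -torch.log(-torch.log(u))
    return (logits.float() / temperature + gumbel_noise).argmax(dim=-1)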

View File

@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
MAX_NUM_ALLOWED_TOKEN_IDS = 1024
MAX_NUM_LOGIT_BIAS_TOKENS = 1024
MAX_NUM_STOP_TOKEN_IDS = 128
class LogitBiasState:
def __init__(self, max_num_reqs: int, device: torch.device):
self.max_num_reqs = max_num_reqs
# Allowed token IDs.
self.num_allowed_token_ids = UvaBackedTensor(
self.max_num_reqs, dtype=torch.int32
)
self.allowed_token_ids = StagedWriteTensor(
(self.max_num_reqs, MAX_NUM_ALLOWED_TOKEN_IDS),
dtype=torch.int32,
device=device,
)
# Logit bias.
self.num_logit_bias = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.logit_bias_token_ids = StagedWriteTensor(
(self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
dtype=torch.int32,
device=device,
)
self.logit_bias = StagedWriteTensor(
(self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
dtype=torch.float32,
device=device,
)
# Min tokens.
self.min_lens = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.num_stop_token_ids = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.stop_token_ids = StagedWriteTensor(
(self.max_num_reqs, MAX_NUM_STOP_TOKEN_IDS),
dtype=torch.int32,
device=device,
)
# Using any of the above.
self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)
def add_request(
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
) -> None:
# Using any logit bias.
use_logit_bias = False
# Allowed token IDs.
allowed_token_ids = sampling_params.allowed_token_ids
if allowed_token_ids:
num_allowed_token_ids = len(allowed_token_ids)
if num_allowed_token_ids > MAX_NUM_ALLOWED_TOKEN_IDS:
raise ValueError(
f"Too many allowed token IDs: {num_allowed_token_ids}. "
f"The max size is {MAX_NUM_ALLOWED_TOKEN_IDS}."
)
self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)
use_logit_bias = True
else:
self.num_allowed_token_ids.np[req_idx] = 0
# Logit bias.
logit_bias = sampling_params.logit_bias
if logit_bias:
num_logit_bias = len(logit_bias)
if num_logit_bias > MAX_NUM_LOGIT_BIAS_TOKENS:
raise ValueError(
f"Too many logit bias tokens: {num_logit_bias}. "
f"The max size is {MAX_NUM_LOGIT_BIAS_TOKENS}."
)
self.num_logit_bias.np[req_idx] = num_logit_bias
            # Materialize the dict views so stage_write() can convert them
            # to arrays.
            self.logit_bias_token_ids.stage_write(req_idx, 0, list(logit_bias.keys()))
            self.logit_bias.stage_write(req_idx, 0, list(logit_bias.values()))
use_logit_bias = True
else:
self.num_logit_bias.np[req_idx] = 0
# Min tokens.
min_tokens = sampling_params.min_tokens
min_len = prompt_len + min_tokens
self.min_lens.np[req_idx] = min_len
stop_token_ids = sampling_params.all_stop_token_ids
if min_tokens > 0 and stop_token_ids:
num_stop_token_ids = len(stop_token_ids)
if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
raise ValueError(
f"Too many stop tokens: {num_stop_token_ids}. "
f"The max size is {MAX_NUM_STOP_TOKEN_IDS}."
)
self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids)
use_logit_bias = True
else:
self.num_stop_token_ids.np[req_idx] = 0
self.use_logit_bias[req_idx] = use_logit_bias
def apply_staged_writes(self) -> None:
self.num_allowed_token_ids.copy_to_uva()
self.allowed_token_ids.apply_write()
self.num_logit_bias.copy_to_uva()
self.logit_bias_token_ids.apply_write()
self.logit_bias.apply_write()
self.min_lens.copy_to_uva()
self.num_stop_token_ids.copy_to_uva()
self.stop_token_ids.apply_write()
def apply_logit_bias(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> None:
if not np.any(self.use_logit_bias[idx_mapping_np]):
# No request uses logit bias. Skip the kernel launch.
return
apply_logit_bias(
logits,
idx_mapping,
pos,
self.num_allowed_token_ids.gpu,
self.allowed_token_ids.gpu,
self.num_logit_bias.gpu,
self.logit_bias_token_ids.gpu,
self.logit_bias.gpu,
self.min_lens.gpu,
self.num_stop_token_ids.gpu,
self.stop_token_ids.gpu,
)
@triton.jit
def _bias_kernel(
logits_ptr,
logits_stride,
vocab_size,
idx_mapping_ptr,
# Allowed token IDs.
num_allowed_token_ids_ptr,
allowed_token_ids_ptr,
allowed_token_ids_stride,
# Logit bias.
num_logit_bias_ptr,
bias_token_ids_ptr,
bias_token_ids_stride,
bias_ptr,
bias_stride,
# Min tokens.
pos_ptr,
min_lens_ptr,
num_stop_token_ids_ptr,
stop_token_ids_ptr,
stop_token_ids_stride,
BLOCK_SIZE: tl.constexpr,
LOGITS_BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
block = tl.arange(0, BLOCK_SIZE)
# Allowed token IDs.
num_allowed_token_ids = tl.load(num_allowed_token_ids_ptr + req_state_idx)
if num_allowed_token_ids > 0:
mask = block < num_allowed_token_ids
# Save logits for allowed token IDs.
allowed_token_ids = tl.load(
allowed_token_ids_ptr + req_state_idx * allowed_token_ids_stride + block,
mask=mask,
)
logits = tl.load(
logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
)
# Set logits to -inf for all tokens.
for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
tl.store(
logits_ptr + batch_idx * logits_stride + offset,
-float("inf"),
mask=offset < vocab_size,
)
# Restore logits for allowed token IDs.
tl.store(
logits_ptr + batch_idx * logits_stride + allowed_token_ids,
logits,
mask=mask,
)
# Logit bias.
num_logit_bias = tl.load(num_logit_bias_ptr + req_state_idx)
if num_logit_bias > 0:
mask = block < num_logit_bias
token_ids = tl.load(
bias_token_ids_ptr + req_state_idx * bias_token_ids_stride + block,
mask=mask,
)
bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
logits += bias
tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)
# Apply min tokens.
num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
min_len = tl.load(min_lens_ptr + req_state_idx)
if num_stop_token_ids > 0 and pos < min_len:
mask = block < num_stop_token_ids
stop_token_ids = tl.load(
stop_token_ids_ptr + req_state_idx * stop_token_ids_stride + block,
mask=mask,
)
tl.store(
logits_ptr + batch_idx * logits_stride + stop_token_ids,
-float("inf"),
mask=mask,
)
def apply_logit_bias(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
pos: torch.Tensor,
num_allowed_token_ids: torch.Tensor,
allowed_token_ids: torch.Tensor,
num_logit_bias: torch.Tensor,
logit_bias_token_ids: torch.Tensor,
logit_bias: torch.Tensor,
min_lens: torch.Tensor,
num_stop_token_ids: torch.Tensor,
stop_token_ids: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = triton.next_power_of_2(
max(
allowed_token_ids.shape[-1],
logit_bias_token_ids.shape[-1],
stop_token_ids.shape[-1],
)
)
LOGITS_BLOCK_SIZE = 8192
_bias_kernel[(num_reqs,)](
logits,
logits.stride(0),
vocab_size,
idx_mapping,
num_allowed_token_ids,
allowed_token_ids,
allowed_token_ids.stride(0),
num_logit_bias,
logit_bias_token_ids,
logit_bias_token_ids.stride(0),
logit_bias,
logit_bias.stride(0),
pos,
min_lens,
num_stop_token_ids,
stop_token_ids,
stop_token_ids.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
)
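# --- Illustrative sketch (eager reference for a single request) ---
# The fused kernel applies three per-request transforms in place: restrict
# the vocabulary to allowed_token_ids, add per-token biases, and mask stop
# tokens while the sequence is shorter than min_len. A hypothetical
# out-of-place reference of the same semantics:
def _reference_logit_bias(
    logits: torch.Tensor,  # [vocab_size]
    allowed_token_ids: list[int],
    logit_bias: dict[int, float],
    stop_token_ids: list[int],
    pos: int,
    min_len: int,
) -> torch.Tensor:
    logits = logits.clone().float()
    if allowed_token_ids:
        kept = logits[allowed_token_ids]
        logits.fill_(float("-inf"))
        logits[allowed_token_ids] = kept
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    if stop_token_ids and pos < min_len:
        logits[stop_token_ids] = float("-inf")
    return logits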

View File

@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
@triton.jit
def _topk_log_softmax_kernel(
output_ptr,
logits_ptr,
logits_stride,
topk_ids_ptr,
topk,
vocab_size,
BLOCK_SIZE: tl.constexpr,
PADDED_TOPK: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore
se = 0.0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
logits = logits.to(tl.float32)
e = tl.exp(logits - max_val)
e = tl.where(block < vocab_size, e, 0.0)
se += tl.sum(e)
lse = tl.log(se)
k_offset = tl.arange(0, PADDED_TOPK)
k_mask = k_offset < topk
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
logits = logits.to(tl.float32)
o = logits - max_val - lse
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
@triton.jit
def _ranks_kernel(
output_ptr,
logits_ptr,
logits_stride,
token_ids_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
token_id = tl.load(token_ids_ptr + req_idx)
x = tl.load(row_ptr + token_id)
n = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
n += tl.sum((logits >= x).to(tl.int32))
tl.store(output_ptr + req_idx, n)
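# --- Illustrative sketch (eager reference) ---
# _ranks_kernel counts how many vocabulary entries have a logit >= the chosen
# token's logit (the chosen token included), so the argmax token gets rank 1.
# A hypothetical eager equivalent:
def _reference_token_ranks(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    chosen = logits.gather(-1, token_ids.to(torch.int64).view(-1, 1))  # [batch, 1]
    return (logits >= chosen).sum(dim=-1)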
def compute_token_logprobs(
logits: torch.Tensor, token_ids: torch.Tensor
) -> torch.Tensor:
batch_size, vocab_size = logits.shape
token_ids = token_ids.to(torch.int64)
num_logprobs = token_ids.shape[1]
logprobs = logits.new_empty((batch_size, num_logprobs), dtype=torch.float32)
_topk_log_softmax_kernel[(batch_size,)](
logprobs,
logits,
logits.stride(0),
token_ids,
num_logprobs,
vocab_size,
BLOCK_SIZE=1024, # type: ignore
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
)
return logprobs
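# --- Illustrative sketch (eager reference) ---
# compute_token_logprobs() returns the log-softmax values at token_ids
# without materializing the full [batch, vocab] log-probs. An equivalent
# (memory-hungry) eager formulation for checking:
def _reference_token_logprobs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    return logprobs.gather(-1, token_ids.to(torch.int64))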
def compute_topk_logprobs(
logits: torch.Tensor,
num_logprobs: int,
sampled_token_ids: torch.Tensor,
cu_num_logits: list[int] | None = None,
) -> LogprobsTensors:
assert num_logprobs >= 0
batch_size, vocab_size = logits.shape
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
if num_logprobs > 0:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
logprobs = compute_token_logprobs(logits, logprob_token_ids)
token_ranks = torch.empty(batch_size, dtype=torch.int64, device=logits.device)
_ranks_kernel[(batch_size,)](
token_ranks,
logits,
logits.stride(0),
sampled_token_ids,
vocab_size,
BLOCK_SIZE=8192, # type: ignore
)
return LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=token_ranks,
cu_num_generated_tokens=cu_num_logits,
)

View File

@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _min_p_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
min_p_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
if min_p == 0.0:
return
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
)
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore
threshold = max_val + tl.log(min_p)
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
)
logits = tl.where(logits < threshold, float("-inf"), logits)
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
def apply_min_p(
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
_min_p_kernel[(num_reqs,)](
logits,
logits.stride(0),
idx_mapping,
min_p,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
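# --- Illustrative sketch (derivation of the threshold) ---
# min-p keeps tokens whose probability is at least min_p * max_prob. Softmax
# is monotone in the logits, so in log space this is exactly
# logit >= max_logit + log(min_p), the `threshold` computed by the kernel.
# A hypothetical eager reference:
def _reference_min_p(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    import math
    if min_p == 0.0:
        return logits
    threshold = logits.max(dim=-1, keepdim=True).values + math.log(min_p)
    return logits.masked_fill(logits < threshold, float("-inf"))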

View File

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.v1.outputs import LogprobsTensors
@dataclass
class SamplerOutput:
sampled_token_ids: torch.Tensor
logprobs_tensors: LogprobsTensors | None
num_nans: torch.Tensor | None

View File

@@ -0,0 +1,311 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import async_tensor_h2d
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
from vllm.v1.worker.gpu.states import RequestState
class PenaltiesState:
def __init__(self, req_states: RequestState):
self.req_states = req_states
max_num_reqs = req_states.max_num_reqs
self.vocab_size = req_states.vocab_size
self.device = req_states.device
self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.use_penalty = np.zeros(max_num_reqs, dtype=bool)
# Initialize repetition penalty manually because 0 is an invalid value for it.
self.repetition_penalty.np.fill(1.0)
self.repetition_penalty.copy_to_uva()
# Statistics for penalties.
self.prompt_bin_mask = torch.zeros(
max_num_reqs,
cdiv(self.vocab_size, 32),
dtype=torch.int32,
device=self.device,
)
# TODO(woosuk): This tensor is rarely used but can be very large, taking up
# GBs of GPU memory. Optimize the memory usage.
self.output_bin_counts = torch.zeros(
max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)
self._new_penalties_reqs: list[int] = []
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
do_penalty = use_penalty(sampling_params)
self.use_penalty[req_idx] = do_penalty
if do_penalty:
self._new_penalties_reqs.append(req_idx)
def apply_staged_writes(self) -> None:
if self._new_penalties_reqs:
idx_mapping = async_tensor_h2d(
self._new_penalties_reqs,
dtype=torch.int32,
target_device=self.device,
pin_memory=True,
)
prefill_lens = self.req_states.prefill_len.np[self._new_penalties_reqs]
max_prefill_len = int(prefill_lens.max())
bincount(
idx_mapping,
self.req_states.all_token_ids.gpu,
self.req_states.prompt_len.gpu,
self.req_states.prefill_len.gpu,
self.prompt_bin_mask,
self.output_bin_counts,
max_prefill_len,
)
self._new_penalties_reqs.clear()
self.repetition_penalty.copy_to_uva()
self.frequency_penalty.copy_to_uva()
self.presence_penalty.copy_to_uva()
def apply_penalties(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
num_speculative_tokens: int,
) -> None:
if not np.any(self.use_penalty[idx_mapping_np]):
# No request uses penalties. Skip the kernel launch.
return
apply_penalties(
logits,
idx_mapping,
input_ids,
expanded_local_pos,
self.repetition_penalty.gpu,
self.frequency_penalty.gpu,
self.presence_penalty.gpu,
self.prompt_bin_mask,
self.output_bin_counts,
num_speculative_tokens,
)
@triton.jit
def _penalties_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
token_ids_ptr,
expanded_local_pos_ptr,
repetition_penalty_ptr,
frequency_penalty_ptr,
presence_penalty_ptr,
prompt_bin_mask_ptr,
prompt_bin_mask_stride,
output_bin_counts_ptr,
output_bin_counts_stride,
vocab_size,
BLOCK_SIZE: tl.constexpr,
MAX_SPEC_LEN: tl.constexpr,
):
token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + token_idx)
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
use_rep_penalty = rep_penalty != 1.0
use_freq_penalty = freq_penalty != 0.0
use_pres_penalty = pres_penalty != 0.0
use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
if not use_penalty:
# Early return to avoid loading logits.
return
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
base_output_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
mask=mask,
other=0,
)
# Compute cumulative draft_counts from previous positions in this request
pos = tl.load(expanded_local_pos_ptr + token_idx)
start_idx = token_idx - pos
draft_counts = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
for prev_pos in tl.static_range(MAX_SPEC_LEN):
if prev_pos < pos:
prev_token = tl.load(token_ids_ptr + start_idx + prev_pos + 1)
token_match = block == prev_token
draft_counts = draft_counts + token_match.to(tl.int32)
# Total counts = base output counts + cumulative draft counts
output_bin_counts = base_output_counts + draft_counts
output_bin_mask = output_bin_counts > 0
# Apply repetition penalties.
if use_rep_penalty:
packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
packed_mask = tl.load(
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
mask=packed_block < tl.cdiv(vocab_size, 32),
other=0,
)
prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
prompt_bin_mask = prompt_bin_mask.to(tl.int1)
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
logits *= tl.where(logits > 0, 1.0 / scale, scale)
# Apply frequency penalties.
logits -= freq_penalty * output_bin_counts
# Apply presence penalties.
logits -= pres_penalty * output_bin_mask
# Store back to logits.
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
def apply_penalties(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
token_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
repetition_penalty: torch.Tensor,
frequency_penalty: torch.Tensor,
presence_penalty: torch.Tensor,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor,
num_speculative_tokens: int,
) -> None:
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_penalties_kernel[(num_tokens, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
frequency_penalty,
presence_penalty,
prompt_bin_mask,
prompt_bin_mask.stride(0),
output_bin_counts,
output_bin_counts.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
MAX_SPEC_LEN=num_speculative_tokens,
)
@triton.jit
def _bincount_kernel(
idx_mapping_ptr,
all_token_ids_ptr,
all_token_ids_stride,
prompt_len_ptr,
prefill_len_ptr,
prompt_bin_mask_ptr,
prompt_bin_mask_stride,
output_bin_counts_ptr,
output_bin_counts_stride,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if block_idx * BLOCK_SIZE >= prefill_len:
return
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if block_idx * BLOCK_SIZE < prompt_len:
mask = block < prompt_len
prompt_tokens = tl.load(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
)
idx = prompt_tokens // 32
bit_idx = prompt_tokens % 32
bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
tl.atomic_or(
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + idx,
bit,
mask=mask,
)
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
mask = block < prefill_len
mask &= block >= prompt_len
output_tokens = tl.load(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
)
tl.atomic_add(
output_bin_counts_ptr
+ req_state_idx * output_bin_counts_stride
+ output_tokens,
1,
mask=mask,
)
def bincount(
idx_mapping: torch.Tensor,
all_token_ids: torch.Tensor,
prompt_len: torch.Tensor,
prefill_len: torch.Tensor,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor,
max_prefill_len: int,
) -> None:
prompt_bin_mask[idx_mapping] = 0
output_bin_counts[idx_mapping] = 0
num_reqs = idx_mapping.shape[0]
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_reqs, num_blocks)](
idx_mapping,
all_token_ids,
all_token_ids.stride(0),
prompt_len,
prefill_len,
prompt_bin_mask,
prompt_bin_mask.stride(0),
output_bin_counts,
output_bin_counts.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
)
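# --- Illustrative sketch (prompt bitmask packing) ---
# The prompt side only needs set membership, so it is packed 32 tokens per
# int32 word: token t sets bit (t % 32) of word (t // 32), which is what the
# atomic_or above implements. A tiny reference:
def _demo_prompt_bitmask(token_id: int = 70) -> tuple[int, int]:
    word, bit_idx = divmod(token_id, 32)
    return word, 1 << bit_idx  # token 70 -> word 2, bit value 1 << 6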
def use_penalty(sampling_params: SamplingParams) -> bool:
return (
sampling_params.repetition_penalty != 1.0
or sampling_params.frequency_penalty != 0.0
or sampling_params.presence_penalty != 0.0
)
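# --- Illustrative sketch (eager reference for a single request) ---
# Per vocabulary entry, the fused kernel applies:
#   repetition: divide positive logits by the penalty (multiply negative
#               ones) wherever the token appeared in the prompt or output,
#   frequency:  subtract frequency_penalty * output_count,
#   presence:   subtract presence_penalty * (output_count > 0).
# A hypothetical eager equivalent over full per-token count tensors:
def _reference_penalties(
    logits: torch.Tensor,  # [vocab_size]
    seen_in_prompt: torch.Tensor,  # [vocab_size], bool
    output_counts: torch.Tensor,  # [vocab_size], int
    repetition_penalty: float,
    frequency_penalty: float,
    presence_penalty: float,
) -> torch.Tensor:
    logits = logits.clone().float()
    seen = seen_in_prompt | (output_counts > 0)
    scale = torch.ones_like(logits)
    scale[seen] = repetition_penalty
    logits = torch.where(logits > 0, logits / scale, logits * scale)
    logits -= frequency_penalty * output_counts
    logits -= presence_penalty * (output_counts > 0).float()
    return logits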

View File

@@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
class PromptLogprobsWorker:
def __init__(self, max_num_reqs: int):
self.max_num_reqs = max_num_reqs
self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
        # req_id -> list of in-progress LogprobsTensors
self.in_progress_prompt_logprobs: dict[str, list[LogprobsTensors]] = {}
def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
        # For now, only each prompt token's own logprob is returned (no top-k).
uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.uses_prompt_logprobs[req_idx] = uses_prompt_logprobs
if uses_prompt_logprobs:
self.in_progress_prompt_logprobs[req_id] = []
def remove_request(self, req_id: str) -> None:
self.in_progress_prompt_logprobs.pop(req_id, None)
def compute_prompt_logprobs(
self,
logits_fn: Callable[[torch.Tensor], torch.Tensor],
hidden_states: torch.Tensor,
input_batch: InputBatch,
# [max_num_reqs, max_model_len]
all_token_ids: torch.Tensor,
# [max_num_reqs]
num_computed_tokens: torch.Tensor,
# [max_num_reqs]
prompt_lens: np.ndarray,
# [max_num_reqs]
prefill_lens: np.ndarray,
# [max_num_reqs]
num_computed_prefill_tokens: np.ndarray,
) -> dict[str, LogprobsTensors]:
idx_mapping_np = input_batch.idx_mapping_np
needs_prompt_logprobs = self.uses_prompt_logprobs[idx_mapping_np]
if not np.any(needs_prompt_logprobs):
# Common case: No request asks for prompt logprobs.
return {}
prompt_lens = prompt_lens[idx_mapping_np]
# NOTE(woosuk): -1 because the last prompt token's hidden state is not
# needed for prompt logprobs.
computed_prefill = num_computed_prefill_tokens[idx_mapping_np]
includes_prompt = computed_prefill < prompt_lens - 1
# NOTE(woosuk): If the request was resumed after preemption, its prompt
# logprobs must have been computed before preemption. Skip.
resumed_after_prompt = prompt_lens < prefill_lens[idx_mapping_np]
needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
if not np.any(needs_prompt_logprobs):
return {}
# Get the prompt logprobs token_ids.
prompt_logprobs_token_ids = get_prompt_logprobs_token_ids(
input_batch.num_tokens,
input_batch.query_start_loc,
input_batch.idx_mapping,
num_computed_tokens,
all_token_ids,
)
# Compute the prompt logprobs.
prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
prompt_logprobs_token_ids,
hidden_states[: input_batch.num_tokens],
logits_fn,
)
pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
is_prompt_chunked = pos_after_step < prompt_lens
query_start_loc_np = input_batch.query_start_loc_np
prompt_token_ids = prompt_logprobs_token_ids.unsqueeze(-1)
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
start_idx = query_start_loc_np[i]
end_idx = query_start_loc_np[i + 1]
assert start_idx < end_idx, (
f"start_idx ({start_idx}) >= end_idx ({end_idx})"
)
if not is_prompt_chunked[i]:
end_idx -= 1
logprobs = LogprobsTensors(
logprob_token_ids=prompt_token_ids[start_idx:end_idx],
logprobs=prompt_logprobs[start_idx:end_idx],
selected_token_ranks=prompt_ranks[start_idx:end_idx],
)
prompt_logprobs_list = self.in_progress_prompt_logprobs[req_id]
if is_prompt_chunked[i]:
# Prompt is chunked. Do not return the logprobs yet.
prompt_logprobs_list.append(logprobs)
continue
if prompt_logprobs_list:
# Merge the in-progress logprobs.
prompt_logprobs_list.append(logprobs)
logprobs = LogprobsTensors(
logprob_token_ids=torch.cat(
[x.logprob_token_ids for x in prompt_logprobs_list]
),
logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
selected_token_ranks=torch.cat(
[x.selected_token_ranks for x in prompt_logprobs_list]
),
)
prompt_logprobs_list.clear()
prompt_logprobs_dict[req_id] = logprobs
return prompt_logprobs_dict
@triton.jit
def _prompt_logprobs_token_ids_kernel(
prompt_logprobs_token_ids_ptr,
query_start_loc_ptr,
idx_mapping_ptr,
num_computed_tokens_ptr,
all_token_ids_ptr,
all_token_ids_stride,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
# NOTE(woosuk): We should shift the pos by one
# because the logprob is computed for the next token.
target_pos = num_computed_tokens + 1 + block
token_ids = tl.load(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + target_pos,
mask=mask,
)
tl.store(
prompt_logprobs_token_ids_ptr + query_start + block, token_ids, mask=mask
)
def get_prompt_logprobs_token_ids(
num_tokens: int,
query_start_loc: torch.Tensor,
idx_mapping: torch.Tensor,
num_computed_tokens: torch.Tensor,
all_token_ids: torch.Tensor,
) -> torch.Tensor:
token_ids = torch.empty(num_tokens, dtype=torch.int64, device=idx_mapping.device)
num_reqs = idx_mapping.shape[0]
_prompt_logprobs_token_ids_kernel[(num_reqs,)](
token_ids,
query_start_loc,
idx_mapping,
num_computed_tokens,
all_token_ids,
all_token_ids.stride(0),
BLOCK_SIZE=1024,
)
return token_ids
def compute_prompt_logprobs_with_chunking(
prompt_token_ids: torch.Tensor,
prompt_hidden_states: torch.Tensor,
logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Since materializing the full prompt logits can take too much memory,
# we compute it in chunks.
CHUNK_SIZE = 1024
logprobs = []
ranks = []
prompt_token_ids = prompt_token_ids.to(torch.int64)
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
end_idx = start_idx + CHUNK_SIZE
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
prompt_logprobs = compute_topk_logprobs(
prompt_logits,
0, # num_logprobs
prompt_token_ids[start_idx:end_idx],
)
logprobs.append(prompt_logprobs.logprobs)
ranks.append(prompt_logprobs.selected_token_ranks)
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
return logprobs, ranks
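# --- Illustrative sketch (why the chunking matters) ---
# Materializing full-prompt logits costs num_tokens * vocab_size * 4 bytes in
# FP32; chunking caps the transient buffer at CHUNK_SIZE rows. With a
# hypothetical 128K-token prompt and a 128K-entry vocabulary:
def _demo_prompt_logits_memory(
    num_tokens: int = 128 * 1024,
    vocab_size: int = 128 * 1024,
    chunk_size: int = 1024,
) -> tuple[float, float]:
    full_gib = num_tokens * vocab_size * 4 / 2**30
    chunked_gib = chunk_size * vocab_size * 4 / 2**30
    return full_gib, chunked_gib  # ~64.0 GiB vs ~0.5 GiB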

View File

@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
from vllm.v1.worker.gpu.states import RequestState
class Sampler:
def __init__(
self,
max_num_reqs: int,
vocab_size: int,
device: torch.device,
req_states: RequestState,
logprobs_mode: LogprobsMode = "raw_logprobs",
num_speculative_tokens: int = 1,
):
if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
self.logprobs_mode = logprobs_mode
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
self.penalties_state = PenaltiesState(req_states)
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
self.bad_words_state = BadWordsState(req_states)
self.num_speculative_tokens = num_speculative_tokens
def add_request(
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
) -> None:
self.sampling_states.add_request(req_idx, sampling_params)
self.penalties_state.add_request(req_idx, sampling_params)
self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
self.bad_words_state.add_request(req_idx, sampling_params)
def apply_staged_writes(self) -> None:
self.sampling_states.apply_staged_writes()
self.penalties_state.apply_staged_writes()
self.logit_bias_state.apply_staged_writes()
self.bad_words_state.apply_staged_writes()
def __call__(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> SamplerOutput:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.compute_nans else None
sampled, processed_logits = self.sample(
logits,
idx_mapping,
idx_mapping_np,
pos,
input_ids,
expanded_local_pos,
)
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
if max_num_logprobs != NO_LOGPROBS:
if self.logprobs_mode == "processed_logprobs":
logits = processed_logits
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
logprobs_tensors = compute_topk_logprobs(
logits, max_num_logprobs, sampled, cu_num_logits
)
else:
logprobs_tensors = None
# These are GPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.view(-1, 1),
logprobs_tensors=logprobs_tensors,
num_nans=num_nans,
)
return sampler_output
def sample(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
# Apply penalties in place.
self.penalties_state.apply_penalties(
logits,
idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
self.num_speculative_tokens,
)
# Apply bad words masking in place.
self.bad_words_state.apply_bad_words(
logits,
idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
)
# Apply temperature in place.
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
# Apply min_p in place.
self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
# Apply top_k and/or top_p. This might or might not return a new tensor.
logits = self.sampling_states.apply_top_k_top_p(
logits, idx_mapping, idx_mapping_np
)
# Sample the next token.
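        # Temperature was already applied in place above, so pass
        # apply_temperature=False to avoid scaling the logits twice.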
sampled = gumbel_sample(
logits,
idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,
pos,
apply_temperature=False,
)
return sampled, logits

View File

@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
NO_LOGPROBS = -1
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
class SamplingStates:
def __init__(self, max_num_reqs: int, vocab_size: int):
self.max_num_reqs = max_num_reqs
self.vocab_size = vocab_size
self.temperature = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.top_k = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)
# Initialize top_k and top_p manually because 0 is an invalid value for them.
self.top_k.np.fill(self.vocab_size)
self.top_k.copy_to_uva()
self.top_p.np.fill(1.0)
self.top_p.copy_to_uva()
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(NO_LOGPROBS)
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
top_k = sampling_params.top_k
if top_k <= 0 or top_k > self.vocab_size:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.min_p.np[req_idx] = sampling_params.min_p
seed = sampling_params.seed
if seed is None:
            seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX, dtype=np.int64)
self.seeds.np[req_idx] = seed
num_logprobs = sampling_params.logprobs
if num_logprobs is None:
num_logprobs = NO_LOGPROBS
self.num_logprobs[req_idx] = num_logprobs
def apply_staged_writes(self) -> None:
self.temperature.copy_to_uva()
self.top_p.copy_to_uva()
self.top_k.copy_to_uva()
self.min_p.copy_to_uva()
self.seeds.copy_to_uva()
def apply_temperature(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
temp_np = self.temperature.np[idx_mapping_np]
if np.all((temp_np == 0.0) | (temp_np == 1.0)):
# No request requires temperature. Skip the kernel launch.
return
apply_temperature(logits, idx_mapping, self.temperature.gpu)
def apply_min_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
if np.all(self.min_p.np[idx_mapping_np] == 0.0):
# No request uses min_p. Skip the kernel launch.
return
apply_min_p(logits, idx_mapping, self.min_p.gpu)
def apply_top_k_top_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> torch.Tensor:
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
do_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0)
if not (do_top_k or do_top_p):
return logits
top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
return apply_top_k_top_p(logits, top_k, top_p)
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
return int(np.max(self.num_logprobs[idx_mapping_np]))
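# --- Illustrative sketch (top_k normalization) ---
# A top_k of <= 0 or > vocab_size is stored as vocab_size, meaning
# "disabled"; apply_top_k_top_p() can then skip the kernel entirely when no
# request in the batch restricts the distribution:
def _demo_normalize_top_k(top_k: int, vocab_size: int) -> int:
    if top_k <= 0 or top_k > vocab_size:
        return vocab_size
    return top_k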

View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import VllmConfig
def init_speculator(vllm_config: VllmConfig, device: torch.device):
speculative_config = vllm_config.speculative_config
assert speculative_config is not None
if speculative_config.use_eagle():
from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator
return EagleSpeculator(vllm_config, device)
raise NotImplementedError(f"{speculative_config.method} is not supported yet.")

View File

@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.model_executor.offloader.base import get_offloader
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (
capture_graphs,
get_cudagraph_sizes,
prepare_inputs_to_capture,
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.utils import AttentionGroup
class EagleCudaGraphManager:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode()
        # Only the uniform-decode cudagraph sizes (the 2nd return value) are needed.
_, self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
uniform_decode_query_len=1,
uniform_decode_cudagraph=True,
)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = None
if self.cudagraph_mode != CUDAGraphMode.NONE:
self.pool = torch.cuda.graph_pool_handle()
def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens)
def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
generate_fn: Callable,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> None:
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
num_reqs = min(num_tokens, self.max_num_reqs)
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
input_buffers,
block_tables,
attn_groups,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=1,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Capture the graph.
capture_fn(
num_reqs=num_reqs,
num_tokens=num_tokens,
generate_fn=generate_fn,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
)
def _capture_full_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
with torch.cuda.graph(graph, self.pool):
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader().join_after_forward()
self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.PIECEWISE,
)
@torch.inference_mode()
def capture(
self,
generate_fn: Callable,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> None:
if self.cudagraph_mode == CUDAGraphMode.NONE:
return
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn,
input_buffers=input_buffers,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
)
def run_fullgraph(self, num_tokens: int) -> None:
assert num_tokens in self.graphs
# Sync offloader before replay - needed when transitioning from
# eager/piecewise to full cudagraph (e.g., prefill → decode).
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader().sync_prev_onload()
self.graphs[num_tokens].replay()
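# --- Illustrative sketch (the generic capture/replay pattern) ---
# The manager above follows the standard static-buffer CUDA graph recipe:
# warm up eagerly, capture into a (possibly shared) pool, then update the
# static inputs in place and replay. A minimal, self-contained example of
# that pattern (unrelated to Eagle; requires a CUDA device):
def _demo_cudagraph_pattern() -> None:
    assert torch.cuda.is_available()
    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")
    def step() -> None:
        static_out.copy_(static_in * 2 + 1)
    step()  # Warm-up outside the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        step()
    static_in.fill_(3.0)  # Refresh the static input in place...
    graph.replay()  # ...and replay: static_out now holds 7s.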

View File

@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import torch.nn as nn
from vllm.config import SpeculativeConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsEagle3, supports_eagle3
logger = init_logger(__name__)
def set_eagle3_aux_hidden_state_layers(
model: nn.Module,
spec_config: SpeculativeConfig,
) -> None:
if not supports_eagle3(model):
raise RuntimeError("Model does not support EAGLE3 interface")
# mypy may infer the class-level overload for supports_eagle3.
# Narrow explicitly to the runtime protocol instance.
if isinstance(model, type):
raise RuntimeError("Expected model instance for EAGLE3 configuration")
eagle3_model = cast(SupportsEagle3, model)
aux_layers = get_eagle3_aux_layers_from_config(spec_config)
if aux_layers:
logger.info("Using Eagle3 auxiliary layers from config: %s", aux_layers)
else:
aux_layers = eagle3_model.get_eagle3_aux_hidden_state_layers()
logger.info("Using Eagle3 auxiliary layers from model: %s", aux_layers)
eagle3_model.set_aux_hidden_state_layers(aux_layers)
def get_eagle3_aux_layers_from_config(
spec_config: SpeculativeConfig,
) -> tuple[int, ...] | None:
if not (spec_config and spec_config.draft_model_config):
return None
hf_config = spec_config.draft_model_config.hf_config
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
return None
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
if layer_ids and isinstance(layer_ids, (list, tuple)):
return tuple(layer_ids)
return None

View File

@@ -0,0 +1,583 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
from vllm.v1.worker.utils import AttentionGroup
logger = init_logger(__name__)
class EagleSpeculator:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.device = device
self.speculative_config = vllm_config.speculative_config
assert self.speculative_config is not None
self.method = self.speculative_config.method
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
self.draft_model_config = self.speculative_config.draft_model_config
self.scheduler_config = vllm_config.scheduler_config
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_model_len = vllm_config.model_config.max_model_len
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
device=device,
)
self.hidden_states = torch.zeros(
self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device
)
self.idx_mapping = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
self.temperature = torch.zeros(
self.max_num_reqs, dtype=torch.float32, device=device
)
self.seeds = torch.zeros(self.max_num_reqs, dtype=torch.int64, device=device)
self.draft_tokens = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
dtype=torch.int64,
device=device,
)
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
def load_model(self, target_model: nn.Module) -> None:
self.model = load_eagle_model(target_model, self.vllm_config)
def set_attn(
self,
kv_cache_config: KVCacheConfig,
attn_groups: list[list[AttentionGroup]],
block_tables: BlockTables,
) -> None:
self.kv_cache_config = kv_cache_config
self.attn_groups = attn_groups
self.block_tables = block_tables
@torch.inference_mode()
def run_model(
self,
num_tokens: int,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_descriptor = BatchDescriptor(num_tokens=num_tokens)
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
batch_descriptor=batch_descriptor,
):
ret_hidden_states = self.model(
input_ids=self.input_buffers.input_ids[:num_tokens],
positions=self.input_buffers.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
return last_hidden_states, hidden_states
def generate_draft(
self,
num_reqs: int,
num_tokens_padded: int,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
idx_mapping = self.idx_mapping[:num_reqs]
for step in range(1, self.num_speculative_steps):
# Run the eagle model.
last_hidden_states, hidden_states = self.run_model(
num_tokens_padded,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
cudagraph_runtime_mode,
)
last_hidden_states = last_hidden_states[:num_reqs]
hidden_states = hidden_states[:num_reqs]
logits = self.model.compute_logits(last_hidden_states)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits,
idx_mapping,
self.temperature,
self.seeds,
pos + 1,
apply_temperature=True,
)
self.draft_tokens[:num_reqs, step] = draft_tokens
if step < self.num_speculative_steps - 1:
# Update the inputs for the next step.
update_eagle_inputs(
draft_tokens,
hidden_states,
self.input_buffers,
self.hidden_states,
self.max_model_len,
)
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
def capture_model(self) -> None:
if self.num_speculative_steps == 1:
return
logger.info("Capturing model for Eagle speculator...")
self.cudagraph_manager.capture(
self.generate_draft,
self.input_buffers,
self.block_tables,
self.attn_groups,
self.kv_cache_config,
)
@torch.inference_mode()
def propose(
self,
input_batch: InputBatch,
# [num_tokens, hidden_size]
last_hidden_states: torch.Tensor,
# num_layers x [num_tokens, hidden_size]
aux_hidden_states: list[torch.Tensor] | None,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [max_num_reqs]
last_sampled: torch.Tensor,
# [max_num_reqs]
next_prefill_tokens: torch.Tensor,
# [max_num_reqs]
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
) -> torch.Tensor:
        # NOTE(woosuk): To avoid a CPU-GPU sync, the CPU never learns how many
        # tokens were rejected. We therefore keep the sizes of eagle's
        # input_ids and hidden_states the same as the target model's, padding
        # each request's query length to include any rejected positions. This
        # also lets us reuse the target model's attention metadata (e.g.,
        # query_start_loc, seq_lens).
if aux_hidden_states:
assert self.method == "eagle3"
hidden_states = self.model.combine_hidden_states(
torch.cat(aux_hidden_states, dim=-1)
)
else:
hidden_states = last_hidden_states
num_tokens = input_batch.num_tokens_after_padding
self.hidden_states[:num_tokens] = hidden_states
# Get the input ids and last token indices for the speculator.
last_token_indices = prepare_eagle_inputs(
self.input_buffers,
input_batch,
num_sampled,
num_rejected,
last_sampled,
next_prefill_tokens,
)
        # Prefill: Run the eagle speculator in eager mode.
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states, hidden_states = self.run_model(
num_tokens,
input_batch.attn_metadata,
input_batch.slot_mappings,
num_tokens_across_dp=None, # FIXME
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
idx_mapping = self.idx_mapping[:num_reqs]
idx_mapping.copy_(input_batch.idx_mapping)
self.temperature.copy_(temperature)
self.seeds.copy_(seeds)
# Gather the values and copy them to the pre-allocated buffers.
pos = self.input_buffers.positions[:num_reqs]
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits,
idx_mapping,
self.temperature,
self.seeds,
pos + 1,
apply_temperature=True,
)
if self.num_speculative_steps == 1:
# Early exit.
return draft_tokens.view(-1, 1)
# Save the draft tokens for the first step.
self.draft_tokens[:num_reqs, 0] = draft_tokens
# Prepare the inputs for the decode steps.
prepare_eagle_decode(
draft_tokens,
hidden_states,
last_token_indices,
input_batch.seq_lens,
num_rejected,
self.input_buffers,
self.hidden_states,
self.max_model_len,
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
cudagraph_mode = self.cudagraph_manager.cudagraph_mode
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
# Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(cudagraph_size)
return self.draft_tokens[:num_reqs]
# Run eager or piecewise CUDA graph.
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
self.generate_draft(
num_reqs,
num_tokens_padded,
attn_metadata,
slot_mappings_by_layer,
num_tokens_across_dp=None, # FIXME
cudagraph_runtime_mode=cudagraph_mode,
)
return self.draft_tokens[:num_reqs]
@triton.jit
def _prepare_eagle_inputs_kernel(
last_token_indices_ptr,
eagle_input_ids_ptr,
eagle_positions_ptr,
target_input_ids_ptr,
target_positions_ptr,
idx_mapping_ptr,
last_sampled_ptr,
next_prefill_tokens_ptr,
num_sampled_ptr,
num_rejected_ptr,
query_start_loc_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
# Get the true query length and next token after accounting for rejected tokens.
num_rejected = tl.load(num_rejected_ptr + batch_idx)
query_len -= num_rejected
num_sampled = tl.load(num_sampled_ptr + batch_idx)
if num_sampled > 0:
next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
else:
# Chunked prefilling.
# Get the next prefill token.
next_token = tl.load(next_prefill_tokens_ptr + req_state_idx)
# Shift target_input_ids by one.
for i in range(1, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
input_ids = tl.load(target_input_ids_ptr + query_start + block, mask=mask)
tl.store(eagle_input_ids_ptr + query_start + block - 1, input_ids, mask=mask)
last_token_index = query_start + query_len - 1
tl.store(last_token_indices_ptr + batch_idx, last_token_index)
tl.store(eagle_input_ids_ptr + last_token_index, next_token)
# Copy positions.
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
target_pos = tl.load(target_positions_ptr + query_start + block, mask=mask)
tl.store(eagle_positions_ptr + query_start + block, target_pos, mask=mask)
def prepare_eagle_inputs(
input_buffers: InputBuffers,
input_batch: InputBatch,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [max_num_reqs]
last_sampled: torch.Tensor,
# [max_num_reqs]
next_prefill_tokens: torch.Tensor,
) -> torch.Tensor:
num_reqs = input_batch.num_reqs
last_token_indices = torch.empty(
num_reqs,
dtype=torch.int64,
device=num_sampled.device,
)
_prepare_eagle_inputs_kernel[(num_reqs,)](
last_token_indices,
input_buffers.input_ids,
input_buffers.positions,
input_batch.input_ids,
input_batch.positions,
input_batch.idx_mapping,
last_sampled,
next_prefill_tokens,
num_sampled,
num_rejected,
input_batch.query_start_loc,
BLOCK_SIZE=1024,
)
return last_token_indices
@triton.jit
def _prepare_eagle_decode_kernel(
draft_tokens_ptr,
output_hidden_states_ptr,
output_hidden_states_stride,
last_token_indices_ptr,
target_seq_lens_ptr,
num_rejected_ptr,
input_ids_ptr,
positions_ptr,
input_hidden_states_ptr,
input_hidden_states_stride,
query_start_loc_ptr,
seq_lens_ptr,
hidden_size,
max_model_len,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
num_reqs = tl.num_programs(0) - 1
if req_idx == num_reqs:
# Compute query_start_loc. Pad it with the last query_start_loc
# for CUDA graphs.
for i in range(0, max_num_reqs + 1, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
q = tl.where(block < num_reqs, block, num_reqs)
mask = block < max_num_reqs + 1
tl.store(query_start_loc_ptr + block, q, mask=mask)
# Pad seq_lens for CUDA graphs.
for i in range(req_idx, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(seq_lens_ptr + block, 0, mask=mask)
return
# draft token -> input id.
draft_token = tl.load(draft_tokens_ptr + req_idx)
tl.store(input_ids_ptr + req_idx, draft_token)
# output hidden states -> input hidden states.
src_idx = tl.load(last_token_indices_ptr + req_idx)
for i in range(0, hidden_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < hidden_size
output_hidden_states = tl.load(
output_hidden_states_ptr + src_idx * output_hidden_states_stride + block,
mask=mask,
)
tl.store(
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
output_hidden_states,
mask=mask,
)
# Compute position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position = tl.load(positions_ptr + req_idx)
position = tl.minimum(position + 1, max_model_len - 1)
tl.store(positions_ptr + req_idx, position)
target_seq_len = tl.load(target_seq_lens_ptr + req_idx)
num_rejected = tl.load(num_rejected_ptr + req_idx)
seq_len = target_seq_len - num_rejected
seq_len = tl.minimum(seq_len + 1, max_model_len)
tl.store(seq_lens_ptr + req_idx, seq_len)
def prepare_eagle_decode(
draft_tokens: torch.Tensor,
output_hidden_states: torch.Tensor,
last_token_indices: torch.Tensor,
target_seq_lens: torch.Tensor,
num_rejected: torch.Tensor,
input_buffers: InputBuffers,
input_hidden_states: torch.Tensor,
max_model_len: int,
max_num_reqs: int,
):
num_reqs = draft_tokens.shape[0]
hidden_size = output_hidden_states.shape[-1]
    _prepare_eagle_decode_kernel[(num_reqs + 1,)](
draft_tokens,
output_hidden_states,
output_hidden_states.stride(0),
last_token_indices,
target_seq_lens,
num_rejected,
input_buffers.input_ids,
input_buffers.positions,
input_hidden_states,
input_hidden_states.stride(0),
input_buffers.query_start_loc,
input_buffers.seq_lens,
hidden_size,
max_model_len,
max_num_reqs,
BLOCK_SIZE=1024,
)
@triton.jit
def _update_eagle_inputs_kernel(
input_ids_ptr,
positions_ptr,
input_hidden_states_ptr,
input_hidden_states_stride,
seq_lens_ptr,
max_model_len,
draft_tokens_ptr,
output_hidden_states_ptr,
output_hidden_states_stride,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
# Draft token -> Input ID.
draft_token = tl.load(draft_tokens_ptr + req_idx)
tl.store(input_ids_ptr + req_idx, draft_token)
# Output hidden states -> Input hidden states.
for i in range(0, hidden_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < hidden_size
output_hidden_states = tl.load(
output_hidden_states_ptr + req_idx * output_hidden_states_stride + block,
mask=mask,
)
tl.store(
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
output_hidden_states,
mask=mask,
)
# Increment position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position = tl.load(positions_ptr + req_idx)
position = tl.minimum(position + 1, max_model_len - 1)
tl.store(positions_ptr + req_idx, position)
seq_len = tl.load(seq_lens_ptr + req_idx)
seq_len = tl.minimum(seq_len + 1, max_model_len)
tl.store(seq_lens_ptr + req_idx, seq_len)
def update_eagle_inputs(
draft_tokens: torch.Tensor,
output_hidden_states: torch.Tensor,
input_buffers: InputBuffers,
hidden_states: torch.Tensor,
max_model_len: int,
):
num_reqs, hidden_size = output_hidden_states.shape
_update_eagle_inputs_kernel[(num_reqs,)](
input_buffers.input_ids,
input_buffers.positions,
hidden_states,
hidden_states.stride(0),
input_buffers.seq_lens,
max_model_len,
draft_tokens,
output_hidden_states,
output_hidden_states.stride(0),
hidden_size,
BLOCK_SIZE=1024,
)
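# A minimal pure-Python sketch (toy values, no Triton) of the shift that
# _prepare_eagle_inputs_kernel performs above: eagle's input ids are the
# target input ids shifted left by one, with the token sampled by the
# target model (or the next prefill token) appended at the end.
def _shift_reference(target_input_ids: list[int], next_token: int) -> list[int]:
    return target_input_ids[1:] + [next_token]

assert _shift_reference([10, 11, 12], 99) == [11, 12, 99]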

View File

@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.model_loader import get_model
def load_eagle_model(target_model: nn.Module, vllm_config: VllmConfig) -> nn.Module:
from vllm.compilation.backends import set_model_tag
speculative_config = vllm_config.speculative_config
assert speculative_config is not None
draft_model_config = speculative_config.draft_model_config
with set_model_tag("eagle_head"):
eagle_model = get_model(
vllm_config=vllm_config, model_config=draft_model_config
)
# Share target embeddings when the draft checkpoint does not include
# its own vocab embedding table.
share_embeddings = True
if hasattr(eagle_model, "has_own_embed_tokens"):
share_embeddings = not eagle_model.has_own_embed_tokens
if share_embeddings:
target_language_model = (
target_model.get_language_model()
if hasattr(target_model, "get_language_model")
else target_model
)
inner_model = getattr(target_language_model, "model", None)
target_embed_tokens = None
if inner_model is not None:
if hasattr(inner_model, "embed_tokens"):
target_embed_tokens = inner_model.embed_tokens
elif hasattr(inner_model, "embedding"):
target_embed_tokens = inner_model.embedding
if target_embed_tokens is not None and hasattr(eagle_model, "model"):
if hasattr(eagle_model.model, "embed_tokens"):
del eagle_model.model.embed_tokens
eagle_model.model.embed_tokens = target_embed_tokens
# Only share target lm_head when the draft model does not own one.
share_lm_head = True
if hasattr(eagle_model, "has_own_lm_head"):
share_lm_head = not eagle_model.has_own_lm_head
if share_lm_head and hasattr(target_model, "lm_head"):
if hasattr(eagle_model, "lm_head"):
del eagle_model.lm_head
eagle_model.lm_head = target_model.lm_head
return eagle_model
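# Hedged usage sketch (names assumed, shown as comments since it needs a
# loaded target model): when the draft checkpoint owns neither embeddings
# nor an lm_head, the returned head aliases the target's modules, so no
# extra weights are materialized.
#
#   eagle_head = load_eagle_model(target_model, vllm_config)
#   # Holds whenever the draft does not ship its own lm_head:
#   assert eagle_head.lm_head is target_model.lm_head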

View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _rejection_sample_kernel(
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
sampled_stride,
num_sampled_ptr, # [num_reqs]
target_sampled_ptr, # [num_draft_tokens + num_reqs]
input_ids_ptr, # [num_draft_tokens + num_reqs]
cu_num_logits_ptr, # [num_reqs + 1]
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
num_sampled = 0
rejected = False
for i in range(num_tokens - 1):
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
num_sampled += 1
if target_sampled != draft_sampled:
rejected = True
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store(
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
)
num_sampled += 1
tl.store(num_sampled_ptr + req_idx, num_sampled)
def rejection_sample(
# [num_draft_tokens + num_reqs]
target_sampled: torch.Tensor,
# [num_draft_tokens + num_reqs]
input_ids: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = torch.empty(
num_reqs,
num_speculative_steps + 1,
dtype=target_sampled.dtype,
device=target_sampled.device,
)
num_sampled = torch.empty(
num_reqs,
dtype=torch.int32,
device=target_sampled.device,
)
_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
num_sampled,
target_sampled,
input_ids,
cu_num_logits,
num_warps=1,
)
return sampled, num_sampled
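# Pure-Python reference of the per-request kernel logic above (a sketch for
# toy lists, not the Triton path): target tokens are accepted while they
# agree with the drafts, the first disagreeing target token is still
# emitted, and the bonus token is kept only if nothing was rejected.
def _rejection_sample_reference(
    target_sampled: list[int], draft_tokens: list[int]
) -> list[int]:
    out: list[int] = []
    for target, draft in zip(target_sampled, draft_tokens):
        out.append(target)
        if target != draft:
            return out
    out.append(target_sampled[-1])  # bonus token
    return out

# Drafts [5, 6]: target agrees on 5, disagrees at step 2 -> [5, 7].
assert _rejection_sample_reference([5, 7, 8], [5, 6]) == [5, 7]
# Full agreement -> all drafts plus the bonus token.
assert _rejection_sample_reference([5, 6, 8], [5, 6]) == [5, 6, 8]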

View File

@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.v1.outputs import DraftTokenIds
from vllm.v1.worker.gpu.async_utils import async_copy_to_np
from vllm.v1.worker.gpu.input_batch import InputBatch
class DraftTokensHandler:
def __init__(self, device: torch.device | None = None):
self.device = device
self.copy_stream = torch.cuda.Stream(device)
self.copy_event = torch.cuda.Event()
self.req_ids: list[str] = []
self.draft_tokens_np: np.ndarray | None = None
self.num_draft_tokens: int = 0
def set_draft_tokens(
self, input_batch: InputBatch, draft_tokens: torch.Tensor
) -> None:
self.req_ids = input_batch.req_ids
self.num_draft_tokens = draft_tokens.shape[1]
if not input_batch.has_structured_output_reqs:
# No draft token validation needs to be performed by
# the scheduler for this batch.
self.draft_tokens_np = None
return
# For spec decoding + structured outputs, we must transfer the
# draft tokens back to the scheduler for grammar validation.
current_stream = torch.cuda.current_stream(self.device)
self.copy_stream.wait_stream(current_stream)
with torch.cuda.stream(self.copy_stream):
self.draft_tokens_np = async_copy_to_np(draft_tokens)
self.copy_event.record()
def get_draft_tokens(self) -> DraftTokenIds | None:
if self.draft_tokens_np is not None:
self.copy_event.synchronize()
draft_token_ids = self.draft_tokens_np.tolist()
else:
# This case only happens when async scheduling is disabled.
draft_token_ids = [[-1] * self.num_draft_tokens for _ in self.req_ids]
return DraftTokenIds(self.req_ids, draft_token_ids)
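# Hedged usage sketch (shown as comments since it needs a CUDA device and a
# live InputBatch): the D2H copy is staged on a side stream right after
# drafting and only synchronized when the scheduler needs the token ids.
#
#   handler = DraftTokensHandler(device)
#   handler.set_draft_tokens(input_batch, draft_tokens)  # async D2H copy
#   ...  # overlap other GPU work here
#   draft_ids = handler.get_draft_tokens()  # synchronizes only if needed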

View File

@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
class RequestState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.num_speculative_steps = num_speculative_steps
self.vocab_size = vocab_size
self.device = device
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_reqs))
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len.
# To save GPU memory, we use UVA instead of GPU for this tensor.
self.all_token_ids = StagedWriteTensor(
(self.max_num_reqs, self.max_model_len),
dtype=torch.int32,
device=device,
uva_instead_of_gpu=True,
)
# NOTE(woosuk): Distinguish clearly between prompt_len and prefill_len:
# - prompt_len: Number of tokens in the user-provided prompt.
# - prefill_len: Number of tokens passed into the model runner.
# This can include the prompt and additional partial output tokens,
# so prefill_len >= prompt_len.
# Usually, prefill_len equals prompt_len, but in cases such as resumption after
# preemption, prefill_len may be greater. Differentiating between these values
# is crucial, as certain features such as prompt logprobs or frequency penalties
# must treat prompt and output tokens separately.
self.prompt_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
# total_len = prompt_len + output_len. It grows as the request progresses.
self.total_len = StagedWriteTensor(
self.max_num_reqs, dtype=torch.int32, device=device
)
# Number of computed tokens.
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = StagedWriteTensor(
self.max_num_reqs, dtype=torch.int32, device=device
)
# Last sampled tokens.
self.last_sampled_tokens = torch.zeros(
self.max_num_reqs, 1, dtype=torch.int64, device=device
)
# Draft tokens.
self.draft_tokens = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
dtype=torch.int64,
device=device,
)
self.next_prefill_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
def add_request(
self,
req_id: str,
prompt_len: int,
all_token_ids: list[int],
num_computed_tokens: int,
) -> None:
assert len(self.free_indices) > 0, "No free indices"
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.prompt_len.np[req_idx] = prompt_len
prefill_len = len(all_token_ids)
assert prefill_len >= prompt_len, (
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
)
self.prefill_len.np[req_idx] = prefill_len
self.total_len.stage_write_elem(req_idx, prefill_len)
self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
def apply_staged_writes(self) -> None:
self.prompt_len.copy_to_uva()
self.prefill_len.copy_to_uva()
self.total_len.apply_write()
self.all_token_ids.apply_write()
self.num_computed_tokens.apply_write()
def remove_request(self, req_id: str) -> None:
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
return np.any(
self.num_computed_prefill_tokens[idx_mapping_np]
< self.prefill_len.np[idx_mapping_np]
)
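# Toy illustration (plain Python, values assumed) of the prompt_len /
# prefill_len distinction documented above: output tokens replayed after
# preemption count toward prefill_len but not prompt_len.
prompt = [1, 2, 3, 4]     # user prompt -> prompt_len = 4
replayed_output = [7, 8]  # partial output replayed after preemption
all_token_ids = prompt + replayed_output
prompt_len, prefill_len = len(prompt), len(all_token_ids)
assert prefill_len >= prompt_len  # 6 >= 4, as add_request() asserts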

View File

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.input_batch import InputBatch
class StructuredOutputsWorker:
def __init__(self, max_num_logits: int, vocab_size: int, device: torch.device):
self.logits_indices = torch.zeros(
max_num_logits, dtype=torch.int32, device=device
)
self.grammar_bitmask = torch.zeros(
(max_num_logits, cdiv(vocab_size, 32)), dtype=torch.int32, device=device
)
self.device = device
self.copy_stream = torch.cuda.Stream()
def apply_grammar_bitmask(
self,
logits: torch.Tensor,
input_batch: InputBatch,
grammar_req_ids: list[str],
grammar_bitmask: np.ndarray,
) -> None:
if not grammar_req_ids:
return
# Asynchronously copy the bitmask to GPU.
with torch.cuda.stream(self.copy_stream):
bitmask = async_copy_to_gpu(
grammar_bitmask, out=self.grammar_bitmask[: grammar_bitmask.shape[0]]
)
# Construct bitmask -> logits mapping
mapping: list[int] = []
req_ids = input_batch.req_ids
cu_num_logits = input_batch.cu_num_logits_np.tolist()
req_id_to_idx = {req_id: i for i, req_id in enumerate(req_ids)}
for grammar_req_id in grammar_req_ids:
req_idx = req_id_to_idx[grammar_req_id]
logits_start_idx = cu_num_logits[req_idx]
logits_end_idx = cu_num_logits[req_idx + 1]
mapping.extend(range(logits_start_idx, logits_end_idx))
# Asynchronously copy the mapping to GPU.
with torch.cuda.stream(self.copy_stream):
logits_indices = torch.tensor(
mapping, dtype=torch.int32, device="cpu", pin_memory=True
)
logits_indices = self.logits_indices[: len(mapping)].copy_(
logits_indices, non_blocking=True
)
# Ensure all async copies are complete before launching the kernel.
current_stream = torch.cuda.current_stream()
current_stream.wait_stream(self.copy_stream)
num_masks = bitmask.shape[0]
assert num_masks == len(mapping)
vocab_size = logits.shape[-1]
BLOCK_SIZE = 8192
grid = (num_masks, triton.cdiv(vocab_size, BLOCK_SIZE))
_apply_grammar_bitmask_kernel[grid](
logits,
logits.stride(0),
logits_indices,
bitmask,
bitmask.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Ensure the copy stream waits for the device tensors to finish being used
# before it re-uses or deallocates them
self.copy_stream.wait_stream(current_stream)
# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@triton.jit
def _apply_grammar_bitmask_kernel(
logits_ptr,
logits_stride,
logits_indices_ptr,
bitmask_ptr,
bitmask_stride,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
bitmask_idx = tl.program_id(0)
logits_idx = tl.load(logits_indices_ptr + bitmask_idx)
# Load the bitmask.
block_id = tl.program_id(1)
bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32)
packed_bitmask = tl.load(
bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset,
mask=bitmask_offset < bitmask_stride,
)
# Unpack the bitmask.
bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
bitmask = bitmask.reshape(BLOCK_SIZE)
# Apply the bitmask to the logits.
block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
logits_ptr + logits_idx * logits_stride + block_offset,
-float("inf"),
mask=bitmask & (block_offset < vocab_size),
)
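# CPU reference of the bitmask semantics above (a sketch, not the Triton
# path): bit == 1 keeps a token, bit == 0 forces its logit to -inf.
def _apply_bitmask_reference(logits: torch.Tensor, packed: torch.Tensor) -> torch.Tensor:
    # packed: [num_rows, ceil(vocab / 32)] int32, little-endian bit order.
    bits = (packed[:, :, None] >> torch.arange(32)) & 1
    keep = bits.reshape(packed.shape[0], -1)[:, : logits.shape[-1]].bool()
    return logits.masked_fill(~keep, float("-inf"))

_logits = torch.zeros(1, 4)
_packed = torch.tensor([[0b0101]], dtype=torch.int32)  # keep tokens 0 and 2
_out = _apply_bitmask_reference(_logits, _packed)
assert _out[0, 0] == 0.0 and _out[0, 1] == float("-inf")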

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,494 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
import torch
import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import (
DPMetadata,
create_forward_context,
get_forward_context,
override_forward_context,
)
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
logger = init_logger(__name__)
@dataclass
class UbatchMetadata:
context: UBatchContext
input_ids: torch.Tensor
positions: torch.Tensor
inputs_embeds: torch.Tensor | None
intermediate_tensors: IntermediateTensors | None
num_tokens: int
@dataclass
class CUDAGraphMetaData:
cudagraph: torch.cuda.CUDAGraph
ubatch_metadata: UbatchMetadata
outputs: Any | None = None
class SMControlContextManager:
def __init__(
self,
comm_sms: int,
set_comm_sms: Callable[[int], None],
set_compute_sms: Callable[[int], None],
):
"""
Context manager for controlling SM (Streaming Multiprocessor)
allocation. Upon entering the context, it sets the number of SMs
allocated for communication and computation to comm_sms and
total_sms - comm_sms respectively. Upon exiting, it restores the
allocation to use all available SMs (i.e. total_sms).
Args:
comm_sms (int): The number of SMs to allocate for communication.
(The remainder will be used for computation.)
set_comm_sms (Callable[[int], None]):
A function that sets the number of SMs for communication.
set_compute_sms (Callable[[int], None]):
A function that sets the number of SMs for computation.
"""
assert current_platform.is_cuda(), (
"SM control is currently only supported on CUDA"
)
total_sms = num_compute_units(torch.cuda.current_device())
assert comm_sms < total_sms
self.total_sms = total_sms
self.compute_sms = total_sms - comm_sms
self.comm_sms = comm_sms
self.set_comm_sms = set_comm_sms
self.set_compute_sms = set_compute_sms
def __enter__(self):
self.set_comm_sms(self.comm_sms)
self.set_compute_sms(self.compute_sms)
def __exit__(self, exc_type, exc_value, traceback):
self.set_comm_sms(self.total_sms)
self.set_compute_sms(self.total_sms)
class UBatchWrapper:
def __init__(
self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
device: torch.cuda.device,
):
self.runnable = runnable
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.comm_stream = torch.cuda.Stream(device=device)
# Ubatch threads plus the main thread
self.ready_barrier = threading.Barrier(
self.vllm_config.parallel_config.num_ubatches + 1
)
self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
self.cudagraph_wrapper = None
self.graph_pool = None
if runtime_mode is not CUDAGraphMode.NONE:
self.cudagraph_wrapper = CUDAGraphWrapper(
runnable, vllm_config, runtime_mode=runtime_mode
)
self.graph_pool = current_platform.get_global_graph_pool()
self.sm_control = self._create_sm_control_context(vllm_config)
self.device = device
@staticmethod
def _create_sm_control_context(vllm_config: VllmConfig):
comm_sms: int = envs.VLLM_DBO_COMM_SMS
set_comm_sms = lambda sms: None
if vllm_config.parallel_config.enable_expert_parallel:
            # Currently only DeepEP high-throughput supports SM control, so
            # this only affects that case.
ep_group = get_ep_group()
device_communicator = ep_group.device_communicator
all2all_manager = None
if device_communicator is not None:
all2all_manager = device_communicator.all2all_manager
if all2all_manager is not None:
max_sms_used = all2all_manager.max_sms_used()
if max_sms_used is not None:
comm_sms = min(comm_sms, max_sms_used)
if comm_sms > 0 and all2all_manager is not None:
set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)
# TODO(lucas): support other kernels besides DeepGEMM
set_compute_sms = lambda sms: None
if has_deep_gemm() and comm_sms > 0:
import deep_gemm as dg
set_compute_sms = lambda sms: dg.set_num_sms(sms)
return SMControlContextManager(
comm_sms=comm_sms,
set_comm_sms=set_comm_sms,
set_compute_sms=set_compute_sms,
)
def __getattr__(self, key: str):
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
        raise AttributeError(
            f"Attribute {key} does not exist in the runnable of "
            f"cudagraph wrapper: {self.runnable}"
        )
def unwrap(self) -> Callable:
# in case we need to access the original runnable.
return self.runnable
def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
"""
Capture a cudagraph for a microbatched run.
The logic here is somewhat complicated because we need to make sure that
each of the ubatch threads initialize the cuda context before we start
the graph capture.
The flow is as follows:
1. The main thread starts up each ubatch thread. Each thread will
initialize its cuda context (torch.cuda.current_blas_handle())
before going to sleep upon entering the ubatch_context.
2. The main thread starts the graph capture and wakes up the first
ubatch thread.
3. Each ubatch thread runs the model to completion and returns the
completed output tensors back to the main thread.
4. The main thread stores the captured cudagraph along with its metadata
and returns
"""
@torch.inference_mode()
def _capture_ubatch_thread(results, ubatch_metadata):
torch.cuda.set_device(self.device)
ubatch_context = ubatch_metadata.context
with torch.cuda.stream(ubatch_context.compute_stream):
_ = torch.cuda.current_blas_handle()
with torch.cuda.stream(ubatch_context.comm_stream):
_ = torch.cuda.current_blas_handle()
with ubatch_context:
model_output = model(
input_ids=ubatch_metadata.input_ids,
positions=ubatch_metadata.positions,
intermediate_tensors=ubatch_metadata.intermediate_tensors,
inputs_embeds=ubatch_metadata.inputs_embeds,
)
results.append((ubatch_metadata.context.id, model_output))
results: list[tuple[int, torch.Tensor]] = []
compute_stream = ubatch_metadata[0].context.compute_stream
num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens
# Ubatches will manually manage the forward context, so we override
# it to None here so we can have it restored correctly later
with override_forward_context(None):
ubatch_threads = []
for metadata in ubatch_metadata:
thread = threading.Thread(
target=_capture_ubatch_thread,
args=(
results,
metadata,
),
)
ubatch_threads.append(thread)
thread.start()
self.ready_barrier.wait() # Wait for both threads to be ready
# Capture the cudagraph
cudagraph_metadata = CUDAGraphMetaData(
cudagraph=torch.cuda.CUDAGraph(),
ubatch_metadata=ubatch_metadata,
)
if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
else:
set_graph_pool_id(current_platform.graph_pool_handle())
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
with torch.cuda.graph(
cudagraph_metadata.cudagraph,
stream=compute_stream,
pool=self.graph_pool,
):
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)
cudagraph_metadata.outputs = result
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader().join_after_forward()
self.cudagraphs[num_tokens] = cudagraph_metadata
return cudagraph_metadata.outputs
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
@torch.inference_mode()
def _ubatch_thread(results, model, ubatch_metadata):
with ubatch_metadata.context:
model_output = model(
input_ids=ubatch_metadata.input_ids,
positions=ubatch_metadata.positions,
intermediate_tensors=ubatch_metadata.intermediate_tensors,
inputs_embeds=ubatch_metadata.inputs_embeds,
)
results.append((ubatch_metadata.context.id, model_output))
results: list[tuple[int, torch.Tensor]] = []
# Ubatch threads will manually manage the forward context, so we
# override it to None here so we can have it restored correctly
# after both threads have finished
with override_forward_context(None):
ubatch_threads = []
for metadata in ubatch_metadata:
thread = threading.Thread(
target=_ubatch_thread,
args=(
results,
model,
metadata,
),
)
ubatch_threads.append(thread)
thread.start()
self.ready_barrier.wait() # Wait for both threads to be ready
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)
return result
def _make_ubatch_metadata(
self,
ubatch_slices,
attn_metadata,
slot_mapping,
input_ids,
positions,
inputs_embeds,
intermediate_tensors,
compute_stream,
dp_metadata,
batch_descriptor,
cudagraph_runtime_mode,
) -> list[UbatchMetadata]:
# Create one forward context per ubatch
forward_contexts = []
# slot_mapping can be None, an empty dict (from create_forward_context
# converting None to {}), or a list of dicts (one per ubatch)
has_slot_mapping = slot_mapping and isinstance(slot_mapping, list)
for i, ubatch_slice in enumerate(ubatch_slices):
forward_contexts.append(
create_forward_context(
attn_metadata[i] if attn_metadata is not None else None,
self.vllm_config,
dp_metadata=dp_metadata[i],
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping[i] if has_slot_mapping else None,
)
)
ubatch_ctxs = make_ubatch_contexts(
num_micro_batches=len(ubatch_slices),
comm_stream=self.comm_stream,
compute_stream=compute_stream,
forward_contexts=forward_contexts,
ready_barrier=self.ready_barrier,
)
ubatch_metadata: list[UbatchMetadata] = []
for i, ubatch_slice in enumerate(ubatch_slices):
(
sliced_input_ids,
sliced_positions,
sliced_inputs_embeds,
sliced_intermediate_tensors,
) = self._slice_model_inputs(
ubatch_slice.token_slice,
input_ids,
positions,
inputs_embeds,
intermediate_tensors,
)
ubatch_metadata.append(
UbatchMetadata(
context=ubatch_ctxs[i],
input_ids=sliced_input_ids,
positions=sliced_positions,
inputs_embeds=sliced_inputs_embeds,
intermediate_tensors=sliced_intermediate_tensors,
num_tokens=ubatch_slice.token_slice.stop
- ubatch_slice.token_slice.start,
)
)
return ubatch_metadata
def _slice_model_inputs(
self,
tokens_slice: slice,
input_ids,
positions,
inputs_embeds,
intermediate_tensors,
):
        sliced_input_ids = input_ids[tokens_slice]
        # M-RoPE adds an extra leading dimension to the positions tensor,
        # so slice along the token dimension in that case.
        if positions.ndim == 2:
            sliced_positions = positions[:, tokens_slice]
        else:
            sliced_positions = positions[tokens_slice]
        # NOTE: truthiness of a multi-element Tensor raises at runtime, so
        # compare against None explicitly here.
        sliced_inputs_embeds = (
            inputs_embeds[tokens_slice] if inputs_embeds is not None else None
        )
        sliced_intermediate_tensors = (
            intermediate_tensors[tokens_slice]
            if intermediate_tensors is not None
            else None
        )
return (
sliced_input_ids,
sliced_positions,
sliced_inputs_embeds,
sliced_intermediate_tensors,
)
def __call__(self, *args, **kwargs):
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
ubatch_slices = forward_context.ubatch_slices
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
# If there's no ubatching, just run the runnable object
if ubatch_slices is None:
# This is to account for the case where ubatching was aborted.
# When we capture full graphs we only capture one graph per shape,
# meaning that if we have a ubatched cudagraph for the current
# num_tokens, we don't have a non-ubatched one. Without this
# check, the cudagraph wrapper will try to capture a cudagraph
# for this shape during a normal run.
if cudagraph_runtime_mode is CUDAGraphMode.FULL:
assert batch_descriptor is not None
if batch_descriptor.num_tokens in self.cudagraphs:
cudagraph_runtime_mode = CUDAGraphMode.NONE
if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE):
return self.runnable(*args, **kwargs)
else:
assert self.cudagraph_wrapper is not None
return self.cudagraph_wrapper(*args, **kwargs)
attn_metadata = forward_context.attn_metadata
slot_mapping = forward_context.slot_mapping
num_tokens = sum(ubatch_slice.num_tokens for ubatch_slice in ubatch_slices)
input_ids = kwargs["input_ids"]
positions = kwargs["positions"]
intermediate_tensors = kwargs["intermediate_tensors"]
inputs_embeds = kwargs["inputs_embeds"]
compute_stream = torch.cuda.current_stream()
dp_metadata = forward_context.dp_metadata
# We shouldn't be here unless we are running with multiple DP ranks
assert dp_metadata is not None
ubatch_dp_metadata = []
for ubatch_slice in ubatch_slices:
dp_size = self.vllm_config.parallel_config.data_parallel_size
ubatch_num_tokens_across_dp = torch.tensor(
[ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32
)
ubatch_dp_metadata.append(
DPMetadata.make(
self.vllm_config.parallel_config,
ubatch_slice.num_tokens,
ubatch_num_tokens_across_dp,
)
)
if (
num_tokens not in self.cudagraphs
and cudagraph_runtime_mode is CUDAGraphMode.FULL
):
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
compute_stream=compute_stream,
dp_metadata=ubatch_dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
with self.sm_control:
return self._capture_ubatches(ubatch_metadata, self.model)
elif (
num_tokens in self.cudagraphs
and cudagraph_runtime_mode is CUDAGraphMode.FULL
):
cudagraph_metadata = self.cudagraphs[num_tokens]
# Sync offloader before replay - ensures any external dependencies
# from pre-capture prefetches are satisfied.
get_offloader().sync_prev_onload()
cudagraph_metadata.cudagraph.replay()
return cudagraph_metadata.outputs
else:
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
compute_stream=compute_stream,
dp_metadata=ubatch_dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
with self.sm_control:
return self._run_ubatches(ubatch_metadata, self.model)
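# Hedged sketch (shown as comments since it needs a CUDA device) of the SM
# split applied around ubatched runs; the lambdas stand in for the DeepEP /
# DeepGEMM setters that _create_sm_control_context wires up in practice.
#
#   sm_control = SMControlContextManager(
#       comm_sms=20,
#       set_comm_sms=lambda sms: print(f"comm SMs -> {sms}"),
#       set_compute_sms=lambda sms: print(f"compute SMs -> {sms}"),
#   )
#   with sm_control:  # communication gets 20 SMs, compute gets the rest
#       ...           # ubatched forward would run here
#   # On exit, both callbacks are reset to the full SM count.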

vllm/v1/worker/gpu_worker.py (new file, 1138 lines)

File diff suppressed because it is too large

View File

@@ -0,0 +1,283 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define KV connector functionality mixin for model runners.
"""
import copy
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_shutdown,
get_kv_transfer_group,
has_kv_transfer_group,
)
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)
from vllm.v1.worker.utils import AttentionGroup
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
class KVConnectorModelRunnerMixin:
@staticmethod
def ensure_kv_transfer_shutdown() -> None:
# has_kv_transfer_group can be None during interpreter shutdown.
if has_kv_transfer_group and has_kv_transfer_group(): # type: ignore[truthy-function]
ensure_kv_transfer_shutdown()
@staticmethod
def kv_connector_no_forward(
scheduler_output: "SchedulerOutput", vllm_config: VllmConfig
) -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with (
set_forward_context(None, vllm_config),
KVConnectorModelRunnerMixin._get_kv_connector_output(
scheduler_output, wait_for_save=False
) as kv_connector_output,
):
pass
if kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
@staticmethod
def maybe_get_kv_connector_output(
scheduler_output: "SchedulerOutput",
clear_metadata: bool = True,
) -> AbstractContextManager[KVConnectorOutput | None]:
return (
KVConnectorModelRunnerMixin._get_kv_connector_output(
scheduler_output, clear_metadata=clear_metadata
)
if has_kv_transfer_group()
else nullcontext()
)
# This context manager must be used within an active forward context.
# It encapsulates the entire KV connector lifecycle within execute_model
@staticmethod
@contextmanager
def _get_kv_connector_output(
scheduler_output: "SchedulerOutput",
wait_for_save: bool = True,
clear_metadata: bool = True,
) -> Generator[KVConnectorOutput, None, None]:
output = KVConnectorOutput()
        # Update the KVConnector with the KVConnector metadata before forward().
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
try:
yield output
finally:
if wait_for_save:
kv_connector.wait_for_save()
output.finished_sending, output.finished_recving = (
kv_connector.get_finished(scheduler_output.finished_req_ids)
)
output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
if clear_metadata:
kv_connector.clear_connector_metadata()
@staticmethod
def clear_kv_connector_metadata() -> None:
"""Clear the KV connector metadata. Call after draft model runs."""
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
kv_connector.clear_connector_metadata()
@staticmethod
def use_uniform_kv_cache(
attn_groups: list[list[AttentionGroup]],
cache_dtype: CacheDType,
) -> bool:
"""
Determines whether a uniform KV layout should be used.
A uniform layout means all layers KV caches will share the same
underlying tensor, where for a given block number, the respective
KV data for all layers will be contiguous.
This will allow efficient KV transfer of per-block KV data for all
layers at once.
        Note this layout will only be applied given 3 conditions:
        1. The KV cache config contains just a single group where all layers
           have the same page size.
        2. A KV connector is configured, and the KV connector instance prefers
           to use this layout (prefer_cross_layer_blocks() returns True).
        3. The flash attention backend supports this layout
           (get_kv_cache_stride_order(True) includes a placement for a
           num_layers dimension).
        Note that the actual placement of the num_layers dimension in the
        unified layers tensor will be determined by the attention backend.
        Thus, the layers' KV data may still not be contiguous per block if the
        attention backend does not support it.
Args:
attn_groups: The list of attention groups for this model
cache_dtype: The KV cache dtype
Returns:
True if we should use a uniform KV cache layout.
"""
if not has_kv_transfer_group():
return False
if not get_kv_transfer_group().prefer_cross_layer_blocks:
return False
if len(attn_groups) != 1 or len(attn_groups[0]) != 1:
return False
attn_group = attn_groups[0][0]
kv_cache_spec = attn_group.kv_cache_spec
if not isinstance(kv_cache_spec, AttentionSpec):
return False
attn_backend = attn_group.backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
            1234,  # dummy num_blocks; only the resulting shape's rank is used
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
)
except (AttributeError, NotImplementedError):
return False
        # Check that the attention backend includes a layers dimension.
return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
@staticmethod
def allocate_uniform_kv_caches(
kv_cache_config: KVCacheConfig,
attn_groups: list[list[AttentionGroup]],
cache_dtype: CacheDType,
device: torch.device,
kernel_block_sizes: list[int],
) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]:
"""
Initializes and reshapes KV caches for the simple case where all
layers have the same layout.
This function assumes use_uniform_kv_cache() returned True.
Args:
kv_cache_config: The KV cache config
attn_groups: The list of attention groups for this model
cache_dtype: The KV cache dtype
device: The torch device to allocate on.
kernel_block_sizes: The kernel block sizes for each KV cache group.
Returns:
A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where:
kv_caches is a dict mapping between layer names to their
corresponding memory buffer for KV cache.
cross_layers_kv_cache is the cross layers kv cache tensor
attn_backend is the attention backend matching this tensor
"""
attn_group = attn_groups[0][0]
kv_cache_spec = attn_group.kv_cache_spec
assert isinstance(kv_cache_spec, AttentionSpec)
tensor_sizes = set(
kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors
)
assert len(tensor_sizes) == 1
tensor_size = tensor_sizes.pop()
page_size = kv_cache_spec.page_size_bytes
assert tensor_size % page_size == 0
num_blocks = tensor_size // page_size
num_layers = len(kv_cache_config.kv_cache_tensors)
total_size = tensor_size * num_layers
assert len(kernel_block_sizes) == 1
kernel_block_size = kernel_block_sizes[0]
num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
attn_backend = attn_group.backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
)
# prepend a num_layers dimension into the shape
kv_cache_shape = (num_layers,) + kv_cache_shape
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
)
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape)
# allocate one contiguous buffer for all layers
cross_layers_kv_cache = (
torch.zeros(total_size, dtype=torch.int8, device=device)
.view(kv_cache_spec.dtype)
.view(kv_cache_shape)
)
# Maintain original KV shape view.
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]
permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order)
kv_caches = {}
for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
tensor = permuted_kv_cache[i]
for layer_name in kv_cache_tensor.shared_by:
kv_caches[layer_name] = tensor
return kv_caches, cross_layers_kv_cache, attn_backend
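# Toy check (made-up shapes) of the stride-order round trip used above:
# permuting the physical buffer by the inverse order recovers the logical
# num_layers-first view that the per-layer slicing relies on.
_logical = (2, 8, 16, 4, 32)  # (layers, blocks, block_size, heads, head_size)
_stride_order = (1, 0, 2, 3, 4)  # e.g. blocks outermost in memory
_buf = torch.zeros(tuple(_logical[i] for i in _stride_order))
_inv = [_stride_order.index(i) for i in range(len(_stride_order))]
assert _buf.permute(*_inv).shape == _logical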

View File

@@ -0,0 +1,285 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define LoRA functionality mixin for model runners.
"""
from contextlib import contextmanager
from typing import TypeAlias
import numpy as np
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping, LoRAMappingType
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch
logger = init_logger(__name__)
# Defined as a mixin for GPUModelRunner
class LoRAModelRunnerMixin:
def load_lora_model(
self,
model: nn.Module,
vllm_config: VllmConfig,
device: torch.device,
) -> nn.Module:
if not supports_lora(model):
raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")
# Add LoRA Manager to the Model Runner
self.lora_manager = LRUCacheWorkerLoRAManager(
vllm_config,
device,
model.embedding_modules,
)
return self.lora_manager.create_lora_manager(model, vllm_config)
def _set_active_loras(
self,
prompt_lora_mapping: tuple[int, ...],
token_lora_mapping: tuple[int, ...],
lora_requests: set[LoRARequest],
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
) -> None:
self._ensure_lora_enabled()
# Set is_prefill to True, so we always use the SGMV kernels on
# non-cuda platforms.
# On cuda platforms we use the same kernels for prefill and
# decode and this flag is generally ignored.
lora_mapping = LoRAMapping(
token_lora_mapping,
prompt_lora_mapping,
is_prefill=True,
type=mapping_type,
)
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def _ensure_lora_enabled(self) -> None:
if not hasattr(self, "lora_manager"):
raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")
def set_active_loras(
self,
input_batch: InputBatch,
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray | None = None,
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
) -> None:
if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
prompt_lora_mapping: tuple[int, ...] # of size np.sum(num_sampled_tokens)
token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens)
lora_requests: set[LoRARequest]
prompt_lora_mapping, token_lora_mapping, lora_requests = (
input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
)
return self._set_active_loras(
prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
)
@contextmanager
def maybe_setup_dummy_loras(
self, lora_config: LoRAConfig | None, remove_lora: bool = True
):
if lora_config is None:
yield
else:
# __enter__ code
assert self.lora_manager is not None, "LoRA is not enabled"
num_loras = lora_config.max_loras
            lora_warmup_rank = min(lora_config.max_lora_rank, 8)
# Make dummy lora requests
lora_requests: set[LoRARequest] = {
LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path",
)
for lora_id in range(1, num_loras + 1)
}
with self.lora_manager.dummy_lora_cache():
# Add the dummy LoRAs here so _set_active_loras doesn't try to
# load from disk.
for lr in lora_requests:
self.lora_manager.add_dummy_lora(lr, rank=lora_warmup_rank)
yield
# __exit__ code
if remove_lora:
self.lora_manager.remove_all_adapters()
@contextmanager
def maybe_select_dummy_loras(
self,
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
num_sampled_tokens: np.ndarray | None = None,
num_active_loras: int = 0,
):
"""
Context manager to select dummy LoRAs for capture/warmup.
Args:
lora_config: LoRA configuration, or None if LoRA is disabled.
num_scheduled_tokens: Array of scheduled token counts per request.
num_sampled_tokens: Array of sampled token counts per request.
num_active_loras: Number of distinct active LoRAs to use.
- 0: No LoRA active (set up zero mappings).
- >0: Use exactly this many distinct LoRAs.
"""
if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
# Skip LoRA setup entirely only if no LoRA config
if lora_config is None:
yield
else:
# __enter__ code
assert self.lora_manager is not None, "LoRA is not enabled"
num_reqs = len(num_scheduled_tokens)
max_loras = lora_config.max_loras
# Determine how many distinct LoRAs to use and whether to include
# no-LoRA tokens (-1 entries).
# When num_active_loras > max_loras (e.g., max_loras + 1), we need
# to include -1 entries to simulate batches with both LoRA and
# no-LoRA tokens. This ensures prepare_tensors computes the correct
# num_active_loras that matches the cudagraph capture key.
if num_active_loras == 0:
# No LoRA active - use 0 mappings like the original code
effective_num_loras = 0
include_no_lora = False
elif num_active_loras > max_loras:
# num_active_loras > max_loras means we want max_loras adapters
# PLUS no-LoRA tokens (-1). This is the max_loras + 1 case.
effective_num_loras = max_loras
include_no_lora = True
else:
# Specific number of active LoRAs requested
effective_num_loras = min(num_active_loras, max_loras)
include_no_lora = False
# Make prompt lora mapping
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
# LoRA IDs are 1-indexed (1 to max_loras) as required by LoRARequest.
# convert_mapping() will convert these to 0-indexed slot indices.
if effective_num_loras > 0:
if include_no_lora:
# Include -1 (no-LoRA) entries by cycling through
# -1, 1, 2, ..., effective_num_loras
# This ensures prepare_tensors sees both LoRA and no-LoRA
# tokens, computing num_active_loras = effective_num_loras+1
cycle_values = np.array(
list(range(1, effective_num_loras + 1)),
dtype=np.int32,
)
prompt_lora_mapping = cycle_values[
np.arange(num_reqs, dtype=np.int32) % len(cycle_values)
]
else:
# Use 1 to effective_num_loras (1-indexed lora IDs)
prompt_lora_mapping = (
np.arange(num_reqs, dtype=np.int32) % effective_num_loras
) + 1
else:
# No LoRA active - use 0 for all tokens (original behavior)
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
# Make sample lora mapping
sample_lora_mapping = np.repeat(prompt_lora_mapping, num_sampled_tokens)
# Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
# Make dummy lora requests (only for the active LoRAs)
lora_requests: set[LoRARequest] = {
LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path",
)
for lora_id in range(1, effective_num_loras + 1)
}
self._set_active_loras(
tuple(sample_lora_mapping),
tuple(token_lora_mapping),
lora_requests,
mapping_type,
)
yield
@contextmanager
def maybe_dummy_run_with_lora(
self,
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray,
remove_lora: bool = True,
num_active_loras: int = 0,
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
):
"""
Context manager for dummy runs with LoRA.
Args:
lora_config: LoRA configuration.
num_scheduled_tokens: Array of scheduled token counts per request.
num_sampled_tokens: Array of sampled token counts per request.
remove_lora: Whether to remove LoRAs after the context exits.
num_active_loras: Number of distinct active LoRAs to use.
LoRA is activated when num_active_loras > 0.
"""
with (
self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(
lora_config,
num_scheduled_tokens,
mapping_type,
num_sampled_tokens,
num_active_loras,
),
):
yield
def maybe_remove_all_loras(self, lora_config: LoRAConfig | None):
if lora_config is None:
return
self.lora_manager.remove_all_adapters()
def add_lora(self, lora_request: LoRARequest) -> bool:
self._ensure_lora_enabled()
return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
self._ensure_lora_enabled()
return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
self._ensure_lora_enabled()
return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> set[int]:
self._ensure_lora_enabled()
return self.lora_manager.list_adapters()
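# Toy illustration (NumPy only, values assumed) of the cyclic dummy-LoRA
# mapping built in maybe_select_dummy_loras: per-request LoRA ids cycle
# through 1..effective_num_loras and are repeated per scheduled token.
_num_reqs, _effective_num_loras = 5, 2
_prompt_mapping = (np.arange(_num_reqs, dtype=np.int32) % _effective_num_loras) + 1
# -> [1, 2, 1, 2, 1]
_num_scheduled = np.array([3, 1, 2, 2, 1], dtype=np.int32)
_token_mapping = np.repeat(_prompt_mapping, _num_scheduled)
# -> [1, 1, 1, 2, 1, 1, 2, 2, 1]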

View File

@@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Any
import torch
from vllm.config import CacheConfig
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
)
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
src_ptr = tl.load(src_ptrs + pid)
dst_ptr = tl.load(dst_ptrs + pid)
size = tl.load(sizes + pid)
offsets = tl.arange(0, BLOCK_SIZE)
for i in range(0, size, BLOCK_SIZE):
mask = (i + offsets) < size
curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
data = tl.load(curr_src_ptr, mask=mask)
tl.store(curr_dst_ptr, data, mask=mask)
def batch_memcpy(src_ptrs, dst_ptrs, sizes):
batch = src_ptrs.shape[0]
assert dst_ptrs.shape[0] == batch
assert sizes.shape[0] == batch
grid = (batch,)
BLOCK_SIZE = 1024
batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)
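# Hedged usage sketch (shown as comments since it needs a CUDA device):
# copy one buffer in a single launch by passing raw byte addresses/sizes.
#
#   src = torch.arange(256, dtype=torch.uint8, device="cuda")
#   dst = torch.empty_like(src)
#   batch_memcpy(
#       torch.tensor([src.data_ptr()], device="cuda", dtype=torch.int64),
#       torch.tensor([dst.data_ptr()], device="cuda", dtype=torch.int64),
#       torch.tensor([src.numel()], device="cuda", dtype=torch.int32),
#   )
#   assert torch.equal(src, dst)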
def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]:
mamba_group_ids: list[int] = []
mamba_specs: list[MambaSpec] = []
for i in range(len(kv_cache_config.kv_cache_groups)):
kv_cache_spec = kv_cache_config.kv_cache_groups[i].kv_cache_spec
if isinstance(kv_cache_spec, MambaSpec):
mamba_group_ids.append(i)
mamba_specs.append(kv_cache_spec)
assert len(mamba_group_ids) > 0, "no mamba layers in the model"
assert all(mamba_specs[0] == spec for spec in mamba_specs)
return mamba_group_ids, mamba_specs[0]
def collect_mamba_copy_meta(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
kv_cache_config: KVCacheConfig,
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
mamba_group_ids: list[int],
src_block_idx: int,
dest_block_idx: int,
accept_token_bias: int,
req_state: CachedRequestState,
forward_context: dict[str, Any],
):
if src_block_idx == dest_block_idx and accept_token_bias == 0:
return
for mamba_group_id in mamba_group_ids:
block_ids = req_state.block_ids[mamba_group_id]
dest_block_id = block_ids[dest_block_idx]
layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
for layer_name in layer_names:
attention = forward_context[layer_name]
kv_caches: list[torch.Tensor] = attention.kv_cache[0]
for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
copy_spec = state_copy_func(
state, block_ids, src_block_idx, accept_token_bias + 1
)
src_state_list.append(copy_spec.start_addr)
dest_state_list.append(state[dest_block_id].data_ptr())
num_elements_list.append(copy_spec.num_elements * state.element_size())
def do_mamba_copy_block(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
if len(src_state_list) == 0:
return
assert len(src_state_list) == len(dest_state_list)
assert len(src_state_list) == len(num_elements_list)
src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)
batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)
def preprocess_mamba(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
cache_config: CacheConfig,
mamba_state_idx: dict[str, int],
input_batch: GPUInputBatch,
requests: dict[str, CachedRequestState],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
"""
    Copy the mamba state of the previous step to the block that is
    (1 + num_speculative_blocks) from the end.
"""
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
num_speculative_blocks = mamba_spec.num_speculative_blocks
# TODO(Chen): we need to optimize this function a lot
assert cache_config.enable_prefix_caching
block_size = mamba_spec.block_size
finished_req_ids = scheduler_output.finished_req_ids
preempted_req_ids = scheduler_output.preempted_req_ids or set()
# We need to clear mamba_state_idx for resumed requests. When requests are
# force-preempted (e.g., during reset_prefix_cache / KV cache flush),
# they appear in resumed_req_ids without a corresponding entry in
# preempted_req_ids, leaving stale mamba_state_idx entries that can
# point to block indices beyond the new (smaller) block allocation.
resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
mamba_state_idx.pop(req_id, None)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
prev_state_idx = mamba_state_idx.get(req_id)
if prev_state_idx is None:
# new / resumed request, no previous state
# if num_computed_tokens is 0, prev_state_idx will be -1
prev_state_idx = (req_state.num_computed_tokens - 1) // block_size
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_blocks: int = (
cdiv(req_state.num_computed_tokens + num_scheduled_tokens, block_size)
+ num_speculative_blocks
)
# We always save the current running state at the last
# (1 + num_speculative_blocks) block.
        # A corner case worth mentioning here: assume we have block_size = 4 and
# num_speculative_tokens = 2. The request is [A, B, C] and contains 2 draft
# tokens [draft 1, draft 2]. Then we will have:
# Block 0: [A, B, C, draft 1]
# Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
# Block 2: speculative block
# Block 3: speculative block
# And use block 1 to save the running state.
curr_state_idx = num_blocks - 1 - num_speculative_blocks
mamba_state_idx[req_id] = curr_state_idx
if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
prev_state_idx,
curr_state_idx,
input_batch.num_accepted_tokens_cpu[i] - 1,
req_state,
forward_context,
)
input_batch.num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
def postprocess_mamba(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
input_batch: GPUInputBatch,
requests: dict[str, CachedRequestState],
mamba_state_idx: dict[str, int],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
"""
    If a block is converted from a partial block to a full block in this step,
    copy the state from the running-state block to the new full block.
"""
num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
# NOTE: can be optimized as this function always returns the same result
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, []))
num_scheduled_tokens = num_scheduled_tokens_dict[req_id]
num_accepted_tokens = num_accepted_tokens_cpu[i]
num_tokens_running_state = (
num_computed_tokens + num_scheduled_tokens - num_draft_tokens
)
new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1
aligned_new_computed_tokens = (
new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size
)
        # TODO: how do we ensure that all blocks on which cache_blocks was
        # called are cached here?
if aligned_new_computed_tokens >= num_tokens_running_state:
accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state
src_block_idx = mamba_state_idx[req_id]
dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
src_block_idx,
dest_block_idx,
accept_token_bias,
req_state,
forward_context,
)
if src_block_idx == dest_block_idx:
num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
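# Worked example of the alignment arithmetic above (illustrative values):
# block_size = 4, num_computed_tokens = 5, num_scheduled_tokens = 3,
# num_draft_tokens = 2:
#   num_tokens_running_state = 5 + 3 - 2 = 6
#   with num_accepted_tokens = 2: aligned = (6 + 2 - 1) // 4 * 4 = 4 < 6,
#   so no block became full this step and nothing is copied;
#   with num_accepted_tokens = 3: aligned = (6 + 3 - 1) // 4 * 4 = 8 >= 6,
#   so the running state is copied into full block 8 // 4 - 1 = 1.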

View File

@@ -0,0 +1,574 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Data structures defining a TPU input batch
from typing import cast
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState
_SAMPLING_EPS = 1e-5
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[int],
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self._req_ids: list[str | None] = []
self.req_id_to_index: dict[str, int] = {}
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
pin_memory=False,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
)
# Sampling-related.
self.temperature = torch.empty(
(max_num_reqs,), dtype=torch.float32, device=device
)
self.temperature_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
self.top_p_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
)
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: set[str] = set()
self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device)
self.top_k_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: set[str] = set()
self.min_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
self.min_p_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
self.min_p_reqs: set[str] = set()
# Frequency penalty related data structures
self.frequency_penalties = torch.empty(
(max_num_reqs,), dtype=torch.float, device=device
)
self.frequency_penalties_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
)
self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: set[str] = set()
# Presence penalty related data structures
self.presence_penalties = torch.empty(
(max_num_reqs,), dtype=torch.float, device=device
)
self.presence_penalties_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
)
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
self.presence_penalties_reqs: set[str] = set()
# Repetition penalty related data structures
self.repetition_penalties = torch.empty(
(max_num_reqs,), dtype=torch.float, device=device
)
self.repetition_penalties_cpu_tensor = torch.empty(
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
)
self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set()
# req_index -> (min_tokens, stop_token_ids)
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
self.logit_bias: list[dict[int, float] | None] = [None] * max_num_reqs
self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token is
        # allowed, the value is False, since we use masked_fill_ to set -inf.
self.allowed_token_ids_mask: torch.Tensor | None = None
self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None
# req_index -> bad_words_token_ids
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
self.req_output_token_ids: list[list[int] | None] = []
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
# while performing state updates to the batch.
return cast(list[str], self._req_ids)
def add_request(
self,
request: "CachedRequestState",
req_index: int | None = None,
) -> None:
if req_index is None:
req_index = self.num_reqs
assert req_index < self.max_num_reqs
req_id = request.req_id
if req_index == len(self._req_ids):
self._req_ids.append(req_id)
self.req_output_token_ids.append(request.output_token_ids)
else:
self._req_ids[req_index] = req_id
self.req_output_token_ids[req_index] = request.output_token_ids
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds
)
# TODO: copy prompt_embeds
self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
sampling_params = request.sampling_params
assert sampling_params is not None, "pooling requests not supported yet"
if sampling_params.sampling_type == SamplingType.GREEDY:
            # Avoid division by zero later in apply_temperature.
self.temperature_cpu[req_index] = 0.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.min_p_cpu[req_index] = sampling_params.min_p
self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
if sampling_params.min_p > _SAMPLING_EPS:
self.min_p_reqs.add(req_id)
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
if sampling_params.min_tokens:
self.min_tokens[req_index] = (
sampling_params.min_tokens,
sampling_params.all_stop_token_ids,
)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device,
)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.bool, device="cpu"
)
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids
] = False
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[req_index] = sampling_params.bad_words_token_ids
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
if lora_id not in self.lora_id_to_request_ids:
self.lora_id_to_request_ids[lora_id] = set()
self.request_lora_mapping[req_index] = lora_id
self.lora_id_to_request_ids[lora_id].add(request.req_id)
self.lora_id_to_lora_request[lora_id] = request.lora_request
else:
# No LoRA
self.request_lora_mapping[req_index] = 0
def remove_request(self, req_id: str) -> int | None:
"""This method must always be followed by a call to condense()."""
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.min_p_reqs.discard(req_id)
self.min_tokens.pop(req_index, None)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
self.lora_id_to_request_ids[lora_id].discard(req_id)
if len(self.lora_id_to_request_ids[lora_id]) == 0:
self.lora_id_to_request_ids.pop(lora_id)
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0
self.logit_bias[req_index] = None
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
return req_index
def swap_states(self, i1: int, i2: int) -> None:
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
self.req_output_token_ids[i2],
self.req_output_token_ids[i1],
)
assert old_id_i1 is not None and old_id_i2 is not None
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
self.req_id_to_index[old_id_i2],
self.req_id_to_index[old_id_i1],
)
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
self.num_tokens_no_spec[i2],
self.num_tokens_no_spec[i1],
)
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
self.num_prompt_tokens[i2],
self.num_prompt_tokens[i1],
)
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
self.num_computed_tokens_cpu[i2],
self.num_computed_tokens_cpu[i1],
)
self.temperature_cpu[i1], self.temperature_cpu[i2] = (
self.temperature_cpu[i2],
self.temperature_cpu[i1],
)
self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
self.frequency_penalties_cpu[i2],
self.frequency_penalties_cpu[i1],
)
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
self.presence_penalties_cpu[i2],
self.presence_penalties_cpu[i1],
)
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
self.repetition_penalties_cpu[i2],
self.repetition_penalties_cpu[i1],
)
self.min_p_cpu[i1], self.min_p_cpu[i2] = self.min_p_cpu[i2], self.min_p_cpu[i1]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.min_tokens, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
self.request_lora_mapping[i2],
self.request_lora_mapping[i1],
)
self.logit_bias[i1], self.logit_bias[i2] = (
self.logit_bias[i2],
self.logit_bias[i1],
)
if self.allowed_token_ids_mask_cpu_tensor is not None:
(
self.allowed_token_ids_mask_cpu_tensor[i1],
self.allowed_token_ids_mask_cpu_tensor[i2],
) = (
self.allowed_token_ids_mask_cpu_tensor[i2],
self.allowed_token_ids_mask_cpu_tensor[i1],
)
self.block_table.swap_row(i1, i2)
def condense(self, empty_req_indices: list[int]) -> None:
"""Move non-empty requests down into lower, empty indices.
Args:
empty_req_indices: empty batch indices, sorted descending.
"""
num_reqs = self.num_reqs
if num_reqs == 0:
# The batched states are empty.
self._req_ids.clear()
self.req_output_token_ids.clear()
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index = num_reqs + len(empty_req_indices) - 1
while empty_req_indices:
# Find the largest non-empty index.
while last_req_index in empty_req_indices:
last_req_index -= 1
# Find the smallest empty index.
empty_index = empty_req_indices.pop()
if empty_index >= last_req_index:
break
# Swap the states.
req_id = self._req_ids[last_req_index]
output_token_ids = self.req_output_token_ids[last_req_index]
assert req_id is not None
self._req_ids[empty_index] = req_id
self._req_ids[last_req_index] = None
self.req_output_token_ids[empty_index] = output_token_ids
self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
num_tokens = self.num_tokens_no_spec[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens
]
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index
]
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
last_req_index
]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
last_req_index
]
self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
last_req_index
]
self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
last_req_index
]
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
min_token = self.min_tokens.pop(last_req_index, None)
if min_token is not None:
self.min_tokens[empty_index] = min_token
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index
]
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
self.allowed_token_ids_mask_cpu_tensor[last_req_index]
)
bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
# Decrement last_req_index since it is now empty.
last_req_index -= 1
# Trim lists to the batch size.
del self._req_ids[self.num_reqs :]
del self.req_output_token_ids[self.num_reqs :]
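    # Illustrative walk-through of condense() with hypothetical indices: with
    # num_reqs = 3 and empty_req_indices = [4, 1] (sorted descending), rows
    # {0, 2, 3} are occupied. last_req_index starts at 3 + 2 - 1 = 4, which is
    # empty, so it moves down to 3; the smallest empty index popped is 1, so
    # row 3 is moved into row 1, leaving rows 0..2 densely packed.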
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[: self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[: self.num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
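    # Example (illustrative token ids): with num_reqs = 2, prompt lengths
    # [3, 1] and vocab_size = 32000, the padded tensor is
    #   [[t0, t1,    t2   ],
    #    [t3, 32000, 32000]]
    # so the pad value can never collide with a real token id, since valid
    # ids are 0 .. vocab_size - 1.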
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
        Given the num_scheduled_tokens for each request in the batch, return
        data structures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values()
)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
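    # Example (illustrative): with request_lora_mapping[:3] = [1, 0, 2] and
    # num_scheduled_tokens = [2, 1, 3], np.repeat expands the per-request ids
    # into per-token ids: token_lora_mapping = (1, 1, 0, 2, 2, 2).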
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_min_p(self) -> bool:
return len(self.min_p_reqs) == 0
@property
def no_penalties(self) -> bool:
return (
len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0
)
@property
def max_num_logprobs(self) -> int | None:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0

View File

@@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TypeAlias
import numpy as np
import torch
from vllm.config import ParallelConfig
from vllm.v1.attention.backend import CommonAttentionMetadata
@dataclass
class UBatchSlice:
request_slice: slice
token_slice: slice
def is_empty(self) -> bool:
return (
self.request_slice.start == self.request_slice.stop
or self.token_slice.start == self.token_slice.stop
)
@property
def num_tokens(self) -> int:
return self.token_slice.stop - self.token_slice.start
UBatchSlices: TypeAlias = list[UBatchSlice]
def is_last_ubatch_empty(
orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
) -> bool:
return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens
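# Example (illustrative): with padded_num_tokens = 16 and num_ubatches = 2,
# the first ubatch covers (16 // 2) * 1 = 8 tokens. For orig_num_tokens = 9
# the last ubatch is non-empty (8 < 9); for orig_num_tokens = 7 it is empty,
# because the first ubatch already covers all 7 real tokens.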
def check_ubatch_thresholds(
config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
if not config.use_ubatching:
return False
if uniform_decode:
return num_tokens >= config.dbo_decode_token_threshold
else:
return num_tokens >= config.dbo_prefill_token_threshold
# This pads the last ubatch slice out to the total number of tokens
# (num_tokens + padding), since `maybe_create_ubatch_slices` runs before
# DP padding is applied.
def _pad_out_ubatch_slices(
ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
) -> UBatchSlices:
last_slice = ubatch_slices[-1]
padded_last_request_slice = slice(last_slice.request_slice.start, num_reqs_padded)
padded_last_token_slice = slice(last_slice.token_slice.start, num_total_tokens)
return ubatch_slices[:-1] + [
UBatchSlice(padded_last_request_slice, padded_last_token_slice)
]
def maybe_create_ubatch_slices(
should_ubatch: bool,
num_scheduled_tokens: np.ndarray,
num_tokens_padded: int,
num_reqs_padded: int,
num_ubatches: int,
split_point: list[int] | int | None = None,
) -> tuple[UBatchSlices | None, UBatchSlices | None]:
if not should_ubatch:
return None, None
    if split_point is None:
        split_point = int(num_tokens_padded) // num_ubatches
    if isinstance(split_point, int):
        token_split_points = [split_point * i for i in range(1, num_ubatches)]
    else:  # an explicit list of token boundaries was provided
        token_split_points = list(split_point)
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
ubatch_slices = []
start_token = 0
# Add the end point to the split points to make iteration easier
all_points = token_split_points + [cu_num_tokens[-1]]
for end_token in all_points:
token_slice = slice(start_token, end_token)
# Determine request slices using exclusive stop semantics
# Ubatch includes requests whose tokens overlap [start_token, end_token)
# Start at the request that contains the start_token
# or the request starting exactly at start_token (if on boundary)
req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1)
# Stop at the request that starts at or after end_token
req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left"))
req_slice = slice(req_start, req_stop)
ubatch_slices.append(UBatchSlice(req_slice, token_slice))
start_token = end_token
ubatch_slices_padded = _pad_out_ubatch_slices(
ubatch_slices, num_tokens_padded, num_reqs_padded
)
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
return ubatch_slices, ubatch_slices_padded
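# Worked example (illustrative): num_scheduled_tokens = [3, 5, 2],
# num_tokens_padded = 12, num_reqs_padded = 3, num_ubatches = 2 gives
# split_point = 12 // 2 = 6, token_split_points = [6] and
# cu_num_tokens = [0, 3, 8, 10].
#   First slice:  tokens [0, 6)  -> requests [0, 2)  (request 1 is split).
#   Second slice: tokens [6, 10) -> requests [1, 3), padded out to [6, 12).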
def slice_query_start_locs(
query_start_loc: torch.Tensor,
request_slice: slice,
) -> torch.Tensor:
"""
Creates a new query_start_loc that corresponds to the requests in
request_slice.
Note: This function creates a new tensor to hold the new query_start_locs.
This will break cudagraph compatibility.
"""
return (
query_start_loc[request_slice.start : request_slice.stop + 1]
- query_start_loc[request_slice.start]
)
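# Example (illustrative): query_start_loc = [0, 3, 8, 10] with
# request_slice = slice(1, 3) yields [3, 8, 10] - 3 = [0, 5, 7].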
def _make_metadata_with_slice(
ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
"""
This function creates a new CommonAttentionMetadata that corresponds to
the requests included in ubatch_slice
"""
assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"
request_slice = ubatch_slice.request_slice
token_slice = ubatch_slice.token_slice
start_locs = attn_metadata.query_start_loc_cpu
first_req = request_slice.start
first_tok = token_slice.start
last_req = request_slice.stop - 1
last_tok = token_slice.stop - 1
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
"Token slice start outside of first request"
)
# NOTE: last token can be outside of the last request if we have CG padding.
# If the request is split across ubatches, we have to adjust the metadata.
# splits_first_request: The first request in this slice is the continuation of
# a request that started in a previous slice.
# splits_last_request: The last request in this slice continues into the
# next slice.
splits_first_request = first_tok > start_locs[first_req]
splits_last_request = last_tok < start_locs[last_req + 1] - 1
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
query_start_loc = slice_query_start_locs(
attn_metadata.query_start_loc, request_slice
)
assert len(query_start_loc) >= 2, (
f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
)
if splits_first_request:
tokens_skipped = first_tok - start_locs[first_req]
query_start_loc[1:] -= tokens_skipped
query_start_loc_cpu[1:] -= tokens_skipped
seq_lens = attn_metadata.seq_lens[request_slice]
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
if splits_last_request:
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
# the tokens skipped because query_start_loc_cpu might have been modified
# if splits_first_request is True.
tokens_skipped = start_locs[last_req + 1] - token_slice.stop
query_start_loc[-1] -= tokens_skipped
query_start_loc_cpu[-1] -= tokens_skipped
# Make sure we don't modify the seq_lens tensors
# (not cudagraph compatible)
seq_lens = seq_lens.clone()
seq_lens_cpu = seq_lens_cpu.clone()
seq_lens[-1] -= tokens_skipped
seq_lens_cpu[-1] -= tokens_skipped
max_seq_len = int(seq_lens_cpu.max())
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]
num_requests = request_slice.stop - request_slice.start
num_actual_tokens = token_slice.stop - token_slice.start
max_query_len = int(
torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
)
# This is to account for the case where we are in a dummy
# run and query_start_loc_cpu is full of 0s
if max_query_len == 0:
max_query_len = attn_metadata.max_query_len
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
slot_mapping = attn_metadata.slot_mapping[token_slice]
return CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
num_reqs=num_requests,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
)
def split_attn_metadata(
ubatch_slices: list[UBatchSlice],
common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
"""
Creates a new CommonAttentionMetadata instance that corresponds to the
requests for each UBatchSlice in ubatch_slices.
Note: This function does not modify common_attn_metadata
"""
results = []
for ubatch_slice in ubatch_slices:
results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
return results

vllm/v1/worker/ubatching.py
View File

@@ -0,0 +1,241 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
import torch
from vllm import forward_context
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.utils.torch_utils import current_stream
logger = init_logger(__name__)
_THREAD_ID_TO_CONTEXT: dict = {}
# Here we hardcode the number of microbatches to 2 by default.
_NUM_UBATCHES: int = 2
_CURRENT_CONTEXTS: list["UBatchContext | None"] = []
class UBatchContext:
"""
Context manager for micro-batching synchronization using threading events.
"""
def __init__(
self,
id: int,
comm_stream: torch.cuda.Stream,
compute_stream: torch.cuda.Stream,
forward_context: ForwardContext,
ready_barrier: threading.Barrier,
cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event,
gpu_comm_done_event: torch.Event,
gpu_compute_done_event: torch.Event,
schedule: str = "default",
):
self.id = id
self.comm_stream = comm_stream
self.compute_stream = compute_stream
self.forward_context = forward_context
self.ready_barrier = ready_barrier
self.cpu_wait_event = cpu_wait_event
self.cpu_signal_event = cpu_signal_event
self.current_stream = compute_stream
self.gpu_comm_done_event = gpu_comm_done_event
self.gpu_compute_done_event = gpu_compute_done_event
self.schedule = schedule
self.recv_hook = None
def __enter__(self):
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
_THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
_CURRENT_CONTEXTS[self.id] = self
# _NUM_UBATCHES is set in make_ubatch_contexts
self.ready_barrier.wait()
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self._restore_context()
# Assume we want to start on the compute stream
self.update_stream(self.compute_stream)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
_CURRENT_CONTEXTS[self.id] = None
del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
self.maybe_run_recv_hook()
self.cpu_signal_event.set()
self.cpu_wait_event.clear()
return False
def _restore_context(self):
forward_context._forward_context = self.forward_context
def update_stream(self, stream):
self.current_stream = stream
if current_stream() != self.current_stream:
torch.cuda.set_stream(self.current_stream)
def _signal_comm_done(self):
self.gpu_comm_done_event.record(self.comm_stream)
def _signal_compute_done(self):
self.gpu_compute_done_event.record(self.compute_stream)
def _wait_compute_done(self):
self.comm_stream.wait_event(self.gpu_compute_done_event)
def _wait_comm_done(self):
self.compute_stream.wait_event(self.gpu_comm_done_event)
def _cpu_yield(self):
# It is critical for correctness that only one thread is running
# at a time. These asserts just make sure that this is the only
# thread running before waking the other one up and going to sleep
assert forward_context._forward_context == self.forward_context
assert current_stream() == self.current_stream
assert not self.cpu_wait_event.is_set()
self.cpu_signal_event.set()
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self._restore_context()
def switch_to_comm(self):
self.update_stream(self.comm_stream)
def switch_to_compute(self):
self.update_stream(self.compute_stream)
def switch_to_comm_sync(self):
self._signal_compute_done()
self.update_stream(self.comm_stream)
self._wait_compute_done()
def switch_to_compute_sync(self):
self._signal_comm_done()
self.update_stream(self.compute_stream)
self._wait_comm_done()
def maybe_run_recv_hook(self):
if self.recv_hook is not None:
self.recv_hook()
self.recv_hook = None
def yield_(self):
self.current_stream = current_stream()
self._cpu_yield()
self.update_stream(self.current_stream)
def yield_and_switch_from_compute_to_comm(self):
assert current_stream() == self.compute_stream
self._signal_compute_done()
self._cpu_yield()
assert self.current_stream == self.compute_stream
self.update_stream(self.comm_stream)
self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self):
assert current_stream() == self.comm_stream
self._signal_comm_done()
self._cpu_yield()
assert self.current_stream == self.comm_stream
self.update_stream(self.compute_stream)
self._wait_comm_done()
def dbo_enabled() -> bool:
return len(_THREAD_ID_TO_CONTEXT) > 0
def dbo_current_ubatch_id() -> int:
if len(_THREAD_ID_TO_CONTEXT) == 0:
return 0
return _THREAD_ID_TO_CONTEXT[threading.get_ident()]
def _register_ubatch_function(func):
def wrapper(*args, **kwargs):
if len(_THREAD_ID_TO_CONTEXT) > 0:
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
ctx = _CURRENT_CONTEXTS[ctx_idx]
func(ctx, *args, **kwargs)
return wrapper
dbo_maybe_run_recv_hook = _register_ubatch_function(UBatchContext.maybe_run_recv_hook)
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
UBatchContext.yield_and_switch_from_compute_to_comm
)
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
UBatchContext.yield_and_switch_from_comm_to_compute
)
dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm)
dbo_switch_to_compute = _register_ubatch_function(UBatchContext.switch_to_compute)
dbo_switch_to_comm_sync = _register_ubatch_function(UBatchContext.switch_to_comm_sync)
dbo_switch_to_compute_sync = _register_ubatch_function(
UBatchContext.switch_to_compute_sync
)
def dbo_register_recv_hook(recv_hook):
if len(_THREAD_ID_TO_CONTEXT) > 0:
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES]
next_ctx.recv_hook = recv_hook
def dbo_get_previous_event(func, *args, **kwargs):
if len(_THREAD_ID_TO_CONTEXT) > 0:
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
ctx = _CURRENT_CONTEXTS[ctx_idx]
# execute callable on the ubatch compute stream to record/wait events there
with torch.cuda.stream(ctx.compute_stream):
return func(*args, **kwargs)
def make_ubatch_contexts(
num_micro_batches: int,
compute_stream: torch.cuda.Stream,
comm_stream: torch.cuda.Stream,
forward_contexts: list[ForwardContext],
ready_barrier: threading.Barrier,
schedule: str = "default",
) -> list[UBatchContext]:
    """
    Create the per-microbatch context managers used for micro-batching
    synchronization.
    """
    global _NUM_UBATCHES, _CURRENT_CONTEXTS
    assert num_micro_batches > 1, "num_micro_batches must be greater than 1"
    _NUM_UBATCHES = num_micro_batches
    # Ensure the global context list is large enough
    if len(_CURRENT_CONTEXTS) < num_micro_batches:
        _CURRENT_CONTEXTS.extend([None] * (num_micro_batches - len(_CURRENT_CONTEXTS)))
cpu_events = [threading.Event() for _ in range(num_micro_batches)]
gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)]
gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)]
ctxs = []
for i in range(num_micro_batches):
ctx = UBatchContext(
id=i,
compute_stream=compute_stream,
comm_stream=comm_stream,
forward_context=forward_contexts[i],
ready_barrier=ready_barrier,
cpu_wait_event=cpu_events[i],
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
gpu_comm_done_event=gpu_comm_done_events[i],
gpu_compute_done_event=gpu_compute_done_events[i],
schedule=schedule,
)
ctxs.append(ctx)
return ctxs
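# Minimal usage sketch (illustrative, not part of this module): each ubatch
# runs in its own thread, entering its context and yielding at the overlap
# points; the main thread releases ubatch 0 by setting its cpu_wait_event.
# `fn`, `compute_stream`, `comm_stream` and `forward_contexts` are assumed to
# be provided by the caller.
#
#     barrier = threading.Barrier(2 + 1)  # both ubatch threads + main thread
#     ctxs = make_ubatch_contexts(2, compute_stream, comm_stream,
#                                 forward_contexts, barrier)
#
#     def run(ctx):
#         with ctx:
#             fn()  # model code calls dbo_yield() at overlap points
#
#     threads = [threading.Thread(target=run, args=(c,)) for c in ctxs]
#     for t in threads:
#         t.start()
#     barrier.wait()                # wait until both threads are ready
#     ctxs[0].cpu_wait_event.set()  # kick off ubatch 0
#     for t in threads:
#         t.join()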

vllm/v1/worker/utils.py
View File

@@ -0,0 +1,239 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections import defaultdict
from dataclasses import dataclass, field
import torch
from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.platforms import current_platform
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
logger = init_logger(__name__)
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
layer_names: list[str]
kv_cache_spec: KVCacheSpec
kv_cache_group_id: int
    # When ubatching is enabled we will have a metadata builder for each ubatch,
    # so that if they use internal persistent buffers for cudagraphs, they
    # won't have to worry about conflicting with the other ubatches.
    metadata_builders: list[AttentionMetadataBuilder] = field(default_factory=list)
def create_metadata_builders(
self,
vllm_config,
device,
kernel_block_size: int | None,
num_metadata_builders: int = 1,
):
kv_cache_spec_builder = (
self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
if kernel_block_size is not None
else self.kv_cache_spec
)
self.metadata_builders = [
self.backend.get_builder_cls()(
kv_cache_spec_builder,
self.layer_names,
vllm_config,
device,
)
for _ in range(num_metadata_builders)
]
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
assert len(self.metadata_builders) > ubatch_id
return self.metadata_builders[ubatch_id]
def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings,
expected_num_items: int,
) -> None:
"""
Perform sanity checks for the result of
[`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
"""
assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
f"or a single 3D tensor, but got {type(mm_embeddings)} "
"instead. This is most likely due to incorrect implementation "
"of the model's `embed_multimodal` method."
)
assert len(mm_embeddings) == expected_num_items, (
"Expected number of multimodal embeddings to match number of "
f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
"instead. This is most likely due to incorrect implementation "
"of the model's `embed_multimodal` method."
)
assert all(e.ndim == 2 for e in mm_embeddings), (
"Expected multimodal embeddings to be a sequence of 2D tensors, "
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
"instead. This is most likely due to incorrect implementation "
"of the model's `embed_multimodal` method."
)
def request_memory(init_snapshot: MemorySnapshot, cache_config: CacheConfig) -> int:
"""
Calculate the amount of memory required by vLLM, then validate
that the current amount of free memory is sufficient for that.
"""
requested_memory = math.ceil(
init_snapshot.total_memory * cache_config.gpu_memory_utilization
)
if init_snapshot.free_memory < requested_memory:
raise ValueError(
f"Free memory on device {init_snapshot.device_} "
f"({format_gib(init_snapshot.free_memory)}/"
f"{format_gib(init_snapshot.total_memory)} GiB) on startup "
f"is less than desired GPU memory utilization "
f"({cache_config.gpu_memory_utilization}, "
f"{format_gib(requested_memory)} GiB). Decrease GPU memory "
f"utilization or reduce GPU memory used by other processes."
)
return requested_memory
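# Example (illustrative): on a device with 80 GiB of total memory and
# gpu_memory_utilization = 0.9, requested_memory = ceil(80 GiB * 0.9) = 72 GiB,
# and startup fails unless at least 72 GiB is currently free.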
def add_kv_sharing_layers_to_kv_cache_groups(
shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec],
runner_only_attn_layers: set[str] | None = None,
) -> None:
"""
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
    for layers that do not allocate their own KV cache, based on the mapping in
`shared_kv_cache_layers`. Adds these layers to the corresponding KV cache
group, which is needed to ensure that attention metadata is assigned later.
Args:
shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
If an Attention layer `layer_name` is in the keys of this dict, it
means this layer will perform attention using the keys and values
from the KV cache of `shared_kv_cache_layers[layer_name]`.
kv_cache_groups: The KV cache groups of the model.
"""
layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
for kv_cache_group in kv_cache_groups:
for layer_name in kv_cache_group.layer_names:
layer_to_kv_cache_group[layer_name] = kv_cache_group
for layer_name, target_layer_name in shared_kv_cache_layers.items():
tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
tgt_kv_cache_group.layer_names.append(layer_name)
if runner_only_attn_layers is not None:
runner_only_attn_layers.add(layer_name)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, Attention],
runner_kv_caches: list[torch.Tensor],
num_attn_module: int = 1,
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
            # One typical case is an encoder-decoder model, e.g., BART.
            # The cross-attention and self-attention in the same decoder layer
            # have different layer_names but the same layer_index.
# TODO - analyze where runner_kv_caches is used and the right
# way to ensure it properly reflects multiple attention layers
# in the same decoder block.
if (
current_platform.is_cuda_alike()
or current_platform.is_xpu()
or current_platform.is_cpu()
):
# We know that the GPU / CPU runner is not impacted by this
# case. Some test code depends on runner_kv_caches, but
# not in a way that's impacted by ignoring this.
pass
else:
raise NotImplementedError
for layer_name in layer_names:
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
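# Example (illustrative layer names): for a 2-layer decoder-only model,
# kv_caches might be {"model.layers.0.self_attn.attn": t0,
# "model.layers.1.self_attn.attn": t1}. extract_layer_index maps these to
# indices 0 and 1, runner_kv_caches becomes [t0, t1], and each Attention
# layer in forward_context gets kv_cache = [t_i].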
def is_residual_scattered_for_sp(
vllm_config: VllmConfig, num_input_tokens: int
) -> bool:
"""Check if the residual tensor is scattered for sequence parallelism.
    The residual tensor is scattered across tensor-parallel ranks when both
    sequence parallelism and tensor parallelism are enabled.
This follows the same logic as SequenceParallelismPass.is_applicable_for_range():
- In full-graph compilation mode (no splitting ops or using inductor graph
partition), SP is always applied
- Otherwise, SP is only applied for specific shapes in compile_sizes
"""
if not vllm_config.compilation_config.pass_config.enable_sp:
return False
tp = vllm_config.parallel_config.tensor_parallel_size
if tp == 1:
return False
# When sequence parallelism is enabled, we always pad num_input_tokens
# to be a multiple of tensor_parallel_size (tp) earlier.
assert num_input_tokens % tp == 0
if (
not vllm_config.compilation_config.splitting_ops
or vllm_config.compilation_config.use_inductor_graph_partition
):
return True
compile_sizes = vllm_config.compilation_config.compile_sizes
if compile_sizes is None:
return False
return num_input_tokens in compile_sizes
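# Example (illustrative config): with enable_sp = True, tp = 4, splitting ops
# in use, and compile_sizes = [256, 512], a batch padded to 512 tokens returns
# True while one padded to 384 tokens returns False.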

View File

@@ -0,0 +1,373 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeVar
import torch
import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tracing import instrument
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.system_utils import update_environment_variables
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.serial_utils import run_method
if TYPE_CHECKING:
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput
else:
SchedulerOutput = object
GrammarOutput = object
AsyncModelRunnerOutput = object
ModelRunnerOutput = object
logger = init_logger(__name__)
_R = TypeVar("_R")
class WorkerBase:
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
) -> None:
"""
Initialize common worker components.
Args:
vllm_config: Complete vLLM configuration
local_rank: Local device index
rank: Global rank in distributed setup
distributed_init_method: Distributed initialization method
is_driver_worker: Whether this worker handles driver
responsibilities
"""
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.kv_transfer_config = vllm_config.kv_transfer_config
self.compilation_config = vllm_config.compilation_config
from vllm.platforms import current_platform
self.current_platform = current_platform
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
# Device and model state
self.device: torch.device | None = None
self.model_runner: nn.Module | None = None
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""Get specifications for KV cache implementation."""
raise NotImplementedError
def compile_or_warm_up_model(self) -> None:
"""Prepare model for execution through compilation/warmup."""
raise NotImplementedError
def check_health(self) -> None:
"""Basic health check (override for device-specific checks)."""
return
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
memory allocations.
"""
raise NotImplementedError
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks."""
raise NotImplementedError
def reset_mm_cache(self) -> None:
reset_fn = getattr(self.model_runner, "reset_mm_cache", None)
if callable(reset_fn):
reset_fn()
def get_model(self) -> nn.Module:
raise NotImplementedError
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
"""Apply a function on the model inside this worker."""
return fn(self.get_model())
def get_model_inspection(self) -> str:
"""Return a transformers-style hierarchical view of the model."""
from vllm.model_inspection import format_model_inspection
return format_model_inspection(self.get_model())
def load_model(self) -> None:
"""Load model onto target device."""
raise NotImplementedError
def execute_model(
self, scheduler_output: SchedulerOutput
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
"""If this method returns None, sample_tokens should be called immediately after
to obtain the ModelRunnerOutput.
Note that this design may be changed in future if/when structured outputs
parallelism is re-architected.
"""
raise NotImplementedError
def sample_tokens(
self, grammar_output: GrammarOutput
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
"""Should be called immediately after execute_model iff it returned None."""
raise NotImplementedError
def get_cache_block_size_bytes(self) -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
"""
raise NotImplementedError
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def list_loras(self) -> set[int]:
raise NotImplementedError
@property
def vocab_size(self) -> int:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()
def shutdown(self) -> None:
"""Clean up resources held by the worker."""
return
class WorkerWrapperBase:
"""
This class represents one process in an executor/engine. It is responsible
for lazily initializing the worker and handling the worker's lifecycle.
    We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. Then we call `update_environment_variables` to set up the
    environment, and the real initialization happens in `init_worker`.
"""
def __init__(
self,
rpc_rank: int = 0,
global_rank: int | None = None,
) -> None:
"""
Initialize the worker wrapper with the given vllm_config and rpc_rank.
Note: rpc_rank is the rank of the worker in the executor. In most cases,
it is also the rank of the worker in the distributed group. However,
when multiple executors work together, they can be different.
e.g. in the case of SPMD-style offline inference with TP=2,
users can launch 2 engines/executors, each with only 1 worker.
All workers have rpc_rank=0, but they have different ranks in the TP
group.
"""
self.rpc_rank = rpc_rank
self.global_rank = self.rpc_rank if global_rank is None else global_rank
# Initialized after init_worker is called
self.worker: WorkerBase
self.vllm_config: VllmConfig
def shutdown(self) -> None:
if self.worker is not None:
self.worker.shutdown()
def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
"""
Adjust the rpc_rank based on the given mapping.
It is only used during the initialization of the executor,
to adjust the rpc_rank of workers after we create all workers.
"""
old_rank = self.rpc_rank
if old_rank in rank_mapping:
self.rpc_rank = rank_mapping[old_rank]
if self.global_rank == old_rank:
self.global_rank = rank_mapping[old_rank]
def update_environment_variables(
self,
envs_list: list[dict[str, str]],
) -> None:
envs = envs_list[self.rpc_rank]
update_environment_variables(envs)
@instrument(span_name="Worker init")
def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
"""
Here we inject some common logic before initializing the worker.
Arguments are passed to the worker class constructor.
"""
kwargs = all_kwargs[self.rpc_rank]
vllm_config: VllmConfig | None = kwargs.get("vllm_config")
assert vllm_config is not None, (
"vllm_config is required to initialize the worker"
)
self.vllm_config = vllm_config
vllm_config.enable_trace_function_call_for_thread()
from vllm.plugins import load_general_plugins
load_general_plugins()
parallel_config = vllm_config.parallel_config
if isinstance(parallel_config.worker_cls, str):
worker_class: type[WorkerBase] = resolve_obj_by_qualname(
parallel_config.worker_cls
)
else:
raise ValueError(
"passing worker_cls is no longer supported. "
"Please pass keep the class in a separate module "
"and pass the qualified name of the class as a string."
)
if parallel_config.worker_extension_cls:
worker_extension_cls = resolve_obj_by_qualname(
parallel_config.worker_extension_cls
)
extended_calls = []
if worker_extension_cls not in worker_class.__bases__:
# check any conflicts between worker and worker_extension_cls
for attr in dir(worker_extension_cls):
if attr.startswith("__"):
continue
assert not hasattr(worker_class, attr), (
f"Worker class {worker_class} already has an attribute"
f" {attr}, which conflicts with the worker"
f" extension class {worker_extension_cls}."
)
if callable(getattr(worker_extension_cls, attr)):
extended_calls.append(attr)
# dynamically inherit the worker extension class
worker_class.__bases__ = worker_class.__bases__ + (
worker_extension_cls,
)
logger.info(
"Injected %s into %s for extended collective_rpc calls %s",
worker_extension_cls,
worker_class,
extended_calls,
)
shared_worker_lock = kwargs.pop("shared_worker_lock", None)
if shared_worker_lock is None:
msg = (
"Missing `shared_worker_lock` argument from executor. "
"This argument is needed for mm_processor_cache_type='shm'."
)
mm_config = vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_processor_cache_type == "shm":
raise ValueError(msg)
else:
logger.warning_once(msg)
self.mm_receiver_cache = None
else:
self.mm_receiver_cache = (
MULTIMODAL_REGISTRY.worker_receiver_cache_from_config(
vllm_config,
shared_worker_lock,
)
)
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
kv_cache_config = kv_cache_configs[self.global_rank]
assert self.vllm_config is not None
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self):
assert self.vllm_config is not None
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization
self.worker.init_device() # type: ignore
def execute_method(self, method: str | bytes, *args, **kwargs):
try:
# Method resolution order:
# a method defined on this class is called directly; otherwise,
# since `__getattr__` redirects attribute queries to `self.worker`,
# the method is called on the wrapped worker.
return run_method(self, method, args, kwargs)
except Exception as e:
# If the driver worker also executes methods, an exception in any
# other worker may deadlock RPC frameworks such as Ray;
# see https://github.com/vllm-project/vllm/issues/3455.
# Log the error so the user can diagnose and resolve it.
msg = (
f"Error executing method {method!r}. "
"This might cause deadlock in distributed execution."
)
logger.exception(msg)
raise e
def __getattr__(self, attr: str):
return getattr(self.worker, attr)
def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
mm_cache = self.mm_receiver_cache
if mm_cache is None:
return
for req_data in scheduler_output.scheduled_new_reqs:
req_data.mm_features = mm_cache.get_and_update_features(
req_data.mm_features
)
def execute_model(
self, scheduler_output: SchedulerOutput
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
self._apply_mm_cache(scheduler_output)
return self.worker.execute_model(scheduler_output)
def reset_mm_cache(self) -> None:
mm_receiver_cache = self.mm_receiver_cache
if mm_receiver_cache is not None:
mm_receiver_cache.clear_cache()
self.worker.reset_mm_cache()
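# Illustrative sketch (not part of the module): the dynamic-inheritance
# pattern used by init_worker above. `_DemoWorkerBase`, `_DemoWorker`, and
# `_DemoExtension` are hypothetical stand-ins for the classes resolved from
# parallel_config; the real code also checks for attribute conflicts first.
class _DemoWorkerBase:
    def execute_model(self) -> str:
        return "base"


class _DemoWorker(_DemoWorkerBase):
    pass


class _DemoExtension:
    def sync_weights(self) -> str:
        return "extended"


def _demo_extension_injection() -> None:
    # Append the extension to the worker's bases at runtime, as init_worker
    # does. Note the worker must have a base other than bare `object`, or
    # CPython rejects the `__bases__` assignment.
    _DemoWorker.__bases__ = _DemoWorker.__bases__ + (_DemoExtension,)
    worker = _DemoWorker()
    # Both the original and the injected methods are now reachable, which
    # is what lets the executor dispatch extended collective_rpc calls.
    assert worker.execute_model() == "base"
    assert worker.sync_weights() == "extended"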

252
vllm/v1/worker/workspace.py Normal file
View File

@@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
import os
from itertools import accumulate
from math import prod
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.math_utils import round_up
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
def _compute_bytes(shape: tuple[int, ...], dtype: torch.dtype) -> int:
return prod(shape) * dtype.itemsize
# Constants
_MB = 1024**2
_GiB = 1024**3
# Global workspace manager instance
_manager: "WorkspaceManager | None" = None
class WorkspaceManager:
"""Manager for workspace allocation.
Manages workspace buffers for DBO (Dual Batch Overlap) execution.
Can be locked to prevent further growth during execution.
"""
def __init__(self, device: torch.device, num_ubatches: int | None = None):
self._device = device
# Cache num ubatches at init based on configuration (default to 1)
self._num_ubatches = num_ubatches if num_ubatches is not None else 1
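# Two slots, one per possible DBO micro-batch; slot 1 simply stays None
# when _num_ubatches == 1.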
self._current_workspaces: list[torch.Tensor | None] = [None, None]
self._locked: bool = False
@staticmethod
def _workspace_size_bytes(workspace: torch.Tensor | None) -> int:
"""Get size of workspace in bytes."""
if workspace is None:
return 0
return workspace.numel() * workspace.element_size()
def lock(self) -> None:
"""Lock the workspace to prevent further growth.
After locking, any attempt to allocate a larger workspace will raise
an assertion error. This ensures workspace size is fixed during execution.
"""
self._locked = True
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Workspace locked. Current sizes: %s",
[
self._workspace_size_bytes(ws) / _MB
for ws in self._current_workspaces
if ws is not None
],
)
def is_locked(self) -> bool:
"""Check if workspace is locked."""
return self._locked
def get_simultaneous(
self, *shapes_and_dtypes: tuple[tuple[int, ...], torch.dtype]
) -> list[torch.Tensor]:
"""Get multiple workspace tensors simultaneously from a single allocation.
Args:
*shapes_and_dtypes: One or more (shape, dtype) tuples.
Returns:
List of tensor views into the workspace buffer, one per shape/dtype pair.
"""
actual_bytes = [_compute_bytes(s, d) for s, d in shapes_and_dtypes]
aligned_bytes = [round_up(actual, 256) for actual in actual_bytes]
total_bytes = sum(aligned_bytes)
# Calculate cumulative offsets using itertools.accumulate
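# e.g. aligned_bytes = [256, 512, 256] -> offsets = [0, 256, 768]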
offsets = list(accumulate([0] + aligned_bytes[:-1]))
current_workspace = self._ensure_workspace_size(total_bytes)
return [
current_workspace[offsets[i] : offsets[i] + actual_bytes[i]]
.view(shapes_and_dtypes[i][1])
.reshape(shapes_and_dtypes[i][0])
for i in range(len(shapes_and_dtypes))
]
def _ensure_workspace_size(self, required_bytes: int) -> torch.Tensor:
"""Ensure workspace is allocated and large enough, return current workspace.
Args:
required_bytes: The number of bytes required.
Returns:
The current workspace tensor.
"""
ubatch_id = dbo_current_ubatch_id()
current_workspace = self._current_workspaces[ubatch_id]
current_size = self._workspace_size_bytes(current_workspace)
if current_size < required_bytes:
def get_caller_info() -> str:
"""Find first frame outside WorkspaceManager."""
curr_frame = inspect.currentframe()
if curr_frame is None:
return "unknown"
# Walk up the stack skipping WorkspaceManager frames
curr_frame = curr_frame.f_back
while curr_frame is not None:
# TODO: This only catches instance methods (self), missing
# classmethods and staticmethods. Once Python 3.11+ is the
# minimum supported version, use co_qualname instead:
# qualname = curr_frame.f_code.co_qualname
# if qualname.startswith("WorkspaceManager."):
if isinstance(curr_frame.f_locals.get("self"), WorkspaceManager):
curr_frame = curr_frame.f_back
continue
filename = os.path.basename(curr_frame.f_code.co_filename)
return (
f"{filename}:{curr_frame.f_lineno}:{curr_frame.f_code.co_name}"
)
return "unknown"
if self._locked:
raise AssertionError(
f"Workspace is locked but allocation from '{get_caller_info()}' "
f"requires {required_bytes / _MB:.2f} MB, current size is "
f"{current_size / _MB:.2f} MB. "
"Workspace growth is not allowed after locking."
)
for ubatch_id in range(self._num_ubatches):
current_workspace = self._current_workspaces[ubatch_id]
if (
current_workspace is None
or self._workspace_size_bytes(current_workspace) < required_bytes
):
# Delete old tensor before allocating new one to avoid
# memory spike from resize_(). resize_() allocates new
# memory before freeing old, which can cause OOM.
# Must clear the list reference first since local var
# is just a copy of the reference.
self._current_workspaces[ubatch_id] = None
del current_workspace
self._current_workspaces[ubatch_id] = torch.empty(
(required_bytes,), dtype=torch.uint8, device=self._device
)
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Resized workspace from '%s': %.2f MB -> "
"%.2f MB (%d ubatches, total memory %.2f MB)",
get_caller_info(),
current_size / _MB,
required_bytes / _MB,
self._num_ubatches,
required_bytes * self._num_ubatches / _MB,
)
current_workspace = self._current_workspaces[dbo_current_ubatch_id()]
return current_workspace
def is_workspace_manager_initialized() -> bool:
"""Check if workspace manager has been initialized.
Returns:
True if workspace manager is initialized, False otherwise.
"""
return _manager is not None
def current_workspace_manager() -> "WorkspaceManager":
"""Get the current workspace manager instance.
Raises:
AssertionError: If workspace manager has not been initialized.
"""
assert _manager is not None, (
"WorkspaceManager not initialized. Call init_workspace_manager() "
"with a device before using workspace functions."
)
return _manager
def init_workspace_manager(
device: torch.device, num_ubatches: int | None = None
) -> None:
"""Initialize the workspace manager with a device.
Must be called before using any workspace functions. Typically called
from GPUModelRunner.__init__.
Args:
device: The device to allocate workspace on.
num_ubatches: Number of micro-batches. Defaults to 1.
"""
global _manager
if _manager is not None:
logger.warning(
"WorkspaceManager already initialized on device %s, "
"reinitializing on device %s",
_manager._device,
device,
)
_manager = WorkspaceManager(device, num_ubatches)
def lock_workspace() -> None:
"""Lock the workspace to prevent further growth.
After calling this function, any attempt to allocate a workspace larger
than the current size will raise an AssertionError. This ensures that
workspace size is fixed during execution and prevents unexpected memory
allocations in the hot path.
Example:
# During initialization
init_workspace_manager(device)
reserve_workspace(shape1, dtype1)
reserve_workspace(shape2, dtype2)
# Lock after warmup/profiling
lock_workspace()
# Now all get_workspace calls must fit in pre-allocated size
"""
current_workspace_manager().lock()
def reset_workspace_manager() -> None:
"""Reset the workspace manager to uninitialized state.
This is primarily intended for testing purposes to allow tests
to reinitialize the workspace manager cleanly.
"""
global _manager
_manager = None
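# Illustrative usage sketch (not part of the module); the shapes, dtypes,
# and CPU device are arbitrary demo choices. Two tensors are carved out of
# one backing buffer, each slice aligned to 256 bytes so the dtype views
# stay valid.
def _demo_workspace_usage() -> None:
    init_workspace_manager(torch.device("cpu"), num_ubatches=1)
    manager = current_workspace_manager()
    a, b = manager.get_simultaneous(
        ((8, 128), torch.float16),  # 2048 bytes at offset 0
        ((4, 64), torch.float32),  # 1024 bytes at offset 2048
    )
    assert a.shape == (8, 128) and b.shape == (4, 64)
    # After warmup, lock so hot-path requests can no longer grow the buffer.
    lock_workspace()
    reset_workspace_manager()  # allow clean re-initialization (e.g. tests)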

52
vllm/v1/worker/xpu_model_runner.py Normal file
View File

@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.torch_utils import supports_xpu_graph
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
class XPUModelRunner(GPUModelRunner):
"""A model runner for XPU devices."""
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
with _torch_cuda_wrapper():
super().__init__(vllm_config, device)
# FIXME: To be verified.
self.cascade_attn_enabled = False
def _sync_device(self) -> None:
torch.xpu.synchronize()
@contextmanager
def _torch_cuda_wrapper():
try:
# replace cuda APIs with xpu APIs, this should work by default
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.default_stream = torch.xpu.current_stream
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.stream = torch.xpu.stream
torch.cuda.mem_get_info = torch.xpu.mem_get_info
torch.cuda.synchronize = torch.xpu.synchronize
if supports_xpu_graph():
torch.cuda.graph = torch.xpu.graph
torch.cuda.CUDAGraph = torch.xpu.XPUGraph
torch.cuda.empty_cache = torch.xpu.empty_cache
yield
finally:
pass
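# Illustrative sketch (not part of the module), assuming an XPU-enabled
# PyTorch build: inside the wrapper, code written against the torch.cuda
# API drives the XPU runtime instead. Note that the `finally` block above
# is a no-op, so the patches persist for the lifetime of the process.
def _demo_cuda_wrapper() -> None:
    with _torch_cuda_wrapper():
        stream = torch.cuda.Stream()  # actually a torch.xpu.Stream
        with torch.cuda.stream(stream):
            free_bytes, total_bytes = torch.cuda.mem_get_info()
        torch.cuda.synchronize()  # dispatches to torch.xpu.synchronize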

114
vllm/v1/worker/xpu_worker.py Normal file
View File

@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import os
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
from vllm.v1.worker.workspace import init_workspace_manager
from vllm.v1.worker.xpu_model_runner import XPUModelRunner
from .utils import request_memory
logger = init_logger(__name__)
class XPUWorker(Worker):
"""A XPU worker class."""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
super().__init__(
vllm_config, local_rank, rank, distributed_init_method, is_driver_worker
)
device_config = self.device_config
assert device_config.device_type == "xpu"
assert current_platform.is_xpu()
# Torch profiler. Enabled and configured through profiler_config.
self.profiler: Any | None = None
profiler_config = vllm_config.profiler_config
if profiler_config.profiler == "torch":
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
self.profiler = TorchProfilerWrapper(
profiler_config,
worker_name=worker_name,
local_rank=self.local_rank,
activities=["CPU", "XPU"],
)
def init_device(self):
device = self.device_config.device
if (
isinstance(device, torch.device)
and device.type == "xpu"
and current_platform.is_xpu()
):
self.device = torch.device(f"xpu:{self.local_rank}")
current_platform.set_device(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)
torch.xpu.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties(
self.local_rank
).total_memory
else:
raise RuntimeError(f"Unsupported device type: {self.device_config.device}")
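# Configure the oneCCL transport and launcher-style env vars for
# distributed init, preserving any values the user has already set.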
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
ENV_LOCAL_WORLD_SIZE = os.getenv(
"LOCAL_WORLD_SIZE", str(self.parallel_config.world_size)
)
os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
os.environ["LOCAL_RANK"] = str(self.local_rank)
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
current_platform.dist_backend,
)
# Set random seed.
set_random_seed(self.model_config.seed)
# Take the memory snapshot only after the distributed backend
# has been initialized.
gc.collect()
torch.xpu.empty_cache()
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
self.requested_memory = request_memory(init_snapshot, self.cache_config)
logger.debug("worker init memory snapshot: %r", self.init_snapshot)
logger.debug(
"worker requested memory: %sGiB", format_gib(self.requested_memory)
)
# Initialize workspace manager
num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
init_workspace_manager(self.device, num_ubatches)
# Construct the model runner
self.model_runner = XPUModelRunner( # type: ignore
self.vllm_config, self.device
)
if self.rank == 0:
# If usage stat is enabled, collect relevant info.
report_usage_stats(self.vllm_config)