root
2026-04-09 11:23:47 +08:00
parent 8082d5f4b2
commit 72387e4fa8
1885 changed files with 611521 additions and 1 deletion

@@ -0,0 +1,4 @@
# [Experimental] Model Runner V2
This directory contains the new model runner, which is under active development.
Ping [Woosuk Kwon](https://github.com/WoosukKwon) about any changes.

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import numpy as np
import torch
from vllm.v1.outputs import AsyncModelRunnerOutput, LogprobsTensors, ModelRunnerOutput
from vllm.v1.worker.gpu.sample.output import SamplerOutput
class AsyncOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput,
num_sampled_tokens: torch.Tensor,
main_stream: torch.cuda.Stream,
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
# NOTE(woosuk): We must retain references to the GPU tensors,
# as the copy operations are performed on a different CUDA stream than
# the one where the tensors were created.
self.model_runner_output = model_runner_output
self.sampler_output = sampler_output
self.num_sampled_tokens = num_sampled_tokens
self.copy_event = copy_event
with stream(copy_stream, main_stream):
copy_stream.wait_stream(main_stream)
self.sampled_token_ids = async_copy_to_np(sampler_output.sampled_token_ids)
self.logprobs_tensors: LogprobsTensors | None = None
if sampler_output.logprobs_tensors is not None:
self.logprobs_tensors = (
sampler_output.logprobs_tensors.to_cpu_nonblocking()
)
self.num_nans: np.ndarray | None = None
if sampler_output.num_nans is not None:
self.num_nans = async_copy_to_np(sampler_output.num_nans)
self.num_sampled_tokens_np = async_copy_to_np(num_sampled_tokens)
self.prompt_logprobs_dict = {
k: v.to_cpu_nonblocking() if v is not None else None
for k, v in self.model_runner_output.prompt_logprobs_dict.items()
}
self.copy_event.record(copy_stream)
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
# NOTE(woosuk): The following code is to ensure compatibility with
# the existing model runner.
# Going forward, we should keep the data structures as NumPy arrays
# rather than Python lists.
sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
num_sampled_tokens: list[int] = self.num_sampled_tokens_np.tolist()
for token_ids, num_tokens in zip(sampled_token_ids, num_sampled_tokens):
del token_ids[num_tokens:]
self.model_runner_output.sampled_token_ids = sampled_token_ids
if self.num_nans is not None:
self.model_runner_output.num_nans_in_logits = dict(
zip(self.model_runner_output.req_ids, self.num_nans.tolist())
)
if self.logprobs_tensors is not None:
self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
return self.model_runner_output
def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
return x.to("cpu", non_blocking=True).numpy()
@contextlib.contextmanager
def stream(to_stream: torch.cuda.Stream, from_stream: torch.cuda.Stream):
"""Lightweight version of torch.cuda.stream() context manager which
avoids current_stream and device lookups.
"""
try:
torch.cuda.set_stream(to_stream)
yield
finally:
torch.cuda.set_stream(from_stream)
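
# ---------------------------------------------------------------------------
# Illustration only (not part of the module above): a minimal sketch of the
# two-stream async device-to-host copy pattern that AsyncOutput relies on.
# It assumes a CUDA device is available; the helper name below is hypothetical.
def _async_copy_sketch() -> None:
    if not torch.cuda.is_available():
        return
    main_stream = torch.cuda.current_stream()
    copy_stream = torch.cuda.Stream()
    copy_event = torch.cuda.Event()
    # Produce a "sampler output" tensor on the main stream.
    sampled = torch.randint(0, 100, (8, 4), device="cuda")
    with stream(copy_stream, main_stream):
        # The copy stream must wait for the producer before reading `sampled`.
        copy_stream.wait_stream(main_stream)
        sampled_np = async_copy_to_np(sampled)
        copy_event.record(copy_stream)
    # Consumer side: only read the NumPy array after the event has completed.
    copy_event.synchronize()
    assert sampled_np.shape == (8, 4)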

@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any, cast
import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheConfig,
KVCacheSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
kv_cache_spec: dict[str, KVCacheSpec] = {}
layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type)
for layer_name, attn_module in attn_layers.items():
        # Skip modules that don't need KV cache (e.g., encoder-only attention).
if spec := attn_module.get_kv_cache_spec(vllm_config):
kv_cache_spec[layer_name] = spec
return kv_cache_spec
def init_attn_backend(
kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
):
attn_backends: dict[str, type[AttentionBackend]] = {}
attn_groups: list[list[AttentionGroup]] = []
attn_backend_workspace: torch.Tensor | None = None
for kv_cache_group_id, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups
):
layer_names = kv_cache_group_spec.layer_names
layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
group_map: dict[tuple[tuple[str, str], KVCacheSpec], AttentionGroup] = {}
group_order: list[tuple[tuple[str, str], KVCacheSpec]] = []
for layer_name in layer_names:
attn_backend = attn_layers[layer_name].get_attn_backend()
attn_backends[layer_name] = attn_backend
layer_kv_cache_spec: KVCacheSpec = kv_cache_group_spec.kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
key = (attn_backend.full_cls_name(), layer_kv_cache_spec)
if key not in group_map:
group_map[key] = AttentionGroup(
attn_backend,
[layer_name],
layer_kv_cache_spec,
kv_cache_group_id,
)
group_order.append(key)
else:
group_map[key].layer_names.append(layer_name)
groups = [group_map[key] for key in group_order]
for group in groups:
group.create_metadata_builders(
vllm_config=vllm_config,
device=device,
kernel_block_size=None,
num_metadata_builders=1,
)
builder = group.get_metadata_builder(0)
if attn_backend_workspace is None:
if hasattr(builder, "_get_workspace_buffer"):
attn_backend_workspace = builder._get_workspace_buffer()
else:
if hasattr(builder, "set_workspace_buffer"):
builder.set_workspace_buffer(attn_backend_workspace)
attn_groups.append(groups)
return attn_backends, attn_groups
def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys()), (
"Some layers are not correctly initialized"
)
return kv_cache_raw_tensors
def _reshape_kv_cache(
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, AttentionBackend],
) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
assert isinstance(kv_cache_spec, AttentionSpec)
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
)
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
dtype = kv_cache_spec.dtype
raw_tensor = raw_tensor.view(dtype)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
return kv_caches
def init_kv_cache(
runner_kv_caches: list[torch.Tensor],
forward_context: dict[str, Any],
kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend],
device: torch.device,
) -> dict[str, torch.Tensor]:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
return kv_caches
def build_slot_mappings_by_layer(
slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig
) -> dict[str, torch.Tensor]:
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
for slot_mapping, kv_cache_group in zip(slot_mappings, kv_cache_groups):
for layer_name in kv_cache_group.layer_names:
slot_mappings_by_layer[layer_name] = slot_mapping
return slot_mappings_by_layer
def build_attn_metadata(
attn_groups: list[list[AttentionGroup]],
num_reqs: int,
num_tokens: int,
query_start_loc_gpu: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
max_query_len: int,
seq_lens: torch.Tensor,
max_seq_len: int,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
dcp_local_seq_lens: torch.Tensor | None = None,
) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs]
if dcp_local_seq_lens is not None:
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
attn_metadata: dict[str, Any] = {}
num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
for i in range(num_kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
causal=True,
dcp_local_seq_lens=dcp_local_seq_lens,
)
for attn_group in attn_groups[i]:
attn_metadata_builder = attn_group.get_metadata_builder(0)
metadata = attn_metadata_builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = metadata
return attn_metadata
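
# ---------------------------------------------------------------------------
# Illustration only (not part of the module above): the view/permute trick
# used by _reshape_kv_cache, shown with plain torch on CPU. The shapes and
# stride order below are made-up examples, not values from any real backend.
def _reshape_sketch() -> None:
    logical_shape = (2, 3, 4, 2, 8)  # (K/V, num_blocks, block_size, num_kv_heads, head_size)
    stride_order = (1, 0, 2, 3, 4)   # physical layout puts num_blocks outermost
    dtype = torch.float16
    elem_size = torch.empty(0, dtype=dtype).element_size()
    numel = 2 * 3 * 4 * 2 * 8
    raw_tensor = torch.zeros(numel * elem_size, dtype=torch.int8)
    physical_shape = tuple(logical_shape[i] for i in stride_order)
    inv_order = [stride_order.index(i) for i in range(len(stride_order))]
    # View the flat buffer with the physical shape, then permute back so the
    # result can be indexed with the logical shape.
    kv_cache = raw_tensor.view(dtype).view(physical_shape).permute(*inv_order)
    assert tuple(kv_cache.shape) == logical_shape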

@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
class BlockTables:
def __init__(
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
cp_size: int = 1,
cp_rank: int = 0,
cp_interleave: int = 1,
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
self.cp_size = cp_size
self.cp_rank = cp_rank
self.cp_interleave = cp_interleave
self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[StagedWriteTensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
# When using DCP, each request's KV cache is sharded among different ranks.
# As a result, one block on the current rank covers `block_size * cp_size`
# tokens in the full, global (unsharded) sequence.
max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size)
block_table = StagedWriteTensor(
(self.max_num_reqs, max_num_blocks),
dtype=torch.int32,
device=device,
)
self.block_tables.append(block_table)
self.block_table_ptrs = self._make_ptr_tensor(
[b.gpu for b in self.block_tables]
)
self.block_table_strides = torch.tensor(
[b.gpu.stride(0) for b in self.block_tables],
dtype=torch.int64,
device=self.device,
)
self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device
)
self.num_blocks = UvaBackedTensor(
(self.num_kv_cache_groups, self.max_num_reqs),
dtype=torch.int32,
)
# Block tables used for model's forward pass.
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(b.gpu) for b in self.block_tables
]
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
self.slot_mappings = torch.zeros(
self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device,
)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
return torch.tensor(
[t.data_ptr() for t in x], dtype=torch.uint64, device=self.device
)
def append_block_ids(
self,
req_index: int,
new_block_ids: tuple[list[int], ...],
overwrite: bool,
) -> None:
for i in range(self.num_kv_cache_groups):
start = self.num_blocks.np[i, req_index] if not overwrite else 0
block_ids = new_block_ids[i]
self.block_tables[i].stage_write(req_index, start, block_ids)
self.num_blocks.np[i, req_index] = start + len(block_ids)
def apply_staged_writes(self) -> None:
# TODO(woosuk): This can be inefficient since it launches one kernel per
# block table. Implement a kernel to handle all block tables at once.
for block_table in self.block_tables:
block_table.apply_write()
self.num_blocks.copy_to_uva()
def gather_block_tables(
self, idx_mapping: torch.Tensor
) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
idx_mapping,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.num_blocks.gpu,
self.num_blocks.gpu.stride(0),
BLOCK_SIZE=1024, # type: ignore
)
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def compute_slot_mappings(
self,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
idx_mapping,
query_start_loc,
positions,
self.block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings,
self.slot_mappings.stride(0),
self.cp_rank,
CP_SIZE=self.cp_size,
CP_INTERLEAVE=self.cp_interleave,
PAD_ID=PAD_SLOT_ID,
TRITON_BLOCK_SIZE=1024, # type: ignore
)
return self.slot_mappings[:, :num_tokens]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
self.slot_mappings.fill_(PAD_SLOT_ID)
return self.slot_mappings[:, :num_tokens]
@triton.jit
def _gather_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
idx_mapping, # [num_reqs]
query_start_loc, # [num_reqs + 1]
pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
block_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride,
cp_rank,
CP_SIZE: tl.constexpr,
CP_INTERLEAVE: tl.constexpr,
PAD_ID: tl.constexpr,
TRITON_BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if batch_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
block_size = tl.load(block_sizes + group_id)
req_state_idx = tl.load(idx_mapping + batch_idx)
start_idx = tl.load(query_start_loc + batch_idx)
end_idx = tl.load(query_start_loc + batch_idx + 1)
for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // (block_size * CP_SIZE)
block_offsets = positions % (block_size * CP_SIZE)
block_numbers = tl.load(
block_table_ptr + req_state_idx * block_table_stride + block_indices
)
if CP_SIZE == 1:
# Common case: Context parallelism is not used.
slot_ids = block_numbers * block_size + block_offsets
else:
# Context parallelism is used.
is_local = block_offsets // CP_INTERLEAVE % CP_SIZE == cp_rank
rounds = block_offsets // (CP_INTERLEAVE * CP_SIZE)
remainder = block_offsets % CP_INTERLEAVE
local_offsets = rounds * CP_INTERLEAVE + remainder
slot_ids = block_numbers * block_size + local_offsets
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr)
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
return tl.multiple_of(ptr, 16)
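
# ---------------------------------------------------------------------------
# Illustration only (not part of the module above): a pure-Python mirror of
# the context-parallel arithmetic in _compute_slot_mappings_kernel. Each
# block on one rank covers block_size * cp_size consecutive global positions,
# with chunks of cp_interleave tokens dealt out to the ranks in turn.
def _slot_id_sketch(
    position: int,
    block_table: list[int],
    block_size: int,
    cp_size: int,
    cp_rank: int,
    cp_interleave: int,
) -> int:
    block_idx = position // (block_size * cp_size)
    block_offset = position % (block_size * cp_size)
    block_number = block_table[block_idx]
    if cp_size == 1:
        return block_number * block_size + block_offset
    if (block_offset // cp_interleave) % cp_size != cp_rank:
        return PAD_SLOT_ID  # This token's KV is stored on a different rank.
    rounds = block_offset // (cp_interleave * cp_size)
    local_offset = rounds * cp_interleave + block_offset % cp_interleave
    return block_number * block_size + local_offset


# E.g. with block_size=4, cp_size=2, cp_interleave=1 and block_table=[7],
# positions 0..7 map to slots [28, PAD, 29, PAD, 30, PAD, 31, PAD] on rank 0
# and [PAD, 28, PAD, 29, PAD, 30, PAD, 31] on rank 1.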

@@ -0,0 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence
from functools import partial
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import (
async_tensor_h2d,
get_accelerator_view_from_cpu_tensor,
)
def async_copy_to_gpu(
x: torch.Tensor | np.ndarray,
out: torch.Tensor | None = None,
device: torch.device | None = None,
) -> torch.Tensor:
if isinstance(x, np.ndarray):
x = torch.from_numpy(x)
assert x.is_cpu
if out is None:
assert device is not None
out = torch.empty_like(x, device=device)
# CPU-to-CPU copy
tmp = x.pin_memory()
assert tmp is not x
# CPU-to-GPU copy
return out.copy_(tmp, non_blocking=True)
class UvaBuffer:
def __init__(self, size: int | Sequence[int], dtype: torch.dtype):
if not is_uva_available():
raise RuntimeError("UVA is not available")
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=True)
self.np = self.cpu.numpy()
self.uva = get_accelerator_view_from_cpu_tensor(self.cpu)
class UvaBufferPool:
def __init__(
self,
size: int | Sequence[int],
dtype: torch.dtype,
max_concurrency: int = 2,
):
self.size = size
self.dtype = dtype
self.max_concurrency = max_concurrency
# UVA buffers for concurrency
self._uva_bufs = [UvaBuffer(size, dtype) for _ in range(max_concurrency)]
# Current buffer index
self._curr = 0
def copy_to_uva(self, x: torch.Tensor | np.ndarray | list) -> torch.Tensor:
# Round robin to the next buffer.
self._curr = (self._curr + 1) % self.max_concurrency
buf = self._uva_bufs[self._curr]
# CPU-to-CPU copy
dst = buf.cpu if isinstance(x, torch.Tensor) else buf.np
n = len(x)
dst[:n] = x
return buf.uva[:n]
def copy_to_gpu(
self,
x: torch.Tensor | np.ndarray,
out: torch.Tensor | None = None,
) -> torch.Tensor:
uva = self.copy_to_uva(x)
# CPU-to-GPU copy
return uva.clone() if out is None else out.copy_(uva, non_blocking=True)
class UvaBackedTensor:
def __init__(
self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
):
self.dtype = dtype
self.max_concurrency = max_concurrency
# Source of truth
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)
self.np = self.cpu.numpy()
# Buffers for concurrency
self.pool = UvaBufferPool(size, dtype, max_concurrency)
self.gpu = self.pool.copy_to_uva(self.np)
def copy_to_uva(self, n: int | None = None) -> torch.Tensor:
# CPU-to-CPU copy
self.gpu = self.pool.copy_to_uva(self.np[:n] if n is not None else self.np)
return self.gpu
class StagedWriteTensor:
def __init__(
self,
size: int | Sequence[int],
dtype: torch.dtype,
device: torch.device,
max_concurrency: int = 2,
uva_instead_of_gpu: bool = False,
):
supported_dtypes = [torch.int32, torch.int64, torch.float32]
if dtype not in supported_dtypes:
raise ValueError(
f"Unsupported dtype {dtype}: should be one of {supported_dtypes}"
)
self.num_rows = size if isinstance(size, int) else size[0]
self.dtype = dtype
self.device = device
self.max_concurrency = max_concurrency
if not uva_instead_of_gpu:
# Create a GPU tensor (default)
self.gpu = torch.zeros(size, dtype=dtype, device=device)
else:
# For a large but not-frequently-accessed tensor, we can use UVA instead of
# GPU to save GPU memory
self._uva_buf = UvaBuffer(size, dtype)
self.gpu = self._uva_buf.uva
self._staged_write_indices: list[int] = []
self._staged_write_starts: list[int] = []
self._staged_write_contents: list[int | float] = []
self._staged_write_cu_lens: list[int] = []
new_buffer = partial(UvaBufferPool, max_concurrency=max_concurrency)
self.write_indices = new_buffer(self.num_rows, dtype=torch.int32)
self.write_starts = new_buffer(self.num_rows, dtype=torch.int32)
self.write_cu_lens = new_buffer(self.num_rows, dtype=torch.int32)
def stage_write(
self, index: int, start: int, x: Iterable[int] | Iterable[float]
) -> None:
assert index >= 0
assert start >= 0
if not x:
return
self._staged_write_indices.append(index)
self._staged_write_starts.append(start)
self._staged_write_contents.extend(x)
self._staged_write_cu_lens.append(len(self._staged_write_contents))
def stage_write_elem(self, index: int, x: int) -> None:
assert index >= 0
self._staged_write_indices.append(index)
self._staged_write_starts.append(0)
self._staged_write_contents.append(x)
self._staged_write_cu_lens.append(len(self._staged_write_contents))
def apply_write(self) -> None:
n = len(self._staged_write_indices)
if n == 0:
return
indices_uva = self.write_indices.copy_to_uva(self._staged_write_indices)
starts_uva = self.write_starts.copy_to_uva(self._staged_write_starts)
cu_lens_uva = self.write_cu_lens.copy_to_uva(self._staged_write_cu_lens)
# Special handling for write_contents
write_contents = async_tensor_h2d(
self._staged_write_contents, self.dtype, self.device, pin_memory=True
)
# Write diffs to the GPU buffer
_apply_write_kernel[(n,)](
self.gpu,
self.gpu.stride(0),
indices_uva,
starts_uva,
write_contents,
cu_lens_uva,
BLOCK_SIZE=1024,
)
# Clear the staged writes
self.clear_staged_writes()
def clear_staged_writes(self) -> None:
self._staged_write_indices.clear()
self._staged_write_starts.clear()
self._staged_write_contents.clear()
self._staged_write_cu_lens.clear()
@triton.jit
def _apply_write_kernel(
output_ptr,
output_stride,
write_indices_ptr,
write_starts_ptr,
write_contents_ptr,
write_cu_lens_ptr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = tl.load(write_indices_ptr + pid)
start_idx = tl.load(write_starts_ptr + pid)
cu_start = tl.load(write_cu_lens_ptr + pid - 1) if pid > 0 else 0
cu_end = tl.load(write_cu_lens_ptr + pid)
content_len = cu_end - cu_start
for i in range(0, content_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < content_len
content = tl.load(write_contents_ptr + cu_start + block, mask=mask)
tl.store(
output_ptr + row_idx * output_stride + start_idx + block, content, mask=mask
)
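
# ---------------------------------------------------------------------------
# Illustration only (not part of the module above): a CPU reference of what
# _apply_write_kernel does with the staged writes. StagedWriteTensor batches
# many small row updates on the host and flushes them with one kernel launch.
def _apply_write_reference(
    out: torch.Tensor,               # [num_rows, row_len]
    indices: list[int],              # row index per staged write
    starts: list[int],               # column offset per staged write
    contents: list[int | float],     # all staged values, concatenated
    cu_lens: list[int],              # cumulative lengths into `contents`
) -> None:
    prev = 0
    for row, start, end in zip(indices, starts, cu_lens):
        length = end - prev
        out[row, start : start + length] = torch.tensor(
            contents[prev:end], dtype=out.dtype
        )
        prev = end


# E.g. staging ([5, 6, 7] at row 0, col 2) and ([9] at row 1, col 0) gives
# indices=[0, 1], starts=[2, 0], contents=[5, 6, 7, 9], cu_lens=[3, 4].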

@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
def prepare_dcp_local_seq_lens(
dcp_local_seq_lens: torch.Tensor,
seq_lens: torch.Tensor,
num_reqs: int,
dcp_size: int,
dcp_rank: int,
cp_interleave: int,
) -> None:
"""Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
if dcp_size == 1:
return
max_num_reqs = dcp_local_seq_lens.shape[0]
BLOCK_SIZE = 128
num_blocks = triton.cdiv(max_num_reqs, BLOCK_SIZE)
_dcp_local_seq_lens_kernel[(num_blocks,)](
dcp_local_seq_lens,
seq_lens,
dcp_size,
dcp_rank,
cp_interleave,
num_reqs,
max_num_reqs,
BLOCK_SIZE,
)
@triton.jit
def _dcp_local_seq_lens_kernel(
out_ptr,
seq_lens_ptr,
dcp_size,
dcp_rank,
cp_interleave,
num_reqs,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
block = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
seq_lens = tl.load(seq_lens_ptr + block, mask=block < num_reqs)
# Distribute KV cache among different ranks, in a round-robin manner.
rounds = seq_lens // (dcp_size * cp_interleave)
remainder = seq_lens % (dcp_size * cp_interleave)
remainder = tl.maximum(remainder - dcp_rank * cp_interleave, 0)
remainder = tl.minimum(remainder, cp_interleave)
local_seq_lens = rounds * cp_interleave + remainder
    # For indices in [num_reqs, max_num_reqs), pad with 0.
local_seq_lens = tl.where(block < num_reqs, local_seq_lens, 0)
tl.store(out_ptr + block, local_seq_lens, mask=block < max_num_reqs)
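
# ---------------------------------------------------------------------------
# Illustration only (not part of the module above): a scalar mirror of the
# round-robin math in _dcp_local_seq_lens_kernel. Tokens are dealt out to
# ranks in chunks of cp_interleave, so the per-rank local lengths always sum
# back to the global sequence length.
def _dcp_local_seq_len_sketch(
    seq_len: int, dcp_size: int, dcp_rank: int, cp_interleave: int
) -> int:
    rounds = seq_len // (dcp_size * cp_interleave)
    remainder = seq_len % (dcp_size * cp_interleave)
    remainder = min(max(remainder - dcp_rank * cp_interleave, 0), cp_interleave)
    return rounds * cp_interleave + remainder


# E.g. seq_len=11, dcp_size=2, cp_interleave=1: rank 0 holds 6 tokens and
# rank 1 holds 5, which sum back to 11.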

@@ -0,0 +1,462 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.offloader.base import get_offloader
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.utils import AttentionGroup
class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
uses_mrope: bool,
use_aux_hidden_state_outputs: bool,
device: torch.device,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.uses_mrope = uses_mrope
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.uniform_decode_query_len = 1
spec_config = vllm_config.speculative_config
if spec_config is not None:
self.uniform_decode_query_len += spec_config.num_speculative_tokens
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_mode = self.compilation_config.cudagraph_mode
use_uniform_decode_cudagraph = (
self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.cudagraph_mode.separate_routine()
)
self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
self.uniform_decode_query_len,
use_uniform_decode_cudagraph,
)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = None
if self.cudagraph_mode != CUDAGraphMode.NONE:
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
self.aux_hidden_states: list[torch.Tensor] = []
def needs_capture(self) -> bool:
return len(self.cudagraph_sizes) > 0
def get_cudagraph_size(
self, num_tokens: int, uniform_decode: bool = False
) -> int | None:
if uniform_decode and self.uniform_decode_cudagraph_sizes:
return self.uniform_decode_cudagraph_sizes.get(num_tokens)
return self.cudagraph_sizes.get(num_tokens)
def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
model: nn.Module,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
has_lora: bool = False,
uniform_decode: bool = False,
) -> None:
# select and check capture function
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
# prepare inputs
if uniform_decode:
num_reqs = min(
cdiv(num_tokens, self.uniform_decode_query_len),
self.max_num_reqs,
)
else:
num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
if self.uses_mrope:
assert mrope_positions is not None
positions = mrope_positions[:, :num_tokens]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:num_tokens]
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
input_buffers,
block_tables,
attn_groups,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=(
self.uniform_decode_query_len if uniform_decode else 0
),
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
):
model_output = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Allocate output buffers if not already done.
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states]
capture_fn(
num_tokens=num_tokens,
num_reqs=num_reqs,
model=model,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
num_tokens_across_dp=num_tokens_across_dp,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
has_lora=has_lora,
)
def _capture_full_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
assert attn_metadata is not None
# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
with (
set_forward_context(
attn_metadata=attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
),
torch.cuda.graph(graph, self.pool),
):
model_output = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader().join_after_forward()
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Copy outputs to the output buffers.
assert self.hidden_states is not None
self.hidden_states[:num_tokens] = hidden_states
if self.use_aux_hidden_state_outputs:
for i, aux_hidden in enumerate(aux_hidden_states):
self.aux_hidden_states[i][:num_tokens] = aux_hidden
self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
# create batch descriptor for piecewise cudagraph dispatch key
batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)
# Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
with set_forward_context(
            attn_metadata=None,  # Piecewise graphs do not need attn_metadata.
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings,
):
model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
@torch.inference_mode()
def capture(
self,
model: nn.Module,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
has_lora: bool = False,
) -> None:
common_kwargs = dict(
device=self.device,
capture_fn=self.capture_graph,
model=model,
input_buffers=input_buffers,
mrope_positions=mrope_positions,
inputs_embeds=inputs_embeds,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
has_lora=has_lora,
)
# Phase 1: Capture for mixed prefill-decode batches if needed.
mixed_mode = self.cudagraph_mode.mixed_mode()
if mixed_mode != CUDAGraphMode.NONE:
capture_graphs(
cudagraph_sizes=self.cudagraph_sizes,
capture_cudagraph_mode=mixed_mode,
desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
uniform_decode=False,
**common_kwargs,
)
# Phase 2: Capture FULL graphs for uniform decode batches if needed.
# This is only needed if we use a separate routine for decode batches
# and the decode_mode is FULL.
if self.uniform_decode_cudagraph_sizes:
capture_graphs(
cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
capture_cudagraph_mode=CUDAGraphMode.FULL,
desc="Capturing CUDA graphs (decode, FULL)",
uniform_decode=True,
**common_kwargs,
)
def get_cudagraph_runtime_mode(
self, num_reqs: int, num_tokens: int, max_query_len: int
) -> tuple[CUDAGraphMode, int | None]:
is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_tokens == max_query_len * num_reqs
)
cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
elif is_uniform_decode:
cudagraph_mode = self.cudagraph_mode.decode_mode()
else:
cudagraph_mode = self.cudagraph_mode.mixed_mode()
if (
cudagraph_mode == CUDAGraphMode.FULL
and cudagraph_size is not None
and cudagraph_size not in self.graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_size = None
return cudagraph_mode, cudagraph_size
def run_fullgraph(
self, num_tokens: int
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
# Sync offloader before replay - needed when transitioning from
# eager/piecewise to full cudagraph (e.g., prefill → decode).
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader().sync_prev_onload()
self.graphs[num_tokens].replay()
assert self.hidden_states is not None
hidden_states = self.hidden_states[:num_tokens]
if not self.use_aux_hidden_state_outputs:
return hidden_states
return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states]
def get_cudagraph_sizes(
capture_sizes: list[int] | None,
max_num_reqs: int,
max_num_tokens: int,
cudagraph_mode: CUDAGraphMode,
uniform_decode_query_len: int = 1,
uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
# Support both FULL and PIECEWISE cudagraph modes
if cudagraph_mode == CUDAGraphMode.NONE:
return {}, {}
    if not capture_sizes:
        return {}, {}
    capture_sizes = sorted(capture_sizes)
cudagraph_sizes: dict[int, int] = {}
for i in range(1, capture_sizes[-1] + 1):
for x in capture_sizes:
if i <= x:
cudagraph_sizes[i] = x
break
uniform_decode_cudagraph_sizes: dict[int, int] = {}
if uniform_decode_cudagraph:
max_num_tokens = max_num_reqs * uniform_decode_query_len
uniform_decode_cudagraph_sizes = {
k: v
for k, v in cudagraph_sizes.items()
if v <= max_num_tokens and v >= uniform_decode_query_len
}
return cudagraph_sizes, uniform_decode_cudagraph_sizes
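
# Illustration only: with capture_sizes=[1, 2, 4, 8] and a mode other than
# NONE, get_cudagraph_sizes pads every batch size up to the next captured
# size, i.e. {1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}; batch sizes
# above 8 get no entry and therefore fall back to eager execution.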
def capture_graphs(
cudagraph_sizes: dict[int, int],
device: torch.device,
capture_fn: Callable,
capture_cudagraph_mode: CUDAGraphMode,
desc: str = "Capturing CUDA graphs",
**capture_kwargs,
) -> None:
# Capture larger graphs first.
sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc=desc)
with graph_capture(device=device):
for size in sizes_to_capture:
capture_fn(size, capture_cudagraph_mode, **capture_kwargs)
def prepare_inputs_to_capture(
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
max_model_len: int,
kv_cache_config: KVCacheConfig,
uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
if uniform_decode_query_len > 0:
num_tokens_per_req = uniform_decode_query_len
else:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc_np[-1] = num_tokens
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len.
input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, kv_cache_config
)
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=num_tokens_per_req,
seq_lens=input_buffers.seq_lens,
max_seq_len=max_model_len,
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
)
return attn_metadata, slot_mappings_by_layer

@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
if dp_size == 1:
return None
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
def get_batch_metadata_across_dp(
num_tokens: int,
cudagraph_size: int,
cudagraph_runtime_mode: int,
dp_size: int,
dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert dp_size > 1
# Use CPU group to avoid CPU-GPU synchronization.
group = get_dp_group().cpu_group
tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = cudagraph_size
tensor[2][dp_rank] = cudagraph_runtime_mode
dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1], tensor[2]
def get_cudagraph_and_dp_padding(
num_tokens: int,
cudagraph_size: int | None,
cudagraph_runtime_mode: int,
dp_size: int,
dp_rank: int,
) -> tuple[int, torch.Tensor | None, int]:
if dp_size == 1:
if cudagraph_size is not None:
return cudagraph_size, None, cudagraph_runtime_mode
else:
return num_tokens, None, cudagraph_runtime_mode
# Convert None to -1 for sync (indicates no cudagraph available)
if num_tokens == 0:
cudagraph_size = 0
elif cudagraph_size is None:
cudagraph_size = -1
num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = (
get_batch_metadata_across_dp(
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank
)
)
if torch.all(num_tokens_across_dp == 0).item():
# All ranks have zero tokens to run.
return 0, None, 0
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item())
# Check if all ranks have valid cudagraph_size.
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
if synced_cudagraph_mode != 0 and all_have_cudagraph:
# All ranks use cudagraph. Pad to max cudagraph_size.
max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
num_tokens_across_dp[:] = max_cudagraph_size
return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
else:
# Fall back to eager mode (no cudagraph).
# Either some rank doesn't have cudagraph size or mode is NONE.
synced_cudagraph_mode = 0
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode
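
# ---------------------------------------------------------------------------
# Illustration only (not part of the module above), for dp_size=2:
#  * Both ranks report a valid cudagraph size (say 4 and 8) and a non-NONE
#    mode: every rank pads its batch to the maximum size (8) and uses the
#    minimum cudagraph mode across ranks.
#  * One rank reports -1 (no captured graph for its batch): all ranks fall
#    back to eager mode and each keeps its own token count, clamped to >= 1.
#  * All ranks have zero tokens: the helper returns (0, None, 0) and the
#    forward pass is skipped.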

@@ -0,0 +1,548 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.utils import random_uuid
class InputBuffers:
def __init__(
self,
max_num_reqs: int,
max_num_tokens: int,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens
self.device = device
self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
# DCP: per-request local seq_lens buffer
self.dcp_local_seq_lens = torch.zeros(
max_num_reqs, dtype=torch.int32, device=device
)
@dataclass
class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
num_reqs: int
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# Identical to idx_mapping except for spec decoding.
expanded_idx_mapping: torch.Tensor
# [total_num_logits] position within request for each logit
expanded_local_pos: torch.Tensor
# [num_reqs]
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
# sum(num_scheduled_tokens)
num_tokens: int
num_tokens_after_padding: int
num_draft_tokens: int
# [num_reqs + 1]
query_start_loc: torch.Tensor
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
# [num_tokens_after_padding]
input_ids: torch.Tensor
# [num_tokens_after_padding]
positions: torch.Tensor
# [3, num_tokens_after_padding]
mrope_positions: torch.Tensor | None
# [num_tokens_after_padding, hidden_size]
inputs_embeds: torch.Tensor | None
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# layer_name -> slot_mapping
slot_mappings: dict[str, torch.Tensor]
# [total_num_logits]
logits_indices: torch.Tensor
# [num_reqs + 1]
cu_num_logits: torch.Tensor
cu_num_logits_np: np.ndarray
# Whether any requests in batch use structured output.
has_structured_output_reqs: bool
@classmethod
def make_dummy(
cls,
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
device: torch.device,
) -> "InputBatch":
assert 0 < num_reqs <= num_tokens
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
expanded_idx_mapping = idx_mapping
expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens
        # seq_len equals query_len for this dummy batch.
input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
# Pad for full CUDA graph mode.
input_buffers.seq_lens[num_reqs:] = 0
seq_lens = input_buffers.seq_lens[:num_reqs]
query_start_loc_np = np.empty(num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
input_buffers.query_start_loc[:1] = 0
torch.cumsum(
seq_lens, dim=0, out=input_buffers.query_start_loc[1 : num_reqs + 1]
)
# Pad for full CUDA graph mode.
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
input_ids = input_buffers.input_ids[:num_tokens].zero_()
positions = input_buffers.positions[:num_tokens].zero_()
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping,
expanded_local_pos=expanded_local_pos,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens,
num_draft_tokens=0,
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
input_ids=input_ids,
positions=positions,
mrope_positions=None,
inputs_embeds=None,
attn_metadata=None, # type: ignore
slot_mappings=None, # type: ignore
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
has_structured_output_reqs=False,
)
@triton.jit
def _prepare_prefill_inputs_kernel(
input_ids_ptr,
next_prefill_tokens_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
all_token_ids_ptr,
all_token_ids_stride,
prefill_lens_ptr,
num_computed_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
if num_computed >= prefill_len:
# Not prefill.
return
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
request_ptr = all_token_ids_ptr + req_state_idx * all_token_ids_stride
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
tokens = tl.load(request_ptr + num_computed + block, mask=mask)
tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
next_pos = num_computed + query_len
if next_pos < prefill_len:
next_token = tl.load(request_ptr + next_pos)
tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
def prepare_prefill_inputs(
input_ids: torch.Tensor,
next_prefill_tokens: torch.Tensor,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
all_token_ids: torch.Tensor,
prefill_len: torch.Tensor,
num_computed_tokens: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_prepare_prefill_inputs_kernel[(num_reqs,)](
input_ids,
next_prefill_tokens,
idx_mapping,
query_start_loc,
all_token_ids,
all_token_ids.stride(0),
prefill_len,
num_computed_tokens,
BLOCK_SIZE=1024,
)
@triton.jit
def _prepare_pos_seq_lens_kernel(
pos_ptr,
seq_lens_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
num_computed_tokens_ptr,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
req_id = tl.program_id(0)
num_reqs = tl.num_programs(0) - 1
if req_id == num_reqs:
# Pad unused seq_lens as 0 for full CUDA graphs.
for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(seq_lens_ptr + block, 0, mask=mask)
return
req_state_idx = tl.load(idx_mapping_ptr + req_id)
num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
start = tl.load(query_start_loc_ptr + req_id)
end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = end - start
seq_len = num_computed_tokens + query_len
tl.store(seq_lens_ptr + req_id, seq_len)
for i in tl.range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
pos = num_computed_tokens + block
tl.store(pos_ptr + start + block, pos, mask=mask)
def prepare_pos_seq_lens(
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
num_computed_tokens: torch.Tensor,
pos: torch.Tensor,
seq_lens: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
# NOTE(woosuk): We do +1 because the last thread block is used
# to pad unused seq_lens as 0 for full CUDA graphs.
_prepare_pos_seq_lens_kernel[(num_reqs + 1,)](
pos,
seq_lens,
idx_mapping,
query_start_loc,
num_computed_tokens,
seq_lens.shape[0],
BLOCK_SIZE=1024,
)
@triton.jit
def _combine_sampled_and_draft_tokens_kernel(
input_ids_ptr,
idx_mapping_ptr,
last_sampled_tokens_ptr,
query_start_loc_ptr,
seq_lens_ptr,
prefill_len_ptr,
draft_tokens_ptr,
draft_tokens_stride,
cu_num_logits_ptr,
logits_indices_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
# Get the number of logits and draft tokens.
cu_num_logits_start = tl.load(cu_num_logits_ptr + batch_idx)
cu_num_logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
num_logits = cu_num_logits_end - cu_num_logits_start
num_draft_tokens = num_logits - 1
# Compute the logits indices.
block = tl.arange(0, BLOCK_SIZE)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
logits_start = query_end - num_logits
tl.store(
logits_indices_ptr + cu_num_logits_start + block,
logits_start + block,
mask=block < num_logits,
)
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if seq_len <= prefill_len:
# Handling prefill tokens. No sampled or draft tokens.
return
# Write the last sampled token ID to input_ids.
last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
tl.store(input_ids_ptr + query_end - num_logits, last_token_id)
# Write the draft tokens (if any) to input_ids.
if num_draft_tokens > 0:
mask = block < num_draft_tokens
draft_tokens = tl.load(
draft_tokens_ptr + req_state_idx * draft_tokens_stride + block,
mask=mask,
)
tl.store(
input_ids_ptr + query_end - num_draft_tokens + block,
draft_tokens,
mask=mask,
)
def combine_sampled_and_draft_tokens(
input_ids: torch.Tensor,
idx_mapping: torch.Tensor,
last_sampled_tokens: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
prefill_len: torch.Tensor,
draft_tokens: torch.Tensor,
cu_num_logits: torch.Tensor,
num_logits: int,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
num_speculative_steps = draft_tokens.shape[-1]
logits_indices = torch.empty(
num_logits,
dtype=torch.int64,
device=input_ids.device,
)
_combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
input_ids,
idx_mapping,
last_sampled_tokens,
query_start_loc,
seq_lens,
prefill_len,
draft_tokens,
draft_tokens.stride(0),
cu_num_logits,
logits_indices,
# NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token
# in addition to all draft tokens.
BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
)
return logits_indices
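
# Illustration only: for a decode request with 2 draft tokens (and prefill
# already finished), the batch gives it a query of 3 tokens and cu_num_logits
# spans 3 logits. With query_start_loc=[0, 3], the kernel stores the last
# sampled token at input_ids[0], the draft tokens at input_ids[1:3], and sets
# logits_indices=[0, 1, 2] so every speculated position gets verified.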
@triton.jit
def _get_num_sampled_and_rejected_kernel(
num_sampled_ptr,
num_rejected_ptr,
seq_lens_ptr,
cu_num_logits_ptr,
idx_mapping_ptr,
prefill_len_ptr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
is_chunked_prefilling = seq_len < prefill_len
num_sampled = tl.load(num_sampled_ptr + batch_idx)
num_sampled = tl.where(is_chunked_prefilling, 0, num_sampled)
tl.store(num_sampled_ptr + batch_idx, num_sampled)
logits_start = tl.load(cu_num_logits_ptr + batch_idx)
logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
num_logits = logits_end - logits_start
num_rejected = num_logits - num_sampled
num_rejected = tl.where(is_chunked_prefilling, 0, num_rejected)
tl.store(num_rejected_ptr + batch_idx, num_rejected)
def get_num_sampled_and_rejected(
num_sampled: torch.Tensor,
seq_lens: torch.Tensor,
cu_num_logits: torch.Tensor,
idx_mapping: torch.Tensor,
prefill_len: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = idx_mapping.shape[0]
num_rejected = torch.empty_like(num_sampled)
_get_num_sampled_and_rejected_kernel[(num_reqs,)](
num_sampled,
num_rejected,
seq_lens,
cu_num_logits,
idx_mapping,
prefill_len,
)
return num_sampled, num_rejected
@triton.jit
def _post_update_kernel(
idx_mapping_ptr,
num_computed_tokens_ptr,
last_sampled_tokens_ptr,
output_bin_counts_ptr,
output_bin_counts_stride,
sampled_tokens_ptr,
sampled_tokens_stride,
num_sampled_ptr,
num_rejected_ptr,
query_start_loc_ptr,
all_token_ids_ptr,
all_token_ids_stride,
total_len_ptr,
):
req_id = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_id)
total_len = tl.load(total_len_ptr + req_state_idx)
num_sampled = tl.load(num_sampled_ptr + req_id)
if num_sampled > 0:
token_id = tl.load(
sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
)
tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
tl.store(total_len_ptr + req_state_idx, total_len + num_sampled)
for i in range(num_sampled):
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
token_ptr = (
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
)
count = tl.load(token_ptr)
count += 1
tl.store(token_ptr, count)
tl.store(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
token_id,
)
query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start
num_rejected = tl.load(num_rejected_ptr + req_id)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
num_computed += query_len - num_rejected
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
def post_update(
# [num_reqs]
idx_mapping: torch.Tensor,
# [max_num_reqs]
num_computed_tokens: torch.Tensor,
# [max_num_reqs]
last_sampled_tokens: torch.Tensor,
# [max_num_reqs, vocab_size]
output_bin_counts: torch.Tensor,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens: torch.Tensor,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [num_reqs + 1]
query_start_loc: torch.Tensor,
# [max_num_reqs, max_model_len]
all_token_ids: torch.Tensor,
# [max_num_reqs]
total_len: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_post_update_kernel[(num_reqs,)](
idx_mapping,
num_computed_tokens,
last_sampled_tokens,
output_bin_counts,
output_bin_counts.stride(0),
sampled_tokens,
sampled_tokens.stride(0),
num_sampled,
num_rejected,
query_start_loc,
all_token_ids,
all_token_ids.stride(0),
total_len,
num_warps=1,
)
@triton.jit
def _expand_idx_mapping_kernel(
idx_mapping_ptr,
expanded_idx_mapping_ptr,
expanded_local_pos_ptr,
cu_num_logits_ptr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
block = tl.arange(0, BLOCK_SIZE)
mask = block < num_tokens
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
tl.store(expanded_idx_mapping_ptr + start_idx + block, req_state_idx, mask=mask)
tl.store(expanded_local_pos_ptr + start_idx + block, block, mask=mask)
def expand_idx_mapping(
idx_mapping: torch.Tensor,
total_num_logits: int,
cu_num_logits: torch.Tensor,
max_expand_len: int,
) -> tuple[torch.Tensor, torch.Tensor]:
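    """Expand the per-request idx_mapping to one entry per logit.

    A request can contribute multiple logits in a single step (e.g., with
    speculative decoding). Returns the expanded request-state indices
    together with each logit's local position within its request.
    """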
num_reqs = idx_mapping.shape[0]
expanded_idx_mapping = idx_mapping.new_empty(total_num_logits)
expanded_local_pos = torch.empty(
total_num_logits, dtype=torch.int32, device=idx_mapping.device
)
_expand_idx_mapping_kernel[(num_reqs,)](
idx_mapping,
expanded_idx_mapping,
expanded_local_pos,
cu_num_logits,
BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
)
return expanded_idx_mapping, expanded_local_pos

View File

@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
kv_transfer_state,
)
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.forward_context import (
get_forward_context,
is_forward_context_available,
set_forward_context,
)
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
class KVConnector:
"""KVConnector interface used by GPUModelRunner."""
def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
pass
def post_forward(
self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True
) -> KVConnectorOutput | None:
return None
def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
return EMPTY_MODEL_RUNNER_OUTPUT
def set_disabled(self, disabled: bool) -> None:
pass
class ActiveKVConnector(KVConnector):
def __init__(
self, vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
):
self.vllm_config = vllm_config
self.kv_connector = get_kv_transfer_group()
# Register kv caches with KV Connector if applicable.
# TODO: support cross_layers_kv_cache
# (see https://github.com/vllm-project/vllm/pull/27743)
self.kv_connector.register_kv_caches(kv_caches_dict)
self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks)
self._disabled = False
def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
if self._disabled:
return
if scheduler_output.preempted_req_ids:
self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
self.kv_connector.bind_connector_metadata(kv_connector_metadata)
# TODO: sort out KV Connectors' use of forward_context
if is_forward_context_available():
self.kv_connector.start_load_kv(get_forward_context())
else:
with set_forward_context(None, self.vllm_config):
self.kv_connector.start_load_kv(get_forward_context())
def post_forward(
self,
scheduler_output: "SchedulerOutput",
wait_for_save: bool = True,
clear_metadata: bool = True,
) -> KVConnectorOutput | None:
if self._disabled:
return None
output = KVConnectorOutput()
if wait_for_save:
self.kv_connector.wait_for_save()
output.finished_sending, output.finished_recving = (
self.kv_connector.get_finished(scheduler_output.finished_req_ids)
)
output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors()
output.kv_connector_stats = self.kv_connector.get_kv_connector_stats()
output.kv_cache_events = self.kv_connector.get_kv_connector_kv_cache_events()
if clear_metadata:
self.kv_connector.clear_connector_metadata()
return output
def clear_metadata(self) -> None:
"""Clear the connector metadata. Call this after draft model runs."""
if not self._disabled:
self.kv_connector.clear_connector_metadata()
def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
if self._disabled:
return EMPTY_MODEL_RUNNER_OUTPUT
self.pre_forward(scheduler_output)
kv_connector_output = self.post_forward(scheduler_output, wait_for_save=False)
if kv_connector_output is None or kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
def set_disabled(self, disabled: bool) -> None:
# Ensure that layer-wise connector hooks aren't called when disabled.
kv_transfer_state._KV_CONNECTOR_AGENT = None if disabled else self.kv_connector
self._disabled = disabled
NO_OP_KV_CONNECTOR = KVConnector()
def get_kv_connector(
vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
) -> KVConnector:
if not has_kv_transfer_group():
# No-op connector.
return NO_OP_KV_CONNECTOR
return ActiveKVConnector(vllm_config, kv_caches_dict)

View File

@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
from vllm.lora.request import LoRARequest
NO_LORA_ID = 0
class LoraState:
def __init__(self, max_num_reqs: int):
self.lora_ids = np.zeros(max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
# req_id -> lora_request
self.lora_requests: dict[str, LoRARequest] = {}
def add_request(
self, req_id: str, req_index: int, lora_request: LoRARequest | None
) -> None:
if lora_request is not None:
self.lora_requests[req_id] = lora_request
self.lora_ids[req_index] = lora_request.lora_int_id
else:
self.lora_ids[req_index] = NO_LORA_ID
def remove_request(self, req_id: str) -> None:
self.lora_requests.pop(req_id, None)
def make_lora_inputs(
self,
req_ids: list[str],
idx_mapping: np.ndarray,
num_scheduled_tokens: np.ndarray,
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
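        """Build the LoRA mappings for the current batch.

        Returns the per-request LoRA IDs, the per-token LoRA IDs (each
        request's ID repeated for its scheduled tokens), and the set of
        active LoRA requests.
        """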
lora_ids = self.lora_ids[idx_mapping]
prompt_lora_mapping = tuple(lora_ids)
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set()
for req_id in req_ids:
lora_request = self.lora_requests.get(req_id)
if lora_request is not None:
active_lora_requests.add(lora_request)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests

View File

View File

@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch._inductor.runtime.triton_helpers import libdevice
from vllm.triton_utils import tl, triton
@triton.jit
def _num_nans_kernel(
logits_ptr,
logits_stride,
num_nans_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
num_nans = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=0
)
logits = logits.to(tl.float32)
is_nan = libdevice.isnan(logits).to(tl.int1)
num_nans += tl.sum(is_nan).to(tl.int32)
tl.store(num_nans_ptr + req_idx, num_nans)
def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
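    """Count NaN values in each row of the logits (one count per request)."""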
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device)
_num_nans_kernel[(num_reqs,)](
logits,
logits.stride(0),
num_nans,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
return num_nans

View File

View File

@@ -0,0 +1,183 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
class EncoderRunner:
def __init__(
self,
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
):
self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
self.inputs_embeds = torch.zeros(
max_num_tokens, hidden_size, dtype=dtype, device=device
)
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
self.encoder_cache: dict[str, torch.Tensor] = {}
def reset_mm_cache(self) -> None:
"""
        Clear the multi-modal cache that was used during profiling but is
        no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_cache.clear()
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
self.req_id_to_mm_features[req_id] = mm_features
def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_cache.pop(mm_hash, None)
def remove_request(self, req_id: str) -> None:
self.req_id_to_mm_features.pop(req_id, None)
def prepare_mm_inputs(
self, scheduled_encoder_inputs: dict[str, list[int]]
) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
mm_hashes: list[str] = []
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
mm_features = self.req_id_to_mm_features[req_id]
for mm_input_id in encoder_input_ids:
mm_feature = mm_features[mm_input_id]
if mm_feature.data is None:
continue
mm_hashes.append(mm_feature.identifier)
mm_kwargs.append((mm_feature.modality, mm_feature.data))
return mm_hashes, mm_kwargs
@torch.inference_mode()
def execute_mm_encoder(
self,
model: SupportsMultiModal,
mm_hashes: list[str],
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
) -> list[torch.Tensor]:
if not mm_hashes:
return []
encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs, device=self.device, pin_memory=False
):
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs, expected_num_items=num_items
)
encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
return encoder_outputs
def gather_mm_embeddings(
self,
req_ids: list[str],
total_num_scheduled_tokens: int,
num_scheduled_tokens: np.ndarray,
query_start_loc: np.ndarray,
prefill_lens: np.ndarray,
computed_prefill_lens: np.ndarray,
) -> tuple[list[torch.Tensor], torch.Tensor]:
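        """Gather the cached encoder outputs needed for this step.

        Returns the multimodal embedding slices that overlap with the tokens
        scheduled in this step, plus a boolean mask (moved to GPU) marking
        which scheduled token positions should be filled with those
        embeddings.
        """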
is_prefilling = (computed_prefill_lens < prefill_lens).tolist()
all_decode = not any(is_prefilling)
if all_decode:
# All decode requests, so no need to gather any embeddings.
return [], torch.zeros(
total_num_scheduled_tokens, dtype=torch.bool, device=self.device
)
query_start = computed_prefill_lens.tolist()
query_end = (computed_prefill_lens + num_scheduled_tokens).tolist()
mm_embeds: list[torch.Tensor] = []
is_mm_embed = torch.zeros(
total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=True
)
for i, req_id in enumerate(req_ids):
if not is_prefilling[i]:
# OPTIMIZATION: Skip decode requests.
continue
mm_features = self.req_id_to_mm_features[req_id]
for mm_feature in mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
if start_pos >= query_end[i]:
# The encoder output is not needed in this step.
break
if start_pos + num_encoder_tokens <= query_start[i]:
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
start_idx = max(query_start[i] - start_pos, 0)
end_idx = min(query_end[i] - start_pos, num_encoder_tokens)
assert start_idx < end_idx
curr_embeds_start, curr_embeds_end = (
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
)
# If there are no embeddings in the current range, we skip
# gathering the embeddings.
if curr_embeds_start == curr_embeds_end:
continue
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
else:
mm_embeds_item = encoder_output[start_idx:end_idx]
req_start_pos = query_start_loc[i] + start_pos - query_start[i]
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True if is_embed is None else is_embed
)
mm_embeds.append(mm_embeds_item)
# Copy the is_mm_embed tensor to the GPU.
is_mm_embed = is_mm_embed.to(device=self.device, non_blocking=True)
return mm_embeds, is_mm_embed
@torch.inference_mode()
def get_inputs_embeds(
self,
model: SupportsMultiModal,
input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor],
is_mm_embed: torch.Tensor,
) -> torch.Tensor:
x = model.embed_input_ids(
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
)
# Copy to the pre-allocated buffer for CUDA graphs.
self.inputs_embeds[: x.shape[0]] = x
return self.inputs_embeds

View File

@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.models.interfaces import SupportsMRoPE
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
class MRopeState:
def __init__(
self,
max_num_reqs: int,
max_num_tokens: int,
max_model_len: int,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens
self.max_model_len = max_model_len
self.device = device
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# wasting a lot of CPU memory.
self.prefill_mrope_positions = StagedWriteTensor(
(max_num_reqs * 3, max_model_len),
dtype=torch.int32,
device=device,
uva_instead_of_gpu=True,
)
self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
# with torch compile.
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
# the modality of inputs. For text-only inputs, each dimension has
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros(
(3, max_num_tokens + 1), dtype=torch.int64, device=device
)
def init_prefill_mrope_positions(
self,
req_idx: int,
mrope_model: SupportsMRoPE,
prefill_token_ids: list[int],
mm_features: list,
) -> None:
prefill_mrope_positions, prefill_mrope_delta = (
mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features)
)
for i in range(3):
pos = prefill_mrope_positions[i].tolist()
self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos)
self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta
def apply_staged_writes(self) -> None:
self.prefill_mrope_positions.apply_write()
self.prefill_mrope_delta.copy_to_uva()
def prepare_mrope_positions(
self,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
prefill_lens: torch.Tensor,
num_computed_tokens: torch.Tensor,
) -> None:
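        """Fill self.mrope_positions for the scheduled tokens.

        Prefill tokens read their 3D positions from the precomputed
        prefill_mrope_positions; decode tokens use the 1D position shifted
        by the per-request M-RoPE delta.
        """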
num_reqs = idx_mapping.shape[0]
_prepare_mrope_positions_kernel[(num_reqs,)](
self.mrope_positions,
self.mrope_positions.stride(0),
self.prefill_mrope_positions.gpu,
3 * self.max_model_len,
self.max_model_len,
self.prefill_mrope_delta.gpu,
idx_mapping,
query_start_loc,
prefill_lens,
num_computed_tokens,
BLOCK_SIZE=1024,
)
@triton.jit
def _prepare_mrope_positions_kernel(
mrope_positions_ptr,
mrope_positions_stride,
prefill_mrope_positions_ptr,
prefill_mrope_positions_stride0,
prefill_mrope_positions_stride1,
prefill_mrope_delta_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
prefill_lens_ptr,
num_computed_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
is_prefill = num_computed < prefill_len
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx)
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
orig_pos = num_computed + block
for j in tl.static_range(3):
if is_prefill:
# Read from pre-computed M-RoPE positions.
pos = tl.load(
prefill_mrope_positions_ptr
+ req_state_idx * prefill_mrope_positions_stride0
+ j * prefill_mrope_positions_stride1
+ orig_pos,
mask=mask,
)
else:
# Apply M-RoPE delta.
pos = orig_pos + mrope_delta
tl.store(
mrope_positions_ptr + j * mrope_positions_stride + query_start + block,
pos,
mask=mask,
)

File diff suppressed because it is too large

View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pipeline Parallelism utils for V2 Model Runner."""
import torch
from vllm.distributed.parallel_state import get_pp_group
def pp_broadcast(
sampled_token_ids: torch.Tensor,
num_sampled: torch.Tensor,
num_rejected: torch.Tensor,
) -> None:
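    """Broadcast sampled tokens and per-request counts from the last PP rank."""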
pp = get_pp_group()
assert pp.is_last_rank
assert sampled_token_ids.dtype == torch.int64
torch.distributed.broadcast(
sampled_token_ids.contiguous(), src=pp.last_rank, group=pp.device_group
)
combined = torch.stack((num_sampled, num_rejected), dim=0)
torch.distributed.broadcast(combined, src=pp.last_rank, group=pp.device_group)
def pp_receive(
num_reqs: int, max_sample_len: int = 1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
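    """Receive sampled tokens, num_sampled, and num_rejected from the last PP rank."""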
pp = get_pp_group()
assert not pp.is_last_rank
sampled_tokens = torch.empty(
num_reqs, max_sample_len, dtype=torch.int64, device=pp.device
)
torch.distributed.broadcast(sampled_tokens, src=pp.last_rank, group=pp.device_group)
combined = torch.empty(2, num_reqs, dtype=torch.int32, device=pp.device)
torch.distributed.broadcast(combined, src=pp.last_rank, group=pp.device_group)
num_sampled, num_rejected = combined.unbind(dim=0)
return sampled_tokens, num_sampled, num_rejected

View File

View File

@@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
from vllm.v1.worker.gpu.states import RequestState
MAX_BAD_WORDS_TOTAL_TOKENS = 1024 # Max total tokens for all bad words per request
MAX_NUM_BAD_WORDS = 128 # Max number of bad words per request
class BadWordsState:
def __init__(self, req_states: RequestState):
self.req_states = req_states
self.max_num_reqs = req_states.max_num_reqs
self.device = req_states.device
# flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS]
self.bad_word_token_ids = StagedWriteTensor(
(self.max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS),
dtype=torch.int32,
device=self.device,
)
# cumulative offsets of bad words: [max_num_reqs, MAX_NUM_BAD_WORDS + 1]
self.bad_word_offsets = StagedWriteTensor(
(self.max_num_reqs, MAX_NUM_BAD_WORDS + 1),
dtype=torch.int32,
device=self.device,
)
# number of bad words per request
self.num_bad_words = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
bad_words_token_ids = sampling_params.bad_words_token_ids
if not bad_words_token_ids:
self.num_bad_words.np[req_idx] = 0
return
num_bad_words = len(bad_words_token_ids)
if num_bad_words > MAX_NUM_BAD_WORDS:
raise ValueError(
f"Too many bad words: {num_bad_words}. "
f"The max number is {MAX_NUM_BAD_WORDS}."
)
# Flatten bad words and compute offsets
flattened_tokens: list[int] = []
offsets: list[int] = [0]
for bad_word in bad_words_token_ids:
flattened_tokens.extend(bad_word)
offsets.append(len(flattened_tokens))
if len(flattened_tokens) > MAX_BAD_WORDS_TOTAL_TOKENS:
raise ValueError(
f"Too many total bad word tokens: {len(flattened_tokens)}. "
f"The max is {MAX_BAD_WORDS_TOTAL_TOKENS}."
)
# Stage writes
self.bad_word_token_ids.stage_write(req_idx, 0, flattened_tokens)
self.bad_word_offsets.stage_write(req_idx, 0, offsets)
self.num_bad_words.np[req_idx] = num_bad_words
def apply_staged_writes(self) -> None:
self.num_bad_words.copy_to_uva()
self.bad_word_token_ids.apply_write()
self.bad_word_offsets.apply_write()
def apply_bad_words(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> None:
max_num_bad_words = int(self.num_bad_words.np[idx_mapping_np].max())
if max_num_bad_words == 0:
# No request uses bad words. Skip the kernel launch.
return
apply_bad_words(
logits,
idx_mapping,
self.bad_word_token_ids.gpu,
self.bad_word_offsets.gpu,
self.num_bad_words.gpu,
self.req_states.all_token_ids.gpu,
self.req_states.prompt_len.gpu,
self.req_states.total_len.gpu,
input_ids,
expanded_local_pos,
max_num_bad_words,
)
@triton.jit
def _bad_words_kernel(
logits_ptr,
logits_stride,
expanded_idx_mapping_ptr,
bad_word_token_ids_ptr,
bad_word_token_ids_stride,
bad_word_offsets_ptr,
bad_word_offsets_stride,
num_bad_words_ptr,
all_token_ids_ptr,
all_token_ids_stride,
prompt_len_ptr,
total_len_ptr,
input_ids_ptr,
expanded_local_pos_ptr,
):
logit_idx = tl.program_id(0)
bw_idx = tl.program_id(1)
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
if bw_idx >= num_bad_words:
return
pos = tl.load(expanded_local_pos_ptr + logit_idx)
cur_req_first_pos = logit_idx - pos
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
total_len = tl.load(total_len_ptr + req_state_idx)
output_len = total_len - prompt_len
effective_len = output_len + pos
bd_offsets_base = bad_word_offsets_ptr + req_state_idx * bad_word_offsets_stride
bd_tokens_base = bad_word_token_ids_ptr + req_state_idx * bad_word_token_ids_stride
output_base = all_token_ids_ptr + req_state_idx * all_token_ids_stride + prompt_len
start = tl.load(bd_offsets_base + bw_idx)
end = tl.load(bd_offsets_base + bw_idx + 1)
bad_word_len = end - start
prefix_len = bad_word_len - 1
if prefix_len > effective_len:
return
last_token = tl.load(bd_tokens_base + end - 1)
match = 1
for i in range(prefix_len):
expected = tl.load(bd_tokens_base + start + i)
actual_pos = effective_len - prefix_len + i
from_spec_input = actual_pos >= output_len
if from_spec_input:
spec_offset = actual_pos - output_len
actual = tl.load(input_ids_ptr + cur_req_first_pos + spec_offset)
else:
actual = tl.load(output_base + actual_pos)
match = match & (expected == actual)
if match:
tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
def apply_bad_words(
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
bad_word_token_ids: torch.Tensor,
bad_word_offsets: torch.Tensor,
num_bad_words: torch.Tensor,
all_token_ids: torch.Tensor,
prompt_len: torch.Tensor,
total_len: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
max_num_bad_words: int,
) -> None:
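    """Mask out bad-word completions in place.

    One program runs per (logit row, bad word): if the most recent tokens
    (including any accepted draft tokens in this step) match the bad word's
    prefix, the bad word's final token is set to -inf in that row's logits.
    """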
total_num_tokens = logits.shape[0]
_bad_words_kernel[(total_num_tokens, max_num_bad_words)](
logits,
logits.stride(0),
expanded_idx_mapping,
bad_word_token_ids,
bad_word_token_ids.stride(0),
bad_word_offsets,
bad_word_offsets.stride(0),
num_bad_words,
all_token_ids,
all_token_ids.stride(0),
prompt_len,
total_len,
input_ids,
expanded_local_pos,
)

View File

@@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _temperature_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
temperature_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
if temperature == 0.0 or temperature == 1.0:
# Early return to avoid loading logits.
return
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
logits = logits / temperature
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
def apply_temperature(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
temperature: torch.Tensor,
) -> None:
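    """Divide the logits by each request's temperature, in place.

    Rows with temperature 0 (greedy) or 1 (no-op) are skipped.
    """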
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_temperature_kernel[(num_reqs, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
@triton.jit
def _gumbel_sample_kernel(
local_argmax_ptr,
local_argmax_stride,
local_max_ptr,
local_max_stride,
logits_ptr,
logits_stride,
idx_mapping_ptr,
seeds_ptr,
pos_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + batch_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise in FP32.
u = tl.rand(gumbel_seed, block)
u = tl.maximum(u, 1e-7)
gumbel_noise = -tl.log(-tl.log(u))
# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
# Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
value, idx = tl.max(logits, axis=0, return_indices=True)
token_id = block_idx * BLOCK_SIZE + idx
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
idx_mapping: torch.Tensor, # [max_num_reqs]
temperature: torch.Tensor, # [max_num_reqs]
seed: torch.Tensor, # [max_num_reqs]
pos: torch.Tensor, # [num_reqs]
apply_temperature: bool,
) -> torch.Tensor:
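    """Sample one token per request with the Gumbel-max trick.

    Adds Gumbel(0, 1) noise to the (optionally temperature-scaled) logits
    and takes the argmax, which is equivalent to sampling from the softmax
    distribution. The noise is seeded per request and position for
    reproducibility; temperature 0 falls back to plain argmax (greedy).
    """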
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = torch.empty(
num_reqs,
num_blocks,
dtype=torch.int64,
device=logits.device,
)
local_max = torch.empty(
num_reqs,
num_blocks,
dtype=torch.float32,
device=logits.device,
)
_gumbel_sample_kernel[(num_reqs, num_blocks)](
local_argmax,
local_argmax.stride(0),
local_max,
local_max.stride(0),
logits,
logits.stride(0),
idx_mapping,
seed,
pos,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
APPLY_TEMPERATURE=apply_temperature,
)
# NOTE(woosuk): Use int64 for later indexing.
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
return sampled
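# NOTE: The reference below is illustrative only (an addition of this
# write-up, not part of the runtime path). It sketches, in eager PyTorch,
# the Gumbel-max trick that _gumbel_sample_kernel implements: taking
# argmax(logits / T + g) with g ~ Gumbel(0, 1) is equivalent to sampling
# from softmax(logits / T).
def _gumbel_sample_reference(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    if temperature == 0.0:
        # Greedy decoding: no noise.
        return logits.argmax(dim=-1)
    u = torch.rand_like(logits).clamp_min(1e-7)
    gumbel_noise = -torch.log(-torch.log(u))
    return (logits / temperature + gumbel_noise).argmax(dim=-1)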

View File

@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
MAX_NUM_ALLOWED_TOKEN_IDS = 1024
MAX_NUM_LOGIT_BIAS_TOKENS = 1024
MAX_NUM_STOP_TOKEN_IDS = 128
class LogitBiasState:
def __init__(self, max_num_reqs: int, device: torch.device):
self.max_num_reqs = max_num_reqs
# Allowed token IDs.
self.num_allowed_token_ids = UvaBackedTensor(
self.max_num_reqs, dtype=torch.int32
)
self.allowed_token_ids = StagedWriteTensor(
(self.max_num_reqs, MAX_NUM_ALLOWED_TOKEN_IDS),
dtype=torch.int32,
device=device,
)
# Logit bias.
self.num_logit_bias = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.logit_bias_token_ids = StagedWriteTensor(
(self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
dtype=torch.int32,
device=device,
)
self.logit_bias = StagedWriteTensor(
(self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
dtype=torch.float32,
device=device,
)
# Min tokens.
self.min_lens = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.num_stop_token_ids = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.stop_token_ids = StagedWriteTensor(
(self.max_num_reqs, MAX_NUM_STOP_TOKEN_IDS),
dtype=torch.int32,
device=device,
)
# Using any of the above.
self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)
def add_request(
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
) -> None:
# Using any logit bias.
use_logit_bias = False
# Allowed token IDs.
allowed_token_ids = sampling_params.allowed_token_ids
if allowed_token_ids:
num_allowed_token_ids = len(allowed_token_ids)
if num_allowed_token_ids > MAX_NUM_ALLOWED_TOKEN_IDS:
raise ValueError(
f"Too many allowed token IDs: {num_allowed_token_ids}. "
f"The max size is {MAX_NUM_ALLOWED_TOKEN_IDS}."
)
self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)
use_logit_bias = True
else:
self.num_allowed_token_ids.np[req_idx] = 0
# Logit bias.
logit_bias = sampling_params.logit_bias
if logit_bias:
num_logit_bias = len(logit_bias)
if num_logit_bias > MAX_NUM_LOGIT_BIAS_TOKENS:
raise ValueError(
f"Too many logit bias tokens: {num_logit_bias}. "
f"The max size is {MAX_NUM_LOGIT_BIAS_TOKENS}."
)
self.num_logit_bias.np[req_idx] = num_logit_bias
self.logit_bias_token_ids.stage_write(req_idx, 0, logit_bias.keys())
self.logit_bias.stage_write(req_idx, 0, logit_bias.values())
use_logit_bias = True
else:
self.num_logit_bias.np[req_idx] = 0
# Min tokens.
min_tokens = sampling_params.min_tokens
min_len = prompt_len + min_tokens
self.min_lens.np[req_idx] = min_len
stop_token_ids = sampling_params.all_stop_token_ids
if min_tokens > 0 and stop_token_ids:
num_stop_token_ids = len(stop_token_ids)
if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
raise ValueError(
f"Too many stop tokens: {num_stop_token_ids}. "
f"The max size is {MAX_NUM_STOP_TOKEN_IDS}."
)
self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids)
use_logit_bias = True
else:
self.num_stop_token_ids.np[req_idx] = 0
self.use_logit_bias[req_idx] = use_logit_bias
def apply_staged_writes(self) -> None:
self.num_allowed_token_ids.copy_to_uva()
self.allowed_token_ids.apply_write()
self.num_logit_bias.copy_to_uva()
self.logit_bias_token_ids.apply_write()
self.logit_bias.apply_write()
self.min_lens.copy_to_uva()
self.num_stop_token_ids.copy_to_uva()
self.stop_token_ids.apply_write()
def apply_logit_bias(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> None:
if not np.any(self.use_logit_bias[idx_mapping_np]):
# No request uses logit bias. Skip the kernel launch.
return
apply_logit_bias(
logits,
idx_mapping,
pos,
self.num_allowed_token_ids.gpu,
self.allowed_token_ids.gpu,
self.num_logit_bias.gpu,
self.logit_bias_token_ids.gpu,
self.logit_bias.gpu,
self.min_lens.gpu,
self.num_stop_token_ids.gpu,
self.stop_token_ids.gpu,
)
@triton.jit
def _bias_kernel(
logits_ptr,
logits_stride,
vocab_size,
idx_mapping_ptr,
# Allowed token IDs.
num_allowed_token_ids_ptr,
allowed_token_ids_ptr,
allowed_token_ids_stride,
# Logit bias.
num_logit_bias_ptr,
bias_token_ids_ptr,
bias_token_ids_stride,
bias_ptr,
bias_stride,
# Min tokens.
pos_ptr,
min_lens_ptr,
num_stop_token_ids_ptr,
stop_token_ids_ptr,
stop_token_ids_stride,
BLOCK_SIZE: tl.constexpr,
LOGITS_BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
block = tl.arange(0, BLOCK_SIZE)
# Allowed token IDs.
num_allowed_token_ids = tl.load(num_allowed_token_ids_ptr + req_state_idx)
if num_allowed_token_ids > 0:
block = tl.arange(0, BLOCK_SIZE)
mask = block < num_allowed_token_ids
# Save logits for allowed token IDs.
allowed_token_ids = tl.load(
allowed_token_ids_ptr + req_state_idx * allowed_token_ids_stride + block,
mask=mask,
)
logits = tl.load(
logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
)
# Set logits to -inf for all tokens.
for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
tl.store(
logits_ptr + batch_idx * logits_stride + offset,
-float("inf"),
mask=offset < vocab_size,
)
# Restore logits for allowed token IDs.
tl.store(
logits_ptr + batch_idx * logits_stride + allowed_token_ids,
logits,
mask=mask,
)
# Logit bias.
num_logit_bias = tl.load(num_logit_bias_ptr + req_state_idx)
if num_logit_bias > 0:
mask = block < num_logit_bias
token_ids = tl.load(
bias_token_ids_ptr + req_state_idx * bias_token_ids_stride + block,
mask=mask,
)
bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
logits += bias
tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)
# Apply min tokens.
num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
min_len = tl.load(min_lens_ptr + req_state_idx)
if num_stop_token_ids > 0 and pos < min_len:
mask = block < num_stop_token_ids
stop_token_ids = tl.load(
stop_token_ids_ptr + req_state_idx * stop_token_ids_stride + block,
mask=mask,
)
tl.store(
logits_ptr + batch_idx * logits_stride + stop_token_ids,
-float("inf"),
mask=mask,
)
def apply_logit_bias(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
pos: torch.Tensor,
num_allowed_token_ids: torch.Tensor,
allowed_token_ids: torch.Tensor,
num_logit_bias: torch.Tensor,
logit_bias_token_ids: torch.Tensor,
logit_bias: torch.Tensor,
min_lens: torch.Tensor,
num_stop_token_ids: torch.Tensor,
stop_token_ids: torch.Tensor,
) -> None:
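    """Apply per-request logit processing in place with a single kernel.

    Handles allowed_token_ids (mask everything else to -inf), additive
    logit_bias, and min_tokens (suppress stop tokens until the minimum
    length is reached).
    """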
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = triton.next_power_of_2(
max(
allowed_token_ids.shape[-1],
logit_bias_token_ids.shape[-1],
stop_token_ids.shape[-1],
)
)
LOGITS_BLOCK_SIZE = 8192
_bias_kernel[(num_reqs,)](
logits,
logits.stride(0),
vocab_size,
idx_mapping,
num_allowed_token_ids,
allowed_token_ids,
allowed_token_ids.stride(0),
num_logit_bias,
logit_bias_token_ids,
logit_bias_token_ids.stride(0),
logit_bias,
logit_bias.stride(0),
pos,
min_lens,
num_stop_token_ids,
stop_token_ids,
stop_token_ids.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
)

View File

@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
@triton.jit
def _topk_log_softmax_kernel(
output_ptr,
logits_ptr,
logits_stride,
topk_ids_ptr,
topk,
vocab_size,
BLOCK_SIZE: tl.constexpr,
PADDED_TOPK: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore
se = 0.0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
logits = logits.to(tl.float32)
e = tl.exp(logits - max_val)
e = tl.where(block < vocab_size, e, 0.0)
se += tl.sum(e)
lse = tl.log(se)
k_offset = tl.arange(0, PADDED_TOPK)
k_mask = k_offset < topk
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
logits = logits.to(tl.float32)
o = logits - max_val - lse
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
@triton.jit
def _ranks_kernel(
output_ptr,
logits_ptr,
logits_stride,
token_ids_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
token_id = tl.load(token_ids_ptr + req_idx)
x = tl.load(row_ptr + token_id)
n = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
n += tl.sum((logits >= x).to(tl.int32))
tl.store(output_ptr + req_idx, n)
def compute_token_logprobs(
logits: torch.Tensor, token_ids: torch.Tensor
) -> torch.Tensor:
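    """Compute log-softmax values only at the given token IDs.

    The kernel streams over the vocabulary to compute the max and
    log-sum-exp, so the full [batch, vocab] logprobs tensor is never
    materialized.
    """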
batch_size, vocab_size = logits.shape
token_ids = token_ids.to(torch.int64)
num_logprobs = token_ids.shape[1]
logprobs = logits.new_empty((batch_size, num_logprobs), dtype=torch.float32)
_topk_log_softmax_kernel[(batch_size,)](
logprobs,
logits,
logits.stride(0),
token_ids,
num_logprobs,
vocab_size,
BLOCK_SIZE=1024, # type: ignore
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
)
return logprobs
def compute_topk_logprobs(
logits: torch.Tensor,
num_logprobs: int,
sampled_token_ids: torch.Tensor,
cu_num_logits: list[int] | None = None,
) -> LogprobsTensors:
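    """Return logprobs for the sampled token and the top-k tokens per row.

    Also computes each sampled token's rank, i.e., the number of logits in
    its row that are greater than or equal to its own.
    """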
assert num_logprobs >= 0
batch_size, vocab_size = logits.shape
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
if num_logprobs > 0:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
logprobs = compute_token_logprobs(logits, logprob_token_ids)
token_ranks = torch.empty(batch_size, dtype=torch.int64, device=logits.device)
_ranks_kernel[(batch_size,)](
token_ranks,
logits,
logits.stride(0),
sampled_token_ids,
vocab_size,
BLOCK_SIZE=8192, # type: ignore
)
return LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=token_ranks,
cu_num_generated_tokens=cu_num_logits,
)

View File

@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _min_p_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
min_p_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
if min_p == 0.0:
return
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
)
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore
threshold = max_val + tl.log(min_p)
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
)
logits = tl.where(logits < threshold, float("-inf"), logits)
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
def apply_min_p(
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
) -> None:
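    """Apply min-p filtering to the logits in place.

    Tokens whose probability is below min_p times the max probability are
    set to -inf, implemented in log space as logit < max_logit + log(min_p).
    """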
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
_min_p_kernel[(num_reqs,)](
logits,
logits.stride(0),
idx_mapping,
min_p,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)

View File

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.v1.outputs import LogprobsTensors
@dataclass
class SamplerOutput:
sampled_token_ids: torch.Tensor
logprobs_tensors: LogprobsTensors | None
num_nans: torch.Tensor | None

View File

@@ -0,0 +1,311 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import async_tensor_h2d
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
from vllm.v1.worker.gpu.states import RequestState
class PenaltiesState:
def __init__(self, req_states: RequestState):
self.req_states = req_states
max_num_reqs = req_states.max_num_reqs
self.vocab_size = req_states.vocab_size
self.device = req_states.device
self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.use_penalty = np.zeros(max_num_reqs, dtype=bool)
# Initialize repetition penalty manually because 0 is an invalid value for it.
self.repetition_penalty.np.fill(1.0)
self.repetition_penalty.copy_to_uva()
# Statistics for penalties.
self.prompt_bin_mask = torch.zeros(
max_num_reqs,
cdiv(self.vocab_size, 32),
dtype=torch.int32,
device=self.device,
)
# TODO(woosuk): This tensor is rarely used but can be very large, taking up
# GBs of GPU memory. Optimize the memory usage.
self.output_bin_counts = torch.zeros(
max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)
self._new_penalties_reqs: list[int] = []
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
do_penalty = use_penalty(sampling_params)
self.use_penalty[req_idx] = do_penalty
if do_penalty:
self._new_penalties_reqs.append(req_idx)
def apply_staged_writes(self) -> None:
if self._new_penalties_reqs:
idx_mapping = async_tensor_h2d(
self._new_penalties_reqs,
dtype=torch.int32,
target_device=self.device,
pin_memory=True,
)
prefill_lens = self.req_states.prefill_len.np[self._new_penalties_reqs]
max_prefill_len = int(prefill_lens.max())
bincount(
idx_mapping,
self.req_states.all_token_ids.gpu,
self.req_states.prompt_len.gpu,
self.req_states.prefill_len.gpu,
self.prompt_bin_mask,
self.output_bin_counts,
max_prefill_len,
)
self._new_penalties_reqs.clear()
self.repetition_penalty.copy_to_uva()
self.frequency_penalty.copy_to_uva()
self.presence_penalty.copy_to_uva()
def apply_penalties(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
num_speculative_tokens: int,
) -> None:
if not np.any(self.use_penalty[idx_mapping_np]):
# No request uses penalties. Skip the kernel launch.
return
apply_penalties(
logits,
idx_mapping,
input_ids,
expanded_local_pos,
self.repetition_penalty.gpu,
self.frequency_penalty.gpu,
self.presence_penalty.gpu,
self.prompt_bin_mask,
self.output_bin_counts,
num_speculative_tokens,
)
@triton.jit
def _penalties_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
token_ids_ptr,
expanded_local_pos_ptr,
repetition_penalty_ptr,
frequency_penalty_ptr,
presence_penalty_ptr,
prompt_bin_mask_ptr,
prompt_bin_mask_stride,
output_bin_counts_ptr,
output_bin_counts_stride,
vocab_size,
BLOCK_SIZE: tl.constexpr,
MAX_SPEC_LEN: tl.constexpr,
):
token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + token_idx)
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
use_rep_penalty = rep_penalty != 1.0
use_freq_penalty = freq_penalty != 0.0
use_pres_penalty = pres_penalty != 0.0
use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
if not use_penalty:
# Early return to avoid loading logits.
return
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
base_output_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
mask=mask,
other=0,
)
# Compute cumulative draft_counts from previous positions in this request
pos = tl.load(expanded_local_pos_ptr + token_idx)
start_idx = token_idx - pos
draft_counts = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
for prev_pos in tl.static_range(MAX_SPEC_LEN):
if prev_pos < pos:
prev_token = tl.load(token_ids_ptr + start_idx + prev_pos + 1)
token_match = block == prev_token
draft_counts = draft_counts + token_match.to(tl.int32)
# Total counts = base output counts + cumulative draft counts
output_bin_counts = base_output_counts + draft_counts
output_bin_mask = output_bin_counts > 0
# Apply repetition penalties.
if use_rep_penalty:
packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
packed_mask = tl.load(
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
mask=packed_block < tl.cdiv(vocab_size, 32),
other=0,
)
prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
prompt_bin_mask = prompt_bin_mask.to(tl.int1)
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
        # If the token appears in the prompt or output, apply the penalty;
        # otherwise use 1.0 (no-op).
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
logits *= tl.where(logits > 0, 1.0 / scale, scale)
# Apply frequency penalties.
logits -= freq_penalty * output_bin_counts
# Apply presence penalties.
logits -= pres_penalty * output_bin_mask
# Store back to logits.
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
def apply_penalties(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
token_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
repetition_penalty: torch.Tensor,
frequency_penalty: torch.Tensor,
presence_penalty: torch.Tensor,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor,
num_speculative_tokens: int,
) -> None:
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_penalties_kernel[(num_tokens, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
frequency_penalty,
presence_penalty,
prompt_bin_mask,
prompt_bin_mask.stride(0),
output_bin_counts,
output_bin_counts.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
MAX_SPEC_LEN=num_speculative_tokens,
)
@triton.jit
def _bincount_kernel(
idx_mapping_ptr,
all_token_ids_ptr,
all_token_ids_stride,
prompt_len_ptr,
prefill_len_ptr,
prompt_bin_mask_ptr,
prompt_bin_mask_stride,
output_bin_counts_ptr,
output_bin_counts_stride,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if block_idx * BLOCK_SIZE >= prefill_len:
return
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if block_idx * BLOCK_SIZE < prompt_len:
mask = block < prompt_len
prompt_tokens = tl.load(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
)
idx = prompt_tokens // 32
bit_idx = prompt_tokens % 32
bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
tl.atomic_or(
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + idx,
bit,
mask=mask,
)
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
mask = block < prefill_len
mask &= block >= prompt_len
output_tokens = tl.load(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
)
tl.atomic_add(
output_bin_counts_ptr
+ req_state_idx * output_bin_counts_stride
+ output_tokens,
1,
mask=mask,
)
def bincount(
idx_mapping: torch.Tensor,
all_token_ids: torch.Tensor,
prompt_len: torch.Tensor,
prefill_len: torch.Tensor,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor,
max_prefill_len: int,
) -> None:
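    """Build penalty statistics for newly added requests.

    For each request, this resets and repopulates a bit-packed mask of token
    IDs appearing in the prompt (prompt_bin_mask) and per-token counts of
    the output tokens generated so far (output_bin_counts).
    """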
prompt_bin_mask[idx_mapping] = 0
output_bin_counts[idx_mapping] = 0
num_reqs = idx_mapping.shape[0]
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_reqs, num_blocks)](
idx_mapping,
all_token_ids,
all_token_ids.stride(0),
prompt_len,
prefill_len,
prompt_bin_mask,
prompt_bin_mask.stride(0),
output_bin_counts,
output_bin_counts.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
)
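# NOTE: Illustrative reference (an addition of this write-up, not used at
# runtime) of the bit-packed layout built by _bincount_kernel: token ID t
# sets bit (t % 32) of word (t // 32) in the request's prompt_bin_mask row.
def _pack_prompt_mask_reference(prompt_token_ids: list[int], vocab_size: int) -> np.ndarray:
    mask = np.zeros((vocab_size + 31) // 32, dtype=np.uint32)
    for t in prompt_token_ids:
        mask[t // 32] |= 1 << (t % 32)
    return mask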
def use_penalty(sampling_params: SamplingParams) -> bool:
return (
sampling_params.repetition_penalty != 1.0
or sampling_params.frequency_penalty != 0.0
or sampling_params.presence_penalty != 0.0
)

View File

@@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
class PromptLogprobsWorker:
def __init__(self, max_num_reqs: int):
self.max_num_reqs = max_num_reqs
self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
# req_idx -> list of in-progress LogprobsTensors
self.in_progress_prompt_logprobs: dict[str, list[LogprobsTensors]] = {}
def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
        # For now, only the logprob of each prompt token itself is supported
        # (no top-k candidates).
uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.uses_prompt_logprobs[req_idx] = uses_prompt_logprobs
if uses_prompt_logprobs:
self.in_progress_prompt_logprobs[req_id] = []
def remove_request(self, req_id: str) -> None:
self.in_progress_prompt_logprobs.pop(req_id, None)
def compute_prompt_logprobs(
self,
logits_fn: Callable[[torch.Tensor], torch.Tensor],
hidden_states: torch.Tensor,
input_batch: InputBatch,
# [max_num_reqs, max_model_len]
all_token_ids: torch.Tensor,
# [max_num_reqs]
num_computed_tokens: torch.Tensor,
# [max_num_reqs]
prompt_lens: np.ndarray,
# [max_num_reqs]
prefill_lens: np.ndarray,
# [max_num_reqs]
num_computed_prefill_tokens: np.ndarray,
) -> dict[str, LogprobsTensors]:
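        """Compute prompt logprobs for requests that requested them.

        Logits are computed in chunks to bound memory usage. For chunked
        prefills, partial results are buffered per request and only returned
        once the whole prompt has been processed.
        """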
idx_mapping_np = input_batch.idx_mapping_np
needs_prompt_logprobs = self.uses_prompt_logprobs[idx_mapping_np]
if not np.any(needs_prompt_logprobs):
# Common case: No request asks for prompt logprobs.
return {}
prompt_lens = prompt_lens[idx_mapping_np]
# NOTE(woosuk): -1 because the last prompt token's hidden state is not
# needed for prompt logprobs.
computed_prefill = num_computed_prefill_tokens[idx_mapping_np]
includes_prompt = computed_prefill < prompt_lens - 1
# NOTE(woosuk): If the request was resumed after preemption, its prompt
# logprobs must have been computed before preemption. Skip.
resumed_after_prompt = prompt_lens < prefill_lens[idx_mapping_np]
needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
if not np.any(needs_prompt_logprobs):
return {}
# Get the prompt logprobs token_ids.
prompt_logprobs_token_ids = get_prompt_logprobs_token_ids(
input_batch.num_tokens,
input_batch.query_start_loc,
input_batch.idx_mapping,
num_computed_tokens,
all_token_ids,
)
# Compute the prompt logprobs.
prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
prompt_logprobs_token_ids,
hidden_states[: input_batch.num_tokens],
logits_fn,
)
pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
is_prompt_chunked = pos_after_step < prompt_lens
query_start_loc_np = input_batch.query_start_loc_np
prompt_token_ids = prompt_logprobs_token_ids.unsqueeze(-1)
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
start_idx = query_start_loc_np[i]
end_idx = query_start_loc_np[i + 1]
assert start_idx < end_idx, (
f"start_idx ({start_idx}) >= end_idx ({end_idx})"
)
if not is_prompt_chunked[i]:
end_idx -= 1
logprobs = LogprobsTensors(
logprob_token_ids=prompt_token_ids[start_idx:end_idx],
logprobs=prompt_logprobs[start_idx:end_idx],
selected_token_ranks=prompt_ranks[start_idx:end_idx],
)
prompt_logprobs_list = self.in_progress_prompt_logprobs[req_id]
if is_prompt_chunked[i]:
# Prompt is chunked. Do not return the logprobs yet.
prompt_logprobs_list.append(logprobs)
continue
if prompt_logprobs_list:
# Merge the in-progress logprobs.
prompt_logprobs_list.append(logprobs)
logprobs = LogprobsTensors(
logprob_token_ids=torch.cat(
[x.logprob_token_ids for x in prompt_logprobs_list]
),
logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
selected_token_ranks=torch.cat(
[x.selected_token_ranks for x in prompt_logprobs_list]
),
)
prompt_logprobs_list.clear()
prompt_logprobs_dict[req_id] = logprobs
return prompt_logprobs_dict
@triton.jit
def _prompt_logprobs_token_ids_kernel(
prompt_logprobs_token_ids_ptr,
query_start_loc_ptr,
idx_mapping_ptr,
num_computed_tokens_ptr,
all_token_ids_ptr,
all_token_ids_stride,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
# NOTE(woosuk): We should shift the pos by one
# because the logprob is computed for the next token.
target_pos = num_computed_tokens + 1 + block
token_ids = tl.load(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + target_pos,
mask=mask,
)
tl.store(
prompt_logprobs_token_ids_ptr + query_start + block, token_ids, mask=mask
)
def get_prompt_logprobs_token_ids(
num_tokens: int,
query_start_loc: torch.Tensor,
idx_mapping: torch.Tensor,
num_computed_tokens: torch.Tensor,
all_token_ids: torch.Tensor,
) -> torch.Tensor:
token_ids = torch.empty(num_tokens, dtype=torch.int64, device=idx_mapping.device)
num_reqs = idx_mapping.shape[0]
_prompt_logprobs_token_ids_kernel[(num_reqs,)](
token_ids,
query_start_loc,
idx_mapping,
num_computed_tokens,
all_token_ids,
all_token_ids.stride(0),
BLOCK_SIZE=1024,
)
return token_ids
def compute_prompt_logprobs_with_chunking(
prompt_token_ids: torch.Tensor,
prompt_hidden_states: torch.Tensor,
logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Since materializing the full prompt logits can take too much memory,
# we compute it in chunks.
CHUNK_SIZE = 1024
logprobs = []
ranks = []
prompt_token_ids = prompt_token_ids.to(torch.int64)
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
end_idx = start_idx + CHUNK_SIZE
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
prompt_logprobs = compute_topk_logprobs(
prompt_logits,
0, # num_logprobs
prompt_token_ids[start_idx:end_idx],
)
logprobs.append(prompt_logprobs.logprobs)
ranks.append(prompt_logprobs.selected_token_ranks)
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
return logprobs, ranks

View File

@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
from vllm.v1.worker.gpu.states import RequestState
class Sampler:
def __init__(
self,
max_num_reqs: int,
vocab_size: int,
device: torch.device,
req_states: RequestState,
logprobs_mode: LogprobsMode = "raw_logprobs",
num_speculative_tokens: int = 1,
):
if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
self.logprobs_mode = logprobs_mode
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
self.penalties_state = PenaltiesState(req_states)
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
self.bad_words_state = BadWordsState(req_states)
self.num_speculative_tokens = num_speculative_tokens
def add_request(
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
) -> None:
self.sampling_states.add_request(req_idx, sampling_params)
self.penalties_state.add_request(req_idx, sampling_params)
self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
self.bad_words_state.add_request(req_idx, sampling_params)
def apply_staged_writes(self) -> None:
self.sampling_states.apply_staged_writes()
self.penalties_state.apply_staged_writes()
self.logit_bias_state.apply_staged_writes()
self.bad_words_state.apply_staged_writes()
def __call__(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> SamplerOutput:
        # NOTE(woosuk): We intentionally compute num_nans here, before sampling,
        # so that it reflects the raw logits prior to penalties and temperature.
num_nans = get_num_nans(logits) if self.compute_nans else None
sampled, processed_logits = self.sample(
logits,
idx_mapping,
idx_mapping_np,
pos,
input_ids,
expanded_local_pos,
)
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
if max_num_logprobs != NO_LOGPROBS:
if self.logprobs_mode == "processed_logprobs":
logits = processed_logits
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
logprobs_tensors = compute_topk_logprobs(
logits, max_num_logprobs, sampled, cu_num_logits
)
else:
logprobs_tensors = None
# These are GPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.view(-1, 1),
logprobs_tensors=logprobs_tensors,
num_nans=num_nans,
)
return sampler_output
def sample(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
# Apply penalties in place.
self.penalties_state.apply_penalties(
logits,
idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
self.num_speculative_tokens,
)
# Apply bad words masking in place.
self.bad_words_state.apply_bad_words(
logits,
idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
)
# Apply temperature in place.
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
# Apply min_p in place.
self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
# Apply top_k and/or top_p. This might or might not return a new tensor.
logits = self.sampling_states.apply_top_k_top_p(
logits, idx_mapping, idx_mapping_np
)
# Sample the next token.
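        # apply_temperature=False: temperature has already been applied to
        # `logits` in place above, so the Gumbel sampler must not apply it again.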
sampled = gumbel_sample(
logits,
idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,
pos,
apply_temperature=False,
)
return sampled, logits
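# Processing order in Sampler.sample(), applied to an FP32 copy of the logits:
#   1. logit bias (allowed_token_ids, min_tokens)  - in place
#   2. penalties                                   - in place
#   3. bad-words masking                           - in place
#   4. temperature                                 - in place
#   5. min_p                                       - in place
#   6. top_k / top_p                               - may return a new tensor
#   7. seeded Gumbel sampling (temperature is not re-applied)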

View File

@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
NO_LOGPROBS = -1
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
class SamplingStates:
def __init__(self, max_num_reqs: int, vocab_size: int):
self.max_num_reqs = max_num_reqs
self.vocab_size = vocab_size
self.temperature = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.top_k = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)
# Initialize top_k and top_p manually because 0 is an invalid value for them.
self.top_k.np.fill(self.vocab_size)
self.top_k.copy_to_uva()
self.top_p.np.fill(1.0)
self.top_p.copy_to_uva()
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(NO_LOGPROBS)
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
top_k = sampling_params.top_k
if top_k <= 0 or top_k > self.vocab_size:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.min_p.np[req_idx] = sampling_params.min_p
seed = sampling_params.seed
if seed is None:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds.np[req_idx] = seed
num_logprobs = sampling_params.logprobs
if num_logprobs is None:
num_logprobs = NO_LOGPROBS
self.num_logprobs[req_idx] = num_logprobs
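    # NOTE: add_request() only stages per-request values in the host-side
    # `.np` arrays; they are published to the UVA-backed tensors by
    # apply_staged_writes() before the next sampling step uses them.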
def apply_staged_writes(self) -> None:
self.temperature.copy_to_uva()
self.top_p.copy_to_uva()
self.top_k.copy_to_uva()
self.min_p.copy_to_uva()
self.seeds.copy_to_uva()
def apply_temperature(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
temp_np = self.temperature.np[idx_mapping_np]
if np.all((temp_np == 0.0) | (temp_np == 1.0)):
            # Every temperature is 0.0 (greedy) or 1.0 (identity), so the
            # scaling kernel would be a no-op. Skip the kernel launch.
return
apply_temperature(logits, idx_mapping, self.temperature.gpu)
def apply_min_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
if np.all(self.min_p.np[idx_mapping_np] == 0.0):
# No request uses min_p. Skip the kernel launch.
return
apply_min_p(logits, idx_mapping, self.min_p.gpu)
def apply_top_k_top_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> torch.Tensor:
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
do_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0)
if not (do_top_k or do_top_p):
return logits
top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
return apply_top_k_top_p(logits, top_k, top_p)
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
return int(np.max(self.num_logprobs[idx_mapping_np]))

View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import VllmConfig
def init_speculator(vllm_config: VllmConfig, device: torch.device):
speculative_config = vllm_config.speculative_config
assert speculative_config is not None
if speculative_config.use_eagle():
from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator
return EagleSpeculator(vllm_config, device)
raise NotImplementedError(f"{speculative_config.method} is not supported yet.")

View File

@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.model_executor.offloader.base import get_offloader
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (
capture_graphs,
get_cudagraph_sizes,
prepare_inputs_to_capture,
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.utils import AttentionGroup
class EagleCudaGraphManager:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode()
# only need to capture uniform decode cudagraph sizes (the 2nd return value)
_, self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
uniform_decode_query_len=1,
uniform_decode_cudagraph=True,
)
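        # During drafting every request contributes exactly one token per step,
        # so the captured batch shapes depend only on the number of requests
        # (uniform_decode_query_len=1).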
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = None
if self.cudagraph_mode != CUDAGraphMode.NONE:
self.pool = torch.cuda.graph_pool_handle()
def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens)
def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
generate_fn: Callable,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> None:
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
num_reqs = min(num_tokens, self.max_num_reqs)
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
input_buffers,
block_tables,
attn_groups,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=1,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Capture the graph.
capture_fn(
num_reqs=num_reqs,
num_tokens=num_tokens,
generate_fn=generate_fn,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
)
def _capture_full_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
with torch.cuda.graph(graph, self.pool):
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader().join_after_forward()
self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.PIECEWISE,
)
@torch.inference_mode()
def capture(
self,
generate_fn: Callable,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> None:
if self.cudagraph_mode == CUDAGraphMode.NONE:
return
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn,
input_buffers=input_buffers,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
)
def run_fullgraph(self, num_tokens: int) -> None:
assert num_tokens in self.graphs
# Sync offloader before replay - needed when transitioning from
# eager/piecewise to full cudagraph (e.g., prefill → decode).
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader().sync_prev_onload()
self.graphs[num_tokens].replay()

View File

@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import torch.nn as nn
from vllm.config import SpeculativeConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsEagle3, supports_eagle3
logger = init_logger(__name__)
def set_eagle3_aux_hidden_state_layers(
model: nn.Module,
spec_config: SpeculativeConfig,
) -> None:
if not supports_eagle3(model):
raise RuntimeError("Model does not support EAGLE3 interface")
# mypy may infer the class-level overload for supports_eagle3.
# Narrow explicitly to the runtime protocol instance.
if isinstance(model, type):
raise RuntimeError("Expected model instance for EAGLE3 configuration")
eagle3_model = cast(SupportsEagle3, model)
aux_layers = get_eagle3_aux_layers_from_config(spec_config)
if aux_layers:
logger.info("Using Eagle3 auxiliary layers from config: %s", aux_layers)
else:
aux_layers = eagle3_model.get_eagle3_aux_hidden_state_layers()
logger.info("Using Eagle3 auxiliary layers from model: %s", aux_layers)
eagle3_model.set_aux_hidden_state_layers(aux_layers)
def get_eagle3_aux_layers_from_config(
spec_config: SpeculativeConfig,
) -> tuple[int, ...] | None:
if not (spec_config and spec_config.draft_model_config):
return None
hf_config = spec_config.draft_model_config.hf_config
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
return None
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
if layer_ids and isinstance(layer_ids, (list, tuple)):
return tuple(layer_ids)
return None

View File

@@ -0,0 +1,583 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
from vllm.v1.worker.utils import AttentionGroup
logger = init_logger(__name__)
class EagleSpeculator:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.device = device
self.speculative_config = vllm_config.speculative_config
assert self.speculative_config is not None
self.method = self.speculative_config.method
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
self.draft_model_config = self.speculative_config.draft_model_config
self.scheduler_config = vllm_config.scheduler_config
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_model_len = vllm_config.model_config.max_model_len
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
device=device,
)
self.hidden_states = torch.zeros(
self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device
)
self.idx_mapping = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
self.temperature = torch.zeros(
self.max_num_reqs, dtype=torch.float32, device=device
)
self.seeds = torch.zeros(self.max_num_reqs, dtype=torch.int64, device=device)
self.draft_tokens = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
dtype=torch.int64,
device=device,
)
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
def load_model(self, target_model: nn.Module) -> None:
self.model = load_eagle_model(target_model, self.vllm_config)
def set_attn(
self,
kv_cache_config: KVCacheConfig,
attn_groups: list[list[AttentionGroup]],
block_tables: BlockTables,
) -> None:
self.kv_cache_config = kv_cache_config
self.attn_groups = attn_groups
self.block_tables = block_tables
@torch.inference_mode()
def run_model(
self,
num_tokens: int,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_descriptor = BatchDescriptor(num_tokens=num_tokens)
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
batch_descriptor=batch_descriptor,
):
ret_hidden_states = self.model(
input_ids=self.input_buffers.input_ids[:num_tokens],
positions=self.input_buffers.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
return last_hidden_states, hidden_states
def generate_draft(
self,
num_reqs: int,
num_tokens_padded: int,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
idx_mapping = self.idx_mapping[:num_reqs]
for step in range(1, self.num_speculative_steps):
# Run the eagle model.
last_hidden_states, hidden_states = self.run_model(
num_tokens_padded,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
cudagraph_runtime_mode,
)
last_hidden_states = last_hidden_states[:num_reqs]
hidden_states = hidden_states[:num_reqs]
logits = self.model.compute_logits(last_hidden_states)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits,
idx_mapping,
self.temperature,
self.seeds,
pos + 1,
apply_temperature=True,
)
self.draft_tokens[:num_reqs, step] = draft_tokens
if step < self.num_speculative_steps - 1:
# Update the inputs for the next step.
update_eagle_inputs(
draft_tokens,
hidden_states,
self.input_buffers,
self.hidden_states,
self.max_model_len,
)
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
def capture_model(self) -> None:
if self.num_speculative_steps == 1:
return
logger.info("Capturing model for Eagle speculator...")
self.cudagraph_manager.capture(
self.generate_draft,
self.input_buffers,
self.block_tables,
self.attn_groups,
self.kv_cache_config,
)
@torch.inference_mode()
def propose(
self,
input_batch: InputBatch,
# [num_tokens, hidden_size]
last_hidden_states: torch.Tensor,
# num_layers x [num_tokens, hidden_size]
aux_hidden_states: list[torch.Tensor] | None,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [max_num_reqs]
last_sampled: torch.Tensor,
# [max_num_reqs]
next_prefill_tokens: torch.Tensor,
# [max_num_reqs]
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
) -> torch.Tensor:
        # NOTE(woosuk): The CPU does not know how many draft tokens were rejected,
        # and we want to avoid a CPU-GPU sync to find out. Therefore, we keep
        # eagle's input_ids and hidden_states the same size as the target model's:
        # each request's query length is padded to include any rejected positions.
        # This also lets us reuse the target model's attention metadata
        # (e.g., query_start_loc, seq_lens).
if aux_hidden_states:
assert self.method == "eagle3"
hidden_states = self.model.combine_hidden_states(
torch.cat(aux_hidden_states, dim=-1)
)
else:
hidden_states = last_hidden_states
num_tokens = input_batch.num_tokens_after_padding
self.hidden_states[:num_tokens] = hidden_states
# Get the input ids and last token indices for the speculator.
last_token_indices = prepare_eagle_inputs(
self.input_buffers,
input_batch,
num_sampled,
num_rejected,
last_sampled,
next_prefill_tokens,
)
# Prefill: Run the eagle speculator with eager mode.
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states, hidden_states = self.run_model(
num_tokens,
input_batch.attn_metadata,
input_batch.slot_mappings,
num_tokens_across_dp=None, # FIXME
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
idx_mapping = self.idx_mapping[:num_reqs]
idx_mapping.copy_(input_batch.idx_mapping)
self.temperature.copy_(temperature)
self.seeds.copy_(seeds)
# Gather the values and copy them to the pre-allocated buffers.
pos = self.input_buffers.positions[:num_reqs]
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits,
idx_mapping,
self.temperature,
self.seeds,
pos + 1,
apply_temperature=True,
)
if self.num_speculative_steps == 1:
# Early exit.
return draft_tokens.view(-1, 1)
# Save the draft tokens for the first step.
self.draft_tokens[:num_reqs, 0] = draft_tokens
# Prepare the inputs for the decode steps.
prepare_eagle_decode(
draft_tokens,
hidden_states,
last_token_indices,
input_batch.seq_lens,
num_rejected,
self.input_buffers,
self.hidden_states,
self.max_model_len,
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
cudagraph_mode = self.cudagraph_manager.cudagraph_mode
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
# Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(cudagraph_size)
return self.draft_tokens[:num_reqs]
# Run eager or piecewise CUDA graph.
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
self.generate_draft(
num_reqs,
num_tokens_padded,
attn_metadata,
slot_mappings_by_layer,
num_tokens_across_dp=None, # FIXME
cudagraph_runtime_mode=cudagraph_mode,
)
return self.draft_tokens[:num_reqs]
@triton.jit
def _prepare_eagle_inputs_kernel(
last_token_indices_ptr,
eagle_input_ids_ptr,
eagle_positions_ptr,
target_input_ids_ptr,
target_positions_ptr,
idx_mapping_ptr,
last_sampled_ptr,
next_prefill_tokens_ptr,
num_sampled_ptr,
num_rejected_ptr,
query_start_loc_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
# Get the true query length and next token after accounting for rejected tokens.
num_rejected = tl.load(num_rejected_ptr + batch_idx)
query_len -= num_rejected
num_sampled = tl.load(num_sampled_ptr + batch_idx)
if num_sampled > 0:
next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
else:
# Chunked prefilling.
# Get the next prefill token.
next_token = tl.load(next_prefill_tokens_ptr + req_state_idx)
# Shift target_input_ids by one.
for i in range(1, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
input_ids = tl.load(target_input_ids_ptr + query_start + block, mask=mask)
tl.store(eagle_input_ids_ptr + query_start + block - 1, input_ids, mask=mask)
last_token_index = query_start + query_len - 1
tl.store(last_token_indices_ptr + batch_idx, last_token_index)
tl.store(eagle_input_ids_ptr + last_token_index, next_token)
# Copy positions.
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
target_pos = tl.load(target_positions_ptr + query_start + block, mask=mask)
tl.store(eagle_positions_ptr + query_start + block, target_pos, mask=mask)
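# Per-request semantics of the kernel above: the target input ids are shifted
# left by one into the eagle buffer (the EAGLE head pairs each hidden state with
# the token it produced), the next token (last sampled token, or the next
# prefill token for chunked prefills) is written into the final slot, positions
# are copied unchanged, and the index of that final slot is recorded in
# last_token_indices.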
def prepare_eagle_inputs(
input_buffers: InputBuffers,
input_batch: InputBatch,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [max_num_reqs]
last_sampled: torch.Tensor,
# [max_num_reqs]
next_prefill_tokens: torch.Tensor,
) -> torch.Tensor:
num_reqs = input_batch.num_reqs
last_token_indices = torch.empty(
num_reqs,
dtype=torch.int64,
device=num_sampled.device,
)
_prepare_eagle_inputs_kernel[(num_reqs,)](
last_token_indices,
input_buffers.input_ids,
input_buffers.positions,
input_batch.input_ids,
input_batch.positions,
input_batch.idx_mapping,
last_sampled,
next_prefill_tokens,
num_sampled,
num_rejected,
input_batch.query_start_loc,
BLOCK_SIZE=1024,
)
return last_token_indices
@triton.jit
def _prepare_eagle_decode_kernel(
draft_tokens_ptr,
output_hidden_states_ptr,
output_hidden_states_stride,
last_token_indices_ptr,
target_seq_lens_ptr,
num_rejected_ptr,
input_ids_ptr,
positions_ptr,
input_hidden_states_ptr,
input_hidden_states_stride,
query_start_loc_ptr,
seq_lens_ptr,
hidden_size,
max_model_len,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
num_reqs = tl.num_programs(0) - 1
if req_idx == num_reqs:
# Compute query_start_loc. Pad it with the last query_start_loc
# for CUDA graphs.
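        # During eagle decode every request contributes exactly one query token,
        # so query_start_loc is simply [0, 1, ..., num_reqs], padded with num_reqs.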
for i in range(0, max_num_reqs + 1, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
q = tl.where(block < num_reqs, block, num_reqs)
mask = block < max_num_reqs + 1
tl.store(query_start_loc_ptr + block, q, mask=mask)
# Pad seq_lens for CUDA graphs.
for i in range(req_idx, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(seq_lens_ptr + block, 0, mask=mask)
return
# draft token -> input id.
draft_token = tl.load(draft_tokens_ptr + req_idx)
tl.store(input_ids_ptr + req_idx, draft_token)
# output hidden states -> input hidden states.
src_idx = tl.load(last_token_indices_ptr + req_idx)
for i in range(0, hidden_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < hidden_size
output_hidden_states = tl.load(
output_hidden_states_ptr + src_idx * output_hidden_states_stride + block,
mask=mask,
)
tl.store(
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
output_hidden_states,
mask=mask,
)
# Compute position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position = tl.load(positions_ptr + req_idx)
position = tl.minimum(position + 1, max_model_len - 1)
tl.store(positions_ptr + req_idx, position)
target_seq_len = tl.load(target_seq_lens_ptr + req_idx)
num_rejected = tl.load(num_rejected_ptr + req_idx)
seq_len = target_seq_len - num_rejected
seq_len = tl.minimum(seq_len + 1, max_model_len)
tl.store(seq_lens_ptr + req_idx, seq_len)
def prepare_eagle_decode(
draft_tokens: torch.Tensor,
output_hidden_states: torch.Tensor,
last_token_indices: torch.Tensor,
target_seq_lens: torch.Tensor,
num_rejected: torch.Tensor,
input_buffers: InputBuffers,
input_hidden_states: torch.Tensor,
max_model_len: int,
max_num_reqs: int,
):
num_reqs = draft_tokens.shape[0]
hidden_size = output_hidden_states.shape[-1]
    _prepare_eagle_decode_kernel[(num_reqs + 1,)](
draft_tokens,
output_hidden_states,
output_hidden_states.stride(0),
last_token_indices,
target_seq_lens,
num_rejected,
input_buffers.input_ids,
input_buffers.positions,
input_hidden_states,
input_hidden_states.stride(0),
input_buffers.query_start_loc,
input_buffers.seq_lens,
hidden_size,
max_model_len,
max_num_reqs,
BLOCK_SIZE=1024,
)
@triton.jit
def _update_eagle_inputs_kernel(
input_ids_ptr,
positions_ptr,
input_hidden_states_ptr,
input_hidden_states_stride,
seq_lens_ptr,
max_model_len,
draft_tokens_ptr,
output_hidden_states_ptr,
output_hidden_states_stride,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
# Draft token -> Input ID.
draft_token = tl.load(draft_tokens_ptr + req_idx)
tl.store(input_ids_ptr + req_idx, draft_token)
# Output hidden states -> Input hidden states.
for i in range(0, hidden_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < hidden_size
output_hidden_states = tl.load(
output_hidden_states_ptr + req_idx * output_hidden_states_stride + block,
mask=mask,
)
tl.store(
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
output_hidden_states,
mask=mask,
)
# Increment position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position = tl.load(positions_ptr + req_idx)
position = tl.minimum(position + 1, max_model_len - 1)
tl.store(positions_ptr + req_idx, position)
seq_len = tl.load(seq_lens_ptr + req_idx)
seq_len = tl.minimum(seq_len + 1, max_model_len)
tl.store(seq_lens_ptr + req_idx, seq_len)
def update_eagle_inputs(
draft_tokens: torch.Tensor,
output_hidden_states: torch.Tensor,
input_buffers: InputBuffers,
hidden_states: torch.Tensor,
max_model_len: int,
):
num_reqs, hidden_size = output_hidden_states.shape
_update_eagle_inputs_kernel[(num_reqs,)](
input_buffers.input_ids,
input_buffers.positions,
hidden_states,
hidden_states.stride(0),
input_buffers.seq_lens,
max_model_len,
draft_tokens,
output_hidden_states,
output_hidden_states.stride(0),
hidden_size,
BLOCK_SIZE=1024,
)

View File

@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.model_loader import get_model
def load_eagle_model(target_model: nn.Module, vllm_config: VllmConfig) -> nn.Module:
from vllm.compilation.backends import set_model_tag
speculative_config = vllm_config.speculative_config
assert speculative_config is not None
draft_model_config = speculative_config.draft_model_config
with set_model_tag("eagle_head"):
eagle_model = get_model(
vllm_config=vllm_config, model_config=draft_model_config
)
# Share target embeddings when the draft checkpoint does not include
# its own vocab embedding table.
share_embeddings = True
if hasattr(eagle_model, "has_own_embed_tokens"):
share_embeddings = not eagle_model.has_own_embed_tokens
if share_embeddings:
target_language_model = (
target_model.get_language_model()
if hasattr(target_model, "get_language_model")
else target_model
)
inner_model = getattr(target_language_model, "model", None)
target_embed_tokens = None
if inner_model is not None:
if hasattr(inner_model, "embed_tokens"):
target_embed_tokens = inner_model.embed_tokens
elif hasattr(inner_model, "embedding"):
target_embed_tokens = inner_model.embedding
if target_embed_tokens is not None and hasattr(eagle_model, "model"):
if hasattr(eagle_model.model, "embed_tokens"):
del eagle_model.model.embed_tokens
eagle_model.model.embed_tokens = target_embed_tokens
# Only share target lm_head when the draft model does not own one.
share_lm_head = True
if hasattr(eagle_model, "has_own_lm_head"):
share_lm_head = not eagle_model.has_own_lm_head
if share_lm_head and hasattr(target_model, "lm_head"):
if hasattr(eagle_model, "lm_head"):
del eagle_model.lm_head
eagle_model.lm_head = target_model.lm_head
return eagle_model

View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _rejection_sample_kernel(
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
sampled_stride,
num_sampled_ptr, # [num_reqs]
target_sampled_ptr, # [num_draft_tokens + num_reqs]
input_ids_ptr, # [num_draft_tokens + num_reqs]
cu_num_logits_ptr, # [num_reqs + 1]
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
num_sampled = 0
rejected = False
for i in range(num_tokens - 1):
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
num_sampled += 1
if target_sampled != draft_sampled:
rejected = True
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store(
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
)
num_sampled += 1
tl.store(num_sampled_ptr + req_idx, num_sampled)
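# Per-request semantics of the kernel above: walk the draft positions in order,
# emitting the target model's token at each step; the first position where the
# target token differs from the draft token ends the walk (that corrective
# token is still emitted). If every draft token is accepted, the final "bonus"
# token sampled by the target model is emitted as well, so num_sampled ranges
# from 1 to num_tokens.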
def rejection_sample(
# [num_draft_tokens + num_reqs]
target_sampled: torch.Tensor,
# [num_draft_tokens + num_reqs]
input_ids: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = torch.empty(
num_reqs,
num_speculative_steps + 1,
dtype=target_sampled.dtype,
device=target_sampled.device,
)
num_sampled = torch.empty(
num_reqs,
dtype=torch.int32,
device=target_sampled.device,
)
_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
num_sampled,
target_sampled,
input_ids,
cu_num_logits,
num_warps=1,
)
return sampled, num_sampled

View File

@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.v1.outputs import DraftTokenIds
from vllm.v1.worker.gpu.async_utils import async_copy_to_np
from vllm.v1.worker.gpu.input_batch import InputBatch
class DraftTokensHandler:
def __init__(self, device: torch.device | None = None):
self.device = device
self.copy_stream = torch.cuda.Stream(device)
self.copy_event = torch.cuda.Event()
self.req_ids: list[str] = []
self.draft_tokens_np: np.ndarray | None = None
self.num_draft_tokens: int = 0
def set_draft_tokens(
self, input_batch: InputBatch, draft_tokens: torch.Tensor
) -> None:
self.req_ids = input_batch.req_ids
self.num_draft_tokens = draft_tokens.shape[1]
if not input_batch.has_structured_output_reqs:
# No draft token validation needs to be performed by
# the scheduler for this batch.
self.draft_tokens_np = None
return
# For spec decoding + structured outputs, we must transfer the
# draft tokens back to the scheduler for grammar validation.
current_stream = torch.cuda.current_stream(self.device)
self.copy_stream.wait_stream(current_stream)
with torch.cuda.stream(self.copy_stream):
self.draft_tokens_np = async_copy_to_np(draft_tokens)
self.copy_event.record()
def get_draft_tokens(self) -> DraftTokenIds | None:
if self.draft_tokens_np is not None:
self.copy_event.synchronize()
draft_token_ids = self.draft_tokens_np.tolist()
else:
# This case only happens when async scheduling is disabled.
draft_token_ids = [[-1] * self.num_draft_tokens for _ in self.req_ids]
return DraftTokenIds(self.req_ids, draft_token_ids)
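# NOTE: set_draft_tokens() launches the device-to-host copy on a side stream and
# only records an event; get_draft_tokens() synchronizes on that event, so the
# copy overlaps with other work between the two calls.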

View File

@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
class RequestState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.num_speculative_steps = num_speculative_steps
self.vocab_size = vocab_size
self.device = device
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_reqs))
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len.
# To save GPU memory, we use UVA instead of GPU for this tensor.
self.all_token_ids = StagedWriteTensor(
(self.max_num_reqs, self.max_model_len),
dtype=torch.int32,
device=device,
uva_instead_of_gpu=True,
)
# NOTE(woosuk): Distinguish clearly between prompt_len and prefill_len:
# - prompt_len: Number of tokens in the user-provided prompt.
# - prefill_len: Number of tokens passed into the model runner.
# This can include the prompt and additional partial output tokens,
# so prefill_len >= prompt_len.
# Usually, prefill_len equals prompt_len, but in cases such as resumption after
# preemption, prefill_len may be greater. Differentiating between these values
# is crucial, as certain features such as prompt logprobs or frequency penalties
# must treat prompt and output tokens separately.
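        # Example: a request preempted after generating 20 output tokens on a
        # 100-token prompt is resumed by passing all 120 tokens back in, so
        # prompt_len == 100 while prefill_len == 120.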
self.prompt_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
# total_len = prompt_len + output_len. It grows as the request progresses.
self.total_len = StagedWriteTensor(
self.max_num_reqs, dtype=torch.int32, device=device
)
# Number of computed tokens.
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = StagedWriteTensor(
self.max_num_reqs, dtype=torch.int32, device=device
)
# Last sampled tokens.
self.last_sampled_tokens = torch.zeros(
self.max_num_reqs, 1, dtype=torch.int64, device=device
)
# Draft tokens.
self.draft_tokens = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
dtype=torch.int64,
device=device,
)
self.next_prefill_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
def add_request(
self,
req_id: str,
prompt_len: int,
all_token_ids: list[int],
num_computed_tokens: int,
) -> None:
assert len(self.free_indices) > 0, "No free indices"
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.prompt_len.np[req_idx] = prompt_len
prefill_len = len(all_token_ids)
assert prefill_len >= prompt_len, (
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
)
self.prefill_len.np[req_idx] = prefill_len
self.total_len.stage_write_elem(req_idx, prefill_len)
self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
def apply_staged_writes(self) -> None:
self.prompt_len.copy_to_uva()
self.prefill_len.copy_to_uva()
self.total_len.apply_write()
self.all_token_ids.apply_write()
self.num_computed_tokens.apply_write()
def remove_request(self, req_id: str) -> None:
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
return np.any(
self.num_computed_prefill_tokens[idx_mapping_np]
< self.prefill_len.np[idx_mapping_np]
)

View File

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.input_batch import InputBatch
class StructuredOutputsWorker:
def __init__(self, max_num_logits: int, vocab_size: int, device: torch.device):
self.logits_indices = torch.zeros(
max_num_logits, dtype=torch.int32, device=device
)
self.grammar_bitmask = torch.zeros(
(max_num_logits, cdiv(vocab_size, 32)), dtype=torch.int32, device=device
)
self.device = device
self.copy_stream = torch.cuda.Stream()
def apply_grammar_bitmask(
self,
logits: torch.Tensor,
input_batch: InputBatch,
grammar_req_ids: list[str],
grammar_bitmask: np.ndarray,
) -> None:
if not grammar_req_ids:
return
# Asynchronously copy the bitmask to GPU.
with torch.cuda.stream(self.copy_stream):
bitmask = async_copy_to_gpu(
grammar_bitmask, out=self.grammar_bitmask[: grammar_bitmask.shape[0]]
)
# Construct bitmask -> logits mapping
mapping: list[int] = []
req_ids = input_batch.req_ids
cu_num_logits = input_batch.cu_num_logits_np.tolist()
req_id_to_idx = {req_id: i for i, req_id in enumerate(req_ids)}
for grammar_req_id in grammar_req_ids:
req_idx = req_id_to_idx[grammar_req_id]
logits_start_idx = cu_num_logits[req_idx]
logits_end_idx = cu_num_logits[req_idx + 1]
mapping.extend(range(logits_start_idx, logits_end_idx))
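        # e.g., with cu_num_logits == [0, 1, 3], a grammar request that is the
        # second request in the batch maps its bitmask rows to logits rows 1 and 2.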
# Asynchronously copy the mapping to GPU.
with torch.cuda.stream(self.copy_stream):
logits_indices = torch.tensor(
mapping, dtype=torch.int32, device="cpu", pin_memory=True
)
logits_indices = self.logits_indices[: len(mapping)].copy_(
logits_indices, non_blocking=True
)
# Ensure all async copies are complete before launching the kernel.
current_stream = torch.cuda.current_stream()
current_stream.wait_stream(self.copy_stream)
num_masks = bitmask.shape[0]
assert num_masks == len(mapping)
vocab_size = logits.shape[-1]
BLOCK_SIZE = 8192
grid = (num_masks, triton.cdiv(vocab_size, BLOCK_SIZE))
_apply_grammar_bitmask_kernel[grid](
logits,
logits.stride(0),
logits_indices,
bitmask,
bitmask.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Ensure the copy stream waits for the device tensors to finish being used
# before it re-uses or deallocates them
self.copy_stream.wait_stream(current_stream)
# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@triton.jit
def _apply_grammar_bitmask_kernel(
logits_ptr,
logits_stride,
logits_indices_ptr,
bitmask_ptr,
bitmask_stride,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
bitmask_idx = tl.program_id(0)
logits_idx = tl.load(logits_indices_ptr + bitmask_idx)
# Load the bitmask.
block_id = tl.program_id(1)
bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32)
packed_bitmask = tl.load(
bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset,
mask=bitmask_offset < bitmask_stride,
)
# Unpack the bitmask.
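    # Each int32 packs 32 vocab entries, least-significant bit first. A set bit
    # marks an allowed token, so entries whose bit is 0 become True in `bitmask`
    # and are masked to -inf below.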
bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
bitmask = bitmask.reshape(BLOCK_SIZE)
# Apply the bitmask to the logits.
block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
logits_ptr + logits_idx * logits_stride + block_offset,
-float("inf"),
mask=bitmask & (block_offset < vocab_size),
)