Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,4 @@
# [Experimental] Model Runner V2
This directory contains the new model runner which is under active development.
Ping [Woosuk Kwon](https://github.com/WoosukKwon) for any changes.

View File

View File

@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import numpy as np
import torch
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
LogprobsTensors,
ModelRunnerOutput,
)
from vllm.v1.worker.gpu.sample.output import SamplerOutput
class AsyncOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput,
num_sampled_tokens: torch.Tensor,
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
# NOTE(woosuk): We must retain references to the GPU tensors,
# as the copy operations are performed on a different CUDA stream than
# the one where the tensors were created.
self.model_runner_output = model_runner_output
self.sampler_output = sampler_output
self.num_sampled_tokens = num_sampled_tokens
self.copy_stream = copy_stream
self.copy_event = copy_event
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.copy_stream):
self.copy_stream.wait_stream(default_stream)
self.sampled_token_ids = async_copy_to_np(sampler_output.sampled_token_ids)
if sampler_output.logprobs_tensors is not None:
self.logprobs_tensors: LogprobsTensors | None = (
sampler_output.logprobs_tensors.to_cpu_nonblocking()
)
else:
self.logprobs_tensors = None
if sampler_output.num_nans is not None:
self.num_nans = async_copy_to_np(sampler_output.num_nans)
else:
self.num_nans = None
self.num_sampled_tokens_np = async_copy_to_np(num_sampled_tokens)
self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
if self.model_runner_output.prompt_logprobs_dict:
for k, v in self.model_runner_output.prompt_logprobs_dict.items():
if v is not None:
self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
else:
self.prompt_logprobs_dict[k] = None
self.copy_event.record(self.copy_stream)
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
# NOTE(woosuk): The following code is to ensure compatibility with
# the existing model runner.
# Going forward, we should keep the data structures as NumPy arrays
# rather than Python lists.
sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
num_reqs = len(sampled_token_ids)
num_sampled_tokens = self.num_sampled_tokens_np.tolist()
for i in range(num_reqs):
del sampled_token_ids[i][num_sampled_tokens[i] :]
self.model_runner_output.sampled_token_ids = sampled_token_ids
if self.num_nans is not None:
num_nans = self.num_nans.tolist()
self.model_runner_output.num_nans_in_logits = {
req_id: num_nans[i]
for i, req_id in enumerate(self.model_runner_output.req_ids)
}
if self.logprobs_tensors is not None:
self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
return self.model_runner_output
@contextmanager
def async_barrier(event: torch.cuda.Event | None):
if event is not None:
event.synchronize()
try:
yield
finally:
if event is not None:
event.record()
def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
return x.to("cpu", non_blocking=True).numpy()

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any, cast
import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheConfig,
KVCacheSpec,
)
from vllm.v1.worker.utils import bind_kv_cache
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
kv_cache_spec: dict[str, KVCacheSpec] = {}
layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type)
for layer_name, attn_module in attn_layers.items():
# Skip modules that don't need KV cache (eg encoder-only attention)
if spec := attn_module.get_kv_cache_spec(vllm_config):
kv_cache_spec[layer_name] = spec
return kv_cache_spec
def init_attn_backend(
kv_cache_config: KVCacheConfig,
vllm_config: VllmConfig,
device: torch.device,
):
attn_backends: dict[str, type[AttentionBackend]] = {}
attn_metadata_builders: list[AttentionMetadataBuilder] = []
flashinfer_workspace: torch.Tensor | None = None
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
layer_names = kv_cache_group_spec.layer_names
any_layer_name = next(iter(layer_names))
layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
attn_backend = attn_layers[any_layer_name].get_attn_backend()
for layer_name in layer_names:
attn_backends[layer_name] = attn_backend
attn_metadata_builder = attn_backend.get_builder_cls()(
kv_cache_group_spec.kv_cache_spec,
layer_names,
vllm_config,
device,
)
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
if "FLASHINFER" in attn_backend.get_name():
if flashinfer_workspace is None:
flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
else:
attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
return attn_backends, attn_metadata_builders
def _allocate_kv_cache(
kv_cache_config: KVCacheConfig,
device: torch.device,
):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys()), (
"Some layers are not correctly initialized"
)
return kv_cache_raw_tensors
def _reshape_kv_cache(
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, AttentionBackend],
) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
assert isinstance(kv_cache_spec, AttentionSpec)
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
)
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
dtype = kv_cache_spec.dtype
raw_tensor = raw_tensor.view(dtype)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
return kv_caches
def init_kv_cache(
runner_kv_caches: list[torch.Tensor],
forward_context: dict[str, Any],
kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend],
device: torch.device,
) -> None:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int,
num_tokens: int,
query_start_loc_gpu: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_np: np.ndarray,
num_computed_tokens_cpu: torch.Tensor | None,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
max_query_len = int(query_start_loc_cpu.max())
seq_lens = seq_lens[:num_reqs]
seq_lens_cpu = torch.from_numpy(seq_lens_np)
max_seq_len = int(seq_lens_np.max())
attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
for i, kv_cache_spec in enumerate(kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
_seq_lens_cpu=seq_lens_cpu,
max_seq_len=max_seq_len,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
causal=True,
)
attn_metadata_builder = attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = metadata
return attn_metadata

View File

@@ -0,0 +1,314 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
class BlockTables:
def __init__(
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
pin_memory: bool,
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
self.pin_memory = pin_memory
self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[torch.Tensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = torch.zeros(
self.max_num_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_tables.append(block_table)
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
# Block tables used for model's forward pass.
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(block_table) for block_table in self.block_tables
]
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables],
dtype=torch.int64,
device=self.device,
)
self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device
)
self.num_blocks = torch.zeros(
self.num_kv_cache_groups,
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
self.slot_mappings = torch.zeros(
self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device,
)
# Misc buffers.
self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
self.cu_num_new_blocks = self._make_buffer(
self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
ptrs_tensor_cpu = torch.tensor(
[t.data_ptr() for t in x],
dtype=torch.uint64,
device="cpu",
pin_memory=self.pin_memory,
)
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids(
self,
# [num_reqs]
req_indices: list[int],
# [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks: tuple[list[int], ...],
# [num_kv_cache_groups, num_new_blocks]
new_block_ids: tuple[list[int], ...],
# [num_reqs]
overwrite: list[bool],
) -> None:
num_reqs = len(req_indices)
self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
# no clear upper bound to the number of new blocks in a single step.
# NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
# guarantee that the buffer is not freed before the copy is completed.
self.new_block_ids_cpu = torch.empty(
self.num_kv_cache_groups,
max(len(x) for x in new_block_ids),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
new_block_ids_np = self.new_block_ids_cpu.numpy()
for i in range(self.num_kv_cache_groups):
new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i]
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True)
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
self.req_indices.copy_to_gpu(num_reqs),
self.cu_num_new_blocks.copy_to_gpu(),
self.cu_num_new_blocks.gpu.stride(0),
new_block_ids_gpu,
new_block_ids_gpu.stride(0),
self.overwrite.copy_to_gpu(num_reqs),
self.block_table_strides,
self.block_table_ptrs,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024, # type: ignore
)
def gather_block_tables(
self,
idx_mapping: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
idx_mapping,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024, # type: ignore
)
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def compute_slot_mappings(
self,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
num_reqs = query_start_loc.shape[0] - 1
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
query_start_loc,
positions,
self.input_block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings,
self.slot_mappings.stride(0),
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024, # type: ignore
)
return self.slot_mappings[:, :num_tokens]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
self.slot_mappings.fill_(PAD_SLOT_ID)
return self.slot_mappings[:, :num_tokens]
@triton.jit
def _append_block_ids_kernel(
# Inputs
req_indices, # [num_reqs]
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks_stride,
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
new_block_ids_stride,
overwrite, # [num_reqs]
block_table_strides, # [num_kv_cache_groups]
# Outputs
block_table_ptrs, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
# Constants
BLOCK_SIZE: tl.constexpr,
):
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx)
group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
num_new_blocks = end_idx - start_idx
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx) if not do_overwrite else 0
dst_end_idx = dst_start_idx + num_new_blocks
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
# Destination
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
row_ptr = block_table_ptr + req_idx * block_table_stride
group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
for i in range(0, num_new_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(
group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
)
tl.store(
row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
)
@triton.jit
def _gather_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
cu_num_tokens, # [num_reqs + 1]
pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
page_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride,
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
req_idx = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if req_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
page_size = tl.load(page_sizes + group_id)
start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1)
for i in range(start_idx, end_idx, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size
block_numbers = tl.load(
block_table_ptr + req_idx * block_table_stride + block_indices
)
slot_ids = block_numbers * page_size + positions % page_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr)
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
return tl.multiple_of(ptr, 16)

View File

@@ -0,0 +1,259 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_mode: CUDAGraphMode
if self.compilation_config.cudagraph_mode is None:
self.cudagraph_mode = CUDAGraphMode.NONE
else:
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
def needs_capture(self) -> bool:
return len(self.cudagraph_sizes) > 0
def get_cudagraph_size(
self,
scheduler_output: SchedulerOutput,
num_tokens_after_padding: int,
) -> int | None:
return get_cudagraph_size(
num_tokens_after_padding,
scheduler_output.num_scheduled_tokens.values(),
self.cudagraph_sizes,
self.cudagraph_mode,
)
def capture_graph(
self,
num_tokens: int,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
attn_metadata = prepare_inputs_to_capture(
num_reqs,
num_tokens,
input_buffers,
block_tables,
attn_metadata_builders,
self.max_model_len,
kv_cache_config,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
with (
set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
),
torch.cuda.graph(graph, self.pool),
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
self.hidden_states[:num_tokens] = hidden_states
self.graphs[num_tokens] = graph
@torch.inference_mode()
def capture(
self,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
model=model,
input_buffers=input_buffers,
block_tables=block_tables,
attn_metadata_builders=attn_metadata_builders,
kv_cache_config=kv_cache_config,
)
def run(self, num_tokens: int) -> torch.Tensor:
assert num_tokens in self.graphs
self.graphs[num_tokens].replay()
assert self.hidden_states is not None
return self.hidden_states[:num_tokens]
def get_cudagraph_sizes(
capture_sizes: list[int] | None,
max_num_reqs: int,
max_num_tokens: int,
cudagraph_mode: CUDAGraphMode,
) -> dict[int, int]:
if not cudagraph_mode.has_full_cudagraphs():
return {}
if not capture_sizes:
return {}
capture_sizes = sorted(capture_sizes)
# Limit the capture sizes to the max number of requests or tokens.
upper_bound = (
max_num_reqs
if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
else max_num_tokens
)
capture_sizes = [x for x in capture_sizes if x <= upper_bound]
if not capture_sizes:
return {}
cudagraph_sizes: dict[int, int] = {}
for i in range(1, capture_sizes[-1] + 1):
for x in capture_sizes:
if i <= x:
cudagraph_sizes[i] = x
break
return cudagraph_sizes
def get_cudagraph_size(
num_tokens_after_dp_padding: int,
num_tokens_per_request: Iterable[int],
cudagraph_sizes: dict[int, int],
cudagraph_mode: CUDAGraphMode,
) -> int | None:
size = cudagraph_sizes.get(num_tokens_after_dp_padding)
if size is None:
# No CUDA graph for this size.
return None
if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
all_decode = all(x == 1 for x in num_tokens_per_request)
if not all_decode:
# Prefill is included.
return None
return size
def capture_graphs(
cudagraph_sizes: dict[int, int],
device: torch.device,
capture_fn: Callable,
**capture_kwargs,
) -> None:
# Capture larger graphs first.
sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
with graph_capture(device=device):
for size in sizes_to_capture:
capture_fn(size, **capture_kwargs)
def prepare_inputs_to_capture(
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
max_model_len: int,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc = input_buffers.query_start_loc
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
query_start_loc.np[num_reqs:] = num_tokens
query_start_loc.copy_to_gpu()
seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len. This introduces a discrepancy between
# seq_lens (on GPU) and seq_lens_np (on CPU), which may cause issues for
# certain attention backends.
input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
attn_metadata = build_attn_metadata(
attn_metadata_builders=attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
seq_lens=input_buffers.seq_lens,
seq_lens_np=seq_lens_np,
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
)
return attn_metadata

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
def get_batch_metadata_across_dp(
num_tokens: int,
cudagraph_size: int,
dp_size: int,
dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert dp_size > 1
# Use CPU group to avoid CPU-GPU synchronization.
group = get_dp_group().cpu_group
tensor = torch.zeros(2, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = cudagraph_size
dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1]
def make_num_tokens_across_dp(
dp_size: int,
num_tokens: int,
) -> torch.Tensor | None:
if dp_size == 1:
return None
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")

View File

@@ -0,0 +1,479 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.utils import random_uuid
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
class InputBuffers:
def __init__(
self,
max_num_reqs: int,
max_num_tokens: int,
inputs_embeds_size: int,
vocab_size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens
self.device = device
self.pin_memory = pin_memory
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
# Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.grammar_bitmask = self._make_buffer(
max_num_reqs, cdiv(vocab_size, 32), dtype=torch.int32
)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
@dataclass
class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
num_reqs: int
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# [num_reqs]
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
# sum(num_scheduled_tokens)
num_tokens: int
num_tokens_after_padding: int
num_draft_tokens: int
# [num_reqs + 1]
query_start_loc: torch.Tensor
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
seq_lens_np: np.ndarray
# [num_tokens_after_padding]
input_ids: torch.Tensor
# [num_tokens_after_padding]
positions: torch.Tensor
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# [total_num_logits]
logits_indices: torch.Tensor
# [num_reqs + 1]
cu_num_logits: torch.Tensor
@classmethod
def make_dummy(
cls,
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
device: torch.device,
) -> "InputBatch":
assert 0 < num_reqs <= num_tokens
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens
input_buffers.query_start_loc.np[0] = 0
input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
num_scheduled_tokens
)
input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
# seq_len equals to query_len
seq_lens_np = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
seq_lens_np[-1] += num_tokens % num_reqs
input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
input_buffers.seq_lens[num_reqs:] = 0
seq_lens = input_buffers.seq_lens[:num_reqs]
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens,
num_draft_tokens=0,
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_np=seq_lens_np,
input_ids=input_ids,
positions=positions,
attn_metadata=None, # type: ignore
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
)
@triton.jit
def _prepare_prefill_inputs_kernel(
input_ids_ptr,
next_prefill_tokens_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
prefill_token_ids_ptr,
prefill_token_ids_stride,
prefill_lens_ptr,
num_computed_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
if num_computed >= prefill_len:
# Not prefill.
return
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
prefill_ptr = prefill_token_ids_ptr + req_state_idx * prefill_token_ids_stride
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
tokens = tl.load(prefill_ptr + num_computed + block, mask=mask)
tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
next_pos = num_computed + query_len
if next_pos < prefill_len:
next_token = tl.load(prefill_ptr + next_pos)
tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
def prepare_prefill_inputs(
input_ids: torch.Tensor,
next_prefill_tokens: torch.Tensor,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
prefill_token_ids: torch.Tensor,
prefill_len: torch.Tensor,
num_computed_tokens: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_prepare_prefill_inputs_kernel[(num_reqs,)](
input_ids,
next_prefill_tokens,
idx_mapping,
query_start_loc,
prefill_token_ids,
prefill_token_ids.stride(0),
prefill_len,
num_computed_tokens,
BLOCK_SIZE=1024,
)
@triton.jit
def _prepare_pos_seq_lens_kernel(
pos_ptr,
seq_lens_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
num_computed_tokens_ptr,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
req_id = tl.program_id(0)
num_reqs = tl.num_programs(0) - 1
if req_id == num_reqs:
# Pad unused seq_lens as 0 for full CUDA graphs.
for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(seq_lens_ptr + block, 0, mask=mask)
return
req_state_idx = tl.load(idx_mapping_ptr + req_id)
num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
start = tl.load(query_start_loc_ptr + req_id)
end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = end - start
seq_len = num_computed_tokens + query_len
tl.store(seq_lens_ptr + req_id, seq_len)
for i in tl.range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
pos = num_computed_tokens + block
tl.store(pos_ptr + start + block, pos, mask=mask)
def prepare_pos_seq_lens(
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
num_computed_tokens: torch.Tensor,
pos: torch.Tensor,
seq_lens: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
# NOTE(woosuk): We do +1 because the last thread block is used
# to pad unused seq_lens as 0 for full CUDA graphs.
_prepare_pos_seq_lens_kernel[(num_reqs + 1,)](
pos,
seq_lens,
idx_mapping,
query_start_loc,
num_computed_tokens,
seq_lens.shape[0],
BLOCK_SIZE=1024,
)
@triton.jit
def _combine_sampled_and_draft_tokens_kernel(
input_ids_ptr,
idx_mapping_ptr,
last_sampled_tokens_ptr,
query_start_loc_ptr,
seq_lens_ptr,
prefill_len_ptr,
draft_tokens_ptr,
draft_tokens_stride,
cu_num_logits_ptr,
logits_indices_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
# Get the number of logits and draft tokens.
cu_num_logits_start = tl.load(cu_num_logits_ptr + batch_idx)
cu_num_logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
num_logits = cu_num_logits_end - cu_num_logits_start
num_draft_tokens = num_logits - 1
# Compute the logits indices.
block = tl.arange(0, BLOCK_SIZE)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
logits_start = query_end - num_logits
tl.store(
logits_indices_ptr + cu_num_logits_start + block,
logits_start + block,
mask=block < num_logits,
)
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if seq_len <= prefill_len:
# Handling prefill tokens. No sampled or draft tokens.
return
# Write the last sampled token ID to input_ids.
last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
tl.store(input_ids_ptr + query_end - num_logits, last_token_id)
# Write the draft tokens (if any) to input_ids.
if num_draft_tokens > 0:
mask = block < num_draft_tokens
draft_tokens = tl.load(
draft_tokens_ptr + req_state_idx * draft_tokens_stride + block,
mask=mask,
)
tl.store(
input_ids_ptr + query_end - num_draft_tokens + block,
draft_tokens,
mask=mask,
)
def combine_sampled_and_draft_tokens(
input_ids: torch.Tensor,
idx_mapping: torch.Tensor,
last_sampled_tokens: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
prefill_len: torch.Tensor,
draft_tokens: torch.Tensor,
cu_num_logits: torch.Tensor,
num_logits: int,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
num_speculative_steps = draft_tokens.shape[-1]
logits_indices = torch.empty(
num_logits,
dtype=torch.int64,
device=input_ids.device,
)
_combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
input_ids,
idx_mapping,
last_sampled_tokens,
query_start_loc,
seq_lens,
prefill_len,
draft_tokens,
draft_tokens.stride(0),
cu_num_logits,
logits_indices,
# NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token
# in addition to all draft tokens.
BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
)
return logits_indices
@triton.jit
def _get_num_sampled_and_rejected_kernel(
num_sampled_ptr,
num_rejected_ptr,
seq_lens_ptr,
cu_num_logits_ptr,
idx_mapping_ptr,
prefill_len_ptr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
is_chunked_prefilling = seq_len < prefill_len
num_sampled = tl.load(num_sampled_ptr + batch_idx)
num_sampled = tl.where(is_chunked_prefilling, 0, num_sampled)
tl.store(num_sampled_ptr + batch_idx, num_sampled)
logits_start = tl.load(cu_num_logits_ptr + batch_idx)
logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
num_logits = logits_end - logits_start
num_rejected = num_logits - num_sampled
num_rejected = tl.where(is_chunked_prefilling, 0, num_rejected)
tl.store(num_rejected_ptr + batch_idx, num_rejected)
def get_num_sampled_and_rejected(
num_sampled: torch.Tensor,
seq_lens: torch.Tensor,
cu_num_logits: torch.Tensor,
idx_mapping: torch.Tensor,
prefill_len: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = idx_mapping.shape[0]
num_rejected = torch.empty_like(num_sampled)
_get_num_sampled_and_rejected_kernel[(num_reqs,)](
num_sampled,
num_rejected,
seq_lens,
cu_num_logits,
idx_mapping,
prefill_len,
)
return num_sampled, num_rejected
@triton.jit
def _post_update_kernel(
idx_mapping_ptr,
num_computed_tokens_ptr,
last_sampled_tokens_ptr,
output_bin_counts_ptr,
output_bin_counts_stride,
sampled_tokens_ptr,
sampled_tokens_stride,
num_sampled_ptr,
num_rejected_ptr,
query_start_loc_ptr,
):
req_id = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_id)
num_sampled = tl.load(num_sampled_ptr + req_id)
if num_sampled > 0:
token_id = tl.load(
sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
)
tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
for i in range(num_sampled):
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
token_ptr = (
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
)
count = tl.load(token_ptr)
count += 1
tl.store(token_ptr, count)
query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start
num_rejected = tl.load(num_rejected_ptr + req_id)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
num_computed += query_len - num_rejected
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
def post_update(
# [num_reqs]
idx_mapping: torch.Tensor,
# [max_num_reqs]
num_computed_tokens: torch.Tensor,
# [max_num_reqs]
last_sampled_tokens: torch.Tensor,
# [max_num_reqs, vocab_size]
output_bin_counts: torch.Tensor,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens: torch.Tensor,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [num_reqs + 1]
query_start_loc: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_post_update_kernel[(num_reqs,)](
idx_mapping,
num_computed_tokens,
last_sampled_tokens,
output_bin_counts,
output_bin_counts.stride(0),
sampled_tokens,
sampled_tokens.stride(0),
num_sampled,
num_rejected,
query_start_loc,
num_warps=1,
)

View File

View File

@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch._inductor.runtime.triton_helpers import libdevice
from vllm.triton_utils import tl, triton
@triton.jit
def _num_nans_kernel(
logits_ptr,
logits_stride,
num_nans_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
num_nans = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=0
)
logits = logits.to(tl.float32)
is_nan = libdevice.isnan(logits).to(tl.int1)
num_nans += tl.sum(is_nan).to(tl.int32)
tl.store(num_nans_ptr + req_idx, num_nans)
def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device)
_num_nans_kernel[(num_reqs,)](
logits,
logits.stride(0),
num_nans,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
return num_nans

File diff suppressed because it is too large Load Diff

View File

View File

@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _gumbel_sample_kernel(
local_argmax_ptr,
local_argmax_stride,
local_max_ptr,
local_max_stride,
logits_ptr,
logits_stride,
seeds_ptr,
pos_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
req_idx = tl.program_id(0)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_idx)
pos = tl.load(pos_ptr + req_idx)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise.
r = tl.rand(gumbel_seed, block).to(tl.float64)
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
gumbel_noise = gumbel_noise.to(tl.float32)
# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
# Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
idx = tl.argmax(logits, axis=0)
token_id = block_idx * BLOCK_SIZE + idx
value = tl.max(logits, axis=0)
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
temperature: torch.Tensor, # [num_reqs]
seed: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_reqs]
apply_temperature: bool,
) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = torch.empty(
num_reqs,
num_blocks,
dtype=torch.int64,
device=logits.device,
)
local_max = torch.empty(
num_reqs,
num_blocks,
dtype=torch.float32,
device=logits.device,
)
_gumbel_sample_kernel[(num_reqs, num_blocks)](
local_argmax,
local_argmax.stride(0),
local_max,
local_max.stride(0),
logits,
logits.stride(0),
seed,
pos,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
APPLY_TEMPERATURE=apply_temperature,
)
# NOTE(woosuk): Use int64 for later indexing.
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
return sampled

View File

@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
@triton.jit
def _topk_log_softmax_kernel(
output_ptr,
logits_ptr,
logits_stride,
topk_ids_ptr,
topk,
vocab_size,
BLOCK_SIZE: tl.constexpr,
PADDED_TOPK: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore
se = 0.0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
logits = logits.to(tl.float32)
e = tl.exp(logits - max_val)
e = tl.where(block < vocab_size, e, 0.0)
se += tl.sum(e)
lse = tl.log(se)
k_offset = tl.arange(0, PADDED_TOPK)
k_mask = k_offset < topk
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
logits = logits.to(tl.float32)
o = logits - max_val - lse
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
@triton.jit
def _ranks_kernel(
output_ptr,
logits_ptr,
logits_stride,
token_ids_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
token_id = tl.load(token_ids_ptr + req_idx)
x = tl.load(row_ptr + token_id)
n = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
n += tl.sum((logits > x).to(tl.int32))
tl.store(output_ptr + req_idx, n)
def compute_token_logprobs(
logits: torch.Tensor,
token_ids: torch.Tensor,
) -> torch.Tensor:
batch_size = logits.shape[0]
vocab_size = logits.shape[1]
token_ids = token_ids.to(torch.int64)
num_logprobs = token_ids.shape[1]
logprobs = torch.empty(
batch_size,
num_logprobs,
dtype=torch.float32,
device=logits.device,
)
_topk_log_softmax_kernel[(batch_size,)](
logprobs,
logits,
logits.stride(0),
token_ids,
num_logprobs,
vocab_size,
BLOCK_SIZE=1024, # type: ignore
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
)
return logprobs
def compute_topk_logprobs(
logits: torch.Tensor,
num_logprobs: int,
sampled_token_ids: torch.Tensor,
) -> LogprobsTensors:
assert num_logprobs >= 0
batch_size, vocab_size = logits.shape
if num_logprobs == 0:
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
else:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat(
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
)
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
logprobs = compute_token_logprobs(logits, logprob_token_ids)
token_ranks = torch.empty(
batch_size,
dtype=torch.int64,
device=logits.device,
)
_ranks_kernel[(batch_size,)](
token_ranks,
logits,
logits.stride(0),
sampled_token_ids,
vocab_size,
BLOCK_SIZE=8192, # type: ignore
)
return LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=token_ranks,
)
def compute_prompt_logprobs(
prompt_token_ids: torch.Tensor,
prompt_hidden_states: torch.Tensor,
logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Since materializing the full prompt logits can take too much memory,
# we compute it in chunks.
CHUNK_SIZE = 1024
logprobs = []
ranks = []
prompt_token_ids = prompt_token_ids.to(torch.int64)
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
end_idx = start_idx + CHUNK_SIZE
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
prompt_logprobs = compute_topk_logprobs(
prompt_logits,
0, # num_logprobs
prompt_token_ids[start_idx:end_idx],
)
logprobs.append(prompt_logprobs.logprobs)
ranks.append(prompt_logprobs.selected_token_ranks)
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
return logprobs, ranks

View File

@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.triton_utils import tl, triton
@dataclass
class SamplingMetadata:
temperature: torch.Tensor
top_p: torch.Tensor | None
top_k: torch.Tensor | None
min_p: torch.Tensor | None
repetition_penalty: torch.Tensor
frequency_penalty: torch.Tensor
presence_penalty: torch.Tensor
seeds: torch.Tensor
pos: torch.Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: int | None
# For penalties
idx_mapping: torch.Tensor
prompt_bin_mask: torch.Tensor
output_bin_counts: torch.Tensor
@classmethod
def make_dummy(
cls,
num_reqs: int,
device: torch.device,
) -> "SamplingMetadata":
assert num_reqs > 0
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
temperature[0] = 0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
# Currently, they are disabled because of memory usage.
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p = None
top_k = None
min_p = torch.zeros(num_reqs, dtype=torch.float32, device=device)
# NOTE(woosuk): We must set penalties to their default values to make sure
# the penalties kernel does not touch the placeholder bin_counts tensors.
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
return cls(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping,
prompt_bin_mask=prompt_bin_mask,
output_bin_counts=output_bin_counts,
)
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
@triton.jit
def _expand_sampling_metadata_kernel(
temp_ptr,
expanded_temp_ptr,
top_p_ptr,
expanded_top_p_ptr,
top_k_ptr,
expanded_top_k_ptr,
min_p_ptr,
expanded_min_p_ptr,
rep_penalty_ptr,
expanded_rep_penalty_ptr,
freq_penalty_ptr,
expanded_freq_penalty_ptr,
pres_penalty_ptr,
expanded_pres_penalty_ptr,
seeds_ptr,
expanded_seeds_ptr,
cu_num_logits_ptr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
block = tl.arange(0, BLOCK_SIZE)
mask = block < num_tokens
temp = tl.load(temp_ptr + req_idx)
tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
if top_p_ptr is not None:
top_p = tl.load(top_p_ptr + req_idx)
tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
if top_k_ptr is not None:
top_k = tl.load(top_k_ptr + req_idx)
tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
if min_p_ptr is not None:
min_p = tl.load(min_p_ptr + req_idx)
tl.store(expanded_min_p_ptr + start_idx + block, min_p, mask=mask)
rep_penalty = tl.load(rep_penalty_ptr + req_idx)
tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask)
freq_penalty = tl.load(freq_penalty_ptr + req_idx)
tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask)
pres_penalty = tl.load(pres_penalty_ptr + req_idx)
tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask)
seed = tl.load(seeds_ptr + req_idx)
tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
def expand_sampling_metadata(
sampling_metadata: SamplingMetadata,
cu_num_logits: torch.Tensor,
max_expand_len: int,
) -> SamplingMetadata:
total_num_logits = sampling_metadata.pos.shape[0]
create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
expanded_temp = create_empty(sampling_metadata.temperature)
expanded_top_p = create_empty(sampling_metadata.top_p)
expanded_top_k = create_empty(sampling_metadata.top_k)
expanded_min_p = create_empty(sampling_metadata.min_p)
expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty)
expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty)
expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty)
expanded_seeds = create_empty(sampling_metadata.seeds)
num_reqs = cu_num_logits.shape[0] - 1
_expand_sampling_metadata_kernel[(num_reqs,)](
sampling_metadata.temperature,
expanded_temp,
sampling_metadata.top_p,
expanded_top_p,
sampling_metadata.top_k,
expanded_top_k,
sampling_metadata.min_p,
expanded_min_p,
sampling_metadata.repetition_penalty,
expanded_repetition_penalty,
sampling_metadata.frequency_penalty,
expanded_frequency_penalty,
sampling_metadata.presence_penalty,
expanded_presence_penalty,
sampling_metadata.seeds,
expanded_seeds,
cu_num_logits,
BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
)
return SamplingMetadata(
temperature=expanded_temp,
top_p=expanded_top_p,
top_k=expanded_top_k,
min_p=expanded_min_p,
seeds=expanded_seeds,
repetition_penalty=expanded_repetition_penalty,
frequency_penalty=expanded_frequency_penalty,
presence_penalty=expanded_presence_penalty,
pos=sampling_metadata.pos,
max_num_logprobs=sampling_metadata.max_num_logprobs,
# TODO(woosuk): Support penalties with spec decoding.
idx_mapping=sampling_metadata.idx_mapping,
prompt_bin_mask=sampling_metadata.prompt_bin_mask,
output_bin_counts=sampling_metadata.output_bin_counts,
)

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _min_p_kernel(
logits_ptr,
logits_stride,
min_p_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
min_p = tl.load(min_p_ptr + req_idx).to(tl.float32)
if min_p == 0.0:
return
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
)
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore
threshold = max_val + tl.log(min_p)
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
)
logits = tl.where(logits < threshold, float("-inf"), logits)
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
_min_p_kernel[(num_reqs,)](
logits,
logits.stride(0),
min_p,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)

View File

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.v1.outputs import LogprobsTensors
@dataclass
class SamplerOutput:
sampled_token_ids: torch.Tensor
logprobs_tensors: LogprobsTensors | None
num_nans: torch.Tensor | None

View File

@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
@triton.jit
def _penalties_and_temperature_kernel(
logits_ptr,
logits_stride,
repetition_penalty_ptr,
frequency_penalty_ptr,
presence_penalty_ptr,
temperature_ptr,
idx_mapping_ptr,
prompt_bin_mask_ptr,
prompt_bin_mask_stride,
output_bin_counts_ptr,
output_bin_counts_stride,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
temperature = tl.load(temperature_ptr + batch_idx)
temperature = tl.where(temperature == 0.0, 1.0, temperature)
use_rep_penalty = rep_penalty != 1.0
use_freq_penalty = freq_penalty != 0.0
use_pres_penalty = pres_penalty != 0.0
use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
use_temperature = temperature != 1.0
if not (use_penalty or use_temperature):
# Early return to avoid loading logits.
return
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
if use_penalty:
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
output_bin_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
mask=mask,
)
output_bin_mask = output_bin_counts > 0
# Apply repetition penalties.
if use_rep_penalty:
packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
packed_mask = tl.load(
prompt_bin_mask_ptr
+ req_state_idx * prompt_bin_mask_stride
+ packed_block,
mask=packed_block < tl.cdiv(vocab_size, 32),
)
prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
prompt_bin_mask = prompt_bin_mask.to(tl.int1)
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
logits *= tl.where(logits > 0, 1.0 / scale, scale)
# Apply frequency penalties.
logits -= freq_penalty * output_bin_counts
# Apply presence penalties.
logits -= pres_penalty * output_bin_mask
# Apply temperature.
logits = logits / temperature
# Store back to logits.
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
def apply_penalties_and_temperature(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_penalties_and_temperature_kernel[(num_reqs, num_blocks)](
logits,
logits.stride(0),
sampling_metadata.repetition_penalty,
sampling_metadata.frequency_penalty,
sampling_metadata.presence_penalty,
sampling_metadata.temperature,
sampling_metadata.idx_mapping,
sampling_metadata.prompt_bin_mask,
sampling_metadata.prompt_bin_mask.stride(0),
sampling_metadata.output_bin_counts,
sampling_metadata.output_bin_counts.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
def _bincount_kernel(
prefill_token_ids_ptr,
prefill_len,
prompt_len,
prompt_bin_mask_ptr,
output_bin_counts_ptr,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
if block_idx * BLOCK_SIZE >= prefill_len:
return
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if block_idx * BLOCK_SIZE < prompt_len:
mask = block < prompt_len
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
idx = prefill_tokens // 32
bit_idx = prefill_tokens % 32
bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask)
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
mask = block < prefill_len
mask &= block >= prompt_len
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
def bincount(
prefill_token_ids: torch.Tensor,
prefill_len: int,
prompt_len: int,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor,
) -> None:
prompt_bin_mask.zero_()
output_bin_counts.zero_()
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_blocks,)](
prefill_token_ids,
prefill_len,
prompt_len,
prompt_bin_mask,
output_bin_counts,
BLOCK_SIZE=BLOCK_SIZE,
)

View File

@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature
class Sampler:
def __init__(
self,
logprobs_mode: LogprobsMode = "raw_logprobs",
):
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
self.logprobs_mode = logprobs_mode
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
def __call__(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.compute_nans else None
sampled, processed_logits = self.sample(logits, sampling_metadata)
if sampling_metadata.max_num_logprobs is not None:
logits = (
processed_logits
if self.logprobs_mode == "processed_logprobs"
else logits
)
logprobs_tensors = compute_topk_logprobs(
logits,
sampling_metadata.max_num_logprobs,
sampled,
)
else:
logprobs_tensors = None
# These are GPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.view(-1, 1),
logprobs_tensors=logprobs_tensors,
num_nans=num_nans,
)
return sampler_output
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply penalties and temperature in place.
apply_penalties_and_temperature(logits, sampling_metadata)
# Apply min_p in place.
if sampling_metadata.min_p is not None:
apply_min_p(logits, sampling_metadata.min_p)
# Apply top_k and/or top_p. This might return a new tensor.
logits = apply_top_k_top_p(
logits, sampling_metadata.top_k, sampling_metadata.top_p
)
sampled = gumbel_sample(
logits,
sampling_metadata.temperature,
sampling_metadata.seeds,
sampling_metadata.pos,
apply_temperature=False,
)
return sampled, logits

View File

@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import VllmConfig
def init_speculator(
vllm_config: VllmConfig,
device: torch.device,
):
speculative_config = vllm_config.speculative_config
assert speculative_config is not None
if speculative_config.use_eagle():
from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
return EagleSpeculator(vllm_config, device)
raise NotImplementedError(f"{speculative_config.method} is not supported yet.")

View File

@@ -0,0 +1,565 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
logger = init_logger(__name__)
class EagleSpeculator:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.device = device
self.speculative_config = vllm_config.speculative_config
assert self.speculative_config is not None
self.method = self.speculative_config.method
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
self.draft_model_config = self.speculative_config.draft_model_config
self.scheduler_config = vllm_config.scheduler_config
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_model_len = vllm_config.model_config.max_model_len
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size()
self.pin_memory = is_pin_memory_available()
self.dtype = vllm_config.model_config.dtype
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
inputs_embeds_size=self.inputs_embeds_size,
vocab_size=self.vocab_size,
dtype=self.dtype,
device=device,
pin_memory=self.pin_memory,
)
self.hidden_states = torch.zeros(
self.max_num_tokens,
self.hidden_size,
dtype=self.dtype,
device=device,
)
self.temperature = torch.zeros(
self.max_num_reqs,
dtype=torch.float32,
device=device,
)
self.seeds = torch.zeros(
self.max_num_reqs,
dtype=torch.int64,
device=device,
)
self.draft_tokens = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
dtype=torch.int64,
device=device,
)
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
def load_model(self, target_model: nn.Module) -> None:
from vllm.compilation.backends import set_model_tag
with set_model_tag("eagle_head"):
self.model = get_model(
vllm_config=self.vllm_config, model_config=self.draft_model_config
)
share_lm_head = True
if share_lm_head and hasattr(target_model, "lm_head"):
if hasattr(self.model, "lm_head"):
del self.model.lm_head
self.model.lm_head = target_model.lm_head
def set_attn(
self,
kv_cache_config: KVCacheConfig,
attn_metadata_builders: list[AttentionMetadataBuilder],
block_tables: BlockTables,
) -> None:
self.kv_cache_config = kv_cache_config
self.attn_metadata_builders = attn_metadata_builders
self.block_tables = block_tables
@torch.inference_mode()
def run_model(
self,
num_tokens: int,
attn_metadata: dict[str, Any],
num_tokens_across_dp: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
):
ret_hidden_states = self.model(
input_ids=self.input_buffers.input_ids[:num_tokens],
positions=self.input_buffers.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
return last_hidden_states, hidden_states
def generate_draft(
self,
num_reqs: int,
attn_metadata: dict[str, Any],
num_tokens_across_dp: torch.Tensor | None,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
for step in range(1, self.num_speculative_steps):
# Run the eagle model.
last_hidden_states, hidden_states = self.run_model(
num_reqs, attn_metadata, num_tokens_across_dp
)
logits = self.model.compute_logits(last_hidden_states)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits,
self.temperature[:num_reqs],
self.seeds[:num_reqs],
pos + 1,
apply_temperature=True,
)
self.draft_tokens[:num_reqs, step] = draft_tokens
if step < self.num_speculative_steps - 1:
# Update the inputs for the next step.
update_eagle_inputs(
draft_tokens,
hidden_states,
self.input_buffers,
self.hidden_states,
self.max_model_len,
)
self.block_tables.compute_slot_mappings(query_start_loc, pos)
def capture_model(self) -> None:
if self.num_speculative_steps == 1:
return
logger.info("Capturing model for Eagle speculator...")
self.cudagraph_manager.capture(
self.generate_draft,
self.input_buffers,
self.block_tables,
self.attn_metadata_builders,
self.kv_cache_config,
)
@torch.inference_mode()
def propose(
self,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
# [num_tokens, hidden_size]
last_hidden_states: torch.Tensor,
# num_layers x [num_tokens, hidden_size]
aux_hidden_states: list[torch.Tensor] | None,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [num_reqs]
last_sampled: torch.Tensor,
# [num_reqs]
next_prefill_tokens: torch.Tensor,
) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and
# hidden_states the same as the target model's. This means, we pad each
# request's query length to include any rejected positions. By doing so,
# we can also reuse the attention metadata (e.g., query_start_loc,
# seq_lens) of the target model.
if aux_hidden_states:
assert self.method == "eagle3"
hidden_states = self.model.combine_hidden_states(
torch.cat(aux_hidden_states, dim=-1)
)
else:
hidden_states = last_hidden_states
num_tokens = input_batch.num_tokens_after_padding
self.hidden_states[:num_tokens] = hidden_states
# Get the input ids and last token indices for the speculator.
last_token_indices = prepare_eagle_inputs(
self.input_buffers,
input_batch,
num_sampled,
num_rejected,
last_sampled,
next_prefill_tokens,
)
# Prefill: Run the eagle speculator with eager mode.
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states, hidden_states = self.run_model(
num_tokens,
input_batch.attn_metadata,
num_tokens_across_dp=None, # FIXME
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs
cu_num_logits = input_batch.cu_num_logits[:num_reqs]
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
temperature = self.temperature[:num_reqs]
seeds = self.seeds[:num_reqs]
pos = self.input_buffers.positions[:num_reqs]
# Gather the values and copy them to the pre-allocated buffers.
torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature)
torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits, temperature, seeds, pos + 1, apply_temperature=True
)
if self.num_speculative_steps == 1:
# Early exit.
return draft_tokens.view(-1, 1)
# Save the draft tokens for the first step.
self.draft_tokens[:num_reqs, 0] = draft_tokens
# Prepare the inputs for the decode steps.
prepare_eagle_decode(
draft_tokens,
hidden_states,
last_token_indices,
input_batch.seq_lens,
num_rejected,
self.input_buffers,
self.hidden_states,
self.max_model_len,
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
if cudagraph_size is not None:
# Run CUDA graph.
self.cudagraph_manager.run(cudagraph_size)
return self.draft_tokens[:num_reqs]
# Run eager mode.
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1)
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
# HACK(woosuk)
seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
seq_lens_np=seq_lens_np,
num_computed_tokens_cpu=None, # FIXME
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None) # FIXME
return self.draft_tokens[:num_reqs]
@triton.jit
def _prepare_eagle_inputs_kernel(
last_token_indices_ptr,
eagle_input_ids_ptr,
eagle_positions_ptr,
target_input_ids_ptr,
target_positions_ptr,
last_sampled_ptr,
next_prefill_tokens_ptr,
num_sampled_ptr,
num_rejected_ptr,
query_start_loc_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
# Get the true query length and next token after accounting for rejected tokens.
num_rejected = tl.load(num_rejected_ptr + batch_idx)
query_len -= num_rejected
num_sampled = tl.load(num_sampled_ptr + batch_idx)
if num_sampled > 0:
next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32)
else:
# Chunked prefilling.
# Get the next prefill token.
next_token = tl.load(next_prefill_tokens_ptr + batch_idx)
# Shift target_input_ids by one.
for i in range(1, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
input_ids = tl.load(target_input_ids_ptr + query_start + block, mask=mask)
tl.store(eagle_input_ids_ptr + query_start + block - 1, input_ids, mask=mask)
last_token_index = query_start + query_len - 1
tl.store(last_token_indices_ptr + batch_idx, last_token_index)
tl.store(eagle_input_ids_ptr + last_token_index, next_token)
# Copy positions.
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
target_pos = tl.load(target_positions_ptr + query_start + block, mask=mask)
tl.store(eagle_positions_ptr + query_start + block, target_pos, mask=mask)
def prepare_eagle_inputs(
input_buffers: InputBuffers,
input_batch: InputBatch,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [num_reqs]
last_sampled: torch.Tensor,
# [num_reqs]
next_prefill_tokens: torch.Tensor,
) -> torch.Tensor:
num_reqs = input_batch.num_reqs
last_token_indices = torch.empty(
num_reqs,
dtype=torch.int64,
device=num_sampled.device,
)
_prepare_eagle_inputs_kernel[(num_reqs,)](
last_token_indices,
input_buffers.input_ids,
input_buffers.positions,
input_batch.input_ids,
input_batch.positions,
last_sampled,
next_prefill_tokens,
num_sampled,
num_rejected,
input_batch.query_start_loc,
BLOCK_SIZE=1024,
)
return last_token_indices
@triton.jit
def _prepare_eagle_docode_kernel(
draft_tokens_ptr,
output_hidden_states_ptr,
output_hidden_states_stride,
last_token_indices_ptr,
target_seq_lens_ptr,
num_rejected_ptr,
input_ids_ptr,
positions_ptr,
input_hidden_states_ptr,
input_hidden_states_stride,
query_start_loc_ptr,
seq_lens_ptr,
hidden_size,
max_model_len,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
num_reqs = tl.num_programs(0) - 1
if req_idx == num_reqs:
# Compute query_start_loc. Pad it with the last query_start_loc
# for CUDA graphs.
for i in range(0, max_num_reqs + 1, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
q = tl.where(block < num_reqs, block, num_reqs)
mask = block < max_num_reqs + 1
tl.store(query_start_loc_ptr + block, q, mask=mask)
# Pad seq_lens for CUDA graphs.
for i in range(req_idx, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(seq_lens_ptr + block, 0, mask=mask)
return
# draft token -> input id.
draft_token = tl.load(draft_tokens_ptr + req_idx)
tl.store(input_ids_ptr + req_idx, draft_token)
# output hidden states -> input hidden states.
src_idx = tl.load(last_token_indices_ptr + req_idx)
for i in range(0, hidden_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < hidden_size
output_hidden_states = tl.load(
output_hidden_states_ptr + src_idx * output_hidden_states_stride + block,
mask=mask,
)
tl.store(
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
output_hidden_states,
mask=mask,
)
# Compute position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position = tl.load(positions_ptr + req_idx)
position = tl.minimum(position + 1, max_model_len - 1)
tl.store(positions_ptr + req_idx, position)
target_seq_len = tl.load(target_seq_lens_ptr + req_idx)
num_rejected = tl.load(num_rejected_ptr + req_idx)
seq_len = target_seq_len - num_rejected
seq_len = tl.minimum(seq_len + 1, max_model_len)
tl.store(seq_lens_ptr + req_idx, seq_len)
def prepare_eagle_decode(
draft_tokens: torch.Tensor,
output_hidden_states: torch.Tensor,
last_token_indices: torch.Tensor,
target_seq_lens: torch.Tensor,
num_rejected: torch.Tensor,
input_buffers: InputBuffers,
input_hidden_states: torch.Tensor,
max_model_len: int,
max_num_reqs: int,
):
num_reqs = draft_tokens.shape[0]
hidden_size = output_hidden_states.shape[-1]
_prepare_eagle_docode_kernel[(num_reqs + 1,)](
draft_tokens,
output_hidden_states,
output_hidden_states.stride(0),
last_token_indices,
target_seq_lens,
num_rejected,
input_buffers.input_ids,
input_buffers.positions,
input_hidden_states,
input_hidden_states.stride(0),
input_buffers.query_start_loc.gpu,
input_buffers.seq_lens,
hidden_size,
max_model_len,
max_num_reqs,
BLOCK_SIZE=1024,
)
@triton.jit
def _update_eagle_inputs_kernel(
input_ids_ptr,
positions_ptr,
input_hidden_states_ptr,
input_hidden_states_stride,
seq_lens_ptr,
max_model_len,
draft_tokens_ptr,
output_hidden_states_ptr,
output_hidden_states_stride,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
# Draft token -> Input ID.
draft_token = tl.load(draft_tokens_ptr + req_idx)
tl.store(input_ids_ptr + req_idx, draft_token)
# Output hidden states -> Input hidden states.
for i in range(0, hidden_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < hidden_size
output_hidden_states = tl.load(
output_hidden_states_ptr + req_idx * output_hidden_states_stride + block,
mask=mask,
)
tl.store(
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
output_hidden_states,
mask=mask,
)
# Increment position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position = tl.load(positions_ptr + req_idx)
position = tl.minimum(position + 1, max_model_len - 1)
tl.store(positions_ptr + req_idx, position)
seq_len = tl.load(seq_lens_ptr + req_idx)
seq_len = tl.minimum(seq_len + 1, max_model_len)
tl.store(seq_lens_ptr + req_idx, seq_len)
def update_eagle_inputs(
draft_tokens: torch.Tensor,
output_hidden_states: torch.Tensor,
input_buffers: InputBuffers,
hidden_states: torch.Tensor,
max_model_len: int,
):
num_reqs, hidden_size = output_hidden_states.shape
_update_eagle_inputs_kernel[(num_reqs,)](
input_buffers.input_ids,
input_buffers.positions,
hidden_states,
hidden_states.stride(0),
input_buffers.seq_lens,
max_model_len,
draft_tokens,
output_hidden_states,
output_hidden_states.stride(0),
hidden_size,
BLOCK_SIZE=1024,
)

View File

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (
capture_graphs,
get_cudagraph_sizes,
prepare_inputs_to_capture,
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
class EagleCudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
cudagraph_mode: CUDAGraphMode
if self.compilation_config.cudagraph_mode is None:
cudagraph_mode = CUDAGraphMode.NONE
else:
cudagraph_mode = self.compilation_config.cudagraph_mode
if cudagraph_mode == CUDAGraphMode.FULL:
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
self.cudagraph_mode = cudagraph_mode
self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens)
def capture_graph(
self,
num_tokens: int,
generate_fn: Callable,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
num_reqs = min(num_tokens, self.max_num_reqs)
attn_metadata = prepare_inputs_to_capture(
num_reqs,
num_tokens,
input_buffers,
block_tables,
attn_metadata_builders,
self.max_model_len,
kv_cache_config,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.pool):
generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
self.graphs[num_tokens] = graph
@torch.inference_mode()
def capture(
self,
generate_fn: Callable,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
generate_fn=generate_fn,
input_buffers=input_buffers,
block_tables=block_tables,
attn_metadata_builders=attn_metadata_builders,
kv_cache_config=kv_cache_config,
)
def run(self, num_tokens: int) -> None:
assert num_tokens in self.graphs
self.graphs[num_tokens].replay()

View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _rejection_sample_kernel(
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
sampled_stride,
num_sampled_ptr, # [num_reqs]
target_sampled_ptr, # [num_draft_tokens + num_reqs]
input_ids_ptr, # [num_draft_tokens + num_reqs]
cu_num_logits_ptr, # [num_reqs + 1]
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
num_sampled = 0
rejected = False
for i in range(num_tokens - 1):
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
num_sampled += 1
if target_sampled != draft_sampled:
rejected = True
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store(
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
)
num_sampled += 1
tl.store(num_sampled_ptr + req_idx, num_sampled)
def rejection_sample(
# [num_draft_tokens + num_reqs]
target_sampled: torch.Tensor,
# [num_draft_tokens + num_reqs]
input_ids: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = torch.empty(
num_reqs,
num_speculative_steps + 1,
dtype=target_sampled.dtype,
device=target_sampled.device,
)
num_sampled = torch.empty(
num_reqs,
dtype=torch.int32,
device=target_sampled.device,
)
_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
num_sampled,
target_sampled,
input_ids,
cu_num_logits,
num_warps=1,
)
return sampled, num_sampled

View File

@@ -0,0 +1,316 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.penalties import bincount
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
NO_LORA_ID = 0
class RequestState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.num_speculative_steps = num_speculative_steps
self.vocab_size = vocab_size
self.device = device
self.pin_memory = pin_memory
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_reqs))
self.extra_data: dict[str, ExtraData] = {}
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len.
self.prefill_token_ids = UvaBuffer(
self.max_num_reqs, self.max_model_len, dtype=torch.int32
)
# NOTE(woosuk): We don't use UVA for prefill_len because its GPU view
# can be used outside of update_states and prepare_inputs.
# Without async barrier, using UVA can cause race conditions.
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
# Number of computed tokens.
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
# Last sampled tokens.
self.last_sampled_tokens = torch.zeros(
self.max_num_reqs,
1,
dtype=torch.int64,
device=device,
)
# Draft tokens.
self.draft_tokens = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
dtype=torch.int64,
device=device,
)
self.next_prefill_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
# LoRA.
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
# Sampling parameters.
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
self.top_k = self._make_param(self.max_num_reqs, torch.int32)
self.min_p = self._make_param(self.max_num_reqs, torch.float32)
self.repetition_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.frequency_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.presence_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.seeds = self._make_param(self.max_num_reqs, torch.int64)
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(-1)
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
# Statistics for penalties.
self.prompt_bin_mask = torch.zeros(
self.max_num_reqs,
cdiv(self.vocab_size, 32),
dtype=torch.int32,
device=self.device,
)
# TODO(woosuk): This tensor is rarely used but can be extremely large.
# Optimize the memory usage.
self.output_bin_counts = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
def add_request(
self,
req_id: str,
prompt_len: int,
prefill_token_ids: list[int],
num_computed_tokens: int,
sampling_params: SamplingParams,
lora_request: LoRARequest | None,
) -> None:
assert len(self.free_indices) > 0, "No free indices"
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.extra_data[req_id] = ExtraData(lora_request)
self.prompt_len[req_idx] = prompt_len
prefill_len = len(prefill_token_ids)
assert prefill_len >= prompt_len, (
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
)
self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids.np[req_idx, :prefill_len] = prefill_token_ids
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
# FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
# Optimize this.
self.num_computed_tokens[req_idx] = num_computed_tokens
if lora_request is not None:
self.lora_ids[req_idx] = lora_request.lora_int_id
else:
self.lora_ids[req_idx] = NO_LORA_ID
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
if 0 < sampling_params.top_k < self.vocab_size:
top_k = sampling_params.top_k
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.min_p.np[req_idx] = sampling_params.min_p
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
if use_penalty(sampling_params):
bincount(
self.prefill_token_ids.gpu[req_idx],
prefill_len,
prompt_len,
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
)
if sampling_params.seed is not None:
seed = sampling_params.seed
else:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds.np[req_idx] = seed
if sampling_params.logprobs is not None:
num_logprobs = sampling_params.logprobs
else:
num_logprobs = -1
self.num_logprobs[req_idx] = num_logprobs
# For now, only support prompt logprobs for the prompt tokens.
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
def make_sampling_metadata(
self,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> SamplingMetadata:
temperature = self.temperature.np[idx_mapping_np]
temperature = self.temperature.copy_np_to_gpu(temperature)
top_p = self.top_p.np[idx_mapping_np]
no_top_p = np.all(top_p == 1.0)
top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None
top_k = self.top_k.np[idx_mapping_np]
no_top_k = np.all(top_k == self.vocab_size)
top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None
min_p = self.min_p.np[idx_mapping_np]
no_min_p = np.all(min_p == 0.0)
min_p = self.min_p.copy_np_to_gpu(min_p) if not no_min_p else None
rep_penalty = self.repetition_penalty.np[idx_mapping_np]
rep_penalty = self.repetition_penalty.copy_np_to_gpu(rep_penalty)
freq_penalty = self.frequency_penalty.np[idx_mapping_np]
freq_penalty = self.frequency_penalty.copy_np_to_gpu(freq_penalty)
pres_penalty = self.presence_penalty.np[idx_mapping_np]
pres_penalty = self.presence_penalty.copy_np_to_gpu(pres_penalty)
seeds = self.seeds.np[idx_mapping_np]
seeds = self.seeds.copy_np_to_gpu(seeds)
num_logprobs = self.num_logprobs[idx_mapping_np]
max_num_logprobs: int | None = int(np.max(num_logprobs))
if max_num_logprobs == -1:
max_num_logprobs = None
return SamplingMetadata(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
repetition_penalty=rep_penalty,
frequency_penalty=freq_penalty,
presence_penalty=pres_penalty,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping,
prompt_bin_mask=self.prompt_bin_mask,
output_bin_counts=self.output_bin_counts,
)
def make_lora_inputs(
self,
req_ids: list[str],
idx_mapping: np.ndarray,
num_scheduled_tokens: np.ndarray,
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
lora_ids = self.lora_ids[idx_mapping]
prompt_lora_mapping = tuple(lora_ids)
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set()
for req_id in req_ids:
lora_request = self.extra_data[req_id].lora_request
if lora_request is not None:
active_lora_requests.add(lora_request)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
class Param:
def __init__(
self,
size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.buffer = CpuGpuBuffer(
size,
dtype=dtype,
device=device,
pin_memory=pin_memory,
)
self.np = np.zeros_like(self.buffer.np)
def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
n = x.shape[0]
self.buffer.np[:n] = x
return self.buffer.copy_to_gpu(n)
@dataclass
class ExtraData:
lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
class UvaBuffer:
def __init__(self, *size: int | torch.SymInt, dtype: torch.dtype):
assert is_uva_available()
self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
self.np = self.cpu.numpy()
self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
def use_penalty(sampling_params: SamplingParams) -> bool:
return (
sampling_params.repetition_penalty != 1.0
or sampling_params.frequency_penalty != 0.0
or sampling_params.presence_penalty != 0.0
)

View File

@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBuffers
def apply_grammar_bitmask(
logits: torch.Tensor,
req_ids: list[str],
grammar_req_ids: list[str],
grammar_bitmask: np.ndarray,
input_buffers: InputBuffers,
) -> None:
input_buffers.grammar_bitmask.np[: grammar_bitmask.shape[0]] = grammar_bitmask
input_buffers.grammar_bitmask.copy_to_gpu(grammar_bitmask.shape[0])
batch_size = logits.shape[0]
grammar_req_id_to_idx = {req_id: i for i, req_id in enumerate(grammar_req_ids)}
# logits -> bitmask mapping
mapping = [grammar_req_id_to_idx.get(req_id, -1) for req_id in req_ids]
input_buffers.bitmask_indices.np[:batch_size] = mapping
input_buffers.bitmask_indices.copy_to_gpu(batch_size)
vocab_size = logits.shape[-1]
BLOCK_SIZE = 8192
grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))
_apply_grammar_bitmask_kernel[grid](
logits,
logits.stride(0),
input_buffers.grammar_bitmask.gpu,
input_buffers.grammar_bitmask.gpu.stride(0),
input_buffers.bitmask_indices.gpu,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@triton.jit
def _apply_grammar_bitmask_kernel(
logits_ptr,
logits_stride,
bitmask_ptr,
bitmask_stride,
bitmask_indices_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
logits_idx = tl.program_id(0)
bitmask_idx = tl.load(bitmask_indices_ptr + logits_idx)
if bitmask_idx == -1:
# No bitmask to apply.
return
# Load the bitmask.
block_id = tl.program_id(1)
bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32)
packed_bitmask = tl.load(
bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset,
mask=bitmask_offset < bitmask_stride,
)
# Unpack the bitmask.
bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
bitmask = bitmask.reshape(BLOCK_SIZE)
# Apply the bitmask to the logits.
block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
logits_ptr + logits_idx * logits_stride + block_offset,
-float("inf"),
mask=bitmask & (block_offset < vocab_size),
)