[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions

0
vllm/v1/__init__.py Normal file
View File

View File

View File

View File

@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
TorchSDPAMetadata)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch
class TorchSDPABackend:
    """Backend descriptor for the CPU torch.sdpa attention path (V1 engine)."""

    accept_output_buffer: bool = False

    @staticmethod
    def get_name() -> str:
        return "TORCH_SDPA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TorchSDPABackendImpl"]:
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
        return TorchSDPAMetadataBuilderV1

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        # Delegate to the paged-attention op so the cache layout stays in
        # lockstep with the kernel implementation.
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        # Cascade attention is not implemented for the CPU backend.
        return False
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
    """Builds per-step ``TorchSDPAMetadata`` for the CPU model runner.

    The builder reorders the persistent batch so that prompt (prefill)
    requests sit at the front and decode requests at the back, then builds a
    single metadata object covering both phases.
    """

    def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable) -> None:
        self.runner = runner
        self.block_table = block_table

        # For reorder: preallocated scratch arrays holding request indices
        # per stage, sized to the max batch so no per-step allocation occurs.
        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
                                                      dtype=np.int64)
        self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
                                                      dtype=np.int64)
        # Number of prompt-stage requests found by the last reorder_batch().
        self.num_prompt_req: int = 0

        # Cumulative per-request sequence lengths (leading zero included),
        # used as KV start offsets for the prefill phase.
        self.seq_start_loc_cpu = torch.zeros(
            runner.max_num_reqs + 1,
            dtype=torch.int32,
            device="cpu",
        )
        # Zero-copy NumPy view over the tensor above (shared storage).
        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

    def reorder_batch(self, input_batch: InputBatch,
                      scheduler_output: SchedulerOutput) -> bool:
        """Move prompt requests before decode requests in ``input_batch``.

        Returns True iff any pair of requests was swapped.
        """
        prompt_list_idx = 0
        decode_list_idx = 0
        for req_index in range(input_batch.num_reqs):
            if input_batch.num_computed_tokens_cpu[
                    req_index] < input_batch.num_prompt_tokens[req_index]:
                # prompt stage: still consuming its prompt tokens
                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
                prompt_list_idx += 1
            else:
                # decode stage
                self.reorder_decode_req_index_list[decode_list_idx] = req_index
                decode_list_idx += 1
        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs

        # Update prompt requests number
        self.num_prompt_req = prompt_list_idx

        # Count decode requests currently sitting inside the prompt region
        # [0, prompt_list_idx); only those need to be swapped out. Indices
        # are ascending, so we can stop at the first one past the region.
        reorder_req_num = 0
        for req_index in range(decode_list_idx):
            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
                reorder_req_num += 1
            else:
                break

        if reorder_req_num == 0:
            return False

        # Pairwise swap: the last `reorder_req_num` prompt requests (which
        # currently sit past the prompt region) with the first
        # `reorder_req_num` decode requests (which sit inside it).
        reorder_prompt_list = (
            self.reorder_prompt_req_index_list[:prompt_list_idx]
            [-reorder_req_num:])
        reorder_decode_list = (
            self.reorder_decode_req_index_list[:decode_list_idx]
            [:reorder_req_num])
        assert reorder_decode_list.size == reorder_prompt_list.size
        for idx in range(reorder_req_num):
            prompt_req_index = reorder_prompt_list[idx].item()
            decode_req_index = reorder_decode_list[idx].item()
            input_batch.swap_states(prompt_req_index, decode_req_index)

        return True

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """Build ``TorchSDPAMetadata`` for the current (reordered) batch.

        Assumes ``reorder_batch()`` has already run, so requests
        [0, num_prompt_req) are prefills and the remainder are decodes.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        runner = self.runner
        block_table = self.block_table
        seq_lens_np = runner.seq_lens_np[:num_reqs]

        num_prompt_req = self.num_prompt_req
        # Max KV length per phase; 0 when the phase is empty.
        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
        ) if num_prompt_req > 0 else 0
        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
        ) if num_prompt_req < num_reqs else 0
        # Prefix-sum of sequence lengths -> per-request KV start offsets
        # (written through the shared NumPy view of seq_start_loc_cpu).
        self.seq_start_loc_np[0] = 0
        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
        # Token counts per phase, derived from the query start offsets.
        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
        ) - num_prefill_tokens
        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
        block_table_tensor = block_table.get_device_tensor()
        attn_metadata = TorchSDPAMetadata(
            num_prefills=num_prompt_req,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping,
            # Decode-phase views (requests after the prefills).
            seq_lens_tensor=runner.
            seq_lens_cpu[num_prompt_req:num_reqs],  # decode
            max_decode_seq_len=max_decode_seq_len,  # decode
            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
            chunked_prefill=True,
            max_query_len=max_query_len,
            max_kv_len=max_prefill_seq_len,
            # Prefill-phase views (requests [0, num_prompt_req)).
            prefill_query_start_loc=runner.
            query_start_loc_cpu[:num_prompt_req + 1],  # prefill
            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
                                                1],  # prefill
            prefill_block_tables=block_table_tensor[:
                                                    num_prompt_req],  # prefill
            # Full-batch query offsets, used for logits indexing.
            query_start_loc=runner.query_start_loc_cpu[:num_reqs +
                                                       1],  # for logits index
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
        )
        return attn_metadata

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,657 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashInfer."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper)
#MultiLevelCascadeAttentionWrapper
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
logger = init_logger(__name__)
class FlashInferBackend(AttentionBackend):
    """Backend descriptor for FlashInfer paged attention (V1 engine)."""

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER_VLLM_V1"

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [64, 128, 256]

    @staticmethod
    def get_impl_cls() -> type[FlashInferImpl]:
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> type[FlashInferMetadata]:
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> type[FlashInferMetadataBuilder]:
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        # Per-block layout: K and V planes side by side (the dim of size 2).
        return (num_blocks, 2, block_size, num_kv_heads, head_size)
@dataclass
class PerLayerParameters:
    """
    Currently, FlashInfer backend only support models in which all layers share
    the same values for the following hyperparameters.
    """

    # Left extent of the sliding attention window; -1 when the layer has no
    # window (see get_per_layer_parameters()).
    window_left: int
    # Soft cap applied to attention logits; None when disabled.
    logits_soft_cap: Optional[float]
    # Softmax scaling factor passed to the attention kernel.
    sm_scale: float
def get_per_layer_parameters(
        vllm_config: "VllmConfig") -> dict[str, "PerLayerParameters"]:
    """
    Scan all attention layers and determine some hyperparameters
    to use during `plan`.
    """
    layers = get_layers_from_vllm_config(vllm_config, Attention)

    def _params_of(layer: Attention) -> PerLayerParameters:
        # Infer hyperparameters from the attention layer
        impl = layer.impl
        assert isinstance(impl, FlashInferImpl)
        sliding = impl.sliding_window
        left = -1 if sliding is None else sliding[0]
        return PerLayerParameters(left, impl.logits_soft_cap, impl.scale)

    return {name: _params_of(layer) for name, layer in layers.items()}
def infer_global_hyperparameters(
        per_layer_params: dict[str, "PerLayerParameters"]
) -> "PerLayerParameters":
    """
    Currently, FlashInfer backend only support models in which all layers share
    the same values for the following hyperparameters:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`

    So this function asserts that all layers share the same values for these
    hyperparameters and returns the global values.
    """
    assert per_layer_params, "No attention layers found in the model."

    global_params = next(iter(per_layer_params.values()))
    mismatch_msg = (
        "FlashInfer backend currently only supports models in which all "
        "layers share the same values for the following hyperparameters: "
        "`window_left`, `logits_soft_cap`, `sm_scale`.")
    for params in per_layer_params.values():
        assert params == global_params, mismatch_msg
    return global_params
@dataclass
class FlashInferMetadata:
    """Per-step attention metadata consumed by ``FlashInferImpl.forward``."""

    num_actual_tokens: int  # Number of tokens excluding padding.

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    qo_indptr: torch.Tensor
    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: torch.Tensor
    # The page indices of the paged kv cache
    paged_kv_indices: torch.Tensor
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: torch.Tensor
    # The number of query/output heads
    num_qo_heads: int
    # The number of key/value heads
    num_kv_heads: int
    # The dimension of the attention heads
    head_dim: int
    # Block size of vllm
    page_size: int
    # The data type of the paged kv cache
    data_type: torch.dtype
    # The data type of the query
    q_data_type: torch.dtype

    slot_mapping: torch.Tensor

    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # For cascade attention (currently disabled; see
    # FlashInferMetadataBuilder.use_cascade_attention).
    use_cascade: bool
    shared_qo_indptr: Optional[torch.Tensor] = None
    shared_kv_page_indptr: Optional[torch.Tensor] = None
    shared_kv_page_indices: Optional[torch.Tensor] = None
    shared_kv_last_page_len: Optional[torch.Tensor] = None

    # Wrappers planned by the builder; only the ones matching the current
    # prefill/decode split are populated.
    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
    # cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

    @property
    def query_start_loc(self):
        # The GPUModelRunner expects to be able to access this property.
        return self.qo_indptr

    def __post_init__(self):
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            # Fixed: the original passed two f-strings (a stray comma),
            # which made ValueError render its message as a tuple.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Builds ``FlashInferMetadata`` and plans the FlashInfer wrappers.

    ``reorder_batch()`` moves decode requests to the front of the batch;
    ``build()`` then constructs the CSR-style paged-KV metadata and calls
    ``_plan()`` to prepare the prefill/decode wrappers for this step.
    """

    def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        self.runner = runner
        # Lazily created workspace / wrapper objects (see _get_* helpers).
        self._workspace_buffer = None
        self._prefill_wrapper = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode
        self._cascade_wrapper = None  # Wrapper for cascade attention

        # Global hyperparameters shared by all attention layers
        self.global_hyperparameters: Optional[PerLayerParameters] = None

        self.vllm_config = runner.vllm_config
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table

    def reorder_batch(self, input_batch: InputBatch,
                      scheduler_output: SchedulerOutput) -> bool:
        # We now want to reorder the batch so that the "decode" requests are
        # at the front and the "prefill" requests are at the back, using the
        # least amount of swaps possible. (NOTE for now we loosely use
        # "decode" to mean requests where attention is likely memory-bound
        # and "prefill" to mean requests where attention is likely
        # compute-bound, TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # for now treat 1 scheduled token as "decode" even if its not,
            # we should update this to something like < 8 in the future but
            # currently the decode run only supports num_tokens = 1
            if num_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new request are generally appended to the
        # persistent batch so already should be at the back)
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back"
        # i.e. past where the last decode should be in the reordered batch
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based on
        # the above loop
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i, we can swap it
            # with the prefill closest to the front of the batch
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                break
            input_batch.swap_states(prefills[i - 1], decode_idx)
            modified_batch = True

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

    def _get_workspace_buffer(self):
        # One shared workspace buffer for all wrappers, allocated on demand.
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        # Lazily build the prefill/append wrapper ("NHD" tensor layout).
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), "NHD")
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        # Lazily build the decode wrapper; tensor cores are enabled when
        # forced via env or when the GQA group size exceeds 4.
        if self._decode_wrapper is None:
            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config))
            num_kv_heads = self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config)
            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                num_qo_heads // num_kv_heads > 4)
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                "NHD",
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    # def _get_cascade_wrapper(self):
    #     if self._cascade_wrapper is None:
    #         self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
    #             2, self._get_workspace_buffer(), "NHD")
    #     return self._cascade_wrapper

    def _plan(self, attn_metadata: FlashInferMetadata):
        """Run FlashInfer ``plan()`` for the wrappers needed this step."""
        if self.global_hyperparameters is None:
            self.global_hyperparameters = infer_global_hyperparameters(
                get_per_layer_parameters(self.vllm_config))
        # NOTE(review): the `and False` below deliberately disables the
        # cascade branch (cascade attention is not supported yet); the body
        # is kept for when support lands.
        if attn_metadata.use_cascade and False:  # not supported
            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
                [
                    attn_metadata.shared_kv_page_indptr,
                    attn_metadata.paged_kv_indptr
                ],
                [
                    attn_metadata.shared_kv_page_indices,
                    attn_metadata.paged_kv_indices
                ],
                [
                    attn_metadata.shared_kv_last_page_len,
                    attn_metadata.paged_kv_last_page_len
                ],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                causal=True,
                sm_scale=self.global_hyperparameters.sm_scale,
                window_left=self.global_hyperparameters.window_left,
                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
                q_data_type=attn_metadata.q_data_type,
            )
        else:
            # Regular attention (common case).
            # Decodes are at the front and prefills are at the back,
            # according to reorder_batch()
            if self._num_prefills > 0:
                # Decodes are first so prefills start after the last decode
                prefill_start = self._num_decodes
                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
                assert attn_metadata.qo_indptr[prefill_start:].shape[
                    0] == self._num_prefills + 1
                assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
                    0] == self._num_prefills + 1
                assert attn_metadata.paged_kv_last_page_len[
                    prefill_start:].shape[0] == self._num_prefills
                # Since prefill_wrapper.run() will be called with
                # query[num_decode_tokens:] we need to adjust the qo_indptr
                # to be relative to the start of the prefill queries.
                qo_indptr = attn_metadata.qo_indptr[
                    prefill_start:] - attn_metadata.qo_indptr[prefill_start]
                attn_metadata.prefill_wrapper.plan(
                    qo_indptr,
                    attn_metadata.paged_kv_indptr[prefill_start:],
                    attn_metadata.paged_kv_indices,
                    attn_metadata.paged_kv_last_page_len[prefill_start:],
                    attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads,
                    attn_metadata.head_dim,
                    attn_metadata.page_size,
                    causal=True,
                    sm_scale=self.global_hyperparameters.sm_scale,
                    window_left=self.global_hyperparameters.window_left,
                    logits_soft_cap=self.global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=attn_metadata.q_data_type,
                    kv_data_type=attn_metadata.data_type,
                )

            if self._num_decodes > 0:
                attn_metadata.decode_wrapper = self._get_decode_wrapper()
                attn_metadata.decode_wrapper.plan(
                    attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
                    attn_metadata.paged_kv_indices,
                    attn_metadata.paged_kv_last_page_len[:self._num_decodes],
                    attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads,
                    attn_metadata.head_dim,
                    attn_metadata.page_size,
                    # Disable flashinfer's pos encoding and use vllm's rope.
                    pos_encoding_mode="NONE",
                    sm_scale=self.global_hyperparameters.sm_scale,
                    window_left=self.global_hyperparameters.window_left,
                    logits_soft_cap=self.global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=attn_metadata.q_data_type,
                    kv_data_type=attn_metadata.data_type,
                )

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """Assemble the paged-KV metadata and plan the wrappers.

        Assumes ``reorder_batch()`` ran this step (decodes first).
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens

        assert self._num_decodes + self._num_prefills == num_reqs
        assert (self._num_decode_tokens +
                self._num_prefill_tokens == num_actual_tokens)

        page_size = self.kv_cache_spec.block_size
        device = self.runner.device
        qo_indptr = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
        slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
            self.runner.device, non_blocking=True).long()

        # Number of KV blocks each request actually occupies (ceil division).
        block_table_bounds = (seq_lens + page_size - 1) // page_size

        use_cascade = common_prefix_len > 0
        if use_cascade:
            # Grab the blocks of the shared prefix from the first request.
            assert common_prefix_len % page_size == 0
            num_common_kv_blocks = common_prefix_len // page_size
            shared_qo_indptr = torch.tensor([0, num_actual_tokens],
                                            dtype=torch.int32,
                                            device=device)
            shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
                                                 dtype=torch.int32,
                                                 device=device)
            shared_kv_page_indices = block_table_tensor[
                0, :num_common_kv_blocks]
            shared_kv_last_page_len = torch.tensor([page_size],
                                                   dtype=torch.int32,
                                                   device=device)
            # Remove the blocks of the shared prefix from all requests.
            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
            block_table_bounds -= num_common_kv_blocks
        else:
            shared_qo_indptr = None
            shared_kv_page_indptr = None
            shared_kv_page_indices = None
            shared_kv_last_page_len = None

        # Flatten per-request block lists into FlashInfer's CSR-style
        # (indices, indptr) layout, dropping the unused tail of each row.
        mask = (torch.arange(block_table_tensor.size(1),
                             dtype=block_table_tensor.dtype,
                             device=block_table_tensor.device).unsqueeze(0)
                < block_table_bounds.unsqueeze(1))
        paged_kv_indices = block_table_tensor[mask]

        paged_kv_indptr = torch.cat([
            torch.zeros(1,
                        dtype=block_table_bounds.dtype,
                        device=block_table_bounds.device),
            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
        ])

        # Entries used in each request's final page (page_size when full).
        paged_kv_last_page_len = seq_lens % page_size
        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                             page_size, paged_kv_last_page_len)

        attn_metadata = FlashInferMetadata(
            num_actual_tokens=num_actual_tokens,
            qo_indptr=qo_indptr,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len,
            num_qo_heads=self.runner.num_query_heads,
            num_kv_heads=self.kv_cache_spec.num_kv_heads,
            head_dim=self.kv_cache_spec.head_size,
            page_size=page_size,
            data_type=self.kv_cache_spec.dtype,
            q_data_type=self.runner.dtype,
            slot_mapping=slot_mapping,
            num_decodes=self._num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            num_prefills=self._num_prefills,
            num_prefill_tokens=self._num_prefill_tokens,
            use_cascade=use_cascade,
            shared_qo_indptr=shared_qo_indptr,
            shared_kv_page_indptr=shared_kv_page_indptr,
            shared_kv_page_indices=shared_kv_page_indices,
            shared_kv_last_page_len=shared_kv_last_page_len,
        )

        self._plan(attn_metadata)

        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Cascade attention is currently force-disabled for FlashInfer."""
        logger.warning_once(
            "Using cascade attention in FlashInfer is not supported yet")
        return False
        # NOTE(review): everything below the early return is unreachable
        # dead code kept from the original implementation; remove it or
        # re-enable it once cascade support is implemented.
        if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
            # TODO: The cascade wrapper currently does not support setting
            # kv cache dtype to something different from query dtype.
            return False
        return use_cascade_attention(*args, **kwargs)
class FlashInferImpl(AttentionImpl):
    """FlashInfer-backed attention implementation (decoder-only)."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[int] = None,
        use_irope: bool = False,
    ) -> None:
        if use_irope:
            logger.warning_once(
                "Using irope in FlashInfer is not supported yet, it will fall"
                " back to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            # (-1, -1) encodes "no sliding window" downstream.
            self.sliding_window = (-1, -1)
        else:
            # (left, right) window widths; right is 0 for causal attention.
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # GQA: query heads must be an integer multiple of KV heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.
        num_actual_tokens = attn_metadata.num_actual_tokens

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        window_left = (self.sliding_window[0]
                       if self.sliding_window is not None else -1)

        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
        output_padded = output
        output = output[:num_actual_tokens]

        # if attn_metadata.use_cascade:
        #     # Cascade attention (rare case).
        #     assert attn_metadata.cascade_wrapper is not None
        #     output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
        #     return output

        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
        if prefill_wrapper := attn_metadata.prefill_wrapper:
            prefill_query = query[num_decode_tokens:]
            assert prefill_query.shape[0] == num_prefill_tokens
            assert prefill_wrapper is not None
            # Sanity-check that the planned wrapper matches this layer's
            # hyperparameters (plan() is shared across layers).
            assert prefill_wrapper._causal
            assert prefill_wrapper._window_left == window_left
            assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                        or 0.0)
            assert prefill_wrapper._sm_scale == self.scale
            prefill_wrapper.run(
                prefill_query,
                kv_cache,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[num_decode_tokens:],
            )

        if decode_wrapper := attn_metadata.decode_wrapper:
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
            assert decode_wrapper is not None
            assert decode_wrapper._window_left == window_left
            assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                       or 0.0)
            assert decode_wrapper._sm_scale == self.scale
            decode_wrapper.run(
                decode_query,
                kv_cache,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[:num_decode_tokens],
            )

        return output_padded

View File

@@ -0,0 +1,473 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
_score_mod_signature,
create_block_mask,
flex_attention)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if current_platform.is_cuda():
pass
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
# Compile the mask builder once at import time; block-mask construction runs
# on the hot path every step.
create_block_mask_compiled = torch.compile(create_block_mask,
                                           fullgraph=True,
                                           mode="reduce-overhead")
# NOTE(review): flex_attention is compiled without mode="reduce-overhead",
# unlike the mask builder above — presumably intentional; confirm.
flex_attention_compiled = torch.compile(flex_attention, fullgraph=True)
def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
device = offsets.device
counts = offsets[1:] - offsets[:-1]
return torch.repeat_interleave(
torch.arange(len(counts), device=device, dtype=torch.int32), counts)
class FlexAttentionBackend(AttentionBackend):
    """Backend descriptor for the torch FlexAttention path."""

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "FLEX_ATTENTION"

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [16, 32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_impl_cls() -> type["FlexAttentionImpl"]:
        return FlexAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return FlexAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]:
        return FlexAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        # K/V split is the leading dimension here (size 2), unlike backends
        # that keep the K/V pair per block.
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        # Cascade attention is not implemented for FlexAttention.
        return False
# @torch.compile(fullgraph=True, mode="reduce-overhead")
def physical_to_logical_mapping(
        block_table: torch.Tensor,
        total_blocks: Optional[int] = None) -> torch.Tensor:
    """
    Creates an inverse mapping from physical block locations to logical indices.

    The original block_table maps from logical blocks to physical locations:
        Logical Blocks:   0 1 2 3 4 5 6 7
        Physical Blocks:  3 5 1 7 4 2 0 6
    This function creates the inverse mapping:
        Physical Blocks:  0 1 2 3 4 5 6 7
        Logical Blocks:   6 2 5 0 4 1 7 3

    If multiple logical blocks map to the same physical block,
    this function returns the first (minimum) logical block index.

    If a physical block is not mapped to by any logical block,
    its value in the result will be -1.

    Args:
        block_table: Tensor of shape [max_reqs, max_num_blocks]
            mapping logical blocks to physical locations
        total_blocks: total number of physical blocks. If None, it is
            inferred as ``block_table.max() + 1`` (the smallest table that
            can hold every referenced physical index).

    Returns:
        A tensor of shape [max_reqs, total_blocks]
    """
    max_reqs, max_num_blocks = block_table.shape
    device = block_table.device
    if total_blocks is None:
        # Fix: the advertised default of None previously crashed in
        # torch.full; infer the size from the table instead.
        total_blocks = int(block_table.max().item()) + 1
    physical_to_logical = torch.full((max_reqs, total_blocks),
                                     -1,
                                     dtype=torch.long,
                                     device=device)
    logical_indices = (torch.arange(max_num_blocks,
                                    device=device).unsqueeze(0).expand(
                                        max_reqs, -1))
    physical_to_logical.scatter_(-1, block_table.to(torch.int64),
                                 logical_indices)
    # TODO Confirm - Seems like block 0 is always empty so we reset it manually
    physical_to_logical[:, 0] = -1
    return physical_to_logical
def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
                    kv_idx: torch.Tensor):
    """Standard causal mask: a query attends only to keys at or before it."""
    return kv_idx <= q_idx
@dataclass
class FlexAttentionMetadata:
    """Per-step metadata consumed by the FlexAttention backend.

    Carries the packed-query layout, the paged-KV-cache inverse mapping
    (physical block -> logical block per request), and the lazily built
    FlexAttention ``BlockMask`` (constructed in ``__post_init__``).
    """
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # Cascade attention is not supported; __post_init__ asserts these are
    # disabled / None.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Block info
    total_cache_tokens: int
    block_size: int
    max_possible_sequence_length: int
    num_reqs: int
    # [num_reqs, num_physical_blocks]: physical block -> first logical block
    # index, -1 where the physical block is unused by the request.
    physical_to_logical: torch.Tensor
    # Per-request number of already-computed tokens, added to local query
    # positions in get_mask_mod.
    decode_offset: torch.Tensor

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

    # Flex Metadata (num_blocks is recomputed in __post_init__).
    num_blocks = 0
    block_mask: Optional[BlockMask] = None
    score_mod: Optional[_score_mod_signature] = None
    mask_mod: Optional[_mask_mod_signature] = None
    logical_mask_mod: _mask_mod_signature = causal_mask_mod

    def get_mask_mod(self) -> _mask_mod_signature:
        """Creates the mask_mod function for FlexAttention.

        This function creates the combined mask mod function that handles:
          1. The paged attention block mapping
          2. The mapping from packed query sequences to logical query entries

        It also by default adds the decoding offset to the query indices.
        With this info we create the "logical" indices that are passed to
        mask_mod functions. This allows mask mod functions to be agnostic to
        the layout of the query and key/value tensors.

        TODO is_within_lower_bound: do sequences start on block_boundaries?
        """
        # Create a lookup mapping from query indices -> request number
        request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)

        def final_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            # Map query indices to corresponding request indices
            q_req = request_lookup[q_idx]

            # Convert physical KV indices to logical indices
            physical_kv_block = physical_kv_idx // self.block_size
            physical_kv_offset = physical_kv_idx % self.block_size
            logical_block_idx = self.physical_to_logical[q_req,
                                                         physical_kv_block]
            logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset  # noqa: E501

            # Determine valid kv indices (unused physical blocks map to -1,
            # which also makes logical_kv_idx negative).
            live_block = logical_block_idx >= 0
            within_upper_bound = logical_kv_idx < self.seq_lens[q_req]
            within_lower_bound = logical_kv_idx >= 0
            is_valid = live_block & within_upper_bound & within_lower_bound

            # Convert physical query indices to logical indices
            local_q_idx = q_idx - self.query_start_loc[q_req]
            logical_q_idx = local_q_idx + self.decode_offset[q_req]

            # Apply mask modification only for valid indices; everything
            # else is masked out (False).
            return torch.where(
                is_valid,
                self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx),
                False,
            )

        return final_mask_mod

    def build_block_mask(self) -> BlockMask:
        """Compile a BlockMask over the whole paged cache using mask_mod."""
        assert self.mask_mod is not None
        return create_block_mask_compiled(
            self.mask_mod,
            None,
            None,
            self.num_actual_tokens,
            self.total_cache_tokens,
        )

    def __post_init__(self):
        # Cascade/prefix attention is not implemented for this backend.
        assert self.use_cascade is False, "Not implemented yet."
        assert self.common_prefix_len == 0, "Not implemented yet."
        assert self.cu_prefix_query_lens is None, "Not implemented yet."
        assert self.prefix_kv_lens is None, "Not implemented yet."
        assert self.suffix_kv_lens is None, "Not implemented yet."
        self.num_blocks = self.total_cache_tokens // self.block_size
        self.mask_mod = self.get_mask_mod()
        self.block_mask = self.build_block_mask()
class FlexAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlexAttentionMetadata]):
    """Builds FlexAttentionMetadata from the runner's per-step state."""

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        model_config = runner.model_config

        self.runner = runner
        self.num_heads_q = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
            runner.parallel_config)
        self.headdim = model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """Assemble FlexAttentionMetadata for the current scheduler step.

        Note: constructing the metadata triggers __post_init__, which builds
        the (compiled) BlockMask for this step.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        block_table = self.block_table
        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
        # Async H2D copy of the slot mapping; consumers are downstream CUDA
        # kernels on the same stream, so no explicit sync is needed here.
        block_table.slot_mapping[:num_actual_tokens].copy_(
            block_table.slot_mapping_cpu[:num_actual_tokens],
            non_blocking=True)
        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

        use_cascade = common_prefix_len > 0
        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        if use_cascade:
            raise NotImplementedError("Not yet my friend")

        block_size = self.kv_cache_spec.block_size
        max_possible_seq_len = self.runner.model_config.max_model_len
        total_cache_tokens = (self.runner.cache_config.num_gpu_blocks *
                              block_size)

        # Invert the block table: physical block -> logical block per request.
        inverse_block_table = physical_to_logical_mapping(
            block_table_tensor, self.runner.cache_config.num_gpu_blocks)

        # Get the original offset tensor (tokens already computed per request)
        offset_tensor = torch.tensor(
            self.runner.input_batch.num_computed_tokens_cpu[:num_reqs]).to(
                self.runner.device, non_blocking=True)

        out = FlexAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            block_size=block_size,
            max_possible_sequence_length=max_possible_seq_len,
            num_reqs=num_reqs,
            physical_to_logical=inverse_block_table,
            total_cache_tokens=total_cache_tokens,
            decode_offset=offset_tensor,
        )
        return out
class FlexAttentionImpl(AttentionImpl):
    """Attention implementation backed by torch's FlexAttention.

    New K/V entries are written into the paged cache via
    ``reshape_and_cache_flash``; the paged layout is made transparent to
    FlexAttention through the mask_mod built in FlexAttentionMetadata.
    """

    sliding_window: Optional[tuple[int, int]]
    alibi_slopes: Optional[torch.Tensor]
    logits_soft_cap: Optional[float]

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Validate the configuration; reject features Flex doesn't support.

        Raises:
            ValueError: for block-sparse params or unsupported head sizes.
            NotImplementedError: for alibi slopes, sliding window, logits
                soft cap, kv sharing, or quantized kv-cache.
        """
        if blocksparse_params is not None:
            # TODO we should support this :think
            # Fix: the message previously said "FlashAttention".
            raise ValueError(
                "FlexAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            raise NotImplementedError(
                "FlexAttention does not support alibi slopes yet.")
        else:
            self.alibi_slopes = None
        if sliding_window is not None:
            raise NotImplementedError(
                "FlexAttention does not support sliding window yet.")
        else:
            self.sliding_window = (-1, -1)
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        if self.logits_soft_cap is not None:
            raise NotImplementedError(
                "FlexAttention does not support logits soft cap yet.")
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError(
                "FlexAttention does not support kv sharing yet.")

        support_head_sizes = FlexAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            # Fix: the message previously said "FlashAttention".
            raise ValueError(
                f"Head size {head_size} is not supported by FlexAttention. "
                f"Supported head sizes are: {support_head_sizes}. "
                "Set VLLM_USE_V1=0 to use another attention backend.")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlexAttention does not support quantized kv-cache. Yet")

    @staticmethod
    def view_as_4d(tensor: torch.Tensor) -> torch.Tensor:
        """View a 3d tensor as 4D (prepend a singleton batch dim)."""
        if tensor.ndim == 4:
            return tensor
        assert tensor.ndim == 3
        return tensor[None, :, :, :]

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlexAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlexAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads,
                head_size]
            attn_metadata: Metadata for attention (None during profiling).
            output: preallocated output buffer; must be provided.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."
        enable_gqa = self.num_kv_heads != self.num_heads

        if attn_metadata is None:
            # Profiling run.
            return output

        num_actual_tokens = attn_metadata.num_actual_tokens

        key_cache, value_cache = kv_cache.unbind(0)

        # Write this step's K/V into the paged cache at slot_mapping.
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        # View out the block_size dim: address the cache as one flat token
        # axis; the paged indirection is handled by the metadata's mask_mod.
        key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
        value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
        query, key_cache, value_cache = map(
            lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
            (query, key_cache, value_cache),
        )
        query = query[:, :, :num_actual_tokens, :]
        # Doesn't work for now -> constraint violation
        # torch._dynamo.try_mark_dynamic(query, 2)

        out = flex_attention_compiled(
            query,
            key_cache,
            value_cache,
            attn_metadata.score_mod,
            attn_metadata.block_mask,
            self.scale,
            enable_gqa=enable_gqa,
            kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
        )

        # Flex doesn't have an out variant today, rely on epilogue fusion
        out = out.permute(0, 2, 1, 3).squeeze(0)
        output[:num_actual_tokens, :, :].copy_(out)
        return output

View File

@@ -0,0 +1,975 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
# MLA Common Components
This file implements common components for MLA implementations.
First we define:
Sq as Q sequence length
Skv as KV sequence length
MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach, we generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
and the data-movement friendly approach for "decode" (i.e. the ratio
Sq / Skv is "large").
NOTE what we deem small and large is currently determined by if its labelled
prefill or decode by the scheduler, but this is something we should probably
tune.
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention.
Below is example of both paths assuming batchsize = 1
## More Extent Definitions:
C Context length, `Skv - Sq`
H hidden size
N number of attention heads
Lq latent dimension for Q 1536 in DSV3
Lkv latent dimension for K/V 512 in DSV3
P nope dimension, no rope. 128 in DSV3
R rope dimension, goes through rope. 64 in DSV3
V V head dim. 128 in DSV3
## Vector/Matrix Definitions
h_t hidden states (input to attention) shape [Sq, H]
q_c latent/compressed Q shape [Sq, Lq]
q_nope uncompressed Q (no-rope) shape [Sq, N, P]
q_pe uncompressed Q (rope) shape [Sq, N, R]
kv_c latent/compressed KV shape [Skv, Lkv]
k_pe decoupled k position embeddings shape [Skv, R]
new_kv_c new kv_c from current iter shape [Sq, Lkv]
new_k_pe new k_pe from current iter shape [Sq, R]
cache_kv_c cached k_c from previous iters shape [C, Lkv]
cache_k_pe cached k_pe from previous iters shape [C, R]
W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N, P]
W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N, V]
W_O project v to h_t shape [N * V, H]
## Compute Friendly Approach (i.e. "_forward_prefill"):
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
// MHA with QK headdim = P + R
// V headdim = V
// spda_o shape [Sq, N, V]
spda_o = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
v
)
return spda_o @ W_O
NOTE: in the actual code,
`kv_b_proj` is [W_UK; W_UV] concatenated per head
`q_b_proj` is [W_UQ; W_QR] concatenated per head
`out_proj` is W_O
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Runtime
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snh,lnh->snl", q, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// MQA with QK headdim = Lkv + R
// V headdim = Lkv
// spda_o shape [Sq, N, Lkv]
// NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention(
torch.cat([ql_nope, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1),
kv_c
)
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
    return o.view(-1, N * V) @ W_O
## Chunked Prefill
For chunked prefill we want to use the compute friendly algorithm. We are
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
fixed workspace size.
The chunked prefill approach is as follows:
MCC Max chunk of context to process per iter, computed dynamically,
used to bound the memory usage
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
// MHA between queries and new KV
// with QK headdim = P + R
// V headdim = V
// curr_o shape [Sq, N, V]
// curr_lse shape [N, Sq], this is just order FA returns
curr_o, curr_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
new_v,
        causal=True,
return_softmax_lse=True
)
// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
chunk_start = chunk_idx * MCC
chunk_end = min(chunk_start + MCC, C)
Sc = chunk_end - chunk_start
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
chunk_o, chunk_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([cache_k_nope_chunk,
cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
dim=-1),
cache_v_chunk,
            causal=False,
return_softmax_lse=True
)
curr_o, curr_lse = merge_attn_states(
suffix_output=curr_o,
suffix_lse=curr_lse,
prefix_output=chunk_o,
prefix_lse=chunk_lse,
)
return curr_o @ W_O
"""
import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.merge_attn_states import merge_attn_states
# from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod)
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
is_vllm_fa = True
except ImportError:
# For rocm use upstream flash attention
from flash_attn import flash_attn_varlen_func
is_vllm_fa = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm import envs
logger = init_logger(__name__)
def get_flash_attn_version():
    """Stub replacing vllm.attention.utils.fa_utils.get_flash_attn_version.

    Always returns None, so MLACommonImpl uses the plain
    flash_attn_varlen_func entry point (no FA2/FA3 version pinning) and
    keeps V-padding enabled (see MLACommonImpl._pad_v).
    """
    return None
class MLACommonBackend(AttentionBackend):
    """V1 backend declaration shared by MLA attention implementations."""

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return MLACommonMetadata

    @staticmethod
    def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
        return MLACommonMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> tuple[int, ...]:
        # MLA caches one latent vector per token, so the kv-head dim is
        # folded away.
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        # 576 = Lkv (512) + R (64) per the dimension table in the module
        # docstring (DeepSeek-V3 values).
        return [576]
@dataclass
class MLACommonPrefillMetadata:
    """ Prefill Specific Metadata """

    @dataclass
    class ChunkedContextMetadata:
        # New for MLA (compared to FlashAttention)
        # For handling chunked prefill
        # [num_chunks, num_prefills + 1] cumulative chunk seq lens per chunk
        cu_seq_lens: torch.Tensor
        # [num_chunks, num_prefills] context start offset of each chunk
        starts: torch.Tensor
        # total tokens gathered per chunk (used to slice the workspace)
        seq_tot: list[int]
        # max per-request chunk length for each chunk
        max_seq_lens: list[int]
        # shared gather buffer for up-projecting cached latent KV
        workspace: torch.Tensor

    block_table: torch.Tensor
    query_start_loc: torch.Tensor
    max_query_len: int
    chunked_context: Optional[ChunkedContextMetadata] = None
@dataclass
class MLACommonDecodeMetadata:
    """Decode-path metadata: per-request block table and sequence lengths."""
    block_table: torch.Tensor
    seq_lens: torch.Tensor
D = TypeVar("D", bound=MLACommonDecodeMetadata)
@dataclass
class MLACommonMetadata(Generic[D]):
    """Metadata for MLACommon.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # The dimension of the attention heads
    head_dim: Optional[int] = None

    decode: Optional[D] = None
    prefill: Optional[MLACommonPrefillMetadata] = None

    def __post_init__(self):
        # Validate eagerly so a misconfigured head_dim fails at metadata
        # construction rather than inside an attention kernel.
        supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            # Fix: previously two comma-separated f-strings were passed to
            # ValueError, producing a tuple-style message; join them into
            # a single message instead.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")
M = TypeVar("M", bound=MLACommonMetadata)
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(self,
                 runner: "GPUModelRunner",
                 kv_cache_spec: AttentionSpec,
                 block_table: BlockTable,
                 metadata_cls: Optional[type[M]] = None):
        self.metadata_cls = metadata_cls \
            if metadata_cls is not None else MLACommonMetadata
        self.runner = runner
        scheduler_config = runner.scheduler_config
        model_config = runner.model_config
        cache_config = runner.cache_config
        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
        self.num_heads = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.mla_dims = get_mla_dims(model_config)
        # AOT scheduling is disabled here (would be current_platform.is_cuda())
        self.aot_schedule = False  # current_platform.is_cuda()
        self.kv_cache_spec = kv_cache_spec

        # Dont try to access the runner on AMD
        if self.aot_schedule:
            self.page_size = self.kv_cache_spec.block_size

        if self.chunked_prefill_enabled:
            self.chunked_prefill_workspace_size = min(
                # Make sure there is enough for 8 full length requests or at
                # least 4 pages of cache per request
                max(
                    8 * model_config.max_model_len, 4 *
                    scheduler_config.max_num_seqs * cache_config.block_size),
                # For long-context models try not to over-allocate limiting
                # kv-cache space, limiting it to 64k tokens,
                # which would result in the workspace being:
                # 2*(576)*(64*1024) = 144mb
                # (assuming 576 MLA head dim, and fp16)
                # which would result in up-projected context being
                # 2*(192*128)*(64*1024) = 3gb
                # (assuming 192 QK head dim, 128 heads, and fp16)
                128 * 1024)
            assert self.chunked_prefill_workspace_size >= \
                scheduler_config.max_num_seqs * cache_config.block_size
            # Gather buffer shared by all prefills of a step (see
            # _compute_prefill_context).
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 model_config.get_head_size()),
                dtype=model_config.dtype,
                device=runner.device,
            )
        self.block_table = block_table

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # We now want to reorder the batch so that the "decode" requests are
        # at the front and the "prefill" requests are at the back, using the
        # least amount of swaps possible. (NOTE for now we loosely use
        # "decode" to mean requests where attention is likely memory-bound
        # and "prefill" to mean requests where attention is likely
        # compute-bound, TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # for now treat 1 scheduled token as "decode" even if its not,
            # we should update this to something like < 8 in the future but
            # currently the TritonMLA._forward_decode only supports
            # num_tokens = 1
            if num_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new requests are generally appended to
        # the persistent batch so already should be at the back)
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back"
        # i.e. past where the last decode should be in the reordered batch
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based
        # on the above loop
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i, we can swap it
            # with the prefill closest to the front of the batch
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                break
            input_batch.swap_states(prefills[i - 1], decode_idx)
            modified_batch = True

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor):
        # Subclasses may override to attach backend-specific decode metadata.
        return MLACommonDecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
        )

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with MLA.
        """
        m = common_attn_metadata

        assert m.num_reqs == m.num_actual_tokens, \
            "MLA only supports decode-only full CUDAGraph capture. " \
            "Make sure all cudagraph capture sizes <= max_num_seq."

        m.max_query_len = 1  # decode-only

        # Update state usually set in reorder_batch.
        self._num_decodes = m.num_reqs
        self._num_decode_tokens = m.num_actual_tokens
        self._num_prefills = 0
        self._num_prefill_tokens = 0
        return self.build(0, m)

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> M:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        # reorder_batch (or build_for_cudagraph_capture) must have run first.
        assert self._num_decodes + self._num_prefills == num_reqs

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.runner.device
        block_table = self.block_table
        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
        block_table.slot_mapping[:num_actual_tokens].copy_(
            block_table.slot_mapping_cpu[:num_actual_tokens],
            non_blocking=True)
        # Mark padded slots with -1 so cache writes for padding are no-ops.
        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        prefill_metadata = None
        if self._num_prefills > 0:
            reqs_start = self._num_decodes  # prefill_start

            context_lens_cpu = self.runner.input_batch.\
                num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
            max_context_len_cpu = context_lens_cpu.max().item()
            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
            # Re-base query_start_loc so the prefill slice starts at 0.
            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]

            chunked_context_metadata = None
            if self.chunked_prefill_enabled and self._num_prefills > 0 \
                    and max_context_len_cpu > 0:
                # NOTE: it is recommended you read the `Chunked Prefill`
                # section in the comment at the top of the file before trying
                # to understand the following code

                # currently we allocate an equal amount of workspace for each
                # prefill in the batch, we could probably use a more advanced
                # algorithm here and allocate more workspace to prefills with
                # longer context lengths
                max_context_chunk = (self.chunked_prefill_workspace_size //
                                     num_prefills_with_context_cpu)

                if self.aot_schedule:
                    # align max_context_chunk to page_size by rounding down,
                    # currently the `gather_cache` kernel cannot handle
                    # `context_chunk_starts` that are not aligned to page_size
                    max_context_chunk = round_down(max_context_chunk,
                                                   self.page_size)

                assert max_context_chunk > 0
                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)

                # if `max_context_chunk = 256`, `num_chunks = 3`, and
                # `num_prefills_with_context = 4`, create a tensor that looks
                # like
                # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
                # Note(simon): this is done in CPU because of downstream's
                # of `to_list`.
                chunk_starts = \
                    torch.arange(num_chunks, dtype=torch.int32) \
                    .unsqueeze(1).expand(-1, self._num_prefills) \
                    * max_context_chunk
                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                       chunk_starts + max_context_chunk)
                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

                cu_seq_lens_cpu = torch.zeros(num_chunks,
                                              self._num_prefills + 1,
                                              dtype=torch.int32,
                                              pin_memory=True)
                torch.cumsum(chunk_seq_lens,
                             dim=1,
                             out=cu_seq_lens_cpu[:, 1:],
                             dtype=torch.int32)

                chunked_context_metadata = \
                    MLACommonPrefillMetadata.ChunkedContextMetadata(
                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                        starts=chunk_starts.to(device, non_blocking=True),
                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                        workspace=self.chunked_prefill_workspace,
                    )

                assert max(chunked_context_metadata.max_seq_lens) <= \
                    self.chunked_prefill_workspace_size

            prefill_metadata = MLACommonPrefillMetadata(
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                max_query_len=max_query_len,
                chunked_context=chunked_context_metadata,
            )

        decode_metadata = None
        if self._num_decodes > 0:
            decode_metadata = self._build_decode(
                block_table_tensor=block_table_tensor[:self._num_decodes, ...],
                seq_lens=seq_lens[:self._num_decodes],
            )

        return self.metadata_cls(
            num_actual_tokens=num_actual_tokens,
            query_start_loc=query_start_loc,
            slot_mapping=slot_mapping,
            head_dim=self.runner.model_config.get_head_size(),
            # MLACommonMetadata Chunk prefill specific
            num_decodes=self._num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            num_prefills=self._num_prefills,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        # Full cudagraphs are decode-only for MLA (max one token per request).
        return common_attn_metadata.max_query_len == 1
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
q_lora_rank: Optional[int],
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported for MLA")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.kv_b_proj = kv_b_proj
self.vllm_flash_attn_version = get_flash_attn_version()
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version()
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
    def _flash_attn_varlen_diff_headdims(self,
                                         q,
                                         k,
                                         v,
                                         return_softmax_lse=False,
                                         softmax_scale=None,
                                         **kwargs):
        """Call flash_attn_varlen_func, papering over backend differences.

        Handles two quirks:
          * V may be zero-padded up to Q's head dim when the backend
            requires matching head dims (see self._pad_v).
          * vllm_flash_attn takes `return_softmax_lse` while upstream
            flash_attn (used on ROCm) calls it `return_attn_probs`.
        """
        maybe_padded_v = v
        if self._pad_v:
            # Zero-pad the last dim of v up to q's head dim.
            maybe_padded_v = torch.nn.functional.pad(
                v, [0, q.shape[-1] - v.shape[-1]], value=0)

        if is_vllm_fa:
            attn_out = self.flash_attn_varlen_func(
                q=q,
                k=k,
                v=maybe_padded_v,
                return_softmax_lse=return_softmax_lse,
                softmax_scale=softmax_scale,
                **kwargs,
            )
        else:
            # Use return_attn_probs instead of return_softmax_lse for RoCM
            attn_out = self.flash_attn_varlen_func(
                q=q,
                k=k,
                v=maybe_padded_v,
                return_attn_probs=return_softmax_lse,
                softmax_scale=softmax_scale,
                **kwargs,
            )

        # Unpack the output if there are multiple results
        lse = None
        if isinstance(attn_out, tuple):
            attn_out, lse = attn_out[0], attn_out[1]

        # Remain consistent with old `flash_attn_varlen_func` where there
        # is only one output tensor if `return_softmax_lse` is False.
        if return_softmax_lse:
            return attn_out, lse
        return attn_out
def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Split kv_b_proj into the W_UV / W_UK^T operands used by bmm.

        Dequantizes kv_b_proj if necessary (offline only, O(N^3)), splits it
        into W_UK and W_UV per head, and stores copies laid out for the
        per-head bmm's (see _v_up_proj).
        """

        def get_layer_weight(layer):
            # Support plain and quantized linear layers.
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(layer.input_size_per_partition,
                                dtype=act_dtype,
                                device=get_layer_weight(layer).device)
                dequant_weights = layer.quant_method.apply(layer,
                                                           eye,
                                                           bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            # NOTE(review): fork-specific layout toggle — the weight may be
            # stored transposed; confirm envs.MACA_VLLM_USE_TN_2_NN semantics.
            return layer.weight if not envs.MACA_VLLM_USE_TN_2_NN else layer.weight.T

        # we currently do not have quantized bmm's which are needed for
        # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
        # the bmm's in 16-bit, the extra memory overhead of this is fairly low
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
                f"{kv_b_proj_weight.shape=}, "
                f"{self.kv_lora_rank=}, "
                f"{self.num_heads=}, "
                f"{self.qk_nope_head_dim=}, "
                f"{self.v_head_dim=}")
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1)
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)
def _compute_prefill_context(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: MLACommonMetadata,
):
    """Attend the prefill queries against cached context, chunk by chunk.

    For each context chunk: gather its latent+rope entries from the
    paged KV cache into the pre-allocated workspace, up-project to keys
    and values, run unmasked attention, and fold the result into the
    running output with merge_attn_states (log-sum-exp weighted).
    Returns (output, output_lse).
    """
    assert attn_metadata.prefill is not None
    prefill_metadata = attn_metadata.prefill
    assert prefill_metadata.chunked_context is not None

    output = None
    iters = len(prefill_metadata.chunked_context.seq_tot)
    workspace = prefill_metadata.chunked_context.workspace

    for i in range(iters):
        toks = prefill_metadata.chunked_context.seq_tot[i]

        # Gather this chunk's cache rows into the contiguous workspace.
        ops.gather_cache(
            src_cache=kv_c_and_k_pe_cache,
            dst=workspace,
            block_table=prefill_metadata.block_table,
            cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
            batch_size=attn_metadata.num_prefills,
            seq_starts=prefill_metadata.chunked_context.starts[i],
        )

        # Workspace rows are [latent | rope] concatenated on the last dim.
        kv_c_normed = workspace[:toks]\
            [..., :self.kv_lora_rank]
        k_pe = workspace[:toks]\
            [..., self.kv_lora_rank:].unsqueeze(1)

        kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Broadcast the shared rope part across all heads.
        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                      dim=-1)

        attn_output, attn_softmax_lse = \
            self._flash_attn_varlen_diff_headdims(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=prefill_metadata.query_start_loc,
                cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
                max_seqlen_q=prefill_metadata.max_query_len,
                max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
                softmax_scale=self.scale,
                causal=False,  # Context is unmasked
                return_softmax_lse=True,
            )

        if output is None:
            output = attn_output
            output_lse = attn_softmax_lse
        else:
            # Merge this chunk into the running result using the LSEs.
            output_tmp = torch.empty_like(output)
            output_lse_tmp = torch.empty_like(output_lse)
            merge_attn_states(
                output=output_tmp,
                output_lse=output_lse_tmp,
                prefix_output=output,
                prefix_lse=output_lse,
                suffix_output=attn_output,
                suffix_lse=attn_softmax_lse,
            )
            output = output_tmp
            output_lse = output_lse_tmp

    return output, output_lse
def _forward_prefill(
    self,
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
    """Causal prefill attention over the newly scheduled tokens.

    NOTE(review): chunked-context merging is hard-disabled in this port
    (``has_context`` is forced to False and ``return_softmax_lse`` is
    commented out), so previously-cached context is NOT folded into the
    prefill output — confirm this is intentional.
    """
    assert attn_metadata.prefill is not None

    has_context = False  # attn_metadata.prefill.chunked_context is not None
    kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
        -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = kv_nope\
        .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    # Broadcast the shared rope part across all heads.
    k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

    output = self._flash_attn_varlen_diff_headdims(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=attn_metadata.prefill.query_start_loc,
        cu_seqlens_k=attn_metadata.prefill.query_start_loc,
        max_seqlen_q=attn_metadata.prefill.max_query_len,
        max_seqlen_k=attn_metadata.prefill.max_query_len,
        softmax_scale=self.scale,
        causal=True,
        # return_softmax_lse=has_context,
    )

    if has_context:
        # Dead branch while has_context is hard-coded False above.
        suffix_output, suffix_lse = output
        context_output, context_lse = self._compute_prefill_context( \
            q, kv_c_and_k_pe_cache, attn_metadata)

        output = torch.empty_like(suffix_output)
        merge_attn_states(
            output=output,
            prefix_output=context_output,
            prefix_lse=context_lse,
            suffix_output=suffix_output,
            suffix_lse=suffix_lse,
        )

    # unpad if necessary: drop the zero-padded tail of each value head
    # (v was padded to q's head dim in _flash_attn_varlen_diff_headdims).
    if self._pad_v:
        output = output.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]] \
            .reshape(-1, self.num_heads * v.shape[-1])

    return output
@abstractmethod
def _forward_decode(
    self,
    ql_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: M,
) -> torch.Tensor:
    """Backend-specific MQA decode over the latent KV cache.

    Implemented by each MLA backend (FlashMLA, CUTLASS, Triton, ...).
    """
    raise NotImplementedError
def forward(
    self,
    layer: AttentionLayer,
    q: torch.Tensor,
    k_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: M,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """MLA forward: write KV cache, then dispatch prefill/decode paths.

    Tokens are ordered decode-first, so the batch is split at
    ``num_decode_tokens``; prefill tokens use the varlen flash path and
    decode tokens use the backend's `_forward_decode`.
    """
    assert output is not None, "Output tensor must be provided."

    if attn_metadata is None:
        # The zero fill is required when used with DP + EP
        # to ensure all ranks within a DP group compute the
        # same expert outputs.
        return output.fill_(0)

    num_actual_toks = attn_metadata.num_actual_tokens

    # Inputs and outputs may be padded for CUDA graphs
    output_padded = output
    output = output[:num_actual_toks, ...]
    q = q[:num_actual_toks, ...]
    k_c_normed = k_c_normed[:num_actual_toks, ...]
    k_pe = k_pe[:num_actual_toks, ...]

    assert attn_metadata.num_decodes is not None and \
        attn_metadata.num_prefills is not None and \
        attn_metadata.num_decode_tokens is not None

    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # Decode tokens come first in the batch.
    decode_q = q[:num_decode_tokens]

    prefill_q = q[num_decode_tokens:]
    prefill_k_pe = k_pe[num_decode_tokens:]
    prefill_k_c_normed = k_c_normed[num_decode_tokens:]

    # write the latent and rope to kv cache
    if kv_cache.numel() > 0:
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype=self.kv_cache_dtype,
            scale=layer._k_scale,
        )

    if has_prefill:
        output[num_decode_tokens:] = self._forward_prefill(
            prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
            attn_metadata)

    if has_decode:
        assert attn_metadata.decode is not None
        decode_q_nope, decode_q_pe = decode_q.split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # Project q_nope into the latent space with W_UK_T so decode can
        # run MQA directly against the latent cache.
        # Convert from (B, N, P) to (N, B, P)
        decode_q_nope = decode_q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        decode_ql_nope = decode_ql_nope.transpose(0, 1)

        output[:num_decode_tokens] = self._forward_decode(
            decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)

    return output_padded

View File

@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
import torch
import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
logger = init_logger(__name__)
class CutlassMLABackend(MLACommonBackend):
    """MLA backend whose decode path runs the CUTLASS MLA kernel."""

    @staticmethod
    def get_impl_cls() -> type["CutlassMLAImpl"]:
        """Return the implementation class for this backend."""
        return CutlassMLAImpl

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "CUTLASS_MLA_VLLM_V1"
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA implementation that decodes with ops.cutlass_mla_decode."""

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        # Reject configurations the CUTLASS kernel cannot handle.
        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "CutlassMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CutlassMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CutlassMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        """Decode via the CUTLASS kernel, then up-project the latents."""
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Cutlass MLA not yet supported")

        B = q_nope.shape[0]

        # Kernel output buffer: (B, N, L) latent-space results.
        o = torch.empty((B, self.num_heads, self.kv_lora_rank),
                        dtype=q_nope.dtype,
                        device=q_nope.device)

        # Run MLA
        # Clone q_nope and q_pe to make sure strides computation is correct.
        q_nope = q_nope.clone()
        q_pe = q_pe.clone()
        ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
                               attn_metadata.decode.seq_lens,
                               attn_metadata.decode.block_table, self.scale)

        return self._v_up_proj(o)

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
    """MLA backend whose decode path runs the FlashMLA kernel."""

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "FLASHMLA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        """Return the implementation class for this backend."""
        return FlashMLAImpl

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        """Return the metadata-builder class for this backend."""
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        """Return the metadata class for this backend."""
        return FlashMLAMetadata
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with FlashMLA scheduler tensors."""
    # Per-SM tile scheduling metadata produced by get_mla_metadata.
    tile_scheduler_metadata: torch.Tensor
    # Per-request split counts produced by get_mla_metadata.
    num_splits: torch.Tensor
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """Common MLA metadata specialized to FlashMLA decode metadata."""
    pass
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds FlashMLA metadata; supports full-CUDA-graph decode."""

    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only

    def __init__(self, runner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

        # Static buffers reused across CUDA graph replays; allocated
        # lazily on the first _build_decode under full_cuda_graph.
        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
        """Build decode metadata, reusing static buffers for CUDA graphs."""
        tile_scheduler_metadata, num_splits = \
            get_mla_metadata(
                seq_lens,
                self.num_q_heads,
                1,  # MQA for the decode path
            )

        if self.runner.full_cuda_graph:
            # First time around (CUDAGraph capture), allocate the static buffer
            if self.cg_buf_tile_scheduler_metadata is None:
                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
                self.cg_buf_num_splits = num_splits
            else:
                assert self.cg_buf_num_splits is not None

                # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
                assert (self.cg_buf_tile_scheduler_metadata.size() ==
                        tile_scheduler_metadata.size())
                self.cg_buf_tile_scheduler_metadata.\
                    copy_(tile_scheduler_metadata)
                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata

                # Num splits is per-batch, varying size (batch_size,)
                n = num_splits.size(0)
                # make sure static buffer is large enough
                assert n <= self.cg_buf_num_splits.size(0)
                num_splits_view = self.cg_buf_num_splits[:n]
                num_splits_view.copy_(num_splits)
                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
                num_splits = num_splits_view

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
        )
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA implementation that decodes with flash_mla_with_kvcache."""

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        assert is_flashmla_supported(), \
            "FlashMLA is not supported on this device"

        # Reject configurations the FlashMLA kernel cannot handle.
        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
    ) -> torch.Tensor:
        """Decode via FlashMLA, then up-project latents to head values."""
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        q = torch.cat([q_nope, q_pe], dim=-1)\
            .unsqueeze(1)  # Add seqlen dim of 1 (decode)

        o, _ = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=attn_metadata.decode.
            tile_scheduler_metadata,
            num_splits=attn_metadata.decode.num_splits,
            softmax_scale=self.scale,
            causal=True,
        )

        return self._v_up_proj(o)

View File

@@ -0,0 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional
import torch
import vllm.envs as envs
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
# yapf: enable
def is_aiter_mla_enabled() -> bool:
    """True when both ROCm AITER and its MLA kernels are enabled by env."""
    aiter_on = envs.VLLM_ROCM_USE_AITER
    return aiter_on and envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
    """MLA backend whose decode path runs the ROCm AITER MLA kernel."""

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "ROCM_AITER_MLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["AiterMLAMetadata"]:
        """Return the metadata class for this backend."""
        return AiterMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
        """Return the metadata-builder class for this backend."""
        return AiterMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["AiterMLAImpl"]:
        """Return the implementation class for this backend."""
        return AiterMLAImpl
@dataclass
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with paged-KV indexing for AITER MLA."""
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The query indptr, shape : [num_decode + 1]
    qo_indptr: Optional[torch.Tensor] = None
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
    """Common MLA metadata specialized to AITER decode metadata."""
    pass
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    """Builds AITER MLA metadata (paged-KV index tensors for decode)."""

    def __init__(self, runner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
        # The AITER decode kernel indexes the cache token-by-token.
        assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
            "only supports block size 1."

    def _get_paged_kv_tensors(
            self, block_table: torch.Tensor,
            seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Flatten the block table into indptr/indices/last-page tensors.

        Returns (paged_kv_indices, paged_kv_indptr,
        paged_kv_last_page_len, qo_indptr).
        """
        page_size = self.kv_cache_spec.block_size
        # Number of pages actually used per request.
        block_table_bounds = (seq_lens + page_size - 1) // page_size
        device = self.runner.device

        # Keep only the in-use entries of each row of the block table.
        mask = (torch.arange(block_table.size(1),
                             dtype=block_table.dtype,
                             device=device).unsqueeze(0)
                < block_table_bounds.unsqueeze(1))
        paged_kv_indices = block_table[mask]

        paged_kv_indptr = torch.cat([
            torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
        ])

        # A full last page is reported as page_size, not 0.
        paged_kv_last_page_len = seq_lens % page_size
        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                             page_size, paged_kv_last_page_len)
        # One query token per decode request.
        qo_indptr = torch.arange(0,
                                 self._num_decodes + 1,
                                 step=1,
                                 dtype=torch.int32,
                                 device=device)

        return (
            paged_kv_indices,
            paged_kv_indptr,
            paged_kv_last_page_len,
            qo_indptr,
        )

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
        """Assemble AiterMLADecodeMetadata for the decode tokens."""
        (
            paged_kv_indices,
            paged_kv_indptr,
            paged_last_page_len,
            qo_indptr,
        ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens)

        attn_metadata = AiterMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_last_page_len,
            qo_indptr=qo_indptr)

        return attn_metadata
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
    """MLA implementation that decodes with aiter_mla_decode_fwd."""

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)
        # The AITER decode kernel is compiled for these head counts only.
        assert (num_heads == 16 or num_heads == 128), (
            f"Aiter MLA only supports 16 or 128 number of heads.\n"
            f"Provided {num_heads} number of heads.\n"
            "Try adjusting tensor_parallel_size value.")
        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "Aiter MLA does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        from aiter import flash_attn_varlen_func
        self.flash_attn_varlen_func = flash_attn_varlen_func

    def _flash_attn_varlen_diff_headdims(self,
                                         q,
                                         k,
                                         v,
                                         return_softmax_lse=False,
                                         softmax_scale=None,
                                         **kwargs):
        """AITER override: no v padding needed; uses `return_lse` kwarg."""
        output = self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            softmax_scale=softmax_scale,
            return_lse=return_softmax_lse,
            **kwargs,
        )

        return output

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AiterMLAMetadata,
    ) -> torch.Tensor:
        """Decode via AITER MLA, then up-project latents to head values."""
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.zeros(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)  # add head dim of 1

        if self.num_heads == 16:
            # AITER MLA decode kernel only supports
            # max_seqlen_q=1 when using 16 heads.
            max_seqlen_qo = 1
        else:
            # AITER MLA decode Kernel handles arbitrary
            # max_seqlen_q values when using 128 heads.
            # NOTE(review): this reads *prefill* metadata on the decode
            # path — confirm prefill is always populated in this mode.
            assert attn_metadata.prefill is not None
            max_seqlen_qo = attn_metadata.prefill.max_query_len

        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
                             attn_metadata.decode.qo_indptr, max_seqlen_qo,
                             attn_metadata.decode.paged_kv_indptr,
                             attn_metadata.decode.paged_kv_indices,
                             attn_metadata.decode.paged_kv_last_page_len)

        return self._v_up_proj(o)

View File

@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional
import torch
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.attention.backends.triton_mla import (load_config,
find_best_mla_para)
logger = init_logger(__name__)
import os

# TODO: Configure environment variables temporarily. New versions do not
# need to be configured.
# These flags enable MACA-specific Triton optimizations; they must be set
# before the Triton kernels below are compiled.
os.environ['TRITON_ENABLE_MACA_OPT_MOVE_DOT_OPERANDS_OUT_LOOP'] = '1'
os.environ['TRITON_ENABLE_MACA_CHAIN_DOT_OPT'] = '1'

# Tuning table consulted by find_best_mla_para at metadata-build time.
JSON_DATA = load_config()
class TritonMLABackend(MLACommonBackend):
    """MLA backend whose decode path runs the Triton decode kernel."""

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "TRITON_MLA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TritonMLAImpl"]:
        """Return the implementation class for this backend."""
        return TritonMLAImpl

    @staticmethod
    def get_metadata_cls() -> type["TritonMLAMetadata"]:
        """Return the metadata class for this backend."""
        return TritonMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["TritonMLAMetadataBuilder"]:
        """Return the metadata-builder class for this backend."""
        return TritonMLAMetadataBuilder
@dataclass
class TritonMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with Triton kernel tuning parameters."""
    # Number of KV splits per sequence for the split-K decode kernel.
    num_kv_splits: int
    # Pipeline stages for the Triton kernel.
    num_stages: int
@dataclass
class TritonMLAMetadata(MLACommonMetadata[TritonMLADecodeMetadata]):
    """Common MLA metadata specialized to Triton decode metadata."""
    pass
class TritonMLAMetadataBuilder(MLACommonMetadataBuilder[TritonMLAMetadata]):
    """Builds Triton MLA metadata, picking tuned kernel parameters."""

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor) -> TritonMLADecodeMetadata:
        """Select num_kv_splits/num_stages from the tuning table."""
        if seq_lens is not None:
            batch = seq_lens.shape[0]
            max_seq_len = int(seq_lens.max())
            # Look up tuned parameters for this (batch, seq_len) shape.
            num_kv_splits, num_stages = find_best_mla_para(JSON_DATA, batch, max_seq_len, 8)
        else:
            # Fallback defaults when no sequence lengths are available.
            num_kv_splits = 4
            num_stages = 1
        return TritonMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
            num_kv_splits=num_kv_splits,
            num_stages=num_stages,
        )
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA implementation that decodes with decode_attention_fwd."""

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        # Reject configurations the Triton kernel cannot handle.
        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "TritonMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "TritonMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        """Decode via the Triton MQA kernel, then up-project latents."""
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Triton MLA not yet supported")

        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.zeros(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        # TODO(lucas) Allocate ahead of time
        # Scratch buffer for the split-K partial results.
        attn_logits = torch.empty(
            (
                B,
                self.num_heads,
                attn_metadata.decode.num_kv_splits,
                # NOTE(lucas) idk why the +1 is here but sglang has it so we
                # just mirror that
                self.kv_lora_rank + 1,
            ),
            dtype=torch.float32,
            device=q.device,
        )

        # Add a head dim of 1
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

        # Run MQA
        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
                             attn_metadata.decode.block_table,
                             attn_metadata.decode.seq_lens, attn_logits,
                             attn_metadata.decode.num_kv_splits,
                             attn_metadata.decode.num_stages,
                             self.scale, PAGE_SIZE)

        return self._v_up_proj(o)

View File

@@ -0,0 +1,240 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional
import torch
# Required to register custom ops.
import torch_xla.experimental.custom_kernel # noqa: F401
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, next_power_of_2
logger = init_logger(__name__)
class PallasAttentionBackend(AttentionBackend):
    """TPU attention backend backed by Pallas custom kernels."""

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "PALLAS_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
        """Return the implementation class for this backend."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> type["PallasMetadata"]:
        """Return the metadata class for this backend."""
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        """Return the attention-state class for this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        # Keys and values are interleaved along the head dim (hence * 2).
        return (num_blocks, block_size, num_kv_heads * 2, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    # In recent TPU generations, up to v6e, the SMEM size is 1MB. The
    # block_tables within the PallasMetadata constitute almost the entire SMEM
    # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
    # we simply make sure that the size is smaller than half of SMEM capacity.
    @staticmethod
    def get_min_page_size(vllm_config: VllmConfig) -> int:
        """Smallest power-of-two page size that keeps block tables in SMEM."""
        max_num_page_per_req = (1024 * 1024 // 2 //
                                vllm_config.scheduler_config.max_num_seqs // 4)
        min_page_size = cdiv(vllm_config.model_config.max_model_len,
                             max_num_page_per_req)
        # Round up to the next power of two.
        min_page_size = 1 << (min_page_size - 1).bit_length()
        return min_page_size

    # TPU has limited SREGs (scalar registers), if page_size is too small, we
    # can spill SREGs easily which leads to bad performance. The strategy we
    # apply here is trying to split max-model-len to 16 pages which make the
    # spill less likely. Meanwhile we make sure the page size is in [16, 256].
    @staticmethod
    def get_page_size(vllm_config: VllmConfig) -> int:
        """Page size targeting ~16 pages per max-length request, in [16, 256]."""
        page_size = next_power_of_2(
            vllm_config.model_config.max_model_len) // 16
        if page_size <= 16:
            return 16
        if page_size >= 256:
            return 256
        return page_size
@dataclass
class PallasMetadata:
    """Per-step attention metadata consumed by the Pallas kernels."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Used in the PallasAttentionBackendImpl
    slot_mapping: torch.Tensor
    block_tables: torch.Tensor
    context_lens: torch.Tensor
    query_start_loc: torch.Tensor
    num_seqs: torch.Tensor
class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation running the Pallas ragged-paged kernel.

    Validates TPU-supported configurations at construction and performs
    cache write + ragged paged attention in forward().
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[int] = None,
        use_irope: bool = False,
    ) -> None:
        if use_irope:
            logger.warning_once(
                "Using irope in Pallas is not supported yet, it will fall back "
                "to global attention for long context.")
        if blocksparse_params is not None:
            raise ValueError("Paged attention Pallas kernel does "
                             "not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        # NOTE: a second `blocksparse_params is not None` check used to
        # live here; it was unreachable because the ValueError above
        # already rejects that case, so it has been removed.
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

        tpu_version = torch_xla.tpu.version()
        if tpu_version < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        # For determine_available_memory case.
        if kv_cache.numel() == 0:
            if output is None:
                output = torch.ones_like(query)
            return output

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        num_tokens, hidden_size = query.shape
        query = query.view(num_tokens, self.num_heads, self.head_size)

        if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
            # Write input keys and values to the KV cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            slot_mapping = attn_metadata.slot_mapping
            write_to_kv_cache(key, value, kv_cache, slot_mapping)

        output = torch.ops.xla.ragged_paged_attention(
            query,
            kv_cache,
            attn_metadata.context_lens,
            attn_metadata.block_tables,
            attn_metadata.query_start_loc,
            attn_metadata.num_seqs,
            # By default, the system utilizes optimized block size and
            # vmem_limit_bytes parameters from the kernel repository. However,
            # these can be manually adjusted for debugging if necessary.
            num_kv_pages_per_block=None,
            num_queries_per_block=None,
            vmem_limit_bytes=None,
            use_kernel=True,
            sm_scale=self.scale,
            sliding_window=self.sliding_window,
            soft_cap=self.logits_soft_cap,
        )

        return output.reshape(num_tokens, hidden_size)
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """ Write the key and values to the KV cache.

    Args:
        key: shape = [num_tokens, num_kv_heads * head_size]
        value: shape = [num_tokens, num_kv_heads * head_size]
        kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
    """
    _, _, num_combined_kv_heads, head_size = kv_cache.shape
    num_kv_heads = num_combined_kv_heads // 2

    # Interleave K and V along the head dim to match the cache layout.
    key_heads = key.view(-1, num_kv_heads, head_size)
    value_heads = value.view(-1, num_kv_heads, head_size)
    combined = torch.cat([key_heads, value_heads], axis=-1)
    combined = combined.reshape(-1, num_combined_kv_heads, head_size)

    # Let XLA reuse the cache buffer in place instead of copying it.
    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)

    flat_cache = kv_cache.flatten(0, 1)
    flat_cache.index_copy_(0, slot_mapping, combined)

View File

@@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
    """Metadata builder for the Triton backend.

    Reuses the FlashAttention metadata builder but disables ahead-of-time
    scheduling, which the Triton kernels do not use.
    """

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table)
        # Must be set after super().__init__ to override the parent's value.
        self.aot_schedule = False
class TritonAttentionBackend(AttentionBackend):
    """V1 attention backend that dispatches to Triton kernels."""

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "TRITON_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TritonAttentionImpl"]:
        return TritonAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        # The Triton backend reuses FlashAttention's metadata layout.
        return FlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        # K and V caches stacked along the leading axis.
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False

    @staticmethod
    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
        return TritonAttentionMetadataBuilder
class TritonAttentionImpl(AttentionImpl):
    """Triton-based attention implementation for the V1 engine.

    Caches new keys/values into a paged KV cache of shape
    [2, num_blocks, block_size, num_kv_heads, head_size] and computes
    attention with either the unified Triton kernel or the chunked
    prefill + paged decode kernels (selected by
    VLLM_V1_USE_PREFILL_DECODE_ATTENTION).
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "TritonAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # Stored as a (left, right) token window; (-1, -1) disables it.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        # If set, this layer reads the KV cache written by an earlier layer
        # and skips writing its own.
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        self.use_irope = use_irope

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by TritonAttention. "
                f"Supported head sizes are: {support_head_sizes}.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonAttentionImpl")

        self.fp8_dtype = current_platform.fp8_dtype()
        self.force_prefill_decode_attn = \
            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

        # Optional per-head attention-sink logits (one entry per head).
        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            layer: The attention layer; supplies the k/v/q scale factors.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            output: Preallocated buffer that receives the result.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.
        use_prefill_decode_attn = self.force_prefill_decode_attn
        num_actual_tokens = attn_metadata.num_actual_tokens

        if use_prefill_decode_attn:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
        else:
            key_cache, value_cache = kv_cache.unbind(0)

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            if use_prefill_decode_attn:
                PagedAttention.write_to_paged_cache(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        if self.kv_cache_dtype.startswith("fp8"):
            # Reinterpret the cache buffers with the platform's fp8 dtype.
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)
            num_tokens, num_heads, head_size = query.shape
            assert layer._q_scale == 1.0, \
                "A non 1.0 q_scale is not currently supported."
            if not current_platform.is_rocm():
                # Skip Q quantization on ROCm, since dequantizing back to
                # f32 in the attention kernel is not supported.
                query, _ = ops.scaled_fp8_quant(
                    query.reshape(
                        (num_tokens, num_heads * head_size)).contiguous(),
                    layer._q_scale)
                query = query.reshape((num_tokens, num_heads, head_size))

        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if use_local_attn:
            # iRoPE: attend within local chunks using the per-chunk metadata.
            assert attn_metadata.local_attn_metadata is not None
            local_metadata = attn_metadata.local_attn_metadata
            cu_seqlens_q = local_metadata.local_query_start_loc
            seqused_k = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
            max_seqlen_k = local_metadata.local_max_seq_len
            block_table = local_metadata.local_block_table
        else:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table

        if use_prefill_decode_attn:
            # Compute attention and update output up to `num_actual_tokens`.
            chunked_prefill_paged_decode(query=query[:num_actual_tokens],
                                         key=key[:num_actual_tokens],
                                         value=value[:num_actual_tokens],
                                         output=output[:num_actual_tokens],
                                         kv_cache_dtype=self.kv_cache_dtype,
                                         key_cache=key_cache,
                                         value_cache=value_cache,
                                         block_table=block_table,
                                         query_start_loc=cu_seqlens_q,
                                         seq_lens=seqused_k,
                                         max_seq_len=max_seqlen_k,
                                         max_query_len=max_seqlen_q,
                                         k_scale=layer._k_scale,
                                         v_scale=layer._v_scale,
                                         alibi_slopes=self.alibi_slopes,
                                         sliding_window=self.sliding_window[0],
                                         sm_scale=self.scale,
                                         sinks=self.sinks)
        else:
            # One descale factor per (sequence, kv_head).
            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

            unified_attention(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                sinks=self.sinks,
            )

        return output

View File

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
import numpy as np
import torch
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.
    """

    # NOTE: Field order is part of the constructor interface; append new
    # fields at the end rather than reordering.
    query_start_loc: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""
    seq_lens: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""
    num_reqs: int
    """Number of requests"""
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""
M = TypeVar("M")
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    """Builds backend-specific attention metadata of type ``M`` from the
    batch-level :class:`CommonAttentionMetadata`."""

    # Does this backend/builder support CUDA Graphs for attention.
    full_cudagraph_supported: ClassVar[bool] = False

    @abstractmethod
    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.
        """
        raise NotImplementedError

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        """
        Can this batch (with given metadata) use CUDA Graphs for attention.
        """
        return False

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default.
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata)

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        num_sms: int,
    ) -> bool:
        """Whether cascade attention should be used for this batch.
        Disabled by default; backends opt in by overriding."""
        return False

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """
        This method can reorder the batch if desired by the backend.
        :return: Has the batch been reordered (default False).
        """
        return False
def validate_kv_sharing_target(current_layer_name, target_layer_name,
                               static_forward_context):
    """Validate that ``target_layer_name`` may share its KV cache with
    ``current_layer_name``.

    A valid target is a *different* Attention layer that already exists in
    the static forward context (i.e. comes before the current layer) and
    has the same attention type. Raises ValueError otherwise.
    """
    error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
                 f"is not valid: target layer {target_layer_name} ")

    # A layer cannot share a KV cache with itself.
    if target_layer_name == current_layer_name:
        raise ValueError(error_msg +
                         "cannot be the same as the current layer.")

    if target_layer_name in static_forward_context:
        # Currently KV sharing is only supported between layers of the same type
        target_layer_attn_type = static_forward_context[
            target_layer_name].attn_type
        expected = static_forward_context[current_layer_name].attn_type
        if target_layer_attn_type != expected:
            raise ValueError(
                error_msg +
                f"must be the same type as the current layer ({expected}).")
        return

    # If target layer name is not in the static fwd context, it means either
    # a) the target layer does not come BEFORE the current layer, or
    # b) the target layer is not an Attention layer that exists in the model
    from vllm.model_executor.models.utils import extract_layer_index
    if extract_layer_index(current_layer_name) <= extract_layer_index(
            target_layer_name):
        raise ValueError(error_msg + "must come before the current layer.")
    raise ValueError(error_msg +
                     "is not a valid Attention layer in the model.")

0
vllm/v1/core/__init__.py Normal file
View File

349
vllm/v1/core/block_pool.py Normal file
View File

@@ -0,0 +1,349 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable
from typing import Callable, Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens)
from vllm.v1.request import Request
logger = init_logger(__name__)
class BlockPool:
    """BlockPool that manages KVCacheBlocks.
    It provides methods to allocate, free and cache the kv cache blocks. The
    free_block_queue stores the free blocks in eviction order to enable
    allocation, free, and cache eviction. The cached_block_hash_to_block
    maps between block hash and cached block to support finding cached blocks
    by their block hash.

    Args:
        num_gpu_blocks: The number of blocks in the pool.
        enable_caching: Whether to enable prefix caching.
        enable_kv_cache_events: Whether to enable kv cache events.
    """

    def __init__(
        self,
        num_gpu_blocks: int,
        enable_caching: bool,
        enable_kv_cache_events: bool = False,
    ):
        assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
        self.num_gpu_blocks = num_gpu_blocks
        self.enable_caching = enable_caching
        # All kv-cache blocks.
        self.blocks: list[KVCacheBlock] = [
            KVCacheBlock(idx) for idx in range(num_gpu_blocks)
        ]
        # Free block queue that constructs and manipulates a doubly linked
        # list of free blocks (including eviction candidates when caching is
        # enabled).
        self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)

        # {block_hash: {block ID: block}}. A cached block is
        # a full block with a block hash that can be used for prefix caching.
        # The cached block may be used by running requests or in the
        # free_block_queue that could potentially be evicted.
        # NOTE: We currently don't de-duplicate the blocks in the cache,
        # meaning that if a block becomes full and is cached, we don't check
        # if there is already an identical block in the cache. This is because
        # we want to make sure the allocated block IDs won't change so that
        # block tables are append-only.
        self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[
            int, KVCacheBlock]] = defaultdict(dict)

        # To represent a placeholder block with block_id=0.
        # The ref_cnt of null_block is not maintained, needs special care to
        # avoid freeing it.
        self.null_block = self.free_block_queue.popleft()
        self.null_block.is_null = True

        self.enable_kv_cache_events = enable_kv_cache_events
        # Pending events, drained by take_events().
        self.kv_event_queue: list[KVCacheEvent] = []

    def get_cached_block(
            self, block_hash: BlockHash,
            kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]:
        """Get the cached block by the block hash for each group in
        `kv_cache_group_ids`, or None if cache miss for any group.
        If there are duplicated blocks, we return the first block in the cache.

        Args:
            block_hash: The hash value of the block.
            kv_cache_group_ids: The ids of the KV cache groups.

        Returns:
            The cached blocks if exists, or None.
        """
        cached_blocks = []
        for group_id in kv_cache_group_ids:
            cached_blocks_one_group = self.cached_block_hash_to_block.get(
                BlockHashWithGroupId(block_hash, group_id))
            if not cached_blocks_one_group:
                # Miss in any group means an overall miss.
                return None
            first_block = next(iter(cached_blocks_one_group.values()))
            cached_blocks.append(first_block)
        return cached_blocks

    def cache_full_blocks(
        self,
        request: Request,
        blocks: list[KVCacheBlock],
        block_hashes: list[BlockHash],
        num_cached_blocks: int,
        num_full_blocks: int,
        block_size: int,
        kv_cache_group_id: int,
        hash_fn: Callable,
    ) -> None:
        """Cache a list of full blocks for prefix caching.
        This function takes a list of blocks that will have their block hash
        metadata to be updated and cached. Given a request, it computes the
        block hashes for the blocks starting from `num_cached_blocks` to
        `num_full_blocks`, updating the metadata for each block
        and caching them in the `cached_block_hash_to_block`.

        Args:
            request: The request to cache the blocks.
            blocks: All blocks in the request.
            block_hashes: Block hashes of the blocks in the request. Note that
            this list may be shorter than the blocks list. In this case the
            missed block hash will be computed in this function.
            num_cached_blocks: The number of blocks that are already cached.
            num_full_blocks: The number of blocks that are full and should
                be cached after this function.
            block_size: Number of tokens in each block.
            kv_cache_group_id: The id of the KV cache group.
            hash_fn: The hash function to use for block hashes.
        """
        if num_cached_blocks == num_full_blocks:
            return
        new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
        assert len(block_hashes) >= num_cached_blocks
        new_block_hashes = block_hashes[num_cached_blocks:]

        # Update the new blocks with the block hashes through the chain.
        if num_cached_blocks == 0:
            prev_block_hash_value = None
        else:
            prev_block = blocks[num_cached_blocks - 1]
            assert prev_block.block_hash is not None
            prev_block_hash_value = prev_block.block_hash.get_hash_value()

        parent_block_hash = prev_block_hash_value
        # Only collect new hashes when events are enabled.
        new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
                                           else None)
        for i, blk in enumerate(new_full_blocks):
            assert blk.block_hash is None

            if i < len(new_block_hashes):
                # The block hash may already be computed in
                # "get_computed_blocks" if the tokens are not generated by
                # this request (either the prompt tokens or the previously
                # generated tokens with preemption), or by other
                # single_type_managers with the same block_size.
                # In this case we simply reuse the block hash.
                block_hash = new_block_hashes[i]
            else:
                # Otherwise compute the block hash and cache it in the request
                # in case it will be preempted in the future.
                blk_idx = num_cached_blocks + i
                start_token_idx = blk_idx * block_size
                end_token_idx = (blk_idx + 1) * block_size
                block_tokens = request.all_token_ids[
                    start_token_idx:end_token_idx]
                assert len(block_tokens) == block_size, (
                    f"Expected {block_size} tokens, got "
                    f"{len(block_tokens)} at {blk_idx}th block for request "
                    f"{request.request_id}({request})")

                # Generate extra keys for multi-modal inputs. Note that since
                # we reach to this branch only when the block is completed with
                # generated tokens, we only need to consider the last mm input.
                extra_keys, _ = generate_block_hash_extra_keys(
                    request, start_token_idx, end_token_idx, -1)

                # Compute the hash of the current block.
                block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
                                               block_tokens, extra_keys)
                block_hashes.append(block_hash)

            # Update the block hash and add the full block to the cache.
            block_hash_with_group_id = BlockHashWithGroupId(
                block_hash, kv_cache_group_id)
            blk.block_hash = block_hash_with_group_id
            self.cached_block_hash_to_block[block_hash_with_group_id][
                blk.block_id] = blk
            if new_hashes is not None:
                new_hashes.append(block_hash.hash_value)
            # Chain the hashes: each block's hash incorporates its parent's.
            prev_block_hash_value = block_hash.hash_value

        if self.enable_kv_cache_events:
            self.kv_event_queue.append(
                BlockStored(
                    block_hashes=new_hashes,
                    parent_block_hash=parent_block_hash,
                    token_ids=request.
                    all_token_ids[num_cached_blocks *
                                  block_size:num_full_blocks * block_size],
                    block_size=block_size,
                    lora_id=request.lora_request.id
                    if request.lora_request else None,
                ))

    def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
        """Get new blocks from the free block pool.

        Note that we do not check block cache in this function.

        Args:
            num_blocks: The number of blocks to allocate.

        Returns:
            A list of new block.
        """
        if num_blocks > self.get_num_free_blocks():
            raise ValueError(
                f"Cannot get {num_blocks} free blocks from the pool")

        ret: list[KVCacheBlock] = []
        idx = 0
        while idx < num_blocks:
            # First allocate blocks.
            curr_block = self.free_block_queue.popleft()
            assert curr_block.ref_cnt == 0

            # If the block is cached, evict it.
            if self.enable_caching:
                self._maybe_evict_cached_block(curr_block)

            curr_block.incr_ref()
            ret.append(curr_block)
            idx += 1

        return ret

    def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
        """
        If a block is cached in `cached_block_hash_to_block`, we reset its hash
        metadata and evict it from the cache.

        Args:
            block: The block to evict.

        Returns:
            True if the block is evicted, False otherwise.
        """
        block_hash = block.block_hash
        if block_hash and block_hash in self.cached_block_hash_to_block:
            block.reset_hash()
            del self.cached_block_hash_to_block[block_hash][block.block_id]

            if len(self.cached_block_hash_to_block[block_hash]) == 0:
                del self.cached_block_hash_to_block[block_hash]

            if self.enable_kv_cache_events:
                # FIXME (Chen): Not sure whether we should return `hash_value`
                # or `(hash_value, group_id)` here. But it's fine now because
                # we disable hybrid kv cache manager when kv cache event is
                # enabled, so there is only one group.
                self.kv_event_queue.append(
                    BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
            return True
        return False

    def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
        """Touch a block increases its reference count by 1, and may remove
        the block from the free queue. This is used when a block is hit by
        another request with the same prefix.

        Args:
            blocks: A list of blocks to touch.
        """
        for blocks_per_group in blocks:
            for block in blocks_per_group:
                # ref_cnt=0 means this block is in the free list (i.e. eviction
                # candidate), so remove it.
                if block.ref_cnt == 0 and not block.is_null:
                    self.free_block_queue.remove(block)
                block.incr_ref()

    def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
        """Free a list of blocks. The blocks should be ordered by their
        eviction priority, where the first block will be evicted first.

        Args:
            ordered_blocks: A list of blocks to free ordered by their eviction
                priority.
        """
        for block in ordered_blocks:
            block.decr_ref()
            # null_block should not be added to the free list.
            if block.ref_cnt == 0 and not block.is_null:
                self.free_block_queue.append(block)

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
        or used for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
        if num_used_blocks != 1:  # The null block is always marked as used
            logger.warning(
                "Failed to reset prefix cache because some "
                "blocks (%d) are not freed yet", num_used_blocks - 1)
            return False

        # Remove all hashes so that no new blocks will hit.
        self.cached_block_hash_to_block = defaultdict(dict)

        # Remove all hashes from all blocks.
        for block in self.blocks:
            block.reset_hash()

        logger.info("Successfully reset prefix cache")

        if self.enable_kv_cache_events:
            self.kv_event_queue.append(AllBlocksCleared())

        return True

    def get_num_free_blocks(self) -> int:
        """Get the number of free blocks in the pool.

        Returns:
            The number of free blocks.
        """
        return self.free_block_queue.num_free_blocks

    def get_usage(self) -> float:
        """Get the KV cache usage.

        Returns:
            The KV cache usage (between 0.0 and 1.0).
        """
        return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks)

    def take_events(self) -> list[KVCacheEvent]:
        """Atomically takes all events and clears the queue.

        Returns:
            A list of KV cache events.
        """
        if not self.enable_kv_cache_events:
            return []
        events = self.kv_event_queue
        self.kv_event_queue = []
        return events

View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.multimodal import MultiModalRegistry
from vllm.v1.request import Request
if TYPE_CHECKING:
from vllm.config import ModelConfig, SchedulerConfig
logger = init_logger(__name__)
class EncoderCacheManager:
    """Tracks which encoder inputs are cached and the remaining budget.

    The budget is measured in encoder tokens: caching an input consumes
    ``request.get_num_encoder_tokens(input_id)`` slots, and freeing it
    returns them.
    """

    def __init__(self, cache_size: int):
        self.cache_size = cache_size
        self.num_free_slots = cache_size
        # req_id -> cached input ids
        self.cached: dict[str, set[int]] = {}
        # (req_id, input_id) pairs freed since the last get_freed_ids() call
        self.freed: list[tuple[str, int]] = []

    def has_cache(self, request: Request, input_id: int) -> bool:
        cached_ids = self.cached.get(request.request_id)
        return cached_ids is not None and input_id in cached_ids

    def can_allocate(self, request: Request, input_id: int) -> bool:
        return request.get_num_encoder_tokens(input_id) <= self.num_free_slots

    def allocate(self, request: Request, input_id: int) -> None:
        self.cached.setdefault(request.request_id, set()).add(input_id)
        self.num_free_slots -= request.get_num_encoder_tokens(input_id)

    def get_cached_input_ids(self, request: Request) -> set[int]:
        return self.cached.get(request.request_id, set())

    def free_encoder_input(self, request: Request, input_id: int) -> None:
        """Free a single encoder input id for the request."""
        req_id = request.request_id
        cached_ids = self.cached.get(req_id)
        if cached_ids is None:
            return

        cached_ids.discard(input_id)
        if not cached_ids:
            del self.cached[req_id]
        self.num_free_slots += request.get_num_encoder_tokens(input_id)
        self.freed.append((req_id, input_id))

    def free(self, request: Request) -> None:
        """Free all cached input ids for the request."""
        for input_id in list(self.get_cached_input_ids(request)):
            self.free_encoder_input(request, input_id)

    def get_freed_ids(self) -> list[tuple[str, int]]:
        freed, self.freed = self.freed, []
        return freed
def compute_encoder_budget(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
    mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.
        mm_registry: Provides information about the token cost.

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
            in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
            in the input sequence.
    """
    if not model_config.is_multimodal_model:
        # Text-only models need no encoder compute or cache budget.
        return 0, 0

    # TODO: handle encoder-decoder models once we support them.
    return _compute_encoder_budget_multimodal(model_config, scheduler_config,
                                              mm_registry)
def _compute_encoder_budget_multimodal(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
    mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations for a multimodal model.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.
        mm_registry: Provides information about the token cost.

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
            in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
            in the input sequence.
    """
    max_tokens_by_modality = mm_registry \
        .get_max_tokens_per_item_by_nonzero_modality(model_config)

    if not max_tokens_by_modality:
        logger.warning(
            "All non-text modalities supported by the model have been "
            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
            "not be initialized.")
        return 0, 0

    # Size everything around the single most expensive multimodal item.
    max_tokens_per_mm_item = max(max_tokens_by_modality.values())

    if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item
            > scheduler_config.max_num_batched_tokens):
        raise ValueError(
            "Chunked MM input disabled but max_tokens_per_mm_item "
            f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens"
            f" ({scheduler_config.max_num_batched_tokens}). Please increase "
            "max_num_batched_tokens.")

    return (max(scheduler_config.max_num_encoder_input_tokens,
                max_tokens_per_mm_item),
            max(scheduler_config.encoder_cache_size, max_tokens_per_mm_item))

View File

@@ -0,0 +1,363 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Callable, Optional
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager, get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.request import Request
class KVCacheCoordinator(ABC):
"""
Coordinate the KV cache of different KV cache groups.
"""
def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
use_eagle: bool,
enable_caching: bool,
caching_hash_fn: Callable,
enable_kv_cache_events: bool,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
enable_kv_cache_events)
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
self.single_type_managers = tuple(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
caching_hash_fn=caching_hash_fn,
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
"""
Get the number of blocks needed to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
Returns:
The number of blocks.
"""
num_blocks_to_allocate = 0
for i, manager in enumerate(self.single_type_managers):
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i])
return num_blocks_to_allocate
def save_new_computed_blocks(
self, request_id: str,
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None:
"""
Add the new computed blocks to the request.
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
"""
for i, manager in enumerate(self.single_type_managers):
manager.save_new_computed_blocks(request_id,
new_computed_blocks[i])
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
Returns:
The new allocated blocks.
"""
return tuple(
manager.allocate_new_blocks(request_id, num_tokens)
for manager in self.single_type_managers)
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
for manager in self.single_type_managers:
manager.cache_blocks(request, block_hashes, num_computed_tokens)
def free(self, request_id: str) -> None:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
for manager in self.single_type_managers:
manager.free(request_id)
def get_num_common_prefix_blocks(self, request_id: str,
                                 num_running_requests: int) -> list[int]:
    """
    Get the number of blocks of the common prefix shared by all running
    requests, per kv cache group.

    Args:
        request_id: The request ID of any request in the RUNNING state.
        num_running_requests: The total number of RUNNING requests.

    Returns:
        The number of common prefix blocks for each kv cache group.
    """
    return [
        mgr.get_num_common_prefix_blocks(request_id, num_running_requests)
        for mgr in self.single_type_managers
    ]
def remove_skipped_blocks(self, request_id: str,
                          num_computed_tokens: int) -> None:
    """
    Drop blocks that are no longer needed by the attention computation
    (e.g., tokens outside a sliding window) and replace the removed blocks
    with null_block.

    Args:
        request_id: The request ID.
        num_computed_tokens: The number of tokens that have been computed.
    """
    for mgr in self.single_type_managers:
        mgr.remove_skipped_blocks(request_id, num_computed_tokens)
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
    """
    Get the blocks currently held by the request, one list per kv cache
    group (an empty list when a group has no blocks for the request).
    """
    per_group_blocks = []
    for mgr in self.single_type_managers:
        blocks = mgr.req_to_blocks.get(request_id)
        per_group_blocks.append(blocks if blocks else [])
    return tuple(per_group_blocks)
@abstractmethod
def find_longest_cache_hit(
    self,
    block_hashes: list[BlockHash],
    max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
    # Subclasses return (per-group lists of cache-hit blocks, number of
    # tokens covered by the hit). The hit length never exceeds
    # max_cache_hit_length.
    pass
class UnitaryKVCacheCoordinator(KVCacheCoordinator):
    """
    KV cache coordinator for models with a single KV cache group, i.e. all
    attention layers share one KV cache type (e.g. every layer uses full
    attention, or every layer uses sliding-window attention).
    """

    def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
                 use_eagle: bool, enable_caching: bool,
                 caching_hash_fn: Callable, enable_kv_cache_events: bool):
        super().__init__(kv_cache_config, max_model_len, use_eagle,
                         enable_caching, caching_hash_fn,
                         enable_kv_cache_events)
        # Cache the single group's spec and block size for fast access in
        # find_longest_cache_hit.
        self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
            0].kv_cache_spec
        self.block_size = self.kv_cache_spec.block_size
        assert len(self.kv_cache_config.kv_cache_groups) == 1, (
            "UnitaryKVCacheCoordinator assumes only one kv cache group")

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """Delegate the prefix-cache lookup to the single manager and
        convert the number of hit blocks into a token count."""
        hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
            block_hashes=block_hashes,
            max_length=max_cache_hit_length,
            kv_cache_group_ids=[0],
            block_pool=self.block_pool,
            kv_cache_spec=self.kv_cache_spec,
            use_eagle=self.use_eagle,
        )
        num_hit_tokens = len(hit_blocks[0]) * self.block_size
        return hit_blocks, num_hit_tokens
class HybridKVCacheCoordinator(KVCacheCoordinator):
    """
    KV cache coordinator for hybrid models with multiple KV cache types, and
    thus multiple kv cache groups.

    To simplify `find_longest_cache_hit`, it only supports the combination of
    two types of KV cache groups, and one of them must be full attention.
    May extend to more general cases in the future.
    """

    def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
                 use_eagle: bool, enable_caching: bool,
                 caching_hash_fn: Callable, enable_kv_cache_events: bool):
        super().__init__(kv_cache_config, max_model_len, use_eagle,
                         enable_caching, caching_hash_fn,
                         enable_kv_cache_events)
        self.verify_and_split_kv_cache_groups()

    def verify_and_split_kv_cache_groups(self) -> None:
        """
        Verifies that the model has exactly two types of KV cache groups, and
        one of them is full attention. Then, split the kv cache groups into full
        attention groups and other groups.
        """
        # type_ids seen so far for each of the two allowed categories; used
        # to assert that each category is homogeneous.
        full_attention_type_id: Optional[str] = None
        other_type_id: Optional[str] = None
        self.full_attention_group_ids: list[int] = []
        self.other_group_ids: list[int] = []
        for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
            if isinstance(g.kv_cache_spec, FullAttentionSpec):
                if full_attention_type_id is None:
                    full_attention_type_id = g.kv_cache_spec.type_id
                else:
                    assert full_attention_type_id == g.kv_cache_spec.type_id, (
                        "HybridKVCacheCoordinator assumes exactly one type of "
                        "full attention groups now.")
                self.full_attention_group_ids.append(i)
            else:
                if other_type_id is None:
                    other_type_id = g.kv_cache_spec.type_id
                else:
                    assert other_type_id == g.kv_cache_spec.type_id, (
                        "HybridKVCacheCoordinator assumes "
                        "exactly one other type of groups now.")
                self.other_group_ids.append(i)
        assert full_attention_type_id is not None, (
            "HybridKVCacheCoordinator assumes exactly one type of full "
            "attention groups now.")
        assert other_type_id is not None, (
            "HybridKVCacheCoordinator assumes exactly one type of other "
            "groups now.")
        # Manager classes and specs for each category; the "other" manager
        # class is taken from the first non-full-attention group's manager.
        self.full_attention_manager_cls = FullAttentionManager
        self.other_attention_cls = self.single_type_managers[
            self.other_group_ids[0]].__class__
        self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
            self.full_attention_group_ids[0]].kv_cache_spec
        self.other_spec = self.kv_cache_config.kv_cache_groups[
            self.other_group_ids[0]].kv_cache_spec
        self.full_attention_block_size = self.full_attention_spec.block_size
        self.other_block_size = self.other_spec.block_size
        assert self.other_block_size % self.full_attention_block_size == 0, (
            "KVCacheCoordinator assumes the block_size of full attention "
            "layers is divisible by other layers now.")
        # Record whether the full attention group ids come before the other
        # group ids, so find_longest_cache_hit can concatenate per-group hit
        # lists in the right order.
        if max(self.full_attention_group_ids) < min(self.other_group_ids):
            self.full_attn_first = True
        elif max(self.other_group_ids) < min(self.full_attention_group_ids):
            self.full_attn_first = False
        else:
            raise ValueError(
                "HybridKVCacheCoordinator assumes the full "
                "attention group ids and other attention group ids "
                "do not interleave, either full attention group ids "
                "are before other attention group ids or vice versa."
                "This is for simplifying merging hit_blocks_full_attn and "
                "hit_blocks_other_attn to hit_blocks.")

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """
        Find the longest cache hit for the request.

        Args:
            block_hashes: The block hashes of the request.
            max_cache_hit_length: The maximum length of the cache hit.

        Returns:
            A tuple containing:
                - A list of the cache hit blocks for each single type manager.
                - The number of tokens of the longest cache hit.
        """
        # First, find the longest cache hit for full attention.
        hit_blocks_full_attn = (
            self.full_attention_manager_cls.find_longest_cache_hit(
                block_hashes=block_hashes,
                max_length=max_cache_hit_length,
                kv_cache_group_ids=self.full_attention_group_ids,
                block_pool=self.block_pool,
                kv_cache_spec=self.full_attention_spec,
                use_eagle=self.use_eagle,
            ))
        hit_length = len(
            hit_blocks_full_attn[0]) * self.full_attention_block_size
        # Next, find the cache hit for the other attention WITHIN
        # the cache hit of full attention.
        hit_blocks_other_attn = (
            self.other_attention_cls.find_longest_cache_hit(
                block_hashes=block_hashes,
                max_length=hit_length,
                kv_cache_group_ids=self.other_group_ids,
                block_pool=self.block_pool,
                kv_cache_spec=self.other_spec,
                use_eagle=self.use_eagle,
            ))
        hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
        # NOTE: the prefix cache hit length must be a multiple of block_size as
        # we don't support partial block cache hit yet. The cache hit length
        # of other attention is ensured to be a multiple of the block size of
        # full attention layers in current implementation, because hit_length is
        # a multiple of other attention's block size, and other attention's
        # block size is a multiple of full attention's block size (verified in
        # `verify_and_split_kv_cache_groups`).
        assert hit_length % self.full_attention_block_size == 0
        # Truncate the full attention cache hit to the length of the
        # cache hit of the other attention.
        for group_hit_blocks in hit_blocks_full_attn:
            del group_hit_blocks[hit_length // self.full_attention_block_size:]
        # Merge the hit blocks of full attention and other attention.
        if self.full_attn_first:
            hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
        else:
            hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
        return hit_blocks, hit_length
def get_kv_cache_coordinator(
        kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
        enable_caching: bool, caching_hash_fn: Callable,
        enable_kv_cache_events: bool) -> KVCacheCoordinator:
    """Factory: return the unitary coordinator for single-group configs and
    the hybrid coordinator for multi-group (hybrid-attention) configs."""
    coordinator_cls = (UnitaryKVCacheCoordinator
                       if len(kv_cache_config.kv_cache_groups) == 1 else
                       HybridKVCacheCoordinator)
    return coordinator_cls(kv_cache_config, max_model_len, use_eagle,
                           enable_caching, caching_hash_fn,
                           enable_kv_cache_events)

View File

@@ -0,0 +1,392 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
hash_request_tokens)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__)
@dataclass
class KVCacheBlocks:
    """
    Allocation result of KVCacheManager. Serves as the boundary object
    between the Scheduler and the KVCacheManager so the scheduler never
    depends on the manager's internal data structures.
    """
    blocks: tuple[list[KVCacheBlock], ...]
    """
    blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
    We don't use block of tokens as the outer dimension because it assumes all
    kv_cache_groups have the same number of blocks, which is true for now but
    will be broken if we want to give different block_size to different
    kv_cache_groups in the future.
    """

    def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
        """Concatenate two KVCacheBlocks instances group-wise."""
        combined = tuple(mine + theirs
                         for mine, theirs in zip(self.blocks, other.blocks))
        return KVCacheBlocks(combined)

    def get_block_ids(self) -> tuple[list[int], ...]:
        """
        Converts the KVCacheBlocks instance to block_ids.

        Returns:
            tuple[list[int], ...]: A tuple of lists where
            * the outer tuple corresponds to KV cache groups
            * each inner list contains the block_ids of the blocks in that group
        """
        return tuple([block.block_id for block in group_blocks]
                     for group_blocks in self.blocks)

    def get_unhashed_block_ids(self) -> list[int]:
        """Get block_ids of unhashed blocks from KVCacheBlocks instance."""
        assert len(self.blocks) == 1, "Only one group is supported"
        return [
            block.block_id for block in self.blocks[0]
            if block.block_hash is None
        ]

    def new_empty(self) -> "KVCacheBlocks":
        """Creates a new KVCacheBlocks instance with no blocks."""
        return KVCacheBlocks(tuple([] for _ in self.blocks))
class KVCacheManager:
    # Facade that the scheduler uses to query, allocate, cache, and free
    # KV-cache blocks for requests; delegates the per-group bookkeeping to a
    # KVCacheCoordinator and the physical block accounting to its BlockPool.

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        enable_caching: bool = True,
        caching_hash_algo: str = "builtin",
        use_eagle: bool = False,
        log_stats: bool = False,
        enable_kv_cache_events: bool = False,
    ) -> None:
        self.max_model_len = max_model_len
        self.enable_caching = enable_caching
        # sha256 trades CPU for practically collision-free hashes; the
        # builtin hash() is the fast default.
        self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
        self.use_eagle = use_eagle
        self.log_stats = log_stats
        # FIXME: make prefix cache stats conditional on log_stats
        self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
        assert len(
            set(g.kv_cache_spec.block_size
                for g in kv_cache_config.kv_cache_groups)
        ) == 1, "Only one block size is supported for now"
        self.block_size = kv_cache_config.kv_cache_groups[
            0].kv_cache_spec.block_size

        self.coordinator = get_kv_cache_coordinator(
            kv_cache_config=kv_cache_config,
            max_model_len=self.max_model_len,
            use_eagle=self.use_eagle,
            enable_caching=enable_caching,
            caching_hash_fn=self.caching_hash_fn,
            enable_kv_cache_events=enable_kv_cache_events,
        )
        self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
        self.block_pool = self.coordinator.block_pool
        self.kv_cache_config = kv_cache_config

        # Mapping from request ID to kv block hashes.
        # This is to avoid recomputing the block hashes for each call of
        # `get_computed_blocks` or `allocate_slots`.
        self.req_to_block_hashes: defaultdict[
            str, list[BlockHash]] = defaultdict(list)

    @property
    def usage(self) -> float:
        """Get the KV cache usage.

        Returns:
            The KV cache usage (between 0.0 and 1.0).
        """
        return self.block_pool.get_usage()

    def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
        """Get (and reset) the prefix cache stats.

        Returns:
            The current prefix caching stats, or None if logging is disabled.
        """
        if not self.log_stats:
            return None
        stats = self.prefix_cache_stats
        # Start a fresh stats window for the next collection interval.
        self.prefix_cache_stats = PrefixCacheStats()
        return stats

    def get_computed_blocks(self,
                            request: Request) -> tuple[KVCacheBlocks, int]:
        """Get the computed (cached) blocks for the request.
        Note that the computed blocks must be full.

        Args:
            request: The request to get the computed blocks.

        Returns:
            A tuple containing:
                - A list of blocks that are computed for the request.
                - The number of computed tokens.
        """
        # Prefix caching is disabled or
        # When the request requires prompt logprobs, we skip prefix caching.
        if (not self.enable_caching
                or request.sampling_params.prompt_logprobs is not None):
            return self.create_empty_block_list(), 0

        # The block hashes for the request may already be computed
        # if the scheduler has tried to schedule the request before.
        block_hashes = self.req_to_block_hashes[request.request_id]
        if not block_hashes:
            block_hashes = hash_request_tokens(self.caching_hash_fn,
                                               self.block_size, request)
            self.req_to_block_hashes[request.request_id] = block_hashes

        if self.log_stats:
            assert self.prefix_cache_stats is not None
            self.prefix_cache_stats.requests += 1

        # NOTE: When all tokens hit the cache, we must recompute the last token
        # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
        # This can trigger recomputation of an entire block, rather than just
        # the single last token, because allocate_slots() requires
        # num_computed_tokens to be block-size aligned. Removing this limitation
        # could slightly improve performance in the future.
        max_cache_hit_length = request.num_tokens - 1
        computed_blocks, num_new_computed_tokens = (
            self.coordinator.find_longest_cache_hit(block_hashes,
                                                    max_cache_hit_length))

        if self.log_stats:
            assert self.prefix_cache_stats is not None
            self.prefix_cache_stats.queries += request.num_tokens
            self.prefix_cache_stats.hits += num_new_computed_tokens

        return KVCacheBlocks(computed_blocks), num_new_computed_tokens

    def allocate_slots(
        self,
        request: Request,
        num_new_tokens: int,
        num_new_computed_tokens: int = 0,
        new_computed_blocks: Optional[KVCacheBlocks] = None,
        num_draft_tokens: int = 0,
        num_lookahead_tokens: int = 0,
        delay_cache_blocks: bool = False,
    ) -> Optional[KVCacheBlocks]:
        """Add slots for a request with new tokens to append.

        Args:
            request: The request to allocate slots.
            num_new_tokens: The number of tokens to allocate, including external
                tokens. Note that this does not include tokens that have
                already been computed locally (i.e. new_computed_blocks).
            num_new_computed_tokens: The number of new computed tokens just
                hitting the prefix caching, excluding external tokens.
            new_computed_blocks: The cached blocks for the above new computed
                tokens.
            num_lookahead_tokens: The number of speculative tokens to allocate.
                This is used by spec decode proposers with kv-cache such
                as eagle.
            delay_cache_blocks: Whether to skip caching the blocks. This is
                used by P/D when allocating blocks used in a KV transfer
                which will complete in a future step.

        Blocks layout:
        ```
        -----------------------------------------------------------------------
        | < computed > | < new computed > |    < new >    | < pre-allocated > |
        -----------------------------------------------------------------------
        |                  < required >                   |
        --------------------------------------------------
        |                    < full >                  |
        ------------------------------------------------
                                          | <new full> |
                                          --------------
        ```
        The following *_blocks are illustrated in this layout.

        Returns:
            A list of new allocated blocks.
        """
        if num_new_tokens == 0:
            raise ValueError("num_new_tokens must be greater than 0")

        if new_computed_blocks is not None:
            new_computed_block_list = new_computed_blocks.blocks
        else:
            new_computed_block_list = tuple(
                [] for _ in range(len(self.kv_cache_config.kv_cache_groups)))

        # Free the blocks that are skipped during the attention computation
        # (e.g., tokens outside the sliding window).
        # We can do this even if we cannot schedule this request due to
        # insufficient free blocks.
        # Should call this function before allocating new blocks to reduce
        # the number of evicted blocks.
        self.coordinator.remove_skipped_blocks(request.request_id,
                                               request.num_computed_tokens)

        # The number of computed tokens is the number of computed tokens plus
        # the new prefix caching hits
        num_computed_tokens = (request.num_computed_tokens +
                               num_new_computed_tokens)
        num_tokens_need_slot = min(
            num_computed_tokens + num_new_tokens + num_lookahead_tokens,
            self.max_model_len)

        num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
            request_id=request.request_id,
            num_tokens=num_tokens_need_slot,
            new_computed_blocks=new_computed_block_list,
        )

        if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
            # Cannot allocate new blocks
            return None

        # Touch the computed blocks to make sure they won't be evicted.
        if self.enable_caching:
            self.block_pool.touch(new_computed_block_list)
        else:
            assert not any(new_computed_block_list), (
                "Computed blocks should be empty when "
                "prefix caching is disabled")

        # Append the new computed blocks to the request blocks until now to
        # avoid the case where the new blocks cannot be allocated.
        self.coordinator.save_new_computed_blocks(request.request_id,
                                                  new_computed_block_list)

        new_blocks = self.coordinator.allocate_new_blocks(
            request.request_id, num_tokens_need_slot)

        # P/D: delay caching blocks if we have to recv from
        # remote. Update state for locally cached blocks.
        if not self.enable_caching or delay_cache_blocks:
            return KVCacheBlocks(new_blocks)

        # Speculated tokens might be rejected in the future, so we does
        # not cache any speculated tokens. We only cache blocks with
        # generated (accepted) tokens.
        self.coordinator.cache_blocks(
            request, self.req_to_block_hashes[request.request_id],
            num_computed_tokens + num_new_tokens - num_draft_tokens)

        return KVCacheBlocks(new_blocks)

    def free(self, request: Request) -> None:
        """Free the blocks allocated for the request.
        We free the blocks in reverse order so that the tail blocks are evicted
        first when caching is enabled.

        Args:
            request: The request to free the blocks.
        """
        self.coordinator.free(request.request_id)

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
        or used for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        if not self.block_pool.reset_prefix_cache():
            return False
        if self.log_stats:
            assert self.prefix_cache_stats is not None
            self.prefix_cache_stats.reset = True
        return True

    def get_num_common_prefix_blocks(
        self,
        request: Request,
        num_running_requests: int,
    ) -> list[int]:
        """Calculate the number of common prefix blocks shared by all requests
        in the RUNNING state for each kv cache group.

        The function determines this by selecting any request and iterating
        through its blocks.  A block is considered a common prefix block if its
        `ref_cnt` equals the total number of requests in the RUNNING state.

        NOTE(woosuk): The number of requests in the RUNNING state is **greater
        than or equal to** the number of requests scheduled in the current step.
        This is because the RUNNING state only indicates that:
        1. The request has not yet finished, and
        2. The request holds its blocks unfreed.

        While all scheduled requests must be in the RUNNING state, the inverse
        is not necessarily true. There may be RUNNING requests that are not
        scheduled in the current step.

        This can result in an edge case where the number of common prefix blocks
        is 0, even though all scheduled requests share a common prefix. This
        occurs because there may be unscheduled RUNNING requests that do not
        share the common prefix. Currently, this case cannot be easily detected,
        so the function returns 0 in such cases.

        Args:
            request: Any request in the RUNNING state, used to identify the
                common prefix blocks.
            num_running_requests: The total number of requests in the RUNNING
                state. This can be different from the number of scheduled
                requests in the current step.

        Returns:
            list[int]: The number of common prefix blocks for each kv cache
            group.
        """
        assert request.status == RequestStatus.RUNNING
        return self.coordinator.get_num_common_prefix_blocks(
            request.request_id, num_running_requests)

    def free_block_hashes(self, request: Request) -> None:
        """Discard the block hashes for the request.

        NOTE: Unlike `free`, this method should be called only when the request
        is finished, not when it is preempted.
        """
        self.req_to_block_hashes.pop(request.request_id, None)

    def take_events(self) -> list[KVCacheEvent]:
        """Take the KV cache events from the block pool.

        Returns:
            A list of KV cache events.
        """
        return self.block_pool.take_events()

    def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
        """Get the block ids of a request."""
        return KVCacheBlocks(
            self.coordinator.get_blocks(request_id)).get_block_ids()

    def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
        """Cache the blocks for the request."""
        block_hashes = self.req_to_block_hashes[request.request_id]
        self.coordinator.cache_blocks(request, block_hashes,
                                      num_computed_tokens)

    def create_empty_block_list(self) -> KVCacheBlocks:
        """Creates a new KVCacheBlocks instance with no blocks."""
        return KVCacheBlocks(tuple([]
                                   for _ in range(self.num_kv_cache_groups)))

View File

@@ -0,0 +1,996 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-Cache Utilities."""
import os
from collections import defaultdict, deque
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
logger = init_logger(__name__)
class BlockHash(NamedTuple):
    """Hash value of a block (int), the token IDs in the block, and extra keys.

    We keep a tuple of token IDs and extra keys to reduce the likelihood of
    hash collisions when the hash value is the same. By using SHA256 however,
    hash collisions are practically impossible.
    """
    # Hash value of the block in an integer.
    hash_value: int
    # Token IDs in the block.
    token_ids: tuple[int, ...]
    # Extra keys for the block, used when token IDs alone do not identify
    # the KV content (e.g. multimodal hashes, LoRA ID, or a cache salt).
    extra_keys: Optional[Any] = None
class BlockHashWithGroupId(NamedTuple):
    """A content BlockHash paired with the KV cache group ID it belongs to."""
    # The hash value for the contents (e.g., token_ids) of a block without group
    # ID. The value is the same for blocks representing the same tokens but for
    # different groups.
    block_hash: BlockHash
    # The KV cache group ID.
    group_id: int

    def get_hash_value(self) -> int:
        """Return the integer hash of the underlying block content."""
        return self.block_hash.hash_value
# The hash seed for the first block of the prefix block sequence.
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
# the initial hash to simplify the code. This is not performance critical
# as it is done one per process.
#
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
# variable if set such that processes can share the seed if needed.
# This aligns with the behavior of Python's hash() function, which also uses
# a random seed if PYTHONHASHSEED is not set.
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
    "PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED"))
class PrefixCachingMetrics:
    """Tracks the prefix-cache hit rate over a sliding window of the most
    recent N requests.

    Args:
        max_recent_requests: The number of the max recent requests to
            aggregate. Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000):
        self.max_recent_requests = max_recent_requests
        # Running totals over the current window.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        # History of (requests, queries, hits) samples, oldest first.
        self.query_queue: deque[tuple[int, int, int]] = deque()

    def observe(self, stats: PrefixCacheStats):
        """Fold one PrefixCacheStats sample into the window.

        Called with information gathered when newly scheduled requests look
        up computed blocks. Once the window holds more than
        `max_recent_requests` requests, the oldest sample is dropped.

        Args:
            stats: The prefix cache stats.
        """
        # reset_prefix_cache was invoked before this sample was taken, so
        # the previously aggregated numbers are stale; start over.
        if stats.reset:
            self.reset()

        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Evict the oldest sample once the window is over capacity.
        if self.aggregated_requests > self.max_recent_requests:
            dropped = self.query_queue.popleft()
            self.aggregated_requests -= dropped[0]
            self.aggregated_query_total -= dropped[1]
            self.aggregated_query_hit -= dropped[2]

    def reset(self):
        """Clear all window state."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def hit_rate(self) -> float:
        """Hit rate over the window; 0.0 when no queries were observed."""
        total = self.aggregated_query_total
        return self.aggregated_query_hit / total if total else 0.0
@dataclass
class KVCacheBlock:
    """KV-cache block metadata."""
    # Block ID, ranging from 0 to num_gpu_blocks - 1.
    block_id: int
    # Reference count.
    ref_cnt: int = 0
    # The hash of the block composed of (block hash, tuple of token IDs).
    # It is only available when the block is full.
    _block_hash: Optional[BlockHashWithGroupId] = None
    # Used to construct a doubly linked list for free blocks.
    # These two attributes should only be manipulated by FreeKVCacheBlockQueue.
    prev_free_block: Optional["KVCacheBlock"] = None
    next_free_block: Optional["KVCacheBlock"] = None
    # Whether the block is a null block that should never be cached.
    is_null: bool = False

    def incr_ref(self):
        """Bump the reference count by one."""
        self.ref_cnt += 1

    def decr_ref(self):
        """Drop the reference count by one."""
        self.ref_cnt -= 1

    @property
    def block_hash(self) -> Optional[BlockHashWithGroupId]:
        return self._block_hash

    @block_hash.setter
    def block_hash(self, block_hash: BlockHashWithGroupId):
        # A hash may only be assigned once until reset_hash() clears it.
        assert self.block_hash is None, (
            "The block already has a hash. This should not happen.")
        self._block_hash = block_hash

    def reset_hash(self):
        """Reset the block hash when the block is evicted."""
        self._block_hash = None

    def __repr__(self) -> str:
        # Show only the neighbors' block ids so repr() does not recurse
        # through the whole free-block linked list.
        prev_id = (None if self.prev_free_block is None else
                   self.prev_free_block.block_id)
        next_id = (None if self.next_free_block is None else
                   self.next_free_block.block_id)
        return (f"KVCacheBlock(block_id={self.block_id}, "
                f"ref_cnt={self.ref_cnt}, "
                f"_block_hash={self._block_hash}, "
                f"prev_free_block={prev_id}, "
                f"next_free_block={next_id})")
class FreeKVCacheBlockQueue:
    """This class organizes a list of KVCacheBlock objects to a doubly linked
    list of free blocks. We implement this class instead of using Python
    builtin deque to support removing a block in the middle of the queue
    in O(1) time. To close the performance gap to the builtin deque which is
    implemented in C++, this class does not allocate any Python objects when
    manipulating the linked list. Instead, this class manipulates the
    prev_free_block and next_free_block attributes of the given blocks.

    The queue is ordered by block ID in the beginning. When a block is allocated
    and then freed, it will be appended back with the eviction order:
    1. The least recent used block is at the front (LRU).
    2. If two blocks have the same last accessed time (allocated by the
       same sequence), the one with more hash tokens (the tail of a block
       chain) is at the front.
    Note that we maintain this order by reversing the block order when free
    blocks of a request. This operation is outside of this class.

    Args:
        blocks: A list of KVCacheBlock objects.
    """

    def __init__(self, blocks: list[KVCacheBlock]) -> None:
        self.num_free_blocks = len(blocks)

        # Initialize the doubly linked list of free blocks. An empty pool is
        # valid and yields an empty queue (head is None and tail is None);
        # indexing blocks[0]/blocks[-1] unconditionally would raise
        # IndexError in that case.
        self.free_list_head: Optional[KVCacheBlock] = (blocks[0]
                                                       if blocks else None)
        self.free_list_tail: Optional[KVCacheBlock] = (blocks[-1]
                                                       if blocks else None)
        for i in range(self.num_free_blocks):
            if i > 0:
                blocks[i].prev_free_block = blocks[i - 1]
            if i < self.num_free_blocks - 1:
                blocks[i].next_free_block = blocks[i + 1]

    def popleft(self) -> KVCacheBlock:
        """Pop the first free block and reduce num_free_blocks by 1.

        Returns:
            The first free block.

        Raises:
            ValueError: If the queue is empty.
        """
        if not self.free_list_head:
            raise ValueError("No free blocks available")

        block = self.free_list_head
        self.remove(block)
        return block

    def remove(self, block: KVCacheBlock) -> None:
        """Remove a block in the free list and reduce num_free_blocks by 1.

        Args:
            block: The block to remove.
        """
        if block.prev_free_block is not None:
            # Link the previous block to the next block.
            block.prev_free_block.next_free_block = block.next_free_block
        if block.next_free_block is not None:
            # Link the next block to the previous block.
            block.next_free_block.prev_free_block = block.prev_free_block

        if block == self.free_list_head:
            # Update the head if the block is the head.
            self.free_list_head = block.next_free_block
        if block == self.free_list_tail:
            # Update the tail if the block is the tail.
            self.free_list_tail = block.prev_free_block

        # Remove the block from the linked list.
        block.prev_free_block = block.next_free_block = None
        self.num_free_blocks -= 1

    def append(self, block: KVCacheBlock) -> None:
        """Put a block back into the free list and increase
        num_free_blocks by 1.

        Args:
            block: The block to append.
        """
        if self.free_list_tail is not None:
            # Link the last block to the new block.
            self.free_list_tail.next_free_block = block
            block.prev_free_block = self.free_list_tail
            self.free_list_tail = block
        else:
            # The free list is empty.
            assert self.free_list_head is None
            self.free_list_head = self.free_list_tail = block

        block.next_free_block = None
        self.num_free_blocks += 1

    def get_all_free_blocks(self) -> list[KVCacheBlock]:
        """Get all free blocks in the free list. Mainly used for testing.

        Returns:
            A list of free blocks.
        """
        ret = []
        curr_block = self.free_list_head
        while curr_block is not None:
            ret.append(curr_block)
            curr_block = curr_block.next_free_block
        return ret
def need_extra_keys(request: Request) -> bool:
    """Check whether the blocks allocated to this request need extra hash keys.

    Args:
        request (Request): The request.

    Returns:
        bool: Whether blocks allocated to this request need extra hash keys.
    """
    # Extra keys are needed whenever token IDs alone cannot identify the KV
    # content: multimodal inputs (MM hash), LoRA (adapter ID), or a
    # caller-provided cache salt.
    has_mm_input = bool(request.mm_positions)
    has_lora = request.lora_request is not None
    has_salt = request.cache_salt is not None
    return has_mm_input or has_lora or has_salt
def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
                            end_token_idx: int,
                            start_mm_idx: int) -> tuple[list[Any], int]:
    """Generate extra keys related to MultiModal request for block hash
    computation. The extra keys are the mm_hash of every multi-modal input
    that overlaps the block's token range.

    Args:
        request: The request object.
        start_token_idx: The start token index of the block.
        end_token_idx: The end token index of the block.
        start_mm_idx: The start multi-modal index of the block.

    Returns:
        A tuple of extra keys and the next multi-modal index.
    """
    extra_keys: list[Any] = []
    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
    # No multi-modal inputs: nothing to add.
    if not mm_positions:
        return extra_keys, start_mm_idx
    if mm_positions and len(mm_positions) != len(mm_hashes):
        raise ValueError(
            "The number of multi-modal positions and hashes must match. This "
            "is likely because you do not enable MM preprocessor hashing. "
            "Please set disable_mm_preprocessor_cache=False.")
    # Note that we assume mm_positions is sorted by offset.
    # We do not need to check all mm inputs if the start token index is out of
    # range. This usually happens in the late prefill phase and decoding phase.
    if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
        return extra_keys, start_mm_idx
    # Support start_mm_idx == -1 to indicate the last mm input.
    if start_mm_idx < 0:
        assert -start_mm_idx <= len(mm_positions)
        start_mm_idx = len(mm_positions) + start_mm_idx
    curr_mm_idx = start_mm_idx
    # Walk forward over the mm inputs, collecting the hash of each one that
    # intersects the token range [start_token_idx, end_token_idx).
    while mm_positions and curr_mm_idx < len(mm_positions):
        assert mm_hashes[curr_mm_idx] is not None
        offset = mm_positions[curr_mm_idx].offset
        length = mm_positions[curr_mm_idx].length
        if end_token_idx > offset:
            if start_token_idx > offset + length:
                # This block has passed the current mm input.
                curr_mm_idx += 1
                continue
            # The block contains the current mm input.
            extra_keys.append(mm_hashes[curr_mm_idx])
            if end_token_idx >= offset + length:
                # If this block contains the end of the current mm input,
                # move to the next mm input as this block may also contain
                # the next mm input.
                curr_mm_idx += 1
            else:
                # Otherwise this block is done with mm inputs.
                break
        else:
            # This block has not reached the current mm input.
            break
    return extra_keys, curr_mm_idx
def _gen_lora_extra_hash_keys(request: Request) -> list[int]:
    """Generate extra keys related to LoRA for block hash computation.

    Args:
        request: The request object.

    Returns:
        Return LoRA id of the request if it is a LoRA request. Return empty
        list otherwise.
    """
    lora_request = request.lora_request
    return [lora_request.lora_int_id] if lora_request else []
def generate_block_hash_extra_keys(
        request: Request, start_token_idx: int, end_token_idx: int,
        start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]:
    """Generate extra keys for the block hash. The extra keys can come from
    the multi-modal inputs and request specific metadata (e.g., LoRA ID).

    Args:
        request: The request object.
        start_token_idx: The start token index of the block.
        end_token_idx: The end token index of the block.
        start_mm_idx: The start multi-modal index of the block.

    Returns:
        A tuple of extra keys and the next multi-modal index.
    """
    mm_keys, next_mm_idx = _gen_mm_extra_hash_keys(request, start_token_idx,
                                                   end_token_idx, start_mm_idx)
    # Key order matters for the hash: LoRA id, then mm hashes, then salt.
    combined: list[Any] = list(_gen_lora_extra_hash_keys(request))
    combined.extend(mm_keys)
    # The cache salt only disambiguates the very first block of the request.
    if start_token_idx == 0 and request.cache_salt:
        combined.append(request.cache_salt)
    if not combined:
        return None, next_mm_idx
    return tuple(combined), next_mm_idx
def hash_block_tokens(
        hash_function: Callable,
        parent_block_hash: Optional[int],
        curr_block_token_ids: Sequence[int],
        extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
    """Computes a hash value corresponding to the contents of a block and
    the contents of the preceding block(s). The hash value is used for
    prefix caching.

    Args:
        parent_block_hash: The hash of the parent block. None
            if this is the first block.
        curr_block_token_ids: A list of token ids in the current
            block. The current block is assumed to be full.
        extra_keys: Extra keys for the block.

    Returns:
        The hash value of the block and the token ids in the block.
        The entire tuple is used as the hash key of the block.
    """
    # A falsy parent (None) falls back to the process-wide NONE_HASH sentinel.
    if not parent_block_hash:
        parent_block_hash = NONE_HASH
    token_ids = tuple(curr_block_token_ids)
    hash_value = hash_function((parent_block_hash, token_ids, extra_keys))
    return BlockHash(hash_value, token_ids, extra_keys)
def hash_request_tokens(hash_function: Any, block_size: int,
                        request: Request) -> list[BlockHash]:
    """Computes hash values of a chain of blocks given a sequence of
    token IDs. The hash value is used for prefix caching.

    Args:
        block_size: The size of each block.
        request: The request object.

    Returns:
        The list of computed hash values.
    """
    token_ids = request.all_token_ids
    needs_extra_keys = need_extra_keys(request)
    extra_keys = None
    curr_mm_idx = 0

    block_hashes: list[BlockHash] = []
    parent_hash_value = None
    # Only full blocks are hashed; a trailing partial block is skipped.
    num_full_blocks = len(token_ids) // block_size
    for block_idx in range(num_full_blocks):
        start = block_idx * block_size
        end = start + block_size
        block_token_ids = token_ids[start:end]
        if needs_extra_keys:
            # MM and LoRA requests need extra keys for block-hash computation.
            extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
                request, start, end, curr_mm_idx)
        block_hash = hash_block_tokens(hash_function, parent_hash_value,
                                       block_token_ids, extra_keys)
        block_hashes.append(block_hash)
        # Chain the hashes so each block's hash covers the whole prefix.
        parent_hash_value = block_hash.hash_value
    return block_hashes
def max_memory_usage_bytes(vllm_config: VllmConfig,
                           kv_cache_specs: Iterable[KVCacheSpec]) -> int:
    """
    Get the maximum memory usage in bytes for the given KV cache specs.
    """
    total = 0
    for spec in kv_cache_specs:
        total += spec.max_memory_usage_bytes(vllm_config)
    return total
def estimate_max_model_len(vllm_config: VllmConfig,
                           kv_cache_spec: dict[str, KVCacheSpec],
                           available_memory: int) -> int:
    """
    Estimates the maximum model length that can fit in the available memory
    using binary search.

    The config's `max_model_len` is temporarily modified while probing and
    restored before returning, so the function has no lasting side effects.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The estimated maximum model length that can fit in the available
        memory; 0 if even a length of 1 does not fit.
    """

    # Check whether a candidate model length fits in the available memory.
    def fits_in_memory(model_len: int) -> bool:
        # max_memory_usage_bytes() reads max_model_len from the config, so
        # temporarily write the candidate value into it.
        vllm_config.model_config.max_model_len = model_len
        memory_needed = max_memory_usage_bytes(vllm_config,
                                               kv_cache_spec.values())
        return memory_needed <= available_memory

    current_max = vllm_config.model_config.max_model_len
    left, right = 1, current_max
    try:
        # If even the smallest model length doesn't fit, return 0.
        if not fits_in_memory(left):
            return 0
        # Binary search for the largest model length that fits.
        result = 1
        while left <= right:
            mid = (left + right) // 2
            if fits_in_memory(mid):
                result = mid
                left = mid + 1
            else:
                right = mid - 1
        return result
    finally:
        # Bug fix: restore the caller-visible config value. Previously the
        # last probed length leaked out of this function as a side effect.
        vllm_config.model_config.max_model_len = current_max
def check_enough_kv_cache_memory(vllm_config: VllmConfig,
                                 kv_cache_spec: dict[str, KVCacheSpec],
                                 available_memory: int):
    """
    Checks whether `available_memory` is enough for the KV cache to hold at
    least one request with the model's max_model_len.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Raises:
        ValueError: If there is not enough memory available for the KV cache.
    """
    if available_memory <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    max_model_len = vllm_config.model_config.max_model_len
    needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())

    if needed_memory > available_memory:
        # Estimate the maximum model length that can fit in the available
        # memory, to give the user an actionable hint.
        estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
                                                   available_memory)
        estimated_msg = ""
        if estimated_max_len > 0:
            estimated_msg = (
                "Based on the available memory, "
                f"the estimated maximum model length is {estimated_max_len}.")

        # Bug fix: the message previously read "models's" and had an
        # unbalanced "(" before the needed-memory figure.
        raise ValueError(
            f"To serve at least one request with the model's max seq len "
            f"({max_model_len}), {needed_memory/GiB_bytes:.2f} GiB KV "
            f"cache is needed, which is larger than the available KV cache "
            f"memory ({available_memory/GiB_bytes:.2f} GiB). "
            f"{estimated_msg} "
            f"Try increasing `gpu_memory_utilization` or decreasing "
            f"`max_model_len` when initializing the engine.")
def create_kv_cache_group_specs(
        kv_cache_spec: dict[str, KVCacheSpec],
        grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
    """
    Create KVCacheGroupSpec object for each kv cache group layer.
    The layers in the same group should share the same
    KVCacheSpec.

    Args:
        kv_cache_spec:
            A mapping from each layer name to its corresponding KVCacheSpec.
        grouped_layer_names:
            A list of kv cache groups, where each element is a list of layer
            names that belong to the same group and should share the same
            KVCacheSpec.

    Returns:
        A list of KVCacheGroupSpec objects, one for each group.
    """
    groups: list[KVCacheGroupSpec] = []
    for layer_names in grouped_layer_names:
        specs = [kv_cache_spec[name] for name in layer_names]
        # Collapse the per-layer specs of this group into one shared spec.
        merged_spec = specs[0].merge(specs)
        groups.append(KVCacheGroupSpec(layer_names, merged_spec))
    return groups
def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
    """
    Whether all layers in the given KVCacheSpec have the same type of KV cache.

    Args:
        kv_cache_spec: The kv cache spec of each attention layer in the model

    Returns:
        True if all layers have the same type, False otherwise.
    """
    type_ids = {layer.type_id for layer in kv_cache_spec.values()}
    return len(type_ids) == 1
def get_max_concurrency_for_kv_cache_config(
        vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float:
    """
    Get the maximum concurrency for the given KV cache configuration.
    """
    groups = kv_cache_config.kv_cache_groups
    layers_per_group = max(len(group.layer_names) for group in groups)
    # Memory a single request at max_model_len needs across all groups.
    per_request_bytes = layers_per_group * max_memory_usage_bytes(
        vllm_config, (group.kv_cache_spec for group in groups))
    block_bytes = groups[0].kv_cache_spec.page_size_bytes * layers_per_group
    blocks_per_request = cdiv(per_request_bytes, block_bytes)
    return kv_cache_config.num_blocks / blocks_per_request
def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
                   available_memory: int, page_size: int) -> int:
    """
    Get the number of kv cache blocks.

    Args:
        vllm_config: The global VllmConfig
        num_layers: The number of layers
        available_memory: Memory available for KV cache in bytes.
        page_size: The page size of the KV cache.

    Returns:
        The number of KV cache blocks, honoring
        `cache_config.num_gpu_blocks_override` when it is set.
    """
    num_blocks = int(available_memory // page_size // num_layers)
    num_blocks = max(num_blocks, 0)
    if vllm_config.cache_config.num_gpu_blocks_override is not None:
        num_gpu_blocks_override = \
            vllm_config.cache_config.num_gpu_blocks_override
        logger.info(
            "Overriding num_gpu_blocks=%d with "
            "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
        # Bug fix: the override was logged but never applied, so the computed
        # value was silently returned instead.
        num_blocks = num_gpu_blocks_override
    return num_blocks
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
    """
    Get the page size of the KV cache, asserting it is uniform across layers.
    """
    sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
    assert len(sizes) == 1
    return sizes.pop()
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                      kv_cache_spec: dict[str, KVCacheSpec],
                                      available_memory: int) -> KVCacheConfig:
    """
    Generates the KV cache configuration for a model with one type of KV cache.
    Divide the available memory equally among all layers.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The generated KVCacheConfig
    """
    page_size = get_uniform_page_size(kv_cache_spec)
    num_layers = len(kv_cache_spec)
    num_blocks = get_num_blocks(vllm_config, num_layers, available_memory,
                                page_size)
    per_layer_size = page_size * num_blocks

    # A single group holding every layer: all layers share the same KV cache
    # spec, and each layer gets its own backing tensor of identical size.
    grouped_layer_names = [list(kv_cache_spec.keys())]
    kv_cache_tensors = [
        KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
        for layer_name in kv_cache_spec
    ]
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
                                                    grouped_layer_names),
    )

    # Log the cache capacity and the resulting maximum concurrency.
    num_tokens = num_blocks * vllm_config.cache_config.block_size
    logger.info("GPU KV cache size: %s tokens", f"{num_tokens:,}")
    max_concurrency = get_max_concurrency_for_kv_cache_config(
        vllm_config, kv_cache_config)
    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                f"{vllm_config.model_config.max_model_len:,}", max_concurrency)
    return kv_cache_config
def is_kv_cache_page_size_uniform(
        kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
    """
    Whether all layers in the given KVCacheSpec have the same page size.

    Args:
        kv_cache_spec: The KVCacheSpec of each attention layer in the model

    Returns:
        True if all layers have the same page size, False otherwise.
    """
    distinct_sizes: set[int] = set()
    for layer in kv_cache_spec.values():
        distinct_sizes.add(layer.page_size_bytes)
    return len(distinct_sizes) == 1
def _get_kv_cache_config_uniform_page_size(
        vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
        available_memory: int) -> KVCacheConfig:
    """
    Generates the KV cache configuration for hybrid models with multiple
    attention types but still with a uniform page size (physical memory per
    block per layer) for all layers.

    Detailed explanation about kv cache management of hybrid models:
    The layers in the models are repeated with some patterns, e.g., a model
    with 10 full attention layers and 20 sliding window attention layers can be
    regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
    The KVCacheManager allocates different block tables for each of the 3 layers
    in the pattern, and repeats each of them 10 times to generate the
    block_table for the 30 layers in the model.
    Therefore, we can group the layers in the model into 3 kv_cache_groups, each
    of which contains 10 layers in the model.
    The KVCacheManager allocates the block_table for each group based on its
    kv_cache spec, and the model runner applies the block table to each layer
    in the group.

    For example:
    1. A model only uses full attention. The pattern is
    (num_hidden_layers * full), so there is only one group and the block table
    is shared by all layers. It is already handled by
    `_get_kv_cache_config_uniform_type`.
    2. A model with 10 full attention layers and 20 sliding window
    attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
    there are 3 kv_cache_groups, each of which represents 10 layers.

    To simplify the implementation, we make the following assumptions:
    1. Physical memory per block: Must be the same across all KV cache groups.
    Breaking this assumption is non-trivial due to memory fragmentation concerns
    when allocating blocks of different sizes.
    2. Tokens per block (block_size): Currently, we directly use
    `CacheConfig.block_size` for all layers. It can be extended to vary by KV
    cache group, but within each KV cache group, all layers must share the same
    block size.
    3. Physical memory per token per layer: This property is decided by model
    config. Currently we only support models that have the same physical memory
    per token per layer for all layers. Can be relaxed with a simple extension,
    but still need to keep physical memory per block the same for all groups.
    4. Number of layers per group: Currently assumed the same for all layers.
    Can be relaxed with a simple extension, but still need to keep physical
    memory per block the same for all groups.
    5. Attention type within groups: All layers in a group must share the same
    attention type. One exception is that, when
    `--disable-hybrid-kv-cache-manager` is true, the single group for full
    attention layers may also include attention layers using sliding window or
    LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details.
    6. Support for multiple attention types: The design for most components is
    general to an arbitrary number of attention types. But
    `find_longest_cache_hit` only supports one attention type or two
    types of full-attention plus exactly one another type. The general
    implementation of this function is feasible but we don't know how to
    implement it cleanly yet.

    As we assume tokens per block, physical memory per token per layer, and
    number of layers per group are the same now, we can ensure that physical
    memory per block is the same for all groups.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The KVCacheSpec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The generated KVCacheConfig
    """
    # Group all layers by type_id.
    # E.g., 2 full attention layers and 3 sliding window attention layers,
    # -> (full.0, full.1), (sw.0, sw.1, sw.2).
    same_type_layers: dict[str, list[str]] = defaultdict(list)
    for layer_name, layer_spec in kv_cache_spec.items():
        same_type_layers[layer_spec.type_id].append(layer_name)
    # Split each group into smaller groups, to make the number of layers in each
    # group identical. Add padding to the last group of each type if necessary.
    # E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
    # split to 3 groups with 2 layers each:
    # (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
    # FIXME(Chen): At the moment of writing this code (2025-06-02), all
    # open-source hybrid model follows a n:1 pattern between different attention
    # types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and
    # full), so we can use the "1" in the n:1 pattern as the group size, which
    # is the minimum number of layers among all attention types. Need a better
    # strategy if we want to support more complex patterns (e.g., 20 full + 30
    # sw, where the group size should be 10).
    group_size = min([len(layers) for layers in same_type_layers.values()])
    grouped_layers = []
    for layers in same_type_layers.values():
        num_padding_layers = group_size - len(layers) % group_size
        if num_padding_layers != group_size:
            logger.warning(
                "Add %d padding layers, may waste at most %.2f%% KV cache memory",  # noqa
                num_padding_layers,
                num_padding_layers / len(layers) * 100,
            )
        for i in range(0, len(layers), group_size):
            grouped_layers.append(layers[i:i + group_size])
    kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec,
                                                  grouped_layers)
    # Determine how model runners should initialize the KV cache tensors.
    # We will have group_size memory pools, each is shared by one layer from
    # each group. As layers of different groups have different block table,
    # they will use different parts of the shared Tensor.
    # The memory layout in the example will be:
    # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
    # full.1, sw.1: share another Tensor with size=available_memory//2
    page_size = get_uniform_page_size(kv_cache_spec)
    num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
                                page_size)
    per_memory_pool_size = page_size * num_blocks
    kv_cache_tensors = []
    for i in range(group_size):
        shared_by = []
        for j in range(len(kv_cache_groups)):
            # Skip padding slots: the last group of a type may contain fewer
            # than group_size real layers.
            if i < len(grouped_layers[j]):
                shared_by.append(grouped_layers[j][i])
        kv_cache_tensors.append(
            KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_groups=kv_cache_groups,
    )
    # Print the KV cache size and maximum concurrency.
    num_tokens = num_blocks // len(
        grouped_layers) * vllm_config.cache_config.block_size
    num_tokens_str = f"{num_tokens:,}"
    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
    max_concurrency = get_max_concurrency_for_kv_cache_config(
        vllm_config, kv_cache_config)
    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                max_model_len_str, max_concurrency)
    return kv_cache_config
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
    """
    This function tries to convert the KV cache specs to one type if the model
    is a hybrid model with multiple type of KV cache. It will convert all
    SlidingWindowSpec to FullAttentionSpec if both types are present.

    Args:
        kv_cache_spec: The kv cache spec of each attention layer in the model
    """

    def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
        type_ids = {spec.type_id for spec in kv_cache_spec.values()}
        return len(type_ids) > 1

    if not is_hybrid(kv_cache_spec):
        return
    logger.warning(
        "Hybrid KV cache manager is disabled for this hybrid model, "
        "This means we do not enable any optimizations for saving KV cache "
        "memory (e.g., dropping the KV cache outside the sliding window). "
        "The compute of layers like sliding window is still saved.")

    has_full_attention = any(
        isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
    has_sliding_window = any(
        isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
    if has_full_attention and has_sliding_window:
        # Promote every sliding-window layer to a full-attention spec while
        # preserving its window size in `sliding_window`.
        for layer_name, spec in kv_cache_spec.items():
            if isinstance(spec, SlidingWindowSpec):
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=spec.block_size,
                    num_kv_heads=spec.num_kv_heads,
                    head_size=spec.head_size,
                    dtype=spec.dtype,
                    use_mla=spec.use_mla,
                    sliding_window=spec.sliding_window,
                )

    if is_hybrid(kv_cache_spec):
        raise ValueError("Hybrid KV cache manager is disabled but failed to "
                         "convert the KV cache specs to one unified type.")
def get_kv_cache_config(
    vllm_config: VllmConfig,
    kv_cache_spec: dict[str, KVCacheSpec],
    available_memory: int,
) -> KVCacheConfig:
    """
    Generates the KV cache configuration for a model.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The generated KVCacheConfigs
    """
    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)

    if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
        unify_hybrid_kv_cache_specs(kv_cache_spec)

    if is_kv_cache_type_uniform(kv_cache_spec):
        # Homogeneous model (the common case): allocate the same amount of
        # memory for each layer.
        return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
                                                 available_memory)
    if is_kv_cache_page_size_uniform(kv_cache_spec):
        # Hybrid model whose layers still share one physical page size:
        # split the layers into groups with the same number of layers, and
        # thus the same total page size.
        return _get_kv_cache_config_uniform_page_size(vllm_config,
                                                      kv_cache_spec,
                                                      available_memory)

    raise NotImplementedError
def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
    """
    Make the KV cache configurations for each worker consistent, so that all
    workers can be controlled by the same KVCacheManager.
    This function verifies that the layer group of each worker are the same,
    and changes the num_blocks of each worker to the smallest among all
    workers.

    Args:
        kv_cache_configs: The KV cache configurations for each worker. Will be
            in-place modified to make them consistent.
    """
    # Canonicalize group order by the spec's type_id so that comparing group
    # lists across ranks is insensitive to the original ordering.
    for config in kv_cache_configs:
        config.kv_cache_groups.sort(key=lambda g: g.kv_cache_spec.type_id)

    # Every rank must have the same group specs as rank 0.
    reference_groups = kv_cache_configs[0].kv_cache_groups
    for config in kv_cache_configs[1:]:
        for ref_group, group in zip(reference_groups, config.kv_cache_groups):
            assert ref_group.kv_cache_spec == group.kv_cache_spec

    # Shrink every rank to the smallest block count. The tensors need not be
    # resized because using only the first `num_blocks` blocks is valid.
    min_num_blocks = min(config.num_blocks for config in kv_cache_configs)
    for config in kv_cache_configs:
        config.num_blocks = min_num_blocks

    return kv_cache_configs

View File

View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
class SchedulerInterface(ABC):
    """Abstract interface for the v1 scheduler.

    A concrete scheduler decides, once per engine step, how many tokens of
    each request to process (`schedule`), and then folds the model runner's
    results back into its internal state (`update_from_output`).
    """
    @abstractmethod
    def schedule(self) -> "SchedulerOutput":
        """Schedule the requests to process in this scheduling step.
        The scheduling decision is made at the iteration level. Each scheduling
        step corresponds to a single forward pass of the model. Therefore, this
        method is called repeatedly by a busy loop in the engine.
        Essentially, the scheduler produces a dictionary of {req_id: num_tokens}
        that specifies how many tokens to process for each request in this
        scheduling step. For example, num_tokens can be as large as the number
        of prompt tokens for new requests, or it can be 1 for the requests that
        are auto-regressively generating new tokens one by one. Otherwise, it
        can be somewhere in between in case of chunked prefills, prefix caching,
        speculative decoding, etc.
        Additionally, the scheduler also returns useful data about each request
        or the batch as a whole. The model runner will use this information in
        preparing inputs to the model.
        Returns:
            A SchedulerOutput object containing information about the scheduled
            requests.
        """
        raise NotImplementedError
    @abstractmethod
    def update_from_output(
        self,
        scheduler_output: "SchedulerOutput",
        model_runner_output: "ModelRunnerOutput",
    ) -> dict[int, "EngineCoreOutputs"]:
        """Update the scheduler state based on the model runner output.
        This method is called after the model runner has processed the scheduled
        requests. The model runner output includes generated token ids, draft
        token ids for next step, etc. The scheduler uses this information to
        update its states, checks the finished requests, and returns the output
        for each request.
        Returns:
            A dict of client index to EngineCoreOutputs object containing the
            outputs for each request originating from that client.
        """
        raise NotImplementedError
    @abstractmethod
    def add_request(self, request: "Request") -> None:
        """Add a new request to the scheduler's internal queue.
        Args:
            request: The new request being added.
        """
        raise NotImplementedError
    @abstractmethod
    def finish_requests(
        self,
        request_ids: Union[str, Iterable[str]],
        finished_status: "RequestStatus",
    ) -> None:
        """Finish the requests in the scheduler's internal queue. If the request
        is not in the queue, this method will do nothing.
        This method is called in two cases:
        1. When the request is aborted by the client.
        2. When the frontend process detects a stop string of the request after
        de-tokenizing its generated tokens.
        Args:
            request_ids: A single or a list of request IDs.
            finished_status: The finished status of the given requests.
        """
        raise NotImplementedError
    @abstractmethod
    def get_num_unfinished_requests(self) -> int:
        """Number of unfinished requests in the scheduler's internal queue."""
        raise NotImplementedError
    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests in the scheduler's
        internal queue."""
        return self.get_num_unfinished_requests() > 0
    @abstractmethod
    def has_finished_requests(self) -> bool:
        """Returns True if there are finished requests that need to be cleared.
        NOTE: This is different from `not self.has_unfinished_requests()`.
        The scheduler maintains an internal list of the requests finished in the
        previous step. This list is returned from the next call to schedule(),
        to be sent to the model runner in the next step to clear cached states
        for these finished requests.
        This method checks if this internal list of finished requests is
        non-empty. This information is useful for DP attention.
        """
        raise NotImplementedError
    def has_requests(self) -> bool:
        """Returns True if there are unfinished requests, or finished requests
        not yet returned in SchedulerOutputs."""
        return self.has_unfinished_requests() or self.has_finished_requests()
    @abstractmethod
    def reset_prefix_cache(self) -> bool:
        """Reset the prefix cache for KV cache.
        This is particularly required when the model weights are live-updated.
        """
        raise NotImplementedError
    @abstractmethod
    def get_request_counts(self) -> tuple[int, int]:
        """Returns (num_running_reqs, num_waiting_reqs)."""
        raise NotImplementedError
    @abstractmethod
    def make_stats(self) -> Optional["SchedulerStats"]:
        """Make a SchedulerStats object for logging.
        The SchedulerStats object is created for every scheduling step.
        """
        raise NotImplementedError
    @abstractmethod
    def shutdown(self) -> None:
        """Shutdown the scheduler."""
        raise NotImplementedError
    def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
        # Base implementation: no KV connector unless a subclass provides one.
        return None

View File

@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata)
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
@dataclass
class NewRequestData:
    """Scheduler-to-worker payload for a request scheduled for the first time.

    Workers cache this data so later steps only need the diffs carried by
    CachedRequestData.
    """
    # Unique ID of the request.
    req_id: str
    # Token IDs of the prompt.
    prompt_token_ids: list[int]
    # Multi-modal inputs, their hashes, and their placeholder positions.
    mm_inputs: list[MultiModalKwargs]
    mm_hashes: list[str]
    mm_positions: list[PlaceholderRange]
    # Sampling parameters for generation.
    sampling_params: SamplingParams
    # Allocated block IDs; presumably one list per KV cache group -- confirm.
    block_ids: tuple[list[int], ...]
    # Number of already-computed tokens for the request.
    num_computed_tokens: int
    # LoRA adapter for this request, if any.
    lora_request: Optional[LoRARequest]
    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: tuple[list[int], ...],
    ) -> NewRequestData:
        """Build a NewRequestData from a scheduler-side Request."""
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            mm_inputs=request.mm_inputs,
            mm_hashes=request.mm_hashes,
            mm_positions=request.mm_positions,
            sampling_params=request.sampling_params,
            block_ids=block_ids,
            num_computed_tokens=request.num_computed_tokens,
            lora_request=request.lora_request,
        )
    def __repr__(self):
        return (f"NewRequestData("
                f"req_id={self.req_id},"
                f"prompt_token_ids={self.prompt_token_ids},"
                f"mm_inputs={self.mm_inputs},"
                f"mm_hashes={self.mm_hashes},"
                f"mm_positions={self.mm_positions},"
                f"sampling_params={self.sampling_params},"
                f"block_ids={self.block_ids},"
                f"num_computed_tokens={self.num_computed_tokens},"
                f"lora_request={self.lora_request}"
                ")")
    # Version of __repr__ with the prompt data obfuscated
    def anon_repr(self):
        return (f"NewRequestData("
                f"req_id={self.req_id},"
                f"prompt_token_ids_len={len(self.prompt_token_ids)},"
                f"mm_inputs={self.mm_inputs},"
                f"mm_hashes={self.mm_hashes},"
                f"mm_positions={self.mm_positions},"
                f"sampling_params={self.sampling_params},"
                f"block_ids={self.block_ids},"
                f"num_computed_tokens={self.num_computed_tokens},"
                f"lora_request={self.lora_request}"
                ")")
@dataclass
class CachedRequestData:
    """Scheduler-to-worker diff for a request that was scheduled before.

    Workers already cached the request's full data from its NewRequestData,
    so only the incremental fields are sent each step.
    """
    # Unique ID of the request.
    req_id: str
    # If resumed_from_preemption is False, new_block_ids will be appended to
    # the request's block IDs. If True, new_block_ids will be used as the
    # request's block IDs instead of appending to the existing block IDs.
    resumed_from_preemption: bool
    # Token IDs newly produced/added since the last step -- supplied by the
    # scheduler; confirm exact semantics with the caller.
    new_token_ids: list[int]
    # Newly allocated block IDs (see resumed_from_preemption above).
    new_block_ids: tuple[list[int], ...]
    # Number of already-computed tokens for the request.
    num_computed_tokens: int
    @classmethod
    def from_request(
        cls,
        request: Request,
        resumed_from_preemption: bool,
        new_token_ids: list[int],
        new_block_ids: tuple[list[int], ...],
    ) -> CachedRequestData:
        """Build the diff payload from a scheduler-side Request."""
        return cls(
            req_id=request.request_id,
            resumed_from_preemption=resumed_from_preemption,
            new_token_ids=new_token_ids,
            new_block_ids=new_block_ids,
            num_computed_tokens=request.num_computed_tokens,
        )
@dataclass
class SchedulerOutput:
    """Per-step scheduling decision sent from the scheduler to the workers."""
    # list of the requests that are scheduled for the first time.
    # We cache the request's data in each worker process, so that we don't
    # need to re-send it every scheduling step.
    scheduled_new_reqs: list[NewRequestData]
    # list of the requests that have been scheduled before.
    # Since the request's data is already cached in the worker processes,
    # we only send the diff to minimize the communication cost.
    scheduled_cached_reqs: list[CachedRequestData]
    # req_id -> num_scheduled_tokens
    # Number of tokens scheduled for each request.
    num_scheduled_tokens: dict[str, int]
    # Total number of tokens scheduled for all requests.
    # Equal to sum(num_scheduled_tokens.values())
    total_num_scheduled_tokens: int
    # req_id -> spec_token_ids
    # If a request does not have any spec decode tokens, it will not be
    # included in the dictionary.
    scheduled_spec_decode_tokens: dict[str, list[int]]
    # req_id -> encoder input indices that need processing.
    # E.g., if a request has [0, 1], it could mean the vision encoder needs
    # to process the request's 0-th and 1-th images in the current step.
    scheduled_encoder_inputs: dict[str, list[int]]
    # Number of common prefix blocks for all requests in each KV cache group.
    # This can be used for cascade attention.
    num_common_prefix_blocks: list[int]
    # Request IDs that are finished in between the previous and the current
    # steps. This is used to notify the workers about the finished requests
    # so that they can free the cached states for those requests.
    finished_req_ids: set[str]
    # list of (req_id, encoder_input_index) tuples.
    # Used to free the encoder cache.
    free_encoder_input_ids: list[tuple[str, int]]
    # Dict of request ids to their index within the batch
    # for filling the next token bitmask
    structured_output_request_ids: dict[str, int]
    # the bitmask for the whole batch
    grammar_bitmask: Optional[npt.NDArray[np.int32]]
    # KV Cache Connector metadata.
    kv_connector_metadata: Optional[KVConnectorMetadata] = None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.request import Request, RequestStatus
def check_stop(request: Request, max_model_len: int) -> bool:
    """Check whether *request* hit a stop condition and mark its status.

    Checks, in priority order: length cap (model context or max_tokens),
    EOS token (unless ignore_eos), then user-provided stop token ids.
    Returns True if the request is finished; also sets ``request.status``
    (and ``stop_reason`` for stop-token hits) as a side effect.
    """
    hit_length_cap = (request.num_tokens >= max_model_len
                      or request.num_output_tokens >= request.max_tokens)
    if hit_length_cap:
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True

    params = request.sampling_params
    # NOTE: assumes at least one output token exists when this is called.
    last_token = request.output_token_ids[-1]

    if last_token == request.eos_token_id and not params.ignore_eos:
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    stop_ids = params.stop_token_ids or ()
    if last_token in stop_ids:
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = last_token
        return True

    return False

View File

@@ -0,0 +1,403 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
from vllm.v1.request import Request
class SingleTypeKVCacheManager(ABC):
    """
    An abstract base class for a manager that handles the kv cache management
    logic of one specific type of attention layer.
    """

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        block_pool: BlockPool,
        kv_cache_group_id: int,
        caching_hash_fn: Callable,
    ) -> None:
        """
        Initializes the manager for one kv cache group.

        Args:
            kv_cache_spec: The kv_cache_spec for this manager.
            block_pool: The block pool.
            kv_cache_group_id: The id of the kv cache group of this manager.
            caching_hash_fn: The caching hash function.
        """
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
        self.block_pool = block_pool
        # Mapping from request ID to blocks to track the blocks allocated
        # for each request, so that we can free the blocks when the request
        # is finished.
        self.req_to_blocks: defaultdict[str,
                                        list[KVCacheBlock]] = defaultdict(list)
        # {req_id: The number of cached blocks for this given request}
        # This is used to track the number of cached blocks for each request.
        # This is only used to track the RUNNING requests, we do not track the
        # data for preempted ones.
        self.num_cached_block: dict[str, int] = {}
        self.caching_hash_fn = caching_hash_fn
        self.kv_cache_group_id = kv_cache_group_id

    def get_num_blocks_to_allocate(
            self, request_id: str, num_tokens: int,
            new_computed_blocks: list[KVCacheBlock]) -> int:
        """
        Get the number of blocks needed to be allocated for the request.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            new_computed_blocks: The new computed blocks just hitting the
                prefix caching.

        Returns:
            The number of blocks.
        """
        num_required_blocks = cdiv(num_tokens, self.block_size)
        num_new_blocks = (num_required_blocks - len(new_computed_blocks) -
                          len(self.req_to_blocks[request_id]))
        # If a computed block of a request is an eviction candidate (in the
        # free queue and ref_cnt == 0), it will be changed from a free block
        # to a computed block when the request is allocated, so we also count
        # it as needed to be allocated.
        num_evictable_computed_blocks = sum(
            blk.ref_cnt == 0 and not blk.is_null
            for blk in new_computed_blocks)
        return num_new_blocks + num_evictable_computed_blocks

    def save_new_computed_blocks(
            self, request_id: str,
            new_computed_blocks: list[KVCacheBlock]) -> None:
        """
        Add the new computed blocks to the request.

        Args:
            request_id: The request ID.
            new_computed_blocks: The new computed blocks just hitting the
                prefix cache.
        """
        if request_id not in self.num_cached_block:
            # A new request.
            req_blocks = self.req_to_blocks[request_id]
            assert len(req_blocks) == 0
            req_blocks.extend(new_computed_blocks)
            self.num_cached_block[request_id] = len(new_computed_blocks)
        else:
            # A running request. Should not have new computed blocks.
            assert len(new_computed_blocks) == 0

    def allocate_new_blocks(self, request_id: str,
                            num_tokens: int) -> list[KVCacheBlock]:
        """
        Allocate new blocks for the request to give it at least `num_tokens`
        token slots.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).

        Returns:
            The new allocated blocks.
        """
        req_blocks = self.req_to_blocks[request_id]
        num_required_blocks = cdiv(num_tokens, self.block_size)
        num_new_blocks = num_required_blocks - len(req_blocks)
        if num_new_blocks <= 0:
            return []
        else:
            new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
            req_blocks.extend(new_blocks)
            return new_blocks

    def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
                     num_tokens: int) -> None:
        """
        Cache the blocks for the request.

        Args:
            request: The request.
            block_hashes: The block hashes of the request.
            num_tokens: The total number of tokens that need to be cached
                (including tokens that are already cached).
        """
        num_cached_blocks = self.num_cached_block[request.request_id]
        # Only fully-filled blocks are eligible for prefix caching.
        num_full_blocks = num_tokens // self.block_size
        self.block_pool.cache_full_blocks(
            request=request,
            blocks=self.req_to_blocks[request.request_id],
            block_hashes=block_hashes,
            num_cached_blocks=num_cached_blocks,
            num_full_blocks=num_full_blocks,
            block_size=self.block_size,
            kv_cache_group_id=self.kv_cache_group_id,
            hash_fn=self.caching_hash_fn,
        )
        self.num_cached_block[request.request_id] = num_full_blocks

    def free(self, request_id: str) -> None:
        """
        Free the blocks for the request.

        Args:
            request_id: The request ID.
        """
        # Default to [] in case a request is freed (aborted) before alloc.
        req_blocks = self.req_to_blocks.pop(request_id, [])
        # Free blocks in reverse order so that the tail blocks are
        # freed first.
        ordered_blocks = reversed(req_blocks)
        self.block_pool.free_blocks(ordered_blocks)
        self.num_cached_block.pop(request_id, None)

    @abstractmethod
    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> int:
        """
        Get the number of common prefix blocks for a request.

        Args:
            request_id: The request ID.
            num_running_requests: The total number of requests in the RUNNING
                state.

        Returns:
            The number of common prefix blocks.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
    ) -> tuple[list[KVCacheBlock], ...]:
        """
        Get the longest cache hit prefix of the blocks that is not longer than
        `max_length`. The prefix should be a common prefix hit for all the
        kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
        return an empty list.
        If eagle is enabled, drop the last matched block to force recompute the
        last block to get the required hidden states for eagle drafting head.
        Need to be customized for each attention type.

        Args:
            block_hashes: The block hashes of the request.
            max_length: The maximum length of the cache hit prefix.
            kv_cache_group_ids: The ids of the kv cache groups.
            block_pool: The block pool.
            kv_cache_spec: The kv cache spec.
            use_eagle: Whether to use eagle.

        Returns:
            A list of cached blocks with skipped blocks replaced by null block
            for each kv cache group in `kv_cache_group_ids`.
            Return a list of length `len(kv_cache_group_ids)`, where the i-th
            element is a list of cached blocks for the i-th kv cache group
            in `kv_cache_group_ids`.
            For example, sliding window manager should return a list like
            ([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4
            and sliding window 8 and len(kv_cache_group_ids) = 1.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        """
        Remove the blocks that are no longer needed from `blocks` and free the
        blocks. The removed blocks should be replaced by null_block.
        Need to be customized for each attention type.

        Args:
            request_id: The request ID.
            num_computed_tokens: The number of tokens that have been computed.
        """
        raise NotImplementedError
class FullAttentionManager(SingleTypeKVCacheManager):
    """Manager for full (non-sliding-window) attention KV cache groups."""

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Collect the longest prefix of blocks cached in every group,
        capped at `max_length` tokens."""
        assert isinstance(kv_cache_spec, FullAttentionSpec), (
            "FullAttentionManager can only be used for full attention groups")
        hit_blocks: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in kv_cache_group_ids)
        max_num_blocks = max_length // kv_cache_spec.block_size
        for block_hash in block_hashes[:max_num_blocks]:
            # Block hashes form a chain: the first miss implies every later
            # block is also uncomputed, so stop at the first miss.
            cached = block_pool.get_cached_block(block_hash,
                                                 kv_cache_group_ids)
            if not cached:
                break
            for group_hits, blk in zip(hit_blocks, cached):
                group_hits.append(blk)
        if use_eagle and hit_blocks[0]:
            # Force recomputation of the last matched block so the eagle
            # drafting head gets the hidden states it needs.
            for group_hits in hit_blocks:
                group_hits.pop()
        return hit_blocks

    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        """Full attention never skips blocks, so there is nothing to free."""
        pass

    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> int:
        """Count the leading blocks referenced by all running requests."""
        shared = 0
        for block in self.req_to_blocks[request_id]:
            if block.ref_cnt != num_running_requests:
                break
            shared += 1
        return shared
class SlidingWindowManager(SingleTypeKVCacheManager):
    """Manager for sliding-window attention KV cache groups."""

    def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
                 **kwargs) -> None:
        super().__init__(kv_cache_spec, block_pool, **kwargs)
        self.sliding_window = kv_cache_spec.sliding_window
        # Cached reference to the pool's null block, used to mark blocks
        # that fell out of the window.
        self._null_block = block_pool.null_block

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
    ) -> tuple[list[KVCacheBlock], ...]:
        assert isinstance(kv_cache_spec, SlidingWindowSpec), (
            "SlidingWindowManager can only be used for sliding window groups")
        # The number of contiguous blocks needed for prefix cache hit.
        # -1 since the input token itself is also included in the window
        sliding_window_contiguous_blocks = cdiv(
            kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size)
        if use_eagle:
            # Need to drop the last matched block if eagle is enabled. For
            # sliding window layer, we achieve this by increasing the number of
            # contiguous blocks needed for prefix cache hit by one and dropping
            # the last matched block.
            sliding_window_contiguous_blocks += 1
        # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
        # optimize the time complexity from O(max_num_blocks) to
        # O(max_num_blocks / sliding_window_contiguous_blocks +
        # sliding_window_contiguous_blocks),
        # which is good for low cache hit rate scenarios.
        max_num_blocks = max_length // kv_cache_spec.block_size
        # Start with all-null placeholders; real hits overwrite positions.
        computed_blocks = tuple([block_pool.null_block] * max_num_blocks
                                for _ in range(len(kv_cache_group_ids)))
        num_contiguous_blocks = 0
        match_found = False
        # Search from right to left and early stop when a match is found.
        for i in range(max_num_blocks - 1, -1, -1):
            if cached_block := block_pool.get_cached_block(
                    block_hashes[i], kv_cache_group_ids):
                for computed, cached in zip(computed_blocks, cached_block):
                    computed[i] = cached
                num_contiguous_blocks += 1
                if num_contiguous_blocks >= sliding_window_contiguous_blocks:
                    # Trim the trailing blocks.
                    # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                    # when sliding_window_contiguous_blocks=2.
                    for computed in computed_blocks:
                        del computed[i + num_contiguous_blocks:]
                    match_found = True
                    break
            else:
                # A miss resets the run; only contiguous hits can serve the
                # window.
                num_contiguous_blocks = 0
        if not match_found:
            # The first `num_contiguous_blocks` is a cache hit even if
            # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
            for computed in computed_blocks:
                del computed[num_contiguous_blocks:]
        if use_eagle and computed_blocks[0]:
            for computed in computed_blocks:
                computed.pop()
        return computed_blocks

    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        # Remove the blocks that are no longer in the sliding window and
        # skipped during the attention computation.
        last_useful_token = num_computed_tokens - self.sliding_window + 1
        last_useful_block = last_useful_token // self.block_size
        blocks = self.req_to_blocks[request_id]
        removed_blocks: list[KVCacheBlock] = []
        for i in range(last_useful_block - 1, -1, -1):
            if blocks[i] == self._null_block:
                # If the block is already a null block, the blocks before it
                # should also have been set to null blocks by the previous calls
                # to this function.
                break
            removed_blocks.append(blocks[i])
            blocks[i] = self._null_block
        self.block_pool.free_blocks(removed_blocks)

    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> int:
        """
        NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
        So it's not correct to count ref_cnt like FullAttentionManager. Return
        0 here for correctness. Need to support cascade attention + sliding
        window in the future.
        """
        return 0
# Registry mapping each KV cache spec type to its dedicated manager class.
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
    FullAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
}


def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec,
                                  **kwargs) -> SingleTypeKVCacheManager:
    """Instantiate the manager registered for *kv_cache_spec*'s exact type.

    Raises KeyError for spec types with no registered manager.
    """
    return spec_manager_map[type(kv_cache_spec)](kv_cache_spec, **kwargs)

173
vllm/v1/engine/__init__.py Normal file
View File

@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
from collections.abc import Sequence
from typing import Any, Optional, Union
import msgspec
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")


class FinishReason(enum.IntEnum):
    """Reason a request finished: stop, length, or abort.

    Stored as an int rather than a string for more compact serialization.

    stop   - a stop string was emitted
    length - max_tokens was consumed, or max_model_len was reached
    abort  - aborted for another reason
    """
    STOP = 0
    LENGTH = 1
    ABORT = 2

    def __str__(self):
        # Map the integer value onto its external API string form.
        return FINISH_REASON_STRINGS[int(self)]
class EngineCoreRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """A new request, serialized from the front-end to the engine core."""

    request_id: str
    prompt_token_ids: list[int]
    mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
    mm_hashes: Optional[list[str]]
    mm_placeholders: Optional[list[PlaceholderRange]]
    sampling_params: SamplingParams
    eos_token_id: Optional[int]
    arrival_time: float
    lora_request: Optional[LoRARequest]
    cache_salt: Optional[str]
    data_parallel_rank: Optional[int]

    # Index of the client, used to ensure outputs are sent back to the same
    # client for this request when scaling out the front-end.
    client_index: int = 0

    # Used in DP case to indicate which wave of requests this is expected to
    # belong to, to cover a race condition where the request is sent before
    # a wave finished notification is received.
    current_wave: int = 0
class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event."""
    # NOTE(review): values are serialized over IPC; presumably they must stay
    # stable across versions - confirm before renumbering.
    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3
class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp is a monotonic timestamp and is used by the engine
    frontend to calculate intervals between engine core events. These
    timestamps should not be compared with timestamps from other processes.
    """
    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(cls,
                  event_type: EngineCoreEventType,
                  timestamp: Optional[float] = None) -> "EngineCoreEvent":
        """Create an event stamped with *timestamp*, defaulting to now."""
        timestamp = time.monotonic() if timestamp is None else timestamp
        return cls(event_type, timestamp)
class EngineCoreOutput(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """Per-step output for one request, sent from the core to the front-end."""

    request_id: str
    new_token_ids: list[int]

    new_logprobs: Optional[LogprobsLists] = None
    new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None

    finish_reason: Optional[FinishReason] = None
    stop_reason: Union[int, str, None] = None
    events: Optional[list[EngineCoreEvent]] = None
    kv_transfer_params: Optional[dict[str, Any]] = None

    # The number of tokens with prefix cache hits.
    num_cached_tokens: int = 0

    @property
    def finished(self) -> bool:
        """Whether the request reached a terminal state in this step."""
        return self.finish_reason is not None
class UtilityOutput(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """Result of a utility (RPC-style) call made to the engine core."""

    call_id: int

    # Non-None implies the call failed, result should be None.
    failure_message: Optional[str] = None
    result: Any = None
class EngineCoreOutputs(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """Batch of per-request outputs plus control signals for one step."""

    #NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    engine_index: int = 0

    # [num_reqs]
    # NOTE(review): a mutable default is presumably safe here because msgspec
    # copies Struct defaults per instance (unlike plain class attributes) -
    # confirm against the msgspec docs before changing.
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: Optional[SchedulerStats] = None
    timestamp: float = 0.0

    utility_output: Optional[UtilityOutput] = None
    finished_requests: Optional[set[str]] = None

    # In DP case, used to signal that the current wave of requests
    # has finished and the engines are paused.
    wave_complete: Optional[int] = None
    # In DP case, used to signal that a request was received for an
    # "old" wave, so the next wave needs to be started in other engines.
    start_wave: Optional[int] = None

    def __post_init__(self):
        # Stamp with a monotonic send time if the producer didn't set one.
        if self.timestamp == 0.0:
            self.timestamp = time.monotonic()
class EngineCoreRequestType(enum.Enum):
    """
    Request types defined as hex byte strings, so they can be sent over
    sockets without a separate encoding step.
    """
    ADD = b'\x00'
    ABORT = b'\x01'
    START_DP_WAVE = b'\x02'
    UTILITY = b'\x03'
    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b'\x04'

558
vllm/v1/engine/async_llm.py Normal file
View File

@@ -0,0 +1,558 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Any, Optional, Union
import numpy as np
import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
logger = init_logger(__name__)
class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats.
            usage_context: Usage context of the LLM.
            mm_registry: Multi-modal registry.
            use_cached_outputs: Whether to use cached outputs.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.

        Returns:
            None
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.log_requests = log_requests
        self.log_stats = log_stats

        # Set up stat loggers; independent set for each DP rank.
        self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
            vllm_config=vllm_config,
            log_stats=self.log_stats,
            engine_num=vllm_config.parallel_config.data_parallel_size,
            custom_stat_loggers=stat_loggers,
        )

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(
            vllm_config=vllm_config,
            tokenizer=self.tokenizer,
            mm_registry=mm_registry,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_index=client_index,
        )
        if self.stat_loggers:
            for stat_logger in self.stat_loggers[0]:
                stat_logger.log_engine_initialized()
        self.output_handler: Optional[asyncio.Task] = None
        try:
            # Start output handler eagerly if we are in the asyncio eventloop.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            # Not inside an event loop yet; generate() will start the
            # handler lazily on the first call.
            pass
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from a fully-built VllmConfig."""
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=not disable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_index=client_index,
        )
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""
# Create the engine configs.
vllm_config = engine_args.create_engine_config(usage_context)
executor_class = Executor.get_class(vllm_config)
# Create the AsyncLLM.
return cls(
vllm_config=vllm_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
    def __del__(self):
        # Best-effort cleanup if the user never called shutdown() explicitly.
        self.shutdown()
def shutdown(self):
"""Shutdown, cleaning up the background proc and IPC."""
shutdown_prometheus()
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
if handler := getattr(self, "output_handler", None):
handler.cancel()
    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        Returns the per-request output collector that the caller (generate)
        drains. For n>1 sampling, one parent request is fanned out into n
        child requests that all feed the same collector.
        """
        if self.errored:
            raise EngineDeadError()
        assert isinstance(params, SamplingParams), \
            "Pooling is not supported in V1"

        # Create a new output collector for the request.
        queue = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, prompt_adapter_request,
            priority, data_parallel_rank)

        if params.n == 1:
            await self._add_request(request, prompt_str, None, 0, queue)
            return queue

        # Fan out child requests (for n>1).
        parent_request = ParentRequest(request_id, params)
        for idx in range(params.n):
            request_id, params = parent_request.get_child_info(idx)
            # The last child reuses the original request object to avoid an
            # extra copy.
            child_request = request if idx == params.n - 1 else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = params
            await self._add_request(child_request, prompt_str, parent_request,
                                    idx, queue)
        return queue
    async def _add_request(self, request: EngineCoreRequest,
                           prompt: Optional[str],
                           parent_req: Optional[ParentRequest], index: int,
                           queue: RequestOutputCollector):
        """Register *request* locally, then hand it off to the engine core."""
        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index,
                                          queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the Detokenizer.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.
        """
        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e
def _run_output_handler(self):
    """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
    # Idempotent: only the first call creates the handler task.
    if self.output_handler is not None:
        return

    # Ensure that the task doesn't have a circular ref back to the AsyncLLM
    # object, or else it won't be garbage collected and cleaned up properly.
    engine_core = self.engine_core
    output_processor = self.output_processor
    log_stats = self.log_stats
    stat_loggers = self.stat_loggers if log_stats else None

    async def output_handler():
        try:
            while True:
                # 1) Pull EngineCoreOutputs from the EngineCore.
                outputs = await engine_core.get_output_async()
                num_outputs = len(outputs.outputs)

                # Stats are only collected when logging is on and there
                # is at least one output this iteration.
                iteration_stats = IterationStats() if (
                    log_stats and num_outputs) else None

                # Split outputs into chunks of at most
                # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                # event loop for too long.
                if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                    slices = (outputs.outputs, )
                else:
                    slices = np.array_split(
                        outputs.outputs,
                        cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

                for i, outputs_slice in enumerate(slices):
                    # 2) Process EngineCoreOutputs.
                    processed_outputs = output_processor.process_outputs(
                        outputs_slice, outputs.timestamp, iteration_stats)
                    # NOTE: RequestOutputs are pushed to their queues.
                    assert not processed_outputs.request_outputs

                    # Allow other asyncio tasks to run between chunks
                    if i + 1 < len(slices):
                        await asyncio.sleep(0)

                    # 3) Abort any reqs that finished due to stop strings.
                    await engine_core.abort_requests_async(
                        processed_outputs.reqs_to_abort)

                # 4) Logging.
                # TODO(rob): make into a coroutine and launch it in
                # background thread once Prometheus overhead is non-trivial.
                if stat_loggers:
                    AsyncLLM._record_stats(
                        stat_loggers[outputs.engine_index],
                        scheduler_stats=outputs.scheduler_stats,
                        iteration_stats=iteration_stats,
                    )
        except Exception as e:
            # Fatal for the loop: propagate so waiting requests fail fast.
            logger.exception("AsyncLLM output_handler failed.")
            output_processor.propagate_error(e)

    self.output_handler = asyncio.create_task(output_handler())
async def abort(self, request_id: str) -> None:
    """Abort a single request in both the OutputProcessor and EngineCore."""
    # Drop front-end state first, then tell the engine to stop it.
    to_abort = self.output_processor.abort_requests((request_id, ))
    await self.engine_core.abort_requests_async(to_abort)
    if self.log_requests:
        logger.info("Aborted request %s.", request_id)
@staticmethod
def _record_stats(
    stat_loggers: list[StatLoggerBase],
    scheduler_stats: Optional[SchedulerStats],
    iteration_stats: Optional[IterationStats],
):
    """Forward one iteration's stats to each logger for this engine.

    Static so that it can be used from the output_handler task without a
    circular ref to AsyncLLM.
    """
    for recorder in stat_loggers:
        recorder.record(scheduler_stats=scheduler_stats,
                        iteration_stats=iteration_stats)
def encode(
    self,
    prompt: PromptType,
    pooling_params: PoolingParams,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    priority: int = 0,
):
    """Pooling/embedding entry point — not implemented on V1."""
    raise ValueError("Not Supported on V1 yet.")
async def get_vllm_config(self) -> VllmConfig:
    """Return the engine's resolved configuration object."""
    return self.vllm_config
async def get_model_config(self) -> ModelConfig:
    """Return the model configuration."""
    return self.model_config
async def get_decoding_config(self):
    """Decoding config is not exposed on V1."""
    raise ValueError("Not Supported on V1 yet.")
async def get_input_preprocessor(self) -> InputPreprocessor:
    """Return the processor's input preprocessor."""
    return self.processor.input_preprocessor
async def get_tokenizer(
    self,
    lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
    """Return the tokenizer, specialized for a LoRA adapter when given."""
    return self.tokenizer.get_lora_tokenizer(lora_request)
async def is_tracing_enabled(self) -> bool:
    """Tracing is not supported on V1; always False."""
    return False
async def do_log_stats(
    self,
    scheduler_outputs=None,
    model_output=None,
) -> None:
    """Invoke log() on every registered stat logger.

    The two arguments are accepted for interface compatibility and are
    not used here.
    """
    for per_engine_loggers in self.stat_loggers:
        for stat_logger in per_engine_loggers:
            stat_logger.log()
async def check_health(self) -> None:
    """Health-check hook; V1 simply logs that it was called."""
    logger.debug("Called check_health.")
async def start_profile(self) -> None:
    """Ask the engine core to turn profiling on."""
    await self.engine_core.profile_async(True)
async def stop_profile(self) -> None:
    """Ask the engine core to turn profiling off."""
    await self.engine_core.profile_async(False)
async def reset_mm_cache(self) -> None:
    """Clear multi-modal caches on the front-end and the engine core."""
    # Front-end caches first...
    self.processor.mm_registry.reset_processor_cache()
    self.processor.mm_input_cache_client.reset()
    # ...then the mirrored cache in the engine-core process.
    await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self,
                             device: Optional[Device] = None) -> None:
    """Reset the prefix cache; rejects CPU since V1 has no CPU cache reset."""
    if device == Device.CPU:
        raise ValueError("Not supported on CPU.")
    await self.engine_core.reset_prefix_cache_async()
async def sleep(self, level: int = 1) -> None:
    """Put the engine into the given sleep level."""
    await self.engine_core.sleep_async(level)
async def wake_up(self, tags: Optional[list[str]] = None) -> None:
    """Wake the engine from sleep, optionally only for the given tags."""
    await self.engine_core.wake_up_async(tags)
async def is_sleeping(self) -> bool:
    """Report whether the engine core is currently sleeping."""
    return await self.engine_core.is_sleeping_async()
async def add_lora(self, lora_request: LoRARequest) -> bool:
    """Load a new LoRA adapter into the engine for future requests."""
    return await self.engine_core.add_lora_async(lora_request)
async def remove_lora(self, lora_id: int) -> bool:
    """Remove an already loaded LoRA adapter."""
    return await self.engine_core.remove_lora_async(lora_id)
async def list_loras(self) -> set[int]:
    """List all registered adapters."""
    return await self.engine_core.list_loras_async()
async def pin_lora(self, lora_id: int) -> bool:
    """Prevent an adapter from being evicted."""
    return await self.engine_core.pin_lora_async(lora_id)
async def collective_rpc(self,
                         method: str,
                         timeout: Optional[float] = None,
                         args: tuple = (),
                         kwargs: Optional[dict] = None):
    """Forward a collective RPC call to the engine core and await it."""
    return await self.engine_core.collective_rpc_async(
        method, timeout, args, kwargs)
@property
def is_running(self) -> bool:
    """True until the output handler task finishes.

    The handler is None before the loop starts; that counts as running.
    """
    handler = self.output_handler
    return handler is None or not handler.done()
@property
def is_stopped(self) -> bool:
    """An errored engine is considered stopped."""
    return self.errored
@property
def errored(self) -> bool:
    """True when the engine-core process died or the handler stopped."""
    if self.engine_core.resources.engine_dead:
        return True
    return not self.is_running
@property
def dead_error(self) -> BaseException:
    """Exception instance used to fail requests once the engine is dead."""
    return EngineDeadError()

View File

@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import time
import weakref
from typing import Optional
import msgspec.msgpack
import zmq
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
logger = init_logger(__name__)
class DPCoordinator:
    """Coordinator process used for data-parallel deployments (DP>1).

    Intermediates between multiple DP engine rank processes and one or more
    front-end API server processes.

    * Collects stats from each DP engine (currently just waiting and running
      queue lengths), and publishes these to all front-ends for use in
      load-balancing decisions.

    * Keeps track of the current DP "request wave" number and running state
      of the engines. This is received from the DP rank 0 engine and published
      to the front-end processes along with the current load stats.

      The engines alternate between a global running/paused state. The global
      "request wave" number is a count of the number of times that the workers
      collectively move from a running state to a paused state. This transition
      is synchronized via the all-reduce operation performed in the
      DPEngineCoreProc._has_global_unfinished_reqs method.

    * Broadcasts the START_DP_WAVE message to engines to move them from paused
      to running state when one engine receives a new request. This can happen
      in two cases:
      1) A front-end sending a new request while the engines are paused will
         concurrently notify the coordinator.
      2) An engine receiving a request for a stale request wave while in paused
         state will notify the coordinator.

      Engines will move into running state when receiving a new request or
      START_DP_WAVE message.
    """

    def __init__(self, parallel_config: ParallelConfig):
        # Assume coordinator is colocated with front-end procs.
        front_publish_address = get_open_zmq_ipc_path()

        dp_size = parallel_config.data_parallel_size
        assert dp_size > 1, "Coordinator only used for data parallel"

        # Engine-facing sockets can stay IPC when all DP ranks are local;
        # otherwise they are bound on the DP master host (TCP).
        local_only = dp_size == parallel_config.data_parallel_size_local
        host = parallel_config.data_parallel_master_ip
        back_publish_address = get_engine_client_zmq_addr(local_only, host)
        back_output_address = get_engine_client_zmq_addr(local_only, host)

        # Run the coordinator loop in a separate daemon process.
        context = get_mp_context()
        self.proc: multiprocessing.Process = context.Process(
            target=CoordinatorProc.run_coordinator,
            name="VLLM_DP_Coordinator",
            kwargs={
                "engine_count": parallel_config.data_parallel_size,
                "front_publish_address": front_publish_address,
                "back_output_address": back_output_address,
                "back_publish_address": back_publish_address,
            },
            daemon=True)
        self.proc.start()

        self.stats_publish_address = front_publish_address
        self.coord_in_address = back_publish_address
        self.coord_out_address = back_output_address
        # Terminate the child process when this object is GC'd or close()d.
        self._finalizer = weakref.finalize(self, shutdown, [self.proc])

    def get_stats_publish_address(self) -> str:
        """Address that front-ends subscribe to for load/wave updates."""
        return self.stats_publish_address

    def get_engine_socket_addresses(self) -> tuple[str, str]:
        """Returns tuple of ZMQ input address, output address."""
        return self.coord_in_address, self.coord_out_address

    def close(self):
        """Shut down the coordinator process."""
        self._finalizer()
class EngineState:
    """Per-engine load snapshot tracked by the coordinator."""

    def __init__(self):
        # [waiting, running] request counts for one DP engine.
        self.request_counts = [0, 0]
class CoordinatorProc:
    """Event loop run inside the DP coordinator process.

    Relays load stats from the engines to the front-ends and tracks the
    global request-wave / running state (see DPCoordinator docstring).
    """

    def __init__(self, engine_count: int):
        self.ctx = zmq.Context()

        # Last-known [waiting, running] counts for each DP engine.
        self.engines = [EngineState() for _ in range(engine_count)]

        self.current_wave = 0
        self.engines_running = False
        # Dirty flag: publish promptly (100 ms) when True, else every 3 s.
        self.stats_changed = False

    @staticmethod
    def run_coordinator(
        engine_count: int,
        front_publish_address: str,
        back_output_address: str,
        back_publish_address: str,
    ):
        """Process entry point: construct the coordinator and run its loop."""
        coordinator = CoordinatorProc(engine_count=engine_count)
        try:
            coordinator.process_input_socket(
                front_publish_address,
                back_output_address,
                back_publish_address,
            )
        except KeyboardInterrupt:
            logger.info("DP Coordinator process exiting")

    def process_input_socket(self, front_publish_address: str,
                             back_output_address: str,
                             back_publish_address: str):
        """Main loop: poll the front-end and engine sockets forever."""

        decoder = MsgpackDecoder(EngineCoreOutputs)

        with make_zmq_socket(
                path=front_publish_address,  # IPC
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
        ) as publish_front, make_zmq_socket(
                path=back_output_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.PULL,
                bind=True,
        ) as output_back, make_zmq_socket(
                path=back_publish_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
        ) as publish_back:

            poller = zmq.Poller()
            poller.register(publish_front, zmq.POLLIN)
            poller.register(output_back, zmq.POLLIN)
            # Milliseconds since epoch of the last front-end publish.
            last_publish_time = 0
            while True:
                elapsed = int(time.time() * 1000) - last_publish_time

                # Send at 100 ms interval if the stats have changed,
                # or otherwise every 3 seconds.
                wait_for = 100 if self.stats_changed else 3000
                events = poller.poll(timeout=max(0, wait_for - elapsed))
                if not events:
                    # Poller timeout - publish current stats to front-ends.
                    engine_req_counts_list = self._get_engine_counts()
                    to_publish = (engine_req_counts_list, self.current_wave,
                                  self.engines_running)
                    publish_front.send(msgspec.msgpack.encode(to_publish))
                    last_publish_time = int(time.time() * 1000)
                    self.stats_changed = False
                    continue

                events = dict(events)

                if publish_front in events:
                    buffer = publish_front.recv()
                    if buffer == b'\x01':
                        # Ignore subscription messages.
                        continue

                    # We received a message on the front-end XPUB socket,
                    # from an API server sending a new request while the
                    # engines are paused, so that we can wake the other
                    # engines.
                    engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
                    if wave < self.current_wave:
                        # If the wave number is stale, ensure the message is
                        # handled by all the engines.
                        engine_to_exclude = None
                    if not self.engines_running:
                        self.engines_running = True
                        self.stats_changed = True
                    self._send_start_wave(publish_back, self.current_wave,
                                          engine_to_exclude)

                if output_back in events:
                    # We received a message from one of the engines.
                    buffer = output_back.recv()
                    outputs: EngineCoreOutputs = decoder.decode(buffer)

                    # Engines send only stats / wave metadata here, never
                    # request outputs or utility results.
                    assert not outputs.outputs
                    assert outputs.utility_output is None

                    eng_index = outputs.engine_index
                    if outputs.scheduler_stats:
                        # 1. Updated request load stats - update our local
                        # state with these.
                        stats = self.engines[eng_index].request_counts
                        stats[0] = outputs.scheduler_stats.num_waiting_reqs
                        stats[1] = outputs.scheduler_stats.num_running_reqs
                        self.stats_changed = True

                    if (wave := outputs.wave_complete) is not None:
                        # 2. Notification from rank 0 engine that we've
                        # moved into the global paused state
                        # (engines_running==False)
                        if self.current_wave <= wave:
                            logger.debug("Moving DP wave from %d to %d.",
                                         self.current_wave, wave)
                            self.current_wave = wave + 1
                            self.engines_running = False
                            self.stats_changed = True
                    elif (wave := outputs.start_wave) is not None and (
                            wave > self.current_wave or
                        (wave == self.current_wave
                         and not self.engines_running)):
                        # 3. The engine received request for a non-current wave
                        # so we must ensure that other engines progress to the
                        # next wave (race condition handling).
                        logger.debug(
                            "Starting wave %d after notification of "
                            "stale wave request from engine.", wave)
                        self.current_wave = wave
                        self.engines_running = True
                        self.stats_changed = True
                        self._send_start_wave(publish_back, wave, eng_index)

    @staticmethod
    def _send_start_wave(socket: zmq.Socket, wave: int,
                         exclude_engine_index: Optional[int]):
        """Broadcast the START_DP_WAVE message to all the engines.
        It includes the current wave number and index of engine which
        has already received a request with this wave number and so doesn't
        require additional notification.
        """
        wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
        socket.send_multipart(
            (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))

    def _get_engine_counts(self) -> list[list[int]]:
        """Return list of [waiting, running] count lists for each engine."""
        return [e.request_counts for e in self.engines]

962
vllm/v1/engine/core.py Normal file
View File

@@ -0,0 +1,962 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import queue
import signal
import sys
import threading
import time
from collections import deque
from collections.abc import Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union
import msgspec
import zmq
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
# Socket polling interval in seconds.
# NOTE(review): not referenced in this chunk — presumably used by the
# socket loops later in this module; confirm.
POLLING_TIMEOUT_S = 2.5
# Max time to wait for the front-end handshake (see startup_handshake).
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar('_R')  # Return type for collective_rpc
class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        # V1 EngineCore only serves generative models.
        assert vllm_config.model_config.runner_type != "pooling"

        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.vllm_config = vllm_config
        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size
            > 1,
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MirroredProcessingCache(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Profile memory and build per-worker KV-cache configs.

        Returns (num_gpu_blocks, num_cpu_blocks, kv_cache_config); the
        config is the one used to initialize the scheduler.
        """
        start = time.time()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all([
            cfg.num_blocks == kv_cache_configs[0].num_blocks
            for cfg in kv_cache_configs
        ])
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.time() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler."""
        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        if req.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""
        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model(self, scheduler_output: SchedulerOutput):
        """Run one forward pass, dumping engine state on any failure."""
        try:
            return self.model_executor.execute_model(scheduler_output)
        except BaseException as err:
            # NOTE: This method is exception-free
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            # Re-raise exception
            raise err

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """
        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return (engine_core_outputs,
                scheduler_output.total_num_scheduled_tokens > 0)

    def step_with_batch_queue(
            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher
        priority than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the
        first batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are
        # scheduled. Note that this is not blocking.
        if not self.batch_queue.full():
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()

            # Blocking until the first result is available.
            model_output = future.result()

            self.batch_queue.task_done()
            engine_core_outputs = (self.scheduler.update_from_output(
                scheduler_output, model_output))

        return engine_core_outputs, scheduled_batch

    def shutdown(self):
        """Tear down structured-output backend, executor, and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (True) or stop (False) the executor's profiler."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the server-side multi-modal input cache."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
        if self.scheduler.has_unfinished_requests():
            logger.warning("Resetting the multi-modal cache when requests are "
                           "in progress may lead to desynced internal caches.")

        self.mm_input_cache_server.reset()

    def reset_prefix_cache(self):
        """Clear the scheduler's prefix cache."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Wake the executor, optionally restricted to the given tags."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Whether the executor is currently sleeping."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on all workers via collective RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter via the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a loaded LoRA adapter."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the IDs of the loaded LoRA adapters."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter so it cannot be evicted."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Persist sharded model state to `path` via the executor."""
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Invoke `method` on every worker and collect the results."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Serialize the model using the given tensorizer config."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config, )
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'
def __init__(
    self,
    vllm_config: VllmConfig,
    on_head_node: bool,
    handshake_address: str,
    executor_class: type[Executor],
    log_stats: bool,
    engine_index: int = 0,
):
    # Client requests and outbound results flow through these queues;
    # daemon threads bridge them to the ZMQ sockets.
    self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
    self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                          bytes]]()
    executor_fail_callback = lambda: self.input_queue.put_nowait(
        (EngineCoreRequestType.EXECUTOR_FAILED, b''))

    self.engine_index = engine_index
    # ZMQ DEALER identity: the engine index as two little-endian bytes.
    identity = self.engine_index.to_bytes(length=2, byteorder="little")
    self.engines_running = False

    with self._perform_handshake(handshake_address, identity, on_head_node,
                                 vllm_config) as addresses:
        self.client_count = len(addresses.outputs)

        # Set up data parallel environment.
        self.has_coordinator = addresses.coordinator_output is not None
        self._init_data_parallel(vllm_config)

        # Heavy init (model load, KV cache profiling) runs inside the
        # handshake context so READY is only sent once it succeeds.
        super().__init__(vllm_config, executor_class, log_stats,
                         executor_fail_callback)

    self.step_fn = (self.step if self.batch_queue is None else
                    self.step_with_batch_queue)

    # Background Threads and Queues for IO. These enable us to
    # overlap ZMQ socket IO with GPU since they release the GIL,
    # and to overlap some serialization/deserialization with the
    # model forward pass.
    # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
    threading.Thread(target=self.process_input_sockets,
                     args=(addresses.inputs, addresses.coordinator_input,
                           identity),
                     daemon=True).start()
    self.output_thread = threading.Thread(
        target=self.process_output_sockets,
        args=(addresses.outputs, addresses.coordinator_output,
              self.engine_index),
        daemon=True)
    self.output_thread.start()
@contextmanager
def _perform_handshake(
    self, handshake_address: str, identity: bytes, on_head_node: bool,
    vllm_config: VllmConfig
) -> Generator[EngineZmqAddresses, None, None]:
    """Handshake with the front-end over a DEALER socket.

    Yields the addresses received from the front-end; the READY message
    is sent only after the caller's with-block (engine initialization)
    completes without raising.
    """
    input_ctx = zmq.Context()
    with make_zmq_socket(input_ctx,
                         handshake_address,
                         zmq.DEALER,
                         identity=identity,
                         linger=5000,
                         bind=False) as handshake_socket:
        # Register engine with front-end.
        addresses = self.startup_handshake(handshake_socket, on_head_node,
                                           vllm_config.parallel_config)

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

        yield addresses

        # Send ready message.
        num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
        handshake_socket.send(
            msgspec.msgpack.encode({
                "status": "READY",
                "local": on_head_node,
                "num_gpu_blocks": num_gpu_blocks,
            }))
@staticmethod
def startup_handshake(
        handshake_socket: zmq.Socket, on_head_node: bool,
        parallel_config: ParallelConfig) -> EngineZmqAddresses:
    """Exchange HELLO/init messages with the front-end.

    Applies the parallel-config overrides sent by the front-end and
    returns the socket addresses this engine should use. Raises
    RuntimeError if no response arrives within HANDSHAKE_TIMEOUT_MINS.
    """
    # Send registration message.
    handshake_socket.send(
        msgspec.msgpack.encode({
            "status": "HELLO",
            "local": on_head_node,
        }))

    # Receive initialization message.
    logger.info("Waiting for init message from front-end.")
    if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
        raise RuntimeError("Did not receive response from front-end "
                           f"process within {HANDSHAKE_TIMEOUT_MINS} "
                           f"minutes")
    init_bytes = handshake_socket.recv()
    init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
        init_bytes, type=EngineHandshakeMetadata)
    logger.debug("Received init message: %s", init_message)

    # Overwrite local parallel config fields with front-end values.
    received_parallel_config = init_message.parallel_config
    for key, value in received_parallel_config.items():
        setattr(parallel_config, key, value)

    return init_message.addresses
@staticmethod
def run_engine_core(*args,
                    dp_rank: int = 0,
                    local_dp_rank: int = 0,
                    **kwargs):
    """Launch EngineCore busy loop in background process."""

    # Signal handler used for graceful termination.
    # SystemExit exception is only raised once to allow this and worker
    # processes to terminate without error
    shutdown_requested = False

    # Ensure we can serialize transformer config after spawning
    maybe_register_config_serialize_by_value()

    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit()

    # Either SIGTERM or SIGINT will terminate the engine_core
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    engine_core: Optional[EngineCoreProc] = None
    try:
        parallel_config: ParallelConfig = kwargs[
            "vllm_config"].parallel_config
        # DP deployments get the DP-aware subclass.
        if parallel_config.data_parallel_size > 1 or dp_rank > 0:
            # Set data parallel rank for this engine process.
            parallel_config.data_parallel_rank = dp_rank
            parallel_config.data_parallel_rank_local = local_dp_rank
            engine_core = DPEngineCoreProc(*args, **kwargs)
        else:
            engine_core = EngineCoreProc(*args, **kwargs)

        engine_core.run_busy_loop()

    except SystemExit:
        logger.debug("EngineCore exiting.")
        raise

    except Exception as e:
        if engine_core is None:
            logger.exception("EngineCore failed to start.")
        else:
            logger.exception("EngineCore encountered a fatal error.")
            # Tell clients the engine is dead before re-raising.
            engine_core._send_engine_dead()
        raise e

    finally:
        if engine_core is not None:
            engine_core.shutdown()
def _init_data_parallel(self, vllm_config: VllmConfig):
    """No-op hook — presumably overridden by the DP engine subclass."""
    pass
def run_busy_loop(self):
    """Core busy loop of the EngineCore."""
    # Loop until process is sent a SIGINT or SIGTERM
    while True:
        # 1) Poll the input queue until there is work to do.
        self._process_input_queue()
        # 2) Step the engine core and return the outputs.
        self._process_engine_step()
def _process_input_queue(self):
    """Exits when an engine step needs to be performed."""
    waited = False
    # Block on the input queue while there is nothing to run: no local
    # scheduler requests and engines_running is False.
    while not self.engines_running and not self.scheduler.has_requests():
        if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
            logger.debug("EngineCore waiting for work.")
            waited = True
        req = self.input_queue.get()
        self._handle_client_request(*req)

    if waited:
        logger.debug("EngineCore loop active.")

    # Handle any more client requests.
    while not self.input_queue.empty():
        req = self.input_queue.get_nowait()
        self._handle_client_request(*req)
def _process_engine_step(self) -> bool:
    """Called only when there are unfinished local requests.

    Runs one engine step and forwards its outputs to the output queue.

    Returns:
        Whether the model actually executed this step.
    """
    # Step the engine core.
    outputs, model_executed = self.step_fn()
    # Put EngineCoreOutputs into the output queue.
    for output in (outputs.items() if outputs else ()):
        self.output_queue.put_nowait(output)

    return model_executed
def _handle_client_request(self, request_type: EngineCoreRequestType,
                           request: Any) -> None:
    """Dispatch request from client.

    ADD/ABORT mutate engine state directly; UTILITY is a generic RPC
    whose result (or failure message) is sent back to the requesting
    client; EXECUTOR_FAILED escalates to a fatal error.
    """
    if request_type == EngineCoreRequestType.ADD:
        self.add_request(request)
    elif request_type == EngineCoreRequestType.ABORT:
        self.abort_requests(request)
    elif request_type == EngineCoreRequestType.UTILITY:
        # Invoke a named method on self with deserialized args and
        # route the UtilityOutput back to the calling client.
        client_idx, call_id, method_name, args = request
        output = UtilityOutput(call_id)
        try:
            method = getattr(self, method_name)
            output.result = method(
                *self._convert_msgspec_args(method, args))
        except BaseException as e:
            # BaseException so even SystemExit-ish failures are reported
            # back to the client rather than silently dropped.
            logger.exception("Invocation of %s method failed", method_name)
            output.failure_message = (f"Call to {method_name} method"
                                      f" failed: {str(e)}")
        self.output_queue.put_nowait(
            (client_idx, EngineCoreOutputs(utility_output=output)))
    elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
        raise RuntimeError("Executor failed.")
    else:
        logger.error("Unrecognized input request type encountered: %s",
                     request_type)
@staticmethod
def _convert_msgspec_args(method, args):
"""If a provided arg type doesn't match corresponding target method
arg type, try converting to msgspec object."""
if not args:
return args
arg_types = signature(method).parameters.values()
assert len(args) <= len(arg_types)
return tuple(
msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
and issubclass(p.annotation, msgspec.Struct)
and not isinstance(v, p.annotation) else v
for v, p in zip(args, arg_types))
def _send_engine_dead(self):
    """Send EngineDead status to the EngineCoreClient."""

    # Put ENGINE_CORE_DEAD in the queue.
    self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

    # Wait until msg sent by the daemon before shutdown.
    # Bounded join so a wedged IO thread cannot block process exit.
    self.output_thread.join(timeout=5.0)
    if self.output_thread.is_alive():
        logger.fatal("vLLM shutdown signal from EngineCore failed "
                     "to send. Please report this issue.")
def process_input_sockets(self, input_addresses: list[str],
                          coord_input_address: Optional[str],
                          identity: bytes):
    """Input socket IO thread.

    Receives (request_type, data) multipart frames from the front-end
    ROUTER socket(s) — and, when present, the DP coordinator — then
    deserializes them and feeds the busy loop via self.input_queue.
    """
    # Msgpack serialization decoding.
    add_request_decoder = MsgpackDecoder(EngineCoreRequest)
    generic_decoder = MsgpackDecoder()

    with ExitStack() as stack, zmq.Context() as ctx:
        input_sockets = [
            stack.enter_context(
                make_zmq_socket(ctx,
                                input_address,
                                zmq.DEALER,
                                identity=identity,
                                bind=False))
            for input_address in input_addresses
        ]
        if coord_input_address is None:
            coord_socket = None
        else:
            # XSUB socket receiving broadcasts from the DP coordinator.
            coord_socket = stack.enter_context(
                make_zmq_socket(ctx,
                                coord_input_address,
                                zmq.XSUB,
                                identity=identity,
                                bind=False))
            # Send subscription message to coordinator.
            coord_socket.send(b'\x01')

        # Register sockets with poller.
        poller = zmq.Poller()
        for input_socket in input_sockets:
            # Send initial message to each input socket - this is required
            # before the front-end ROUTER socket can send input messages
            # back to us.
            input_socket.send(b'')
            poller.register(input_socket, zmq.POLLIN)
        if coord_socket is not None:
            poller.register(coord_socket, zmq.POLLIN)

        while True:
            for input_socket, _ in poller.poll():
                # (RequestType, RequestData)
                type_frame, *data_frames = input_socket.recv_multipart(
                    copy=False)
                request_type = EngineCoreRequestType(
                    bytes(type_frame.buffer))

                # Deserialize the request data.
                decoder = add_request_decoder if (
                    request_type
                    == EngineCoreRequestType.ADD) else generic_decoder
                request = decoder.decode(data_frames)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))
def process_output_sockets(self, output_paths: list[str],
                           coord_output_path: Optional[str],
                           engine_index: int):
    """Output socket IO thread.

    Drains self.output_queue, serializes EngineCoreOutputs with msgpack
    and pushes them to the per-client PUSH sockets (client_index -1
    routes to the DP coordinator). Serialization buffers are recycled
    once zmq confirms a zero-copy send has completed.
    """
    # Msgpack serialization encoding.
    encoder = MsgpackEncoder()
    # Send buffers to reuse.
    reuse_buffers: list[bytearray] = []
    # Keep references to outputs and buffers until zmq is finished
    # with them (outputs may contain tensors/np arrays whose
    # backing buffers were extracted for zero-copy send).
    pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
    # We must set linger to ensure the ENGINE_CORE_DEAD
    # message is sent prior to closing the socket.
    with ExitStack() as stack, zmq.Context() as ctx:
        sockets = [
            stack.enter_context(
                make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
            for output_path in output_paths
        ]
        coord_socket = stack.enter_context(
            make_zmq_socket(
                ctx, coord_output_path, zmq.PUSH, bind=False,
                linger=4000)) if coord_output_path is not None else None
        max_reuse_bufs = len(sockets) + 1

        while True:
            output = self.output_queue.get()
            if output == EngineCoreProc.ENGINE_CORE_DEAD:
                # Sentinel: broadcast the death notice and exit thread.
                for socket in sockets:
                    socket.send(output)
                break
            assert not isinstance(output, bytes)
            client_index, outputs = output
            outputs.engine_index = engine_index

            if client_index == -1:
                # Don't reuse buffer for coordinator message
                # which will be very small.
                assert coord_socket is not None
                coord_socket.send_multipart(encoder.encode(outputs))
                continue

            # Reclaim buffers that zmq is finished with.
            while pending and pending[-1][0].done:
                reuse_buffers.append(pending.pop()[2])

            buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
            buffers = encoder.encode_into(outputs, buffer)
            tracker = sockets[client_index].send_multipart(buffers,
                                                           copy=False,
                                                           track=True)
            if not tracker.done:
                # Send still in flight: hold refs until tracker completes.
                ref = outputs if len(buffers) > 1 else None
                pending.appendleft((tracker, ref, buffer))
            elif len(reuse_buffers) < max_reuse_bufs:
                # Limit the number of buffers to reuse.
                reuse_buffers.append(buffer)
class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context.

    Adds "wave" tracking so all DP ranks start/stop together, periodic
    all-reduce of unfinished-request state, and publication of request
    counts to the DP coordinator.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        # Prefix stdout/stderr before any engine output is produced
        # (overridden to a no-op in the Ray actor subclass).
        self._decorate_logs()

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0
        # Index of the current DP request "wave"; advanced each time all
        # engines pause together (see run_busy_loop).
        self.current_wave = 0
        # Most recently published request counts, used to avoid
        # re-publishing unchanged stats to the coordinator.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(vllm_config, on_head_node, handshake_address,
                         executor_class, log_stats, dp_rank)

    def _decorate_logs(self):
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Configure GPUs and stateless process group for data parallel."""
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        # Give this local rank a disjoint world_size-wide slice of devices.
        os.environ[device_control_env_var] = ",".join(
            str(current_platform.device_id_to_physical_device_id(i))
            for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
                           world_size))
        # NOTE(review): mirrors the device list into MACA_VISIBLE_DEVICES,
        # presumably for the MetaX MACA runtime — confirm this is harmless
        # on non-MACA platforms.
        os.environ["MACA_VISIBLE_DEVICES"] = os.environ[device_control_env_var]

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        super().shutdown()
        # Tear down the DP all-reduce group if it was created.
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: EngineCoreRequest):
        """Add a request, reconciling the client's wave number with ours."""
        if self.has_coordinator and request.current_wave != self.current_wave:
            if request.current_wave > self.current_wave:
                # Client is ahead of us: jump forward to its wave.
                self.current_wave = request.current_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))

        super().add_request(request)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Handle DP wave-start messages; defer everything else to base."""
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            # Ignore waves we originated ourselves and stale wave numbers.
            if exclude_eng_index != self.engine_index and (
                    new_wave >= self.current_wave):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.",
                                 new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to the DP coordinator if changed."""
        if not self.has_coordinator:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(*counts)
            # client_index -1 routes this message to the coordinator.
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.engines_running:
                if self.dp_rank == 0:
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.",
                                 self.current_wave)
                    self.output_queue.put_nowait(
                        (-1,
                         EngineCoreOutputs(wave_complete=self.current_wave)))
                self.current_wave += 1

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """All-reduce whether any DP rank still has unfinished requests.

        Between syncs we conservatively report True (still running).
        """
        # Optimization - only perform finish-sync all-reduce every 24 steps.
        self.counter += 1
        if self.counter != 24:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)
class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context.

    Under Ray all ZMQ addresses are known before actor creation, so the
    handshake is replaced by simply yielding the stored addresses.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Saved for _perform_handshake (which yields them directly).
        self.addresses = addresses

        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = \
            local_dp_rank

        # Ray sets CUDA_VISIBLE_DEVICES to empty string,
        # we clean this up to be able to properly initialize
        # data parallel groups.
        # NOTE(review): the cleanup below is commented out — confirm
        # device visibility is handled elsewhere on this platform.
        # del os.environ['CUDA_VISIBLE_DEVICES']

        # Empty handshake_address: unused, see _perform_handshake.
        super().__init__(vllm_config, on_head_node, "", executor_class,
                         log_stats)

    def _decorate_logs(self):
        # Ray already prefixes actor log output; nothing to do here.
        pass

    @contextmanager
    def _perform_handshake(self, handshake_address: str, identity: bytes,
                           on_head_node: bool, vllm_config: VllmConfig):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()

File diff suppressed because it is too large. Load the diff to view it.

View File

@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Optional
import tokenizers
from packaging import version
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)

# Only tokenizers >= 0.21.1 supports DecodeStream used for
# FastIncrementalDetokenizer.
USE_FAST_DETOKENIZER = version.parse(
    tokenizers.__version__) >= version.parse("0.21.1")

# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
# Matched by message text in FastIncrementalDetokenizer._protected_step.
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
class IncrementalDetokenizer:
    """Base/no-op detokenizer: accumulates token ids, produces no text.

    Used directly when no tokenizer is available (detokenization
    disabled); also the factory entry point (from_new_request) for the
    real implementations below.
    """

    def __init__(self):
        # All token ids recorded for this request so far.
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        return self.token_ids

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        """Record new tokens; returns matched stop string (never, here)."""
        self.token_ids.extend(new_token_ids)
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """No detokenization => no text to return."""
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Choose the appropriate detokenizer implementation."""
        if tokenizer is None:
            # No tokenizer => skipping detokenization.
            return IncrementalDetokenizer()

        if USE_FAST_DETOKENIZER and isinstance(tokenizer,
                                               PreTrainedTokenizerFast):
            # Fast tokenizer => use tokenizers library DecodeStream.
            return FastIncrementalDetokenizer(tokenizer, request)

        # Fall back to slow python-based incremental detokenization.
        return SlowIncrementalDetokenizer(tokenizer, request)
class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Shared stop-string handling on top of a decode_next() primitive.

    Subclasses implement decode_next(); this class accumulates output
    text, checks stop strings, and buffers trailing chars that could be
    the start of a stop string when stops are excluded from output.
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        # Stop strings
        params = request.sampling_params
        self.stop = stop = params.stop
        self.include_stop_str_in_output = params.include_stop_str_in_output

        # Number of chars to hold back when stop strings are to be excluded
        # from streamed output.
        if stop and not self.include_stop_str_in_output:
            self.stop_buffer_length = max(len(s) for s in stop) - 1
        else:
            self.stop_buffer_length = 0
        # Offset into output_text already streamed to the client
        # (used by get_next_output_text in delta mode).
        self._last_output_text_offset: int = 0

        # Generation data
        self.output_text = ""

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        """
        Update RequestState for the request_id by:
            1) Detokenize the new token ids incrementally.
            2) Evaluate stop criteria.

        Return matched stop string or None.
        """
        if not new_token_ids:
            # Skip detokenization if no new token ids.
            return None

        if stop_terminated and not self.include_stop_str_in_output:
            # If stop-terminated, exclude last token from detokenization
            # based on include_stop_str_in_output parameter.
            skipped_stop_token_id = new_token_ids[-1]
            new_token_ids = new_token_ids[:-1]
        else:
            skipped_stop_token_id = None

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        offset_before = len(self.output_text)
        for new_token_id in new_token_ids:
            self.token_ids.append(new_token_id)
            self.output_text += self.decode_next(new_token_id)

        if stop_terminated:
            if skipped_stop_token_id is not None:
                # Cleanup after skipping detokenization.
                self.token_ids.append(skipped_stop_token_id)
            # Stop token triggered; skip stop string check.
            return None

        # 2) Evaluate stop strings.
        stop_string = None
        if self.stop:
            stop = StopChecker.check_stop_strings(
                output_text=self.output_text,
                new_char_count=len(self.output_text) - offset_before,
                stop=self.stop,
                include_in_output=self.include_stop_str_in_output,
            )
            if stop is not None:
                stop_string, truncate_to = stop
                if truncate_to != -1:
                    # Trim any text emitted past the matched stop string.
                    self.output_text = self.output_text[:truncate_to]

        return stop_string

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Detokenize a single token id and return its text piece."""
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned"""

        # We return the full output text if the sequence is finished.
        buffer_length = 0 if finished else self.stop_buffer_length
        if not delta:
            return self.output_text[:-buffer_length] if buffer_length else (
                self.output_text)
        length = len(self.output_text) - buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""
class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer backed by the tokenizers-library
    DecodeStream (requires a HF "fast" tokenizer)."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast,
                 request: EngineCoreRequest):
        """Prime a DecodeStream with a suffix of the prompt so the first
        generated tokens detokenize with correct leading context.

        Args:
            tokenizer: HF fast tokenizer wrapping a rust Tokenizer.
            request: New request; provides sampling params and prompt ids.
        """
        super().__init__(request)

        sampling_params = request.sampling_params
        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(
            skip_special_tokens=self.skip_special_tokens)

        # Underlying rust tokenizer (DecodeStream.step takes this).
        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Find a safe place to start: walk back 4..23 tokens from the end
        # of the prompt until the suffix decodes without a U+FFFD
        # replacement character (i.e. it doesn't split a multi-byte
        # character).
        prompt_suffix = request.prompt_token_ids
        prompt_len = len(prompt_suffix)
        if prompt_len > 4:
            for i in range(4, min(prompt_len + 1, 24)):
                suffix = request.prompt_token_ids[-i:]
                # BUGFIX: this literal was previously the mojibake text
                # "<EFBFBD>" (U+FFFD's UTF-8 bytes rendered as hex), which
                # never appears in decoded output, so the first 4-token
                # suffix was always accepted even when it split a
                # multi-byte character.
                if "\ufffd" not in self.tokenizer.decode(suffix):
                    prompt_suffix = suffix
                    break

        # Prime the stream.
        for tid in prompt_suffix:
            self._protected_step(tid)

        # When special tokens are skipped there are none to put spaces
        # between, so treat spaces_between_special_tokens as enabled.
        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens)

        if not self.spaces_between_special_tokens:
            # Store dict of added token ids so that we can suppress
            # the spaces between them. Cached on the tokenizer object so
            # it is computed at most once per tokenizer.
            if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
                                           None)) is None:
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in
                    self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens.
                self.spaces_between_special_tokens = True

    def decode_next(self, next_token_id: int) -> str:
        """Detokenize one token; suppress the space inserted between two
        consecutive added (special) tokens when so configured."""
        token = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""

    def _protected_step(self, next_token_id: int) -> Optional[str]:
        """DecodeStream.step with recovery from invalid-prefix errors."""
        try:
            token = self.stream.step(self.tokenizer, next_token_id)
        except Exception as e:
            if str(e) != INVALID_PREFIX_ERR_MSG:
                raise e
            # Recover from edge case where tokenizer can produce non-monotonic,
            # invalid UTF-8 output, which breaks the internal state of
            # tokenizers' DecodeStream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.", self.request_id)
            # Keyword arg for consistency with __init__'s construction.
            self.stream = DecodeStream(
                skip_special_tokens=self.skip_special_tokens)
            token = self.stream.step(self.tokenizer, next_token_id)
        return token
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Pure-python incremental detokenization fallback.

    Used when the fast tokenizers-library DecodeStream path is
    unavailable (non-fast tokenizer or old tokenizers version).
    """

    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
        super().__init__(request)

        self.tokenizer = tokenizer

        # Metadata for incremental detokenization.
        self.tokens, self.prefix_offset, self.read_offset = (
            convert_prompt_ids_to_tokens(
                tokenizer=tokenizer,
                prompt_ids=request.prompt_token_ids,
                skip_special_tokens=request.sampling_params.
                skip_special_tokens,
            ))
        # Unlike the fast path, prompt ids are kept in token_ids, so
        # output_token_ids must slice them off (see property below).
        self.token_ids.extend(request.prompt_token_ids)
        self.prompt_len = len(request.prompt_token_ids)

        params = request.sampling_params
        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = (
            params.spaces_between_special_tokens)

    @property
    def output_token_ids(self) -> list[int]:
        # Exclude the prompt portion of token_ids.
        return self.token_ids if not self.prompt_len else (
            self.token_ids[self.prompt_len:])

    def decode_next(self, next_token_id: int) -> str:
        new_tokens, decoded_text, prefix_offset, read_offset = (
            detokenize_incrementally(
                tokenizer=self.tokenizer,
                all_input_ids=self.token_ids,
                prev_tokens=self.tokens,
                prefix_offset=self.prefix_offset,
                read_offset=self.read_offset,
                skip_special_tokens=self.skip_special_tokens,
                spaces_between_special_tokens=self.
                spaces_between_special_tokens,
            ))

        # Advance incremental-detokenization state for the next call.
        self.tokens.extend(new_tokens)
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

        return decoded_text

View File

@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
class EngineGenerateError(Exception):
    """Raised when an AsyncLLM.generate() call fails.

    Recoverable: the engine itself remains healthy and can continue
    serving other requests.
    """
class EngineDeadError(Exception):
    """Raised when the EngineCore dies. Unrecoverable."""

    def __init__(self, *args, suppress_context: bool = False, **kwargs):
        # Fixed prefix so every EngineDeadError points the reader at the
        # root-cause traceback logged above it.
        dead_message = ("EngineCore encountered an issue. "
                        "See stack trace (above) for the root cause.")
        super().__init__(dead_message, *args, **kwargs)
        # Make stack trace clearer when used with LLMEngine by
        # silencing the irrelevant chained ZMQError.
        self.__suppress_context__ = suppress_context

View File

@@ -0,0 +1,317 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from copy import copy
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__)

# Generic return type for collective_rpc results (defaults to Any).
_R = TypeVar("_R", default=Any)
class LLMEngine:
    """Legacy LLMEngine for backwards compatibility.

    Synchronous v1 engine front-end: wires together the tokenizer,
    Processor (input side), EngineCoreClient (scheduler/model) and
    OutputProcessor (output side), exposing the v0-style
    add_request()/step() API.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        # NOTE(review): `use_cached_outputs` is accepted for v0 API
        # compatibility but unused in this class.
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        if stat_loggers is not None:
            # NOTE(review): "file and issue" in this message looks like a
            # typo for "file an issue".
            raise NotImplementedError(
                "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
                "Set VLLM_USE_V1=0 and file and issue on Github.")

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats
        self.stat_logger: Optional[StatLoggerBase] = None
        if self.log_stats:
            self.stat_logger = PrometheusStatLogger(vllm_config)

        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        parallel_config = vllm_config.parallel_config
        if not multiprocess_mode and parallel_config.data_parallel_size > 1:
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        # Set when DP peers still have work but we don't; step() then
        # runs a dummy batch to keep DP collectives in lockstep.
        self.should_execute_dummy_batch = False

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)

        # Processor (convert Inputs --> EngineCoreRequests)
        self.processor = Processor(vllm_config=vllm_config,
                                   tokenizer=self.tokenizer,
                                   mm_registry=mm_registry)

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an LLMEngine from a pre-built VllmConfig."""
        return cls(vllm_config=vllm_config,
                   executor_class=Executor.get_class(vllm_config),
                   log_stats=(not disable_log_stats),
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Env var overrides the caller's multiprocessing preference.
        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(vllm_config=vllm_config,
                   executor_class=executor_class,
                   log_stats=not engine_args.disable_log_stats,
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=enable_multiprocessing)

    def get_num_unfinished_requests(self) -> int:
        """Number of requests not yet finished or aborted."""
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Whether this engine (or, with DP, any peer engine) has work left."""
        has_unfinished = self.output_processor.has_unfinished_requests()
        if self.dp_group is None:
            return has_unfinished
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """All-reduce unfinished status across the DP group.

        If peers have work but we don't, flag that the next step() must
        run a dummy batch to keep DP collectives in lockstep.
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished)
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        # v0-compat hook: no output validation is performed in V1.
        return outputs

    def abort_request(self, request_ids: list[str]) -> None:
        """Remove request_ids from EngineCore and Detokenizer."""
        request_ids = self.output_processor.abort_requests(request_ids)
        self.engine_core.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Add a new request, fanning n>1 sampling out into child requests."""
        # Process raw inputs into the request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, prompt_adapter_request,
            priority)

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, prompt_str, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request_id, params)
        for idx in range(n):
            request_id, params = parent_req.get_child_info(idx)
            # Last child reuses the original request object to save a copy.
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = params

            # Make a new RequestState and queue.
            self.output_processor.add_request(child_request, prompt_str,
                                              parent_req, idx)
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)

    def step(self) -> list[RequestOutput]:
        """Run one engine iteration and return new/updated RequestOutputs."""

        if self.should_execute_dummy_batch:
            # DP peers still have work: run a dummy batch instead of a
            # real step to keep collectives synchronized.
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        iteration_stats = IterationStats() if self.log_stats else None
        processed_outputs = self.output_processor.process_outputs(
            outputs.outputs,
            engine_core_timestamp=outputs.timestamp,
            iteration_stats=iteration_stats)

        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        # 4) Record stats
        if self.stat_logger is not None:
            assert outputs.scheduler_stats is not None
            self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
                                    iteration_stats=iteration_stats)

        return processed_outputs.request_outputs

    def get_vllm_config(self):
        return self.vllm_config

    def get_model_config(self):
        return self.model_config

    def start_profile(self):
        self.engine_core.profile(True)

    def stop_profile(self):
        self.engine_core.profile(False)

    def reset_mm_cache(self):
        """Clear multimodal processor/input caches (front-end and core)."""
        self.processor.mm_registry.reset_processor_cache()
        self.processor.mm_input_cache_client.reset()
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(self, device: Optional[Device] = None):
        # NOTE(review): `device` is accepted for API compatibility but is
        # not forwarded to the engine core.
        self.engine_core.reset_prefix_cache()

    def sleep(self, level: int = 1):
        self.engine_core.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        self.engine_core.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.engine_core.is_sleeping()

    def get_metrics(self) -> list[Metric]:
        """Snapshot of Prometheus metrics; requires log_stats=True."""
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

    def get_tokenizer_group(self) -> TokenizerGroup:
        if self.tokenizer is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        return self.tokenizer

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward an RPC to all workers via the engine core."""
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def __del__(self):
        # Best-effort teardown of the DP process group (may run late in
        # interpreter shutdown, hence the defensive getattr).
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

199
vllm/v1/engine/logprobs.py Normal file
View File

@@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional
from vllm.logger import init_logger
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_ids_list_to_tokens)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
logger = init_logger(__name__)
NONES = itertools.repeat(None)
@dataclass
class LogprobsProcessor:
    """Accumulates sample and prompt logprobs for a single request.

    Converts the payloads received from EngineCore into the per-position
    ``dict[token_id, Logprob]`` representation used by RequestOutput,
    optionally detokenizing the top tokens.
    """

    # Tokenizer for this request,
    # None if detokenization is disabled.
    tokenizer: Optional[AnyTokenizer]

    # Logprobs for this request
    logprobs: Optional[SampleLogprobs]
    prompt_logprobs: Optional[PromptLogprobs]
    cumulative_logprob: Optional[float]
    num_logprobs: Optional[int]
    num_prompt_logprobs: Optional[int]

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
    ) -> "LogprobsProcessor":
        """Create a processor sized by the request's sampling params."""
        num_logprobs = request.sampling_params.logprobs
        num_prompt_logprobs = request.sampling_params.prompt_logprobs
        return cls(
            tokenizer=tokenizer,
            cumulative_logprob=(None if num_logprobs is None else 0.),
            logprobs=(None if num_logprobs is None else []),
            # NOTE: logprob of first prompt token is None.
            prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
            num_prompt_logprobs=num_prompt_logprobs,
            num_logprobs=num_logprobs,
        )

    def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
        """Update with sample logprobs from EngineCore.

        Outer lists are only of len > 1 if EngineCore made
        >1 tokens in prior step (e.g. in spec decoding).

        Args:
          logprobs_lists: the lists of logprob tokens, logprobs, and ranks.
        """
        assert self.num_logprobs is not None
        assert self.logprobs is not None
        assert self.cumulative_logprob is not None

        token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists

        for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
                                             token_ids_lst):

            # Detokenize (non-incrementally).
            decoded_tokens = NONES if self.tokenizer is None else (
                convert_ids_list_to_tokens(self.tokenizer, token_ids))

            # Sampler puts the sampled logprob in first.
            sampled_token_logprob = logprobs[0]
            self.cumulative_logprob += sampled_token_logprob

            # Update with the Logprob dictionary for this pos.
            self.logprobs.append(
                self._make_logprob_dict(
                    logprobs,
                    token_ids,
                    decoded_tokens,
                    rank,
                    self.num_logprobs,
                ))

    def _update_prompt_logprobs(
        self,
        prompt_logprobs_tensors: LogprobsTensors,
    ) -> None:
        """Update with prompt logprobs from EngineCore.

        Args:
          prompt_logprobs_tensors: tuple containing the prompt logprobs
                                   tensors.
        """
        # Prompt logprobs are enabled.
        assert self.num_prompt_logprobs is not None
        assert self.prompt_logprobs is not None

        token_ids, logprobs, ranks = prompt_logprobs_tensors

        # Detokenize non-incrementally.
        # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
        decoded_tokens = None if self.tokenizer is None else (
            convert_ids_list_to_tokens(self.tokenizer,
                                       token_ids.flatten().tolist()))

        # Recover shapes.
        num_prompt_tokens, num_logprobs = logprobs.shape

        # Pythonize the torch tensors.
        prompt_token_ranks = ranks.tolist()
        prompt_logprobs = logprobs.tolist()
        token_ids = token_ids.tolist()

        # Make Logprob for each position.
        for pos in range(num_prompt_tokens):
            # Handle flattening.
            offset = pos * num_logprobs
            offset_end = offset + num_logprobs
            decoded_tokens_for_pos = NONES \
                if decoded_tokens is None else decoded_tokens[offset:offset_end]

            # Update with the Logprob dictionary for this pos.
            self.prompt_logprobs.append(
                self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos],
                                        decoded_tokens_for_pos,
                                        prompt_token_ranks[pos],
                                        self.num_prompt_logprobs))

    def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
        """Pop and return all request prompt logprobs

        The logprobs processor aggregates prompt chunk logprobs
        over one or more prefill chunks. This method returns
        all prompt logprobs at once and then forgets them.
        Ensures correct RequestOutputKind.DELTA semantics
        wherein all prompt logprobs are returned at once at
        the end of prefill.

        Returns:
          None if prompt logprobs are disabled for this request.
          List of all prompt logprobs, otherwise.
        """
        plp = self.prompt_logprobs
        if plp:
            self.prompt_logprobs = []
        return plp

    @staticmethod
    def _make_logprob_dict(
        logprobs: list[float],
        logprob_token_ids: list[int],
        decoded_tokens: Iterable[Optional[str]],
        rank: int,
        num_logprobs: int,
    ) -> dict[int, Logprob]:
        """Make a Logprob dictionary for a position.

        Args:
          logprobs: list of log probabilities
          logprob_token_ids: list of top token ids
          decoded_tokens: list of decoded top tokens
          rank: rank of the sampled token
          num_logprobs: number of logprobs requested
            by the user (in addition to sampled logprob)

        Returns:
          dict[token id, Logprob]
        """
        # We do not need a special case for the sampled token
        # being in the topk, since inserting duplicated data
        # into a dictionary twice is the same as doing it once.
        topk_ranks = range(1, num_logprobs + 1)
        ranks = itertools.chain((rank, ), topk_ranks)

        return {
            token_id: Logprob(
                logprob=logprob,
                rank=rank,
                decoded_token=token,
            )
            for token_id, logprob, rank, token in zip(
                logprob_token_ids, logprobs, ranks, decoded_tokens)
        }

    def update_from_output(self, output: EngineCoreOutput) -> None:
        """Fold one EngineCoreOutput's sample/prompt logprobs into state."""
        if output.new_logprobs is not None:
            self._update_sample_logprobs(output.new_logprobs)
        if output.new_prompt_logprobs_tensors is not None:
            self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)

View File

@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache
from vllm.utils import is_list_of
# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- Client:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
# - MirroredProcessingCache to keep track of the cached entries and
# determine whether to send the MultiModalKwargs to P1.
#
# -- Server:
# - MirroredProcessingCache to store the MultiModalKwargs from P0.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
# client (=P0) and server (=P1) processes if the mm_hash is found in the client
# cache.
# Both Client and Server must use the same cache size
# (to perform mirrored caching). This cache size is set by the environment
# variable VLLM_MM_INPUT_CACHE_GIB.
class MirroredProcessingCache:
    """LRU cache of processed multi-modal inputs, mirrored between the
    frontend client (P0) and the core server (P1) so that large payloads
    only need to be transferred on a cache miss."""

    def __init__(self, model_config):
        mm_config = model_config.multimodal_config
        cache_disabled = (mm_config is not None
                          and mm_config.disable_mm_preprocessor_cache)
        self.use_cache = not cache_disabled
        self.mm_cache = ProcessingCache.get_lru_cache(
            VLLM_MM_INPUT_CACHE_GIB, MultiModalKwargs)

    def get_and_update_p0(
        self,
        mm_inputs: Sequence[MultiModalKwargs],
        mm_hashes: list[str],
    ) -> Sequence[Optional[MultiModalKwargs]]:
        """Client side: replace inputs the server already has with None
        placeholders, and record newly-seen inputs in the mirror."""
        assert len(mm_inputs) == len(mm_hashes)

        if not self.use_cache:
            assert is_list_of(mm_inputs, MultiModalKwargs)
            return mm_inputs

        out: list[Optional[MultiModalKwargs]] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if self.mm_cache.get(mm_hash) is None:
                # First sighting: remember it and send it through.
                self.mm_cache[mm_hash] = mm_input
                out.append(mm_input)
            else:
                # Already mirrored on the server; send a placeholder.
                out.append(None)
        return out

    def get_and_update_p1(
        self,
        mm_inputs: Sequence[Optional[MultiModalKwargs]],
        mm_hashes: list[str],
    ) -> Sequence[MultiModalKwargs]:
        """Server side: resolve None placeholders from the mirror and
        store newly received inputs."""
        assert len(mm_inputs) == len(mm_hashes)

        if not self.use_cache:
            assert is_list_of(mm_inputs, MultiModalKwargs)
            return mm_inputs

        out: list[MultiModalKwargs] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_input is None:
                out.append(self.mm_cache[mm_hash])
            else:
                self.mm_cache[mm_hash] = mm_input
                out.append(mm_input)
        return out

    def reset(self) -> bool:
        """Clear the mirrored cache; returns True to acknowledge the RPC."""
        self.mm_cache.clear()
        return True

View File

@@ -0,0 +1,428 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Optional, Union
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
RequestStateStats)
class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """

    def __init__(self, output_kind: RequestOutputKind):
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.output: Optional[Union[RequestOutput, Exception]] = None
        self.ready = asyncio.Event()

    def put(self, output: Union[RequestOutput, Exception]) -> None:
        """Non-blocking put operation."""
        pending = self.output
        if pending is None or isinstance(output, Exception):
            # Empty slot, or an error (errors always take precedence).
            self.output = output
            self.ready.set()
        elif isinstance(pending, RequestOutput):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            pending.add(output, aggregate=self.aggregate)

    async def get(self) -> RequestOutput:
        """Get operation blocks on put event."""
        while True:
            output = self.output
            if output is not None:
                break
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(self) -> Optional[RequestOutput]:
        """Non-blocking get operation."""
        output = self.output
        if output is None:
            return None
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output
@dataclass
class OutputProcessorOutput:
    """Result of processing one batch of EngineCoreOutputs."""

    # Outputs to hand back to LLMEngine callers (empty when per-request
    # queues are used instead).
    request_outputs: list[RequestOutput]
    # Request ids that must additionally be aborted in EngineCore.
    reqs_to_abort: list[str]
class RequestState:
    """Detokenization and output-assembly state for one request (or one
    child request of a parallel-sampling parent)."""

    def __init__(
        self,
        request_id: str,
        parent_req: Optional[ParentRequest],
        request_index: int,
        lora_name: Optional[str],
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: list[int],
        logprobs_processor: LogprobsProcessor,
        detokenizer: IncrementalDetokenizer,
        max_tokens_param: Optional[int],
        arrival_time: float,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_len = len(prompt_token_ids)
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        # Flipped to False once the first EngineCore output arrives.
        self.is_prefilling = True
        self.queue = queue
        # Per-request timing/count stats; None when logging is disabled.
        self.stats = RequestStateStats(
            arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest],
        request_index: int,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ) -> "RequestState":
        """Build the state object for a newly-added request."""
        # A None tokenizer disables detokenization downstream.
        if not request.sampling_params.detokenize:
            tokenizer = None
        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(request.lora_request.name
                       if request.lora_request is not None else None),
            output_kind=request.sampling_params.output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            logprobs_processor=LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            detokenizer=IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            max_tokens_param=(request.sampling_params.max_tokens if
                              request.sampling_params is not None else None),
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
        kv_transfer_params: Optional[dict[str, Any]] = None,
        num_cached_tokens: int = 0,
    ) -> Optional[RequestOutput]:
        """Assemble this step's RequestOutput, or None when nothing should
        be emitted (FINAL_ONLY before completion, or a parallel-sampling
        parent that is still waiting on siblings)."""
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        completion_output = self._new_completion_output(
            new_token_ids, finish_reason, stop_reason)

        request_id = self.request_id
        if self.parent_req is None:
            outputs = [completion_output]
        else:
            # Parallel sampling: the parent decides whether/when the child
            # outputs are surfaced, and under which request id.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, completion_output)
            if not outputs:
                return None

        return self._new_request_output(request_id, outputs, finished,
                                        kv_transfer_params, num_cached_tokens)

    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput],
        finished: bool,
        kv_transfer_params: Optional[dict[str, Any]] = None,
        num_cached_tokens: int = 0,
    ) -> RequestOutput:
        """Wrap the given completion outputs in a RequestOutput."""
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=self.prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=outputs,
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=num_cached_tokens,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step (delta or cumulative,
        depending on output_kind)."""
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            logprobs = logprobs[-len(token_ids):]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None)
class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: TokenizerGroup,
        log_stats: bool,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Per-request state, keyed by (child) request id.
        self.request_states: dict[str, RequestState] = {}
        # Parallel-sampling parents, keyed by parent request id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()

    def get_num_unfinished_requests(self):
        """Number of requests still being processed."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """True while any request has not yet finished."""
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""
        for _, state in self.request_states.items():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Drop state for the given request ids; returns the ids actually
        aborted (expanding parallel-sampling parents to their children)."""
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)
            else:
                # The id may belong to a parallel-sampling parent; abort
                # its still-running children instead.
                parent = self.parent_requests.pop(request_id, None)
                if parent and parent.child_requests:
                    self.abort_requests(parent.child_requests)
                    request_ids_to_abort.extend(parent.child_requests)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest] = None,
        request_index: int = 0,
        queue: Optional[RequestOutputCollector] = None,
    ) -> None:
        """Register a new request (or child request) for output processing.

        Raises:
            ValueError: if the request id is already being processed.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats)
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """
        request_outputs: list[RequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            new_token_ids = engine_core_output.new_token_ids
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            # 2) Detokenize the token ids into text and perform stop checks.
            stop_string = req_state.detokenizer.update(
                new_token_ids, finish_reason == FinishReason.STOP)
            if stop_string:
                finish_reason = FinishReason.STOP
                stop_reason = stop_string

            # 3) Compute sample and prompt logprobs for request, if required.
            req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                    new_token_ids, finish_reason, stop_reason,
                    kv_transfer_params, num_cached_tokens):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(req_state, finish_reason,
                                                 iteration_stats)

        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
                                  iteration_stats: Optional[IterationStats]):
        """Fold one output's timing/token counts into iteration stats
        (no-op when stat logging is disabled)."""
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(engine_core_output,
                                           engine_core_timestamp,
                                           req_state.is_prefilling,
                                           req_state.prompt_len,
                                           req_state.stats, lora_stats)

    def _update_stats_from_finished(self, req_state: RequestState,
                                    finish_reason: Optional[FinishReason],
                                    iteration_stats: Optional[IterationStats]):
        """Record end-of-request stats (no-op when logging is disabled)."""
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=len(req_state.prompt_token_ids),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats)
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats,
            req_state.stats.num_generation_tokens)

View File

@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import copy
from typing import Optional
from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.metrics.stats import IterationStats
class ParentRequest:
    """Info, state & processing for parallel sampling request.

    Store parent request ID and sampling params.
    Facilitate generating child request sampling params.
    """

    request_id: str
    sampling_params: SamplingParams

    # To track the completion of child requests
    child_requests: set[str]

    # To aggregate child completions when not streaming
    output_aggregator: list[CompletionOutput]

    # To find the max number of generated tokens across all children
    max_num_generation_tokens: int

    # To efficiently obtain child sampling params
    cached_child_sampling_params: Optional[SamplingParams]

    def __init__(self, request_id: str,
                 sampling_params: SamplingParams) -> None:
        self.request_id = request_id
        self.sampling_params = sampling_params
        self.child_requests = set()
        # FINAL_ONLY aggregates all n finals (indexed by child); streaming
        # kinds emit child outputs as they arrive, so no slots are needed.
        self.output_aggregator = [None] * sampling_params.n if (
            sampling_params.output_kind
            == RequestOutputKind.FINAL_ONLY) else []
        self.max_num_generation_tokens = 0
        self.cached_child_sampling_params = None

    def _get_child_sampling_params(
        self,
        index: int,
    ) -> SamplingParams:
        """Efficiently obtain child `sampling_params`

        If `sampling_params.seed` is not `None` then
        each child request requires a unique clone of
        parent `sampling_params` with a unique seed.

        Args:
          index: index within `n` child requests

        Returns:
          Child `sampling_params` instance.
        """
        seed = self.sampling_params.seed
        if self.cached_child_sampling_params:
            # Reuse child sampling_params data structure
            return self.cached_child_sampling_params
        # Build child sampling_params
        child_sampling_params = copy(self.sampling_params)
        child_sampling_params.n = 1
        if seed is None:
            # Cache child sampling_params for later reuse
            self.cached_child_sampling_params = child_sampling_params
        else:
            # Each child gets a clone with a unique seed
            child_sampling_params.seed = seed + index
        return child_sampling_params

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        """Get child request ID and sampling params.

        Args:
          index: index within `n` child requests.

        Returns:
          (request ID, sampling_params) tuple
        """
        child_req_id = f"{index}_{self.request_id}"
        self.child_requests.add(child_req_id)
        return child_req_id, self._get_child_sampling_params(index)

    @property
    def n(self) -> int:
        """The number of parallel samples requested by the parent."""
        return self.sampling_params.n

    def get_outputs(
        self,
        child_request_id: str,
        completion_output: CompletionOutput,
    ) -> tuple[str, list[CompletionOutput], bool]:
        """Map a child completion to (parent id, outputs, finished).

        In streaming kinds the child output passes straight through; in
        FINAL_ONLY it is held in the aggregator until all children finish.
        """
        if completion_output.finished():
            self.child_requests.remove(child_request_id)

        if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
            # If streaming, just return the current output.
            outputs = [completion_output]
        else:
            # If not streaming, aggregate the n final outputs.
            self.output_aggregator[completion_output.index] = completion_output
            outputs = [] if self.child_requests else self.output_aggregator

        finished = not self.child_requests
        return self.request_id, outputs, finished

    def observe_num_generation_tokens(self, num_generation_tokens: int):
        """Track the running max of generated tokens across children."""
        self.max_num_generation_tokens = max(num_generation_tokens,
                                             self.max_num_generation_tokens)
        return self.max_num_generation_tokens

    @staticmethod
    def observe_finished_request(parent_req: Optional['ParentRequest'],
                                 iteration_stats: IterationStats,
                                 num_generation_tokens: int):
        """Record per-request stats, deferring until all children of a
        parallel-sampling parent have finished."""
        n_param = parent_req.n if parent_req is not None else 1
        if parent_req is not None:
            num_generation_tokens = parent_req.observe_num_generation_tokens(
                num_generation_tokens)

        # Child requests finished, we can now record to iteration stats
        if parent_req is None or not parent_req.child_requests:
            iteration_stats.max_num_generation_tokens_iter.append(
                num_generation_tokens)
            iteration_stats.n_params_iter.append(n_param)

407
vllm/v1/engine/processor.py Normal file
View File

@@ -0,0 +1,407 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Optional, Union
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_xgrammar import (
validate_xgrammar_grammar)
class Processor:
    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: TokenizerGroup,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Frontend request processor: validates params, tokenizes and
        preprocesses prompts before hand-off to EngineCore."""
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.decoding_config = vllm_config.decoding_config
        self.tokenizer = tokenizer

        self.generation_config_fields = (
            self.model_config.try_get_generation_config())
        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)
        self.mm_input_cache_client = MirroredProcessingCache(self.model_config)

        # Multi-modal hasher (for images)
        # Hashing is needed either for the mirrored mm input cache or for
        # prefix caching of multi-modal prompts.
        self.use_hash = self.mm_input_cache_client.use_cache or \
            self.cache_config.enable_prefix_caching
    @property
    def mm_registry(self):
        """The multi-modal registry used by the input preprocessor."""
        return self.input_preprocessor.mm_registry
def _validate_logprobs(
self,
params: SamplingParams,
) -> None:
max_logprobs = self.model_config.max_logprobs
# Validate sample logprobs.
if params.logprobs and params.logprobs > max_logprobs:
raise ValueError(
f"Requested sample logprobs of {params.logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
# Validate prompt logprobs.
if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
raise ValueError(
f"Requested prompt logprobs of {params.prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
def _validate_sampling_params(
self,
params: SamplingParams,
lora_request: Optional[LoRARequest],
) -> None:
self._validate_structured_output(params)
self._validate_logit_bias(params)
if params.allowed_token_ids is None:
return
if not params.allowed_token_ids:
raise ValueError("allowed_token_ids is not None and empty!")
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
vocab_size = len(tokenizer)
if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
raise ValueError(
"allowed_token_ids contains out-of-vocab token id!")
def _validate_logit_bias(
self,
params: SamplingParams,
) -> None:
"""Validate logit_bias token IDs are within vocabulary range."""
if not params.logit_bias:
return
vocab_size = self.model_config.get_vocab_size()
invalid_token_ids = []
for token_id in params.logit_bias:
if token_id < 0 or token_id >= vocab_size:
invalid_token_ids.append(token_id)
if invalid_token_ids:
raise ValueError(
f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
def _validate_supported_sampling_params(
self,
params: SamplingParams,
) -> None:
# Best of not yet supported.
if params.best_of is not None and params.best_of > 1:
raise ValueError("vLLM V1 does not yet support best_of.")
# Logits processors not supported.
if params.logits_processors:
raise ValueError("vLLM V1 does not support per request "
"user provided logits processors.")
    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest],
    ):
        """
        Validate supported SamplingParam.
        Should raise ValueError if unsupported for API Server.
        """
        # Pooling models are not supported in V1, so only SamplingParams
        # is accepted here.
        if not isinstance(params, SamplingParams):
            raise ValueError("V1 does not yet support Pooling models.")

        self._validate_logprobs(params)
        self._validate_sampling_params(params, lora_request)
        self._validate_supported_sampling_params(params)
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
    def _validate_structured_output(self, params: SamplingParams) -> None:
        """Validate guided-decoding backend selection and grammar content,
        resolving the "auto" backend to a concrete one as a side effect."""
        if not params.guided_decoding or not self.decoding_config:
            return

        engine_level_backend = self.decoding_config.backend
        if params.guided_decoding.backend:
            # Request-level backend selection is not supported in V1.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_auto` option set on the backend in the params.
            if (params.guided_decoding.backend != engine_level_backend
                    and not (engine_level_backend == "auto"
                             and params.guided_decoding.backend_was_auto)):
                raise ValueError(
                    "Request-level structured output backend selection is no "
                    "longer supported. The request specified "
                    f"'{params.guided_decoding.backend}', but vLLM was "
                    f"initialised with '{engine_level_backend}'. This error "
                    "can be resolved by removing backend selection from the "
                    "request.")
        else:
            params.guided_decoding.backend = engine_level_backend

        # Request content validation
        if engine_level_backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif engine_level_backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
        else:
            # NOTE: engine_level_backend must be "auto" here, because we have
            # checked supported_backends above.
            # "auto" is an opt-in to opinionated behavior where we try to
            # choose a backend based on request contents. This is not the
            # default as it is less predictable and subject to change
            # between releases as feature support changes.
            try:
                validate_xgrammar_grammar(params)
                params.guided_decoding.backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar. Fall back to guidance.
                validate_guidance_grammar(params, tokenizer=None)
                params.guided_decoding.backend = "guidance"
            # Remember that this backend was set automatically
            params.guided_decoding.backend_was_auto = True
    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:
        """Validate and preprocess one request into an ``EngineCoreRequest``.

        Args:
            request_id: Unique identifier for this request.
            prompt: Raw input (text, token ids, and/or multimodal data).
            params: Sampling (or pooling) parameters; this V1 path only
                supports ``SamplingParams`` (asserted below).
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            lora_request: Optional LoRA adapter used for tokenization and
                sampling-parameter updates.
            tokenization_kwargs: Extra kwargs forwarded to the tokenizer.
            trace_headers: Unsupported in V1; must be ``None``.
            prompt_adapter_request: Unsupported in V1; must be ``None``.
            priority: Unsupported in V1; must be ``0``.
            data_parallel_rank: Optional target DP rank, validated against
                the configured data-parallel size.

        Returns:
            A tuple of (original prompt string or ``None``, the constructed
            ``EngineCoreRequest``).

        Raises:
            ValueError: If an unsupported feature is requested or input
                validation fails.
        """
        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.
        self._validate_lora(lora_request)
        self._validate_params(params, lora_request)
        # Reject features that the V1 engine does not implement yet.
        if priority != 0:
            raise ValueError("V1 does not support priority yet.")
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")
        if prompt_adapter_request is not None:
            raise ValueError("V1 does not support prompt_adapter_request.")

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
                                                   data_parallel_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {data_parallel_size}).")

        if arrival_time is None:
            arrival_time = time.time()

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #    multimodal data and expand prompt token ids accordingly.
        # 3. Apply prompt adapter to prompt token ids if one exists.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=self.use_hash,
        )
        # Give the platform a chance to reject requests it cannot serve.
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        # TODO: Impl encoder-decoder
        if encoder_inputs is not None:
            raise NotImplementedError

        assert isinstance(params, SamplingParams)
        # TODO: can we avoid cloning here in multiproc case?
        sampling_params = params.clone()
        # If unset max tokens, then generate up to the max_model_len.
        if sampling_params.max_tokens is None:
            sampling_params.max_tokens = (
                self.model_config.max_model_len -
                len(decoder_inputs["prompt_token_ids"]))
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)
        sampling_params.update_from_tokenizer(
            self.tokenizer.get_lora_tokenizer(lora_request))

        # Multimodal related.
        sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
        sorted_mm_positions: Optional[list[PlaceholderRange]] = None
        sorted_mm_hashes: Optional[list[str]] = None
        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            (
                sorted_item_modalities,
                sorted_mm_positions,
                sorted_mm_hashes,
            ) = merge_and_sort_multimodal_metadata(
                decoder_inputs["mm_placeholders"],
                decoder_inputs["mm_hashes"] if self.use_hash else None,
            )

            # The output of merged multi-modal processor (`decoder_mm_inputs`)
            # is a single MultiModalKwargs for all items from all modalities.
            # This code flattens kwargs for individual items in a list and
            # sorts them by each item's position in the input sequence if there
            # are multiple modalities.
            unique_modalities = set(sorted_item_modalities)
            if len(unique_modalities) > 1:
                # Walk the position-sorted modality list, consuming each
                # modality's items in order via a per-modality cursor.
                orig_sorted_mm_inputs = []
                used_indices = {modality: 0 for modality in unique_modalities}

                for modality in sorted_item_modalities:
                    items = decoder_mm_inputs.get_items(modality)
                    item = items[used_indices[modality]]

                    orig_sorted_mm_inputs.append(
                        MultiModalKwargs.from_items([item]))
                    used_indices[modality] += 1
            else:
                # Single modality: items are already in positional order.
                orig_sorted_mm_inputs = [
                    MultiModalKwargs.from_items([item]) for item in
                    decoder_mm_inputs.get_items(sorted_item_modalities[0])
                ]

            if sorted_mm_hashes is not None:
                # Cache lookup/update keyed by per-item hashes (P0 side).
                sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
                    orig_sorted_mm_inputs, sorted_mm_hashes)
            else:
                sorted_mm_inputs = orig_sorted_mm_inputs

        return decoder_inputs.get("prompt"), EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
            mm_inputs=sorted_mm_inputs,
            mm_hashes=sorted_mm_hashes,
            mm_placeholders=sorted_mm_positions,
            sampling_params=sampling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            data_parallel_rank=data_parallel_rank,
        )
def _validate_model_inputs(self,
inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")
self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Validate one (encoder or decoder) prompt against the model.

        Checks that the prompt is non-empty, its token ids fall within the
        tokenizer vocabulary, and its length does not exceed
        ``max_model_len``.

        Raises:
            ValueError: If any of the checks fails.
        """
        model_config = self.model_config
        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)

        prompt_ids = prompt_inputs["prompt_token_ids"]
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        # default=0 keeps this safe for the empty-encoder-prompt case above.
        max_input_id = max(prompt_ids, default=0)
        if max_input_id > tokenizer.max_token_id:
            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

        max_prompt_len = self.model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper

            # Tailor the error message to whether multimodal tokens could be
            # inflating the prompt length.
            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")
            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens

View File

View File

@@ -0,0 +1,113 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import Future
from typing import Callable, Union
import torch
import torch.distributed as dist
from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import ( # noqa
UniProcExecutor as UniProcExecutorV0)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
FailureCallback = Callable[[], None]
class Executor(ExecutorBase):
    """
    Abstract class for v1 executors, mainly define some methods for v1.
    For methods shared by v0 and v1, define them in ExecutorBase.
    """

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> type["Executor"]:
        """Resolve the concrete Executor subclass for this config.

        The backend comes from
        ``parallel_config.distributed_executor_backend`` and may be either a
        user-supplied class or one of the strings "ray", "mp", "uni",
        "external_launcher".

        Raises:
            TypeError: If a class is given that does not extend ExecutorBase.
            ValueError: If the backend string is not recognized.
        """
        executor_class: type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = (
            parallel_config.distributed_executor_backend)
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            # Imported lazily so Ray is only required when selected.
            from vllm.v1.executor.ray_distributed_executor import (  # noqa
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("Unknown distributed executor backend: "
                             f"{distributed_executor_backend}")
        return executor_class

    def initialize_from_config(self,
                               kv_cache_configs: list[KVCacheConfig]) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_from_config",
                            args=(kv_cache_configs, ))
        self.collective_rpc("compile_or_warm_up_model")

    def register_failure_callback(self, callback: FailureCallback):
        """
        Register a function to be called if the executor enters a permanent
        failed state.
        """
        # Default is a no-op; concrete executors (e.g. multiproc) override.
        pass

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return each worker's available KV-cache memory, in bytes."""
        output = self.collective_rpc("determine_available_memory")
        return output

    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
        """Return each worker's per-layer KV cache specs."""
        output = self.collective_rpc("get_kv_cache_spec")
        return output

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Run one model step on all workers and return rank 0's output."""
        output = self.collective_rpc("execute_model",
                                     args=(scheduler_output, ))
        return output[0]

    @property
    def max_concurrent_batches(self) -> int:
        # Number of batches that may be in flight at once; 1 means fully
        # synchronous execution.
        return 1

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on all workers."""
        self.collective_rpc("profile", args=(is_start, ))
class UniProcExecutor(UniProcExecutorV0, Executor):
    # Single-process executor: reuses the v0 implementation and mixes in the
    # v1 Executor interface; no overrides are needed.
    pass
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """Executor for ranks launched externally (e.g. torchrun)."""

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return the minimum available KV-cache memory across all ranks.

        An all-reduce (MIN) over the CPU process group makes the value
        identical on every rank, so all ranks allocate the same number of
        KV cache blocks.
        """
        # same as determine_num_available_blocks in v0,
        # we need to get the min across all ranks.
        memory = super().determine_available_memory()
        from vllm.distributed.parallel_state import get_world_group
        cpu_group = get_world_group().cpu_group
        # NOTE(review): assumes `memory` holds a single per-rank value so
        # that `.item()` below is valid — confirm against the superclass.
        memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        return [memory_tensor.item()]

View File

@@ -0,0 +1,537 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import os
import pickle
import signal
import sys
import threading
import time
import traceback
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from threading import Thread
from typing import Any, Callable, Optional, Union, cast
import cloudpickle
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.executor.multiproc_worker_utils import (
_add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
get_open_port)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
EXECUTE_MODEL_TIMEOUT_S = 300
class MultiprocExecutor(Executor):
    """Single-node executor that runs each worker in its own process and
    communicates with them over shared-memory message queues."""

    def _init_executor(self) -> None:
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        self._finalizer = weakref.finalize(self, self.shutdown)
        self.is_failed = False
        self.shutdown_event = threading.Event()
        self.failure_callback: Optional[FailureCallback] = None
        self.io_thread_pool: Optional[ThreadPoolExecutor] = None

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
        pp_parallel_size = self.parallel_config.pipeline_parallel_size
        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
            f"world_size ({self.world_size}) must be equal to the "
            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
            f"_parallel_size ({pp_parallel_size}). ")

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # 127.0.0.1 for communication.
        distributed_init_method = get_distributed_init_method(
            "127.0.0.1", get_open_port())

        # Initialize worker and set up message queues for SchedulerOutputs
        # and ModelRunnerOutputs
        max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
        self.rpc_broadcast_mq = MessageQueue(self.world_size,
                                             self.world_size,
                                             max_chunk_bytes=max_chunk_bytes)
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Create workers
        unready_workers: list[UnreadyWorkerProcHandle] = []
        success = False
        try:
            for rank in range(self.world_size):
                unready_workers.append(
                    WorkerProc.make_worker_process(
                        vllm_config=self.vllm_config,
                        local_rank=rank,
                        rank=rank,
                        distributed_init_method=distributed_init_method,
                        input_shm_handle=scheduler_output_handle,
                    ))

            # Workers must be created before wait_for_ready to avoid
            # deadlock, since worker.init_device() does a device sync.
            self.workers = WorkerProc.wait_for_ready(unready_workers)

            # Ensure message queues are ready. Will deadlock if re-ordered
            # Must be kept consistent with the WorkerProc.
            self.rpc_broadcast_mq.wait_until_ready()
            for w in self.workers:
                w.worker_response_mq.wait_until_ready()
            self.start_worker_monitor()
            success = True
        finally:
            if not success:
                # Clean up the worker procs if there was a failure.
                self._ensure_worker_termination(
                    [w.proc for w in unready_workers])

        # For pipeline parallel, we use a thread pool for asynchronous
        # execute_model.
        if self.max_concurrent_batches > 1:
            # Note: must use only 1 IO thread to keep dequeue sequence
            # from the response queue
            self.io_thread_pool = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="mp_exec_io")

        self.output_rank = self._get_output_rank()

    def start_worker_monitor(self):
        """Spawn a daemon thread that watches worker process liveness."""
        workers = self.workers
        self_ref = weakref.ref(self)

        # Monitors worker process liveness. If any die unexpectedly,
        # logs an error, shuts down the executor and invokes the failure
        # callback to inform the engine.
        def monitor_workers():
            sentinels = [h.proc.sentinel for h in workers]
            died = multiprocessing.connection.wait(sentinels)
            # Weakref so the monitor thread doesn't keep the executor alive.
            _self = self_ref()
            if not _self or getattr(_self, 'shutting_down', False):
                return
            _self.is_failed = True
            proc_name = next(h.proc.name for h in workers
                             if h.proc.sentinel == died[0])
            logger.error(
                "Worker proc %s died unexpectedly, "
                "shutting down executor.", proc_name)
            _self.shutdown()
            callback = _self.failure_callback
            if callback is not None:
                _self.failure_callback = None
                callback()

        Thread(target=monitor_workers,
               daemon=True,
               name="MultiprocWorkerMonitor").start()

    def register_failure_callback(self, callback: FailureCallback):
        """Register (or immediately invoke, if already failed) the engine's
        failure callback."""
        if self.is_failed:
            callback()
        else:
            self.failure_callback = callback

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Broadcast one model step; only the output rank replies.

        Returns a Future when pipeline parallelism allows more than one
        in-flight batch, otherwise the output itself.
        """
        (output, ) = self.collective_rpc("execute_model",
                                         args=(scheduler_output, ),
                                         unique_reply_rank=self.output_rank,
                                         non_block=self.max_concurrent_batches
                                         > 1,
                                         timeout=EXECUTE_MODEL_TIMEOUT_S)
        return output

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False,
                       unique_reply_rank: Optional[int] = None) -> list[Any]:
        """Broadcast an RPC to all workers and collect their responses.

        Args:
            method: Worker method name, or a callable (cloudpickled).
            timeout: Overall deadline in seconds for all responses.
            args/kwargs: Arguments forwarded to the worker method.
            non_block: If True, return Futures resolved by the IO thread.
            unique_reply_rank: If set, only that rank replies.

        Raises:
            RuntimeError: If the executor already failed or a worker
                reports failure.
            TimeoutError: If the deadline elapses.
        """
        if self.is_failed:
            raise RuntimeError("Executor failed.")

        deadline = None if timeout is None else time.monotonic() + timeout
        kwargs = kwargs or {}

        # NOTE: If the args are heterogeneous, then we pack them into a list,
        # and unpack them in the method of every worker, because every worker
        # knows their own rank.
        try:
            if isinstance(method, str):
                send_method = method
            else:
                send_method = cloudpickle.dumps(
                    method, protocol=pickle.HIGHEST_PROTOCOL)
            self.rpc_broadcast_mq.enqueue(
                (send_method, args, kwargs, unique_reply_rank))

            workers = (self.workers[unique_reply_rank],
                       ) if unique_reply_rank is not None else self.workers
            responses = []

            def get_response(w: WorkerProcHandle,
                             dequeue_timeout: Optional[float] = None,
                             cancel_event: Optional[threading.Event] = None):
                status, result = w.worker_response_mq.dequeue(
                    timeout=dequeue_timeout, cancel=cancel_event)

                if status != WorkerProc.ResponseStatus.SUCCESS:
                    raise RuntimeError(
                        f"Worker failed with error '{result}', please check the"
                        " stack trace above for the root cause")
                return result

            for w in workers:
                # Remaining time until the shared deadline.
                dequeue_timeout = None if deadline is None else (
                    deadline - time.monotonic())

                if non_block:
                    result = self.io_thread_pool.submit(  # type: ignore
                        get_response, w, dequeue_timeout, self.shutdown_event)
                else:
                    result = get_response(w, dequeue_timeout)

                responses.append(result)

            return responses
        except TimeoutError as e:
            raise TimeoutError(f"RPC call to {method} timed out.") from e

    @staticmethod
    def _ensure_worker_termination(worker_procs: list[BaseProcess]):
        """Ensure that all worker processes are terminated. Assumes workers have
        received termination requests. Waits for processing, then sends
        termination and kill signals if needed."""

        def wait_for_termination(procs, timeout):
            if not time:
                # If we are in late stage shutdown, the interpreter may replace
                # `time` with `None`.
                return all(not proc.is_alive() for proc in procs)
            start_time = time.time()
            while time.time() - start_time < timeout:
                if all(not proc.is_alive() for proc in procs):
                    return True
                time.sleep(0.1)
            return False

        # Send SIGTERM if still running
        active_procs = [proc for proc in worker_procs if proc.is_alive()]
        for p in active_procs:
            p.terminate()
        if not wait_for_termination(active_procs, 4):
            # Send SIGKILL if still running
            active_procs = [p for p in active_procs if p.is_alive()]
            for p in active_procs:
                p.kill()

    def shutdown(self):
        """Properly shut down the executor and its workers"""
        # `shutting_down` guard makes this idempotent (it is reached from the
        # finalizer, the monitor thread, and explicit calls).
        if not getattr(self, 'shutting_down', False):
            self.shutting_down = True
            self.shutdown_event.set()

            if self.io_thread_pool is not None:
                self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
                self.io_thread_pool = None

            if workers := getattr(self, 'workers', None):
                for w in workers:
                    w.worker_response_mq = None
                self._ensure_worker_termination([w.proc for w in workers])

        self.rpc_broadcast_mq = None

    def check_health(self) -> None:
        """Raise if any worker fails to answer a health-check RPC in time."""
        self.collective_rpc("check_health", timeout=10)
        return

    @property
    def max_concurrent_batches(self) -> int:
        # One in-flight batch per pipeline stage.
        return self.parallel_config.pipeline_parallel_size

    def _get_output_rank(self) -> int:
        # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
        # (the first TP worker of the last PP stage).
        # Example:
        # Assuming TP=8, PP=4, then the world_size=32
        # 0-7, PP rank 0
        # 8-15, PP rank 1
        # 16-23, PP rank 2
        # 24-31, PP rank 3
        # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
        return self.world_size - self.parallel_config.tensor_parallel_size
@dataclass
class UnreadyWorkerProcHandle:
    """WorkerProcess handle before READY."""
    # The spawned worker process.
    proc: BaseProcess
    # Global rank of the worker.
    rank: int
    # Read end of the pipe on which the worker reports readiness.
    ready_pipe: Connection
@dataclass
class WorkerProcHandle:
    """Handle for a worker process that has completed startup."""
    # The spawned worker process.
    proc: BaseProcess
    # Global rank of the worker.
    rank: int
    worker_response_mq: MessageQueue  # The worker process writes to this MQ

    @classmethod
    def from_unready_handle(
            cls, unready_handle: UnreadyWorkerProcHandle,
            worker_response_mq: MessageQueue) -> "WorkerProcHandle":
        """Promote an unready handle once its response queue is known."""
        return cls(
            proc=unready_handle.proc,
            rank=unready_handle.rank,
            worker_response_mq=worker_response_mq,
        )
class WorkerProc:
    """Wrapper that runs one Worker in a separate process."""

    # Status string sent to the parent once initialization succeeds.
    READY_STR = "READY"

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        input_shm_handle: Handle,
    ):
        """Construct the worker in the child process: init the worker
        object, wire up both message queues, then load the model."""
        self.rank = rank
        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
        # TODO: move `init_worker` to executor level as a collective rpc call
        all_kwargs: list[dict] = [
            {} for _ in range(vllm_config.parallel_config.world_size)
        ]
        # The first worker of each TP group acts as the driver.
        is_driver_worker = (
            rank % vllm_config.parallel_config.tensor_parallel_size == 0)
        all_kwargs[rank] = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "is_driver_worker": is_driver_worker,
        }
        wrapper.init_worker(all_kwargs)
        self.worker = wrapper

        # Tag this process's stdout/stderr so interleaved logs are readable.
        pid = os.getpid()
        _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
        _add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)

        # Initialize MessageQueue for receiving SchedulerOutput
        self.rpc_broadcast_mq = MessageQueue.create_from_handle(
            input_shm_handle, self.worker.rank)

        # Initializes a message queue for sending the model output
        self.worker_response_mq = MessageQueue(1, 1)

        # Initialize device and loads weights
        self.worker.init_device()
        self.worker.load_model()

    @staticmethod
    def make_worker_process(
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            input_shm_handle,  # Receive SchedulerOutput
    ) -> UnreadyWorkerProcHandle:
        """Spawn one worker process and return its not-yet-ready handle."""
        context = get_mp_context()
        # (reader, writer)
        reader, writer = context.Pipe(duplex=False)

        process_kwargs = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "input_shm_handle": input_shm_handle,
            "ready_pipe": (reader, writer),
        }
        # Run EngineCore busy loop in background process.
        proc = context.Process(target=WorkerProc.worker_main,
                               kwargs=process_kwargs,
                               name=f"VllmWorker-{rank}",
                               daemon=True)

        proc.start()
        # Parent keeps only the read end; the child owns the write end.
        writer.close()
        return UnreadyWorkerProcHandle(proc, rank, reader)

    @staticmethod
    def wait_for_ready(
        unready_proc_handles: list[UnreadyWorkerProcHandle]
    ) -> list[WorkerProcHandle]:
        """Block until every worker reports READY; raise if any fails.

        Returns the ready handles ordered by worker rank.
        """
        e = Exception("WorkerProc initialization failed due to "
                      "an exception in a background process. "
                      "See stack trace for root cause.")

        pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
        ready_proc_handles: list[Optional[WorkerProcHandle]] = (
            [None] * len(unready_proc_handles))
        while pipes:
            # Wait on whichever readiness pipes become readable first.
            ready = multiprocessing.connection.wait(pipes.keys())
            for pipe in ready:
                assert isinstance(pipe, Connection)
                try:
                    # Wait until the WorkerProc is ready.
                    unready_proc_handle = pipes.pop(pipe)
                    response: dict[str, Any] = pipe.recv()
                    if response["status"] != "READY":
                        raise e

                    # Extract the message queue handle.
                    worker_response_mq = MessageQueue.create_from_handle(
                        response["handle"], 0)
                    ready_proc_handles[unready_proc_handle.rank] = (
                        WorkerProcHandle.from_unready_handle(
                            unready_proc_handle, worker_response_mq))

                except EOFError:
                    # Pipe closed without a READY message: worker died.
                    e.__suppress_context__ = True
                    raise e from None

                finally:
                    # Close connection.
                    pipe.close()

        return cast(list[WorkerProcHandle], ready_proc_handles)

    def shutdown(self):
        """Release message queues and tear down distributed state."""
        self.rpc_broadcast_mq = None
        self.worker_response_mq = None
        destroy_model_parallel()
        destroy_distributed_environment()

    @staticmethod
    def worker_main(*args, **kwargs):
        """ Worker initialization and execution loops.
        This runs a background process """
        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the worker
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        worker = None
        # tuple[Connection, Connection]
        reader, ready_writer = kwargs.pop("ready_pipe")
        try:
            # Child only writes on this pipe; drop the inherited read end.
            reader.close()
            worker = WorkerProc(*args, **kwargs)

            # Send READY once we know everything is loaded
            ready_writer.send({
                "status":
                WorkerProc.READY_STR,
                "handle":
                worker.worker_response_mq.export_handle(),
            })

            # Ensure message queues are ready. Will deadlock if re-ordered.
            # Must be kept consistent with the Executor
            worker.rpc_broadcast_mq.wait_until_ready()
            worker.worker_response_mq.wait_until_ready()
            ready_writer.close()
            ready_writer = None

            worker.worker_busy_loop()

        except Exception:
            # NOTE: if an Exception arises in busy_loop, we send
            # a FAILURE message over the MQ RPC to notify the Executor,
            # which triggers system shutdown.
            # TODO(rob): handle case where the MQ itself breaks.

            # ready_writer still open means we never reached the busy loop.
            if ready_writer is not None:
                logger.exception("WorkerProc failed to start.")
            else:
                logger.exception("WorkerProc failed.")

            # The parent sends a SIGTERM to all worker processes if
            # any worker dies. Set this value so we don't re-throw
            # SystemExit() to avoid zmq exceptions in __del__.
            shutdown_requested = True

        finally:
            if ready_writer is not None:
                ready_writer.close()
            # Clean up once worker exits busy loop
            if worker is not None:
                worker.shutdown()

    class ResponseStatus(Enum):
        # Outcome tag carried on every response in worker_response_mq.
        SUCCESS = auto()
        FAILURE = auto()

    def worker_busy_loop(self):
        """Main busy loop for Multiprocessing Workers"""
        while True:
            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()

            try:
                if isinstance(method, str):
                    func = getattr(self.worker, method)
                elif isinstance(method, bytes):
                    # Callable shipped from the executor via cloudpickle.
                    func = partial(cloudpickle.loads(method), self.worker)
                output = func(*args, **kwargs)
            except Exception as e:
                # Notes have been introduced in python 3.11
                if hasattr(e, "add_note"):
                    e.add_note(traceback.format_exc())
                logger.exception("WorkerProc hit an exception.")
                # exception might not be serializable, so we convert it to
                # string, only for logging purpose.
                if output_rank is None or self.rank == output_rank:
                    self.worker_response_mq.enqueue(
                        (WorkerProc.ResponseStatus.FAILURE, str(e)))
                continue

            # Only the designated rank (or all ranks, if None) replies.
            if output_rank is None or self.rank == output_rank:
                self.worker_response_mq.enqueue(
                    (WorkerProc.ResponseStatus.SUCCESS, output))

View File

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import Future
from typing import Union
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
class FutureWrapper(Future):
    """Adapts a Ray output reference to the ``Future`` interface that
    callers of ``.execute_model()`` expect.
    """

    def __init__(self, ref):
        super().__init__()
        self.ref = ref

    def result(self, timeout=None):
        """Block until the wrapped Ray reference resolves, then return it."""
        if timeout is None:
            return self.ref.get()
        raise NotImplementedError("timeout is not supported")
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    """Ray distributed executor using Ray Compiled Graphs."""

    @property
    def max_concurrent_batches(self) -> int:
        """Ray distributed executor supports pipeline parallelism,
        meaning that it allows PP size batches to be executed concurrently.
        """
        return self.parallel_config.pipeline_parallel_size

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Execute the model on the Ray workers.

        Args:
            scheduler_output: The scheduler output to execute.

        Returns:
            The model runner output (or a Future thereof when pipeline
            parallelism permits multiple in-flight batches).
        """
        # Build the compiled DAG for the first time.
        if self.forward_dag is None:  # type: ignore
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        refs = self.forward_dag.execute(scheduler_output)  # type: ignore

        # When PP is not used, we block here until the result is available.
        if self.max_concurrent_batches == 1:
            return refs[0].get()

        # When PP is used, we return a FutureWrapper immediately so that
        # the scheduler can yield to the next batch.
        return FutureWrapper(refs[0])

View File

@@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from dataclasses import dataclass
from typing import Optional
import torch
from typing_extensions import Self
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size
logger = init_logger(__name__)
@dataclass
class KVCacheSpec:
    """
    A base class for specifying the KV cache format of one layer.
    """

    # number of tokens in a block
    block_size: int

    @property
    def type_id(self) -> str:
        """
        The type identifier of this KV cache.
        Return different strings for layers with different KV cache type (e.g.,
        different number of tokens like full attention vs sliding window
        attention, different KV cache size per token like layers with different
        number of heads)

        Returns:
            The type identifier of this KV cache.
        """
        raise NotImplementedError

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Returns:
            The KV cache size in bytes
        """
        raise NotImplementedError

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
        """
        # Merging is only defined for layers of the same KV cache type;
        # the result is a copy of the first spec.
        assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
            "All layers in the same KV cache group must share the same "
            "type_id.")
        return copy.deepcopy(specs[0])
@dataclass
class AttentionSpec(KVCacheSpec):
    """KV cache spec shared by all attention layer variants."""
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool

    @property
    def page_size_bytes(self) -> int:
        """Bytes needed to cache one `block_size`-token page of this layer."""
        # MLA caches a single latent vector per token; standard attention
        # caches separate K and V entries.
        num_tensors = 1 if self.use_mla else 2
        bytes_per_token = (self.num_kv_heads * self.head_size *
                           get_dtype_size(self.dtype))
        return num_tensors * self.block_size * bytes_per_token
@dataclass
class FullAttentionSpec(AttentionSpec):
    sliding_window: Optional[int] = None
    """
    When hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding
    window attention are regarded as full attention in KV cache manager
    (blocks are allocated for all tokens), while computed as sliding window
    attention in model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    Default to None for not using sliding window attention.
    """

    @property
    def type_id(self) -> str:
        # Layers with equal block size and page size are interchangeable.
        return f"full_attention_{self.block_size}_{self.page_size_bytes}"

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Worst case: one page per block of the longest possible sequence."""
        max_model_len = vllm_config.model_config.max_model_len
        num_pages = cdiv(max_model_len, self.block_size)
        return num_pages * self.page_size_bytes

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.
        """
        merged_spec = super().merge(specs)
        window_sizes = {
            spec.sliding_window
            for spec in specs if spec.sliding_window is not None
        }
        if not window_sizes:
            merged_spec.sliding_window = None
        elif len(window_sizes) == 1:
            merged_spec.sliding_window = window_sizes.pop()
        else:
            raise ValueError(
                "All sliding window layers in the same KV cache group "
                "must have the same window size.")
        return merged_spec
@dataclass
class SlidingWindowSpec(AttentionSpec):
    """KV cache spec for a sliding-window attention layer."""
    # Number of most recent tokens the layer attends to (window size).
    sliding_window: int

    def __post_init__(self):
        assert not self.use_mla, "MLA is not supported for sliding window"

    @property
    def type_id(self) -> str:
        return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}"  # noqa

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Worst-case KV cache size for one request with this layer."""
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
                         max_model_len)

        # +1 here because the sliding window may not start from the beginning
        # of the block. For example, if the block size is 4 and num_token
        # is 4, we need two blocks [XXCD] [EF] to store the sliding
        # window [CDEF] of 6 tokens.
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
@dataclass
class KVCacheTensor:
    """
    A class for specifying how the workers should initialize the KV cache.
    """
    # Size of the KV cache tensor in bytes.
    size: int  # size of the KV cache tensor in bytes
    # Names of the model layers that share this same backing tensor.
    shared_by: list[str]  # layer names that share the same KV cache tensor
@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block table.
    These layers are regarded as one layer in the KV cache manager.
    """
    # The names of model layers in this group
    layer_names: list[str]
    # The KV cache spec of this manager layer
    # (all layers in the group must be compatible with it).
    kv_cache_spec: KVCacheSpec
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """
    # NOTE: the bare string literals below are informal per-field docs,
    # placed *before* the field they describe (original file style).
    """The number of KV cache blocks"""
    num_blocks: int
    """How should model runner initialize the KV cache tensors for each layer"""
    kv_cache_tensors: list[KVCacheTensor]
    """
    The kv cache groups of the model.
    For models with only one type of attention, there is only one group that
    contains all layers.
    For models with multiple types of attention, there will be multiple groups,
    see `_get_kv_cache_config_uniform_page_size` for more details.
    """
    kv_cache_groups: list[KVCacheGroupSpec]

View File

523
vllm/v1/metrics/loggers.py Normal file
View File

@@ -0,0 +1,523 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import time
from abc import ABC, abstractmethod
from typing import Callable, Optional
import numpy as np
import prometheus_client
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
class StatLoggerBase(ABC):
    """Interface for logging metrics.

    API users may define custom loggers that implement this interface.
    However, note that the `SchedulerStats` and `IterationStats` classes
    are not considered stable interfaces and may change in future versions.
    """
    @abstractmethod
    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        ...
    @abstractmethod
    def record(self, scheduler_stats: Optional[SchedulerStats],
               iteration_stats: Optional[IterationStats]):
        # Record one batch of stats; either argument may be None when that
        # side produced nothing this step.
        ...
    @abstractmethod
    def log_engine_initialized(self):
        ...
    def log(self):  # noqa
        # Optional periodic summary hook; default no-op so subclasses that
        # only record (e.g. Prometheus) need not override it.
        pass
class LoggingStatLogger(StatLoggerBase):
    """Stat logger that periodically prints throughput and cache stats
    to the standard logger. `record()` accumulates; `log()` summarizes
    the interval since the previous `log()` call and resets it.
    """
    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        self.engine_index = engine_index
        self.vllm_config = vllm_config
        self._reset(time.monotonic())
        # Most recent scheduler snapshot; re-printed by log().
        self.last_scheduler_stats = SchedulerStats()
        # Prefix cache metrics. This cannot be reset.
        # TODO: Make the interval configurable.
        self.prefix_caching_metrics = PrefixCachingMetrics()
        self.spec_decoding_logging = SpecDecodingLogging()
        # Previous interval's throughputs, used to keep logging one extra
        # interval after activity stops (see log()).
        self.last_prompt_throughput: float = 0.0
        self.last_generation_throughput: float = 0.0
    def _reset(self, now):
        # Start a new logging interval at monotonic time `now`.
        self.last_log_time = now
        # Tracked stats over current local logging interval.
        self.num_prompt_tokens: list[int] = []
        self.num_generation_tokens: list[int] = []
    def _track_iteration_stats(self, iteration_stats: IterationStats):
        # Save tracked stats for token counters.
        self.num_prompt_tokens.append(iteration_stats.num_prompt_tokens)
        self.num_generation_tokens.append(
            iteration_stats.num_generation_tokens)
    def _get_throughput(self, tracked_stats: list[int], now: float) -> float:
        # Compute summary metrics for tracked stats
        # NOTE(review): divides by the elapsed interval; assumes log() is not
        # invoked twice at the same monotonic instant -- confirm callers.
        return float(np.sum(tracked_stats) / (now - self.last_log_time))
    def record(self, scheduler_stats: Optional[SchedulerStats],
               iteration_stats: Optional[IterationStats]):
        """Log Stats to standard output."""
        if iteration_stats:
            self._track_iteration_stats(iteration_stats)
        if scheduler_stats is not None:
            self.prefix_caching_metrics.observe(
                scheduler_stats.prefix_cache_stats)
            if scheduler_stats.spec_decoding_stats is not None:
                self.spec_decoding_logging.observe(
                    scheduler_stats.spec_decoding_stats)
            self.last_scheduler_stats = scheduler_stats
    def log(self):
        """Print the interval summary and reset interval accumulators."""
        now = time.monotonic()
        prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
        generation_throughput = self._get_throughput(
            self.num_generation_tokens, now)
        self._reset(now)
        scheduler_stats = self.last_scheduler_stats
        log_fn = logger.info
        if not any(
            (prompt_throughput, generation_throughput,
             self.last_prompt_throughput, self.last_generation_throughput)):
            # Avoid log noise on an idle production system
            log_fn = logger.debug
        self.last_generation_throughput = generation_throughput
        self.last_prompt_throughput = prompt_throughput
        # Format and print output.
        log_fn(
            "Engine %03d: "
            "Avg prompt throughput: %.1f tokens/s, "
            "Avg generation throughput: %.1f tokens/s, "
            "Running: %d reqs, Waiting: %d reqs, "
            "GPU KV cache usage: %.1f%%, "
            "Prefix cache hit rate: %.1f%%",
            self.engine_index,
            prompt_throughput,
            generation_throughput,
            scheduler_stats.num_running_reqs,
            scheduler_stats.num_waiting_reqs,
            scheduler_stats.gpu_cache_usage * 100,
            self.prefix_caching_metrics.hit_rate * 100,
        )
        self.spec_decoding_logging.log(log_fn=log_fn)
    def log_engine_initialized(self):
        """Emit a one-shot line with the allocated GPU block count."""
        if self.vllm_config.cache_config.num_gpu_blocks:
            logger.info(
                "Engine %03d: vllm cache_config_info with initialization "
                "after num_gpu_blocks is: %d", self.engine_index,
                self.vllm_config.cache_config.num_gpu_blocks)
class PrometheusStatLogger(StatLoggerBase):
    """StatLoggerBase implementation that exports scheduler and iteration
    stats as Prometheus metrics, labelled by model name and engine index.

    The metric constructor classes are class attributes so subclasses
    (e.g. RayPrometheusStatLogger) can substitute API-compatible
    implementations without re-declaring every metric.
    """
    _gauge_cls = prometheus_client.Gauge
    _counter_cls = prometheus_client.Counter
    _histogram_cls = prometheus_client.Histogram
    _spec_decoding_cls = SpecDecodingProm

    @staticmethod
    def _unregister_vllm_metrics():
        # Overridable hook: subclasses that do not touch the global
        # prometheus registry (e.g. RayPrometheusStatLogger) override this
        # with a no-op. Previously __init__ called the module-level
        # function directly, which made that override dead code.
        unregister_vllm_metrics()

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        # Drop any vLLM collectors left over from a previous engine in
        # this process so metric names do not collide.
        self._unregister_vllm_metrics()
        self.vllm_config = vllm_config
        self.engine_index = engine_index
        # Use this flag to hide metrics that were deprecated in
        # a previous release and which will be removed future
        self.show_hidden_metrics = \
            vllm_config.observability_config.show_hidden_metrics

        labelnames = ["model_name", "engine"]
        labelvalues = [
            vllm_config.model_config.served_model_name,
            str(engine_index)
        ]
        max_model_len = vllm_config.model_config.max_model_len

        self.spec_decoding_prom = self._spec_decoding_cls(
            vllm_config.speculative_config, labelnames, labelvalues)

        #
        # Scheduler state
        #
        self.gauge_scheduler_running = self._gauge_cls(
            name="vllm:num_requests_running",
            documentation="Number of requests in model execution batches.",
            multiprocess_mode="mostrecent",
            labelnames=labelnames).labels(*labelvalues)
        self.gauge_scheduler_waiting = self._gauge_cls(
            name="vllm:num_requests_waiting",
            documentation="Number of requests waiting to be processed.",
            multiprocess_mode="mostrecent",
            labelnames=labelnames).labels(*labelvalues)

        #
        # GPU cache
        #
        self.gauge_gpu_cache_usage = self._gauge_cls(
            name="vllm:gpu_cache_usage_perc",
            documentation="GPU KV-cache usage. 1 means 100 percent usage.",
            multiprocess_mode="mostrecent",
            labelnames=labelnames).labels(*labelvalues)
        self.counter_gpu_prefix_cache_queries = self._counter_cls(
            name="vllm:gpu_prefix_cache_queries",
            documentation=
            "GPU prefix cache queries, in terms of number of queried tokens.",
            labelnames=labelnames).labels(*labelvalues)
        self.counter_gpu_prefix_cache_hits = self._counter_cls(
            name="vllm:gpu_prefix_cache_hits",
            documentation=
            "GPU prefix cache hits, in terms of number of cached tokens.",
            labelnames=labelnames).labels(*labelvalues)

        #
        # Counters
        #
        self.counter_num_preempted_reqs = self._counter_cls(
            name="vllm:num_preemptions",
            documentation="Cumulative number of preemption from the engine.",
            labelnames=labelnames).labels(*labelvalues)
        self.counter_prompt_tokens = self._counter_cls(
            name="vllm:prompt_tokens",
            documentation="Number of prefill tokens processed.",
            labelnames=labelnames).labels(*labelvalues)
        self.counter_generation_tokens = self._counter_cls(
            name="vllm:generation_tokens",
            documentation="Number of generation tokens processed.",
            labelnames=labelnames).labels(*labelvalues)
        # One success counter per finish reason, pre-bound to its labels.
        self.counter_request_success: dict[FinishReason,
                                           prometheus_client.Counter] = {}
        counter_request_success_base = self._counter_cls(
            name="vllm:request_success",
            documentation="Count of successfully processed requests.",
            labelnames=labelnames + ["finished_reason"])
        for reason in FinishReason:
            self.counter_request_success[
                reason] = counter_request_success_base.labels(*(labelvalues +
                                                                [str(reason)]))

        #
        # Histograms of counts
        #
        self.histogram_num_prompt_tokens_request = \
            self._histogram_cls(
                name="vllm:request_prompt_tokens",
                documentation="Number of prefill tokens processed.",
                buckets=build_1_2_5_buckets(max_model_len),
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_num_generation_tokens_request = \
            self._histogram_cls(
                name="vllm:request_generation_tokens",
                documentation="Number of generation tokens processed.",
                buckets=build_1_2_5_buckets(max_model_len),
                labelnames=labelnames).labels(*labelvalues)
        # TODO: This metric might be incorrect in case of using multiple
        # api_server counts which uses prometheus mp.
        # See: https://github.com/vllm-project/vllm/pull/18053
        self.histogram_iteration_tokens = \
            self._histogram_cls(
                name="vllm:iteration_tokens_total",
                documentation="Histogram of number of tokens per engine_step.",
                buckets=[
                    1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,
                    16384
                ],
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_max_num_generation_tokens_request = \
            self._histogram_cls(
                name="vllm:request_max_num_generation_tokens",
                documentation=
                "Histogram of maximum number of requested generation tokens.",
                buckets=build_1_2_5_buckets(max_model_len),
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_n_request = \
            self._histogram_cls(
                name="vllm:request_params_n",
                documentation="Histogram of the n request parameter.",
                buckets=[1, 2, 5, 10, 20],
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_max_tokens_request = \
            self._histogram_cls(
                name="vllm:request_params_max_tokens",
                documentation="Histogram of the max_tokens request parameter.",
                buckets=build_1_2_5_buckets(max_model_len),
                labelnames=labelnames).labels(*labelvalues)

        #
        # Histogram of timing intervals
        #
        self.histogram_time_to_first_token = \
            self._histogram_cls(
                name="vllm:time_to_first_token_seconds",
                documentation="Histogram of time to first token in seconds.",
                buckets=[
                    0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
                    0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
                    640.0, 2560.0
                ],
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_time_per_output_token = \
            self._histogram_cls(
                name="vllm:time_per_output_token_seconds",
                documentation="Histogram of time per output token in seconds.",
                buckets=[
                    0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5,
                    0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
                ],
                labelnames=labelnames).labels(*labelvalues)

        # Shared bucket layout for all request-phase latency histograms.
        request_latency_buckets = [
            0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
            40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
        ]
        self.histogram_e2e_time_request = \
            self._histogram_cls(
                name="vllm:e2e_request_latency_seconds",
                documentation="Histogram of e2e request latency in seconds.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_queue_time_request = \
            self._histogram_cls(
                name="vllm:request_queue_time_seconds",
                documentation=
                "Histogram of time spent in WAITING phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_inference_time_request = \
            self._histogram_cls(
                name="vllm:request_inference_time_seconds",
                documentation=
                "Histogram of time spent in RUNNING phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_prefill_time_request = \
            self._histogram_cls(
                name="vllm:request_prefill_time_seconds",
                documentation=
                "Histogram of time spent in PREFILL phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_decode_time_request = \
            self._histogram_cls(
                name="vllm:request_decode_time_seconds",
                documentation=
                "Histogram of time spent in DECODE phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)

        #
        # LoRA metrics
        #
        # TODO: This metric might be incorrect in case of using multiple
        # api_server counts which uses prometheus mp.
        self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
        if vllm_config.lora_config is not None:
            self.labelname_max_lora = "max_lora"
            self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
            self.labelname_running_lora_adapters = "running_lora_adapters"
            self.max_lora = vllm_config.lora_config.max_loras
            self.gauge_lora_info = \
                self._gauge_cls(
                    name="vllm:lora_requests_info",
                    documentation="Running stats on lora requests.",
                    multiprocess_mode="sum",
                    labelnames=[
                        self.labelname_max_lora,
                        self.labelname_waiting_lora_adapters,
                        self.labelname_running_lora_adapters,
                    ],
                )

    def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
        """Emit a config-info metric (a gauge pinned to 1) for `config_obj`.

        Raises AssertionError for an unknown `type`.
        """
        metrics_info = config_obj.metrics_info()
        metrics_info["engine"] = self.engine_index

        name, documentation = None, None
        if type == "cache_config":
            name = "vllm:cache_config_info"
            documentation = "Information of the LLMEngine CacheConfig"
        assert name is not None, f"Unknown metrics info type {type}"

        # Info type metrics are syntactic sugar for a gauge permanently set to 1
        # Since prometheus multiprocessing mode does not support Info, emulate
        # info here with a gauge.
        info_gauge = self._gauge_cls(
            name=name,
            documentation=documentation,
            multiprocess_mode="mostrecent",
            labelnames=metrics_info.keys(),
        ).labels(**metrics_info)
        info_gauge.set(1)

    def record(self, scheduler_stats: Optional[SchedulerStats],
               iteration_stats: Optional[IterationStats]):
        """Log to prometheus."""
        if scheduler_stats is not None:
            self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
            self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)

            self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)

            self.counter_gpu_prefix_cache_queries.inc(
                scheduler_stats.prefix_cache_stats.queries)
            self.counter_gpu_prefix_cache_hits.inc(
                scheduler_stats.prefix_cache_stats.hits)

            if scheduler_stats.spec_decoding_stats is not None:
                self.spec_decoding_prom.observe(
                    scheduler_stats.spec_decoding_stats)

        if iteration_stats is None:
            return

        self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs)
        self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
        self.counter_generation_tokens.inc(
            iteration_stats.num_generation_tokens)
        self.histogram_iteration_tokens.observe(
            iteration_stats.num_prompt_tokens + \
            iteration_stats.num_generation_tokens)

        for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter:
            self.histogram_max_num_generation_tokens_request.observe(
                max_gen_tokens)
        for n_param in iteration_stats.n_params_iter:
            self.histogram_n_request.observe(n_param)
        for ttft in iteration_stats.time_to_first_tokens_iter:
            self.histogram_time_to_first_token.observe(ttft)
        for tpot in iteration_stats.time_per_output_tokens_iter:
            self.histogram_time_per_output_token.observe(tpot)

        for finished_request in iteration_stats.finished_requests:
            self.counter_request_success[finished_request.finish_reason].inc()
            self.histogram_e2e_time_request.observe(
                finished_request.e2e_latency)
            self.histogram_queue_time_request.observe(
                finished_request.queued_time)
            self.histogram_prefill_time_request.observe(
                finished_request.prefill_time)
            self.histogram_inference_time_request.observe(
                finished_request.inference_time)
            self.histogram_decode_time_request.observe(
                finished_request.decode_time)
            self.histogram_num_prompt_tokens_request.observe(
                finished_request.num_prompt_tokens)
            self.histogram_num_generation_tokens_request.observe(
                finished_request.num_generation_tokens)
            # NOTE(review): max_tokens_param is Optional in
            # FinishedRequestStats -- confirm it is always set by the time a
            # request finishes, since observe(None) would fail.
            self.histogram_max_tokens_request.observe(
                finished_request.max_tokens_param)

        if self.gauge_lora_info is not None:
            running_lora_adapters = \
                ",".join(iteration_stats.running_lora_adapters.keys())
            waiting_lora_adapters = \
                ",".join(iteration_stats.waiting_lora_adapters.keys())
            lora_info_labels = {
                self.labelname_running_lora_adapters: running_lora_adapters,
                self.labelname_waiting_lora_adapters: waiting_lora_adapters,
                self.labelname_max_lora: self.max_lora,
            }
            self.gauge_lora_info.labels(**lora_info_labels)\
                                .set_to_current_time()

    def log_engine_initialized(self):
        """Publish the cache-config info metric once the engine is up."""
        self.log_metrics_info("cache_config", self.vllm_config.cache_config)
def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
"""
Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values until the value exceeds the specified maximum.
"""
exponent = 0
buckets: list[int] = []
while True:
for m in mantissa_lst:
value = m * 10**exponent
if value <= max_value:
buckets.append(value)
else:
return buckets
exponent += 1
def build_1_2_5_buckets(max_value: int) -> list[int]:
    """Return histogram buckets following the 1-2-5 series, capped at
    *max_value*.

    Example:
    >>> build_1_2_5_buckets(100)
    [1, 2, 5, 10, 20, 50, 100]
    """
    return build_buckets([1, 2, 5], max_value)
def setup_default_loggers(
    vllm_config: VllmConfig,
    log_stats: bool,
    engine_num: int,
    custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> list[list[StatLoggerBase]]:
    """Setup logging and prometheus metrics.

    Returns one list of stat loggers per engine; empty when stats are
    disabled. Custom factories, when given, replace the defaults entirely.
    """
    if not log_stats:
        return []

    factories: list[StatLoggerFactory]
    if custom_stat_loggers is None:
        factories = [PrometheusStatLogger]
        # Stdout logging only pays off when INFO messages are visible.
        if logger.isEnabledFor(logging.INFO):
            factories.append(LoggingStatLogger)
    else:
        factories = custom_stat_loggers

    return [[factory(vllm_config, engine_idx) for factory in factories]
            for engine_idx in range(engine_num)]

View File

@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import tempfile
from typing import Optional
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
from vllm.logger import init_logger
logger = init_logger(__name__)
# Global temporary directory for prometheus multiprocessing
_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None
def setup_multiprocess_prometheus():
    """Set up prometheus multiprocessing directory if not already configured.
    """
    global _prometheus_multiproc_dir
    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        # A user-managed directory cannot be wiped by us between runs, so
        # stale metric files may produce misleading numbers.
        logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                       "This directory must be wiped between vLLM runs or "
                       "you will find inaccurate metrics. Unset the variable "
                       "and vLLM will properly handle cleanup.")
    else:
        # Keep a module-level handle so the TemporaryDirectory object stays
        # alive and is cleaned up automatically at interpreter exit.
        _prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name
        logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s",
                     _prometheus_multiproc_dir.name)
def get_prometheus_registry():
    """Get the appropriate prometheus registry based on multiprocessing
    configuration.

    Returns:
        Registry: A prometheus registry
    """
    # Without a multiprocess dir, the default global registry suffices.
    if os.getenv("PROMETHEUS_MULTIPROC_DIR") is None:
        return REGISTRY
    logger.debug("Using multiprocess registry for prometheus metrics")
    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)
    return registry
def unregister_vllm_metrics():
    """Unregister any existing vLLM collectors from the prometheus registry.

    This is useful for testing and CI/CD where metrics may be registered
    multiple times across test runs.
    Also, in case of multiprocess, we need to unregister the metrics from the
    global registry.
    """
    registry = REGISTRY
    # NOTE: relies on the private _collector_to_names mapping; snapshot the
    # matches first since unregister() mutates the registry.
    vllm_collectors = [
        collector for collector in list(registry._collector_to_names)
        if hasattr(collector, "_name") and "vllm" in collector._name
    ]
    for collector in vllm_collectors:
        registry.unregister(collector)
def shutdown_prometheus():
    """Shutdown prometheus metrics.

    Marks this process as dead in the multiprocess directory so its
    live-gauge files can be reaped. No-op when vLLM did not create the
    directory itself.
    """
    path = _prometheus_multiproc_dir
    if path is None:
        return
    try:
        pid = os.getpid()
        # mark_process_dead expects the directory path as a string; the
        # previous code passed the TemporaryDirectory object itself, which
        # raised a TypeError inside os.path.join and silently skipped the
        # cleanup (only surfacing as the error log below).
        multiprocess.mark_process_dead(pid, path.name)
        logger.debug("Marked Prometheus metrics for process %d as dead", pid)
    except Exception as e:
        # Best-effort: never let metrics teardown break process shutdown.
        logger.error("Error during metrics cleanup: %s", str(e))

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Optional, Union
from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm
try:
from ray.util import metrics as ray_metrics
from ray.util.metrics import Metric
except ImportError:
ray_metrics = None
class RayPrometheusMetric:
    """Common base for Ray-backed metric wrappers that mirror the
    prometheus_client API surface used by PrometheusStatLogger."""

    def __init__(self):
        # Fail fast when Ray is unavailable rather than on first use.
        if ray_metrics is None:
            raise ImportError(
                "RayPrometheusMetric requires Ray to be installed.")
        self.metric: Metric = None

    def labels(self, *labels, **labelskwargs):
        """Bind label values as Ray default tags; returns self for chaining."""
        if labelskwargs:
            # Ray tags must be strings; coerce any non-str values.
            coerced = {
                key: value if isinstance(value, str) else str(value)
                for key, value in labelskwargs.items()
            }
            self.metric.set_default_tags(coerced)

        if labels:
            tag_keys = self.metric._tag_keys
            if len(labels) != len(tag_keys):
                raise ValueError(
                    "Number of labels must match the number of tag keys. "
                    f"Expected {len(self.metric._tag_keys)}, got {len(labels)}"
                )
            self.metric.set_default_tags(dict(zip(tag_keys, labels)))

        return self
class RayGaugeWrapper(RayPrometheusMetric):
    """Wraps ray.util.metrics.Gauge behind the prometheus_client.Gauge
    API used by the stat loggers."""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None):
        tag_keys = tuple(labelnames) if labelnames else None
        self.metric = ray_metrics.Gauge(name=name,
                                        description=documentation,
                                        tag_keys=tag_keys)

    def set(self, value: Union[int, float]):
        """Record an absolute gauge value."""
        return self.metric.set(value)

    def set_to_current_time(self):
        # Ray metrics has no set_to_current_time; emulate prometheus_client
        # semantics with the wall-clock timestamp.
        # https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
        return self.metric.set(time.time())
class RayCounterWrapper(RayPrometheusMetric):
    """Wraps ray.util.metrics.Counter behind the prometheus_client.Counter
    API used by the stat loggers."""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None):
        tag_keys = tuple(labelnames) if labelnames else None
        self.metric = ray_metrics.Counter(name=name,
                                          description=documentation,
                                          tag_keys=tag_keys)

    def inc(self, value: Union[int, float] = 1.0):
        """Increment the counter; zero increments are dropped.

        NOTE(review): presumably Ray's Counter rejects zero/non-positive
        increments -- confirm against ray.util.metrics.
        """
        if value == 0:
            return
        return self.metric.inc(value)
class RayHistogramWrapper(RayPrometheusMetric):
    """Wraps ray.util.metrics.Histogram behind the
    prometheus_client.Histogram API used by the stat loggers."""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None,
                 buckets: Optional[list[float]] = None):
        tag_keys = tuple(labelnames) if labelnames else None
        # Ray calls bucket bounds "boundaries"; an empty list when unset.
        boundaries = buckets if buckets else []
        self.metric = ray_metrics.Histogram(name=name,
                                            description=documentation,
                                            tag_keys=tag_keys,
                                            boundaries=boundaries)

    def observe(self, value: Union[int, float]):
        """Record one observation into the histogram."""
        return self.metric.observe(value)
class RaySpecDecodingProm(SpecDecodingProm):
    """
    RaySpecDecodingProm is used by RayMetrics to log to Ray metrics.
    Provides the same metrics as SpecDecodingProm but uses Ray's
    util.metrics library.
    """
    # Only the counter constructor is swapped; assumes SpecDecodingProm
    # creates counters exclusively -- confirm if it grows other metric types.
    _counter_cls = RayCounterWrapper
class RayPrometheusStatLogger(PrometheusStatLogger):
    """RayPrometheusStatLogger uses Ray metrics instead."""
    # Swap every metric constructor for its Ray-backed wrapper.
    _gauge_cls = RayGaugeWrapper
    _counter_cls = RayCounterWrapper
    _histogram_cls = RayHistogramWrapper
    _spec_decoding_cls = RaySpecDecodingProm
    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        # Pure pass-through; kept for an explicit constructor signature.
        super().__init__(vllm_config, engine_index)
    @staticmethod
    def _unregister_vllm_metrics():
        # No-op on purpose
        # NOTE(review): PrometheusStatLogger.__init__ calls the module-level
        # unregister_vllm_metrics() directly, so this override is currently
        # never invoked -- confirm whether __init__ should call
        # self._unregister_vllm_metrics() instead.
        pass

246
vllm/v1/metrics/reader.py Normal file
View File

@@ -0,0 +1,246 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
from prometheus_client import REGISTRY
from prometheus_client import Metric as PromMetric
from prometheus_client.samples import Sample
@dataclass
class Metric:
    """A base class for prometheus metrics.

    Each metric may be associated with key=value labels, and
    in some cases a single vLLM instance may have multiple
    metrics with the same name but different sets of labels.
    """
    # Metric name as exported, e.g. "vllm:num_requests_running".
    name: str
    # Prometheus label key/value pairs attached to this sample.
    labels: dict[str, str]
@dataclass
class Counter(Metric):
    """A monotonically increasing integer counter."""
    # Cumulative count since process start.
    value: int
@dataclass
class Vector(Metric):
    """An ordered array of integer counters.

    This type - which doesn't exist in Prometheus - models one very
    specific metric, vllm:spec_decode_num_accepted_tokens_per_pos.
    """
    # values[i] is the counter for spec-decode token position i.
    values: list[int]
@dataclass
class Gauge(Metric):
    """A numerical value that can go up or down."""
    # Most recent observed value.
    value: float
@dataclass
class Histogram(Metric):
    """Observations recorded in configurable buckets.

    Buckets are represented by a dictionary. The key is
    the upper limit of the bucket, and the value is the
    observed count in that bucket. A '+Inf' key always
    exists.

    The count property is the total count across all
    buckets, identical to the count of the '+Inf' bucket.

    The sum property is the total sum of all observed
    values.
    """
    # Total number of observations (equals buckets["+Inf"]).
    count: int
    # Sum of all observed values.
    sum: float
    # Cumulative counts keyed by the bucket's 'le' upper bound.
    buckets: dict[str, int]
def get_metrics_snapshot() -> list[Metric]:
    """An API for accessing in-memory Prometheus metrics.

    Converts every "vllm:"-prefixed metric in the global registry into
    one of the simplified Gauge/Counter/Vector/Histogram dataclasses.
    Raises AssertionError on an unrecognized prometheus metric type.

    Example:
    >>> for metric in llm.get_metrics():
    ...     if isinstance(metric, Counter):
    ...         print(f"{metric} = {metric.value}")
    ...     elif isinstance(metric, Gauge):
    ...         print(f"{metric} = {metric.value}")
    ...     elif isinstance(metric, Histogram):
    ...         print(f"{metric}")
    ...         print(f"    sum = {metric.sum}")
    ...         print(f"    count = {metric.count}")
    ...         for bucket_le, value in metric.buckets.items():
    ...             print(f"    {bucket_le} = {value}")
    """
    collected: list[Metric] = []
    for metric in REGISTRY.collect():
        # Only expose vLLM's own metrics, not python/process defaults.
        if not metric.name.startswith("vllm:"):
            continue
        if metric.type == "gauge":
            samples = _get_samples(metric)
            for s in samples:
                collected.append(
                    Gauge(name=metric.name, labels=s.labels, value=s.value))
        elif metric.type == "counter":
            # Counters expose their value under the '_total' suffix.
            samples = _get_samples(metric, "_total")
            if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
                #
                # Ugly vllm:num_accepted_tokens_per_pos special case.
                #
                # This metric is a vector of counters - for each spec
                # decoding token position, we observe the number of
                # accepted tokens using a Counter labeled with 'position'.
                # We convert these into a vector of integer values.
                #
                for labels, values in _digest_num_accepted_by_pos_samples(
                        samples):
                    collected.append(
                        Vector(name=metric.name, labels=labels, values=values))
            else:
                for s in samples:
                    collected.append(
                        Counter(name=metric.name,
                                labels=s.labels,
                                value=int(s.value)))
        elif metric.type == "histogram":
            #
            # A histogram has a number of '_bucket' samples where
            # the 'le' label represents the upper limit of the bucket.
            # We convert these bucketized values into a dict of values
            # indexed by the value of the 'le' label. The 'le=+Inf'
            # label is a special case, catching all values observed.
            #
            bucket_samples = _get_samples(metric, "_bucket")
            count_samples = _get_samples(metric, "_count")
            sum_samples = _get_samples(metric, "_sum")
            for labels, buckets, count_value, sum_value in _digest_histogram(
                    bucket_samples, count_samples, sum_samples):
                collected.append(
                    Histogram(name=metric.name,
                              labels=labels,
                              buckets=buckets,
                              count=count_value,
                              sum=sum_value))
        else:
            raise AssertionError(f"Unknown metric type {metric.type}")
    return collected
def _get_samples(metric: PromMetric,
                 suffix: Optional[str] = None) -> list[Sample]:
    """Return the metric's samples whose name is metric.name plus *suffix*
    (or exactly metric.name when no suffix is given)."""
    wanted = metric.name if suffix is None else metric.name + suffix
    return [sample for sample in metric.samples if sample.name == wanted]
def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]:
labels_copy = labels.copy()
labels_copy.pop(key_to_remove)
return labels_copy
def _digest_histogram(
    bucket_samples: list[Sample], count_samples: list[Sample],
    sum_samples: list[Sample]
) -> list[tuple[dict[str, str], dict[str, int], int, float]]:
    """Regroup flat per-bucket/count/sum samples into one tuple per label set.

    In the DP case the registry yields an indigestible
    per-bucket-per-engine list of labelled samples, e.g.:

      bucket_samples (in):
        labels = {bucket: 100, idx: 0}, value = 2
        labels = {bucket: Inf, idx: 0}, value = 10
        labels = {bucket: 100, idx: 1}, value = 1
      count_samples (in):
        labels = {idx: 0}, value = 10
      sum_samples (in):
        labels = {idx: 0}, value = 2000

      output: [
        {idx: 0}, {"100": 2, "Inf": 10}, 10, 2000
        ...
      ]
    """
    per_label_buckets: dict[frozenset[tuple[str, str]], dict[str, int]] = {}
    for sample in bucket_samples:
        upper = sample.labels["le"]
        # Key on the label set minus 'le' so buckets of one series group up.
        key = frozenset(_strip_label(sample.labels, "le").items())
        per_label_buckets.setdefault(key, {})[upper] = int(sample.value)

    per_label_counts = {
        frozenset(sample.labels.items()): int(sample.value)
        for sample in count_samples
    }
    per_label_sums = {
        frozenset(sample.labels.items()): sample.value
        for sample in sum_samples
    }

    # Every label set must appear in all three sample families.
    assert set(per_label_buckets.keys()) == set(
        per_label_counts.keys()) == set(per_label_sums.keys())

    return [(dict(key), per_label_buckets[key], per_label_counts[key],
             per_label_sums[key]) for key in per_label_buckets]
def _digest_num_accepted_by_pos_samples(
        samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]:
    """Fold per-position counter samples into one dense vector per label set.

    In the DP case the registry yields an indigestible
    per-position-per-engine list of labelled samples, e.g.:

      samples (in):
        labels = {pos: 0, idx: 0}, value = 10
        labels = {pos: 1, idx: 0}, value = 7
        labels = {pos: 0, idx: 1}, value = 5

      output: [
        {idx: 0}, [10, 7]
        {idx: 1}, [5, 0]
      ]

    Positions missing from the input default to 0; all vectors share the
    length of the globally largest position + 1.
    """
    highest_position = 0
    per_label_values: dict[frozenset[tuple[str, str]], dict[int, int]] = {}
    for sample in samples:
        position = int(sample.labels["position"])
        highest_position = max(highest_position, position)
        key = frozenset(_strip_label(sample.labels, "position").items())
        per_label_values.setdefault(key, {})[position] = int(sample.value)

    result = []
    for key, by_position in per_label_values.items():
        vector = [0] * (highest_position + 1)
        for position, count in by_position.items():
            vector[position] = count
        result.append((dict(key), vector))
    return result

239
vllm/v1/metrics/stats.py Normal file
View File

@@ -0,0 +1,239 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional
from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
from vllm.v1.engine.output_processor import RequestState
@dataclass
class PrefixCacheStats:
    """Stores prefix cache hit statistics.

    Accumulated per metrics update; all counters default to zero.
    """

    # Whether reset_prefix_cache was invoked.
    reset: bool = False
    # The number of requests in this update.
    requests: int = 0
    # The number of queries in these requests. Note that "queries" here
    # means the number of tokens that were queried from the cache.
    queries: int = 0
    # The number of hits in these requests (tokens found in the cache).
    hits: int = 0
@dataclass
class SchedulerStats:
    """Stats associated with the scheduler."""

    # Current request counts in the scheduler queues.
    num_running_reqs: int = 0
    num_waiting_reqs: int = 0

    # GPU KV cache usage (presumably a 0..1 fraction — confirm with the
    # producer of this field).
    gpu_cache_usage: float = 0.0

    prefix_cache_stats: PrefixCacheStats = field(
        default_factory=PrefixCacheStats)

    # Only populated when speculative decoding is in use.
    spec_decoding_stats: Optional[SpecDecodingStats] = None
@dataclass
class LoRAStats:
    """Request-id sets tracking one LoRA adapter's queue state.

    Maintained by LoRARequestStates: a request id moves from
    waiting_requests to running_requests when scheduled, and back on
    preemption.
    """
    waiting_requests: set[str] = field(default_factory=set)
    running_requests: set[str] = field(default_factory=set)
@dataclass
class RequestStateStats:
    """Stats that need to be tracked across delta updates."""

    num_generation_tokens: int = 0

    # This is a engine frontend timestamp (wall-clock)
    arrival_time: float = 0.0

    # These are engine core timestamps (monotonic).
    # 0.0 means the event has not happened yet (see
    # IterationStats.update_from_events, which only records the first
    # SCHEDULED timestamp).
    queued_ts: float = 0.0
    scheduled_ts: float = 0.0
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0
@dataclass
class FinishedRequestStats:
    """Stats associated with a finished request."""

    finish_reason: "FinishReason"
    # Wall-clock latency from frontend arrival to finish.
    e2e_latency: float = 0.0
    num_prompt_tokens: int = 0
    num_generation_tokens: int = 0
    # The request's max_tokens sampling parameter, if it was set.
    max_tokens_param: Optional[int] = None
    # Intervals derived from engine-core monotonic timestamps; computed
    # in IterationStats.update_from_finished_request.
    queued_time: float = 0.0
    prefill_time: float = 0.0
    inference_time: float = 0.0
    decode_time: float = 0.0
class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs."""

    def __init__(self):
        # Wall-clock time this iteration's stats were created; all
        # frontend-relative intervals are measured against it.
        self.iteration_timestamp = time.time()
        self.num_generation_tokens = 0
        self.num_prompt_tokens = 0
        self.num_preempted_reqs = 0
        self.finished_requests: list[FinishedRequestStats] = []
        self.max_num_generation_tokens_iter: list[int] = []
        self.n_params_iter: list[int] = []
        self.time_to_first_tokens_iter: list[float] = []
        self.time_per_output_tokens_iter: list[float] = []
        # adapter name -> number of waiting/running requests.
        self.waiting_lora_adapters: dict[str, int] = {}
        self.running_lora_adapters: dict[str, int] = {}

    def _time_since(self, start: float) -> float:
        """Calculate an interval relative to this iteration's timestamp."""
        return self.iteration_timestamp - start

    def update_from_output(self, output: "EngineCoreOutput",
                           engine_core_timestamp: float, is_prefilling: bool,
                           prompt_len: int, req_stats: RequestStateStats,
                           lora_stats: Optional[LoRAStats]):
        """Fold one request's engine-core output into this iteration.

        Updates both iteration-level counters and the request's
        cross-iteration `req_stats` timestamps.
        """
        num_new_generation_tokens = len(output.new_token_ids)

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
            # A request transitioning out of prefill must have produced
            # at least its first token.
            assert num_new_generation_tokens > 0
            self.num_prompt_tokens += prompt_len

            first_token_latency = self._time_since(req_stats.arrival_time)
            self.time_to_first_tokens_iter.append(first_token_latency)

        req_stats.num_generation_tokens += num_new_generation_tokens

        # Process request-level engine core events
        if output.events is not None:
            self.update_from_events(output.request_id, output.events,
                                    is_prefilling, req_stats, lora_stats)

        # Process the batch-level "new tokens" engine core event
        if is_prefilling:
            req_stats.first_token_ts = engine_core_timestamp
        else:
            # Time-per-output-token: gap since the previous token batch.
            tpot = engine_core_timestamp - req_stats.last_token_ts
            self.time_per_output_tokens_iter.append(tpot)

        req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
                           is_prefilling: bool, req_stats: RequestStateStats,
                           lora_stats: Optional[LoRAStats]):
        """Record QUEUED/SCHEDULED/PREEMPTED lifecycle events."""
        # Avoid circular dependency
        from vllm.v1.engine import EngineCoreEventType
        for event in events:
            if event.type == EngineCoreEventType.QUEUED:
                req_stats.queued_ts = event.timestamp
                if lora_stats is not None:
                    lora_stats.waiting_requests.add(req_id)
            elif event.type == EngineCoreEventType.SCHEDULED:
                if req_stats.scheduled_ts == 0.0:  # ignore preemptions
                    req_stats.scheduled_ts = event.timestamp
                LoRARequestStates.scheduled_request(lora_stats, req_id)
            elif event.type == EngineCoreEventType.PREEMPTED:
                self.num_preempted_reqs += 1
                LoRARequestStates.preempted_request(lora_stats, req_id)

    def update_from_finished_request(self, finish_reason: "FinishReason",
                                     num_prompt_tokens: int,
                                     max_tokens_param: Optional[int],
                                     req_stats: RequestStateStats):
        """Derive the per-request latency breakdown and record it."""
        e2e_latency = self._time_since(req_stats.arrival_time)

        # Queued interval is from first QUEUED event to first SCHEDULED
        queued_time = req_stats.scheduled_ts - req_stats.queued_ts

        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
        # Any preemptions during prefill is included in the interval
        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts

        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
        # Any preemptions during decode are included
        decode_time = req_stats.last_token_ts - req_stats.first_token_ts

        # Inference interval is from first SCHEDULED to last NEW_TOKEN
        # Any preemptions during prefill or decode are included
        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts

        finished_req = \
            FinishedRequestStats(finish_reason=finish_reason,
                                 e2e_latency=e2e_latency,
                                 num_prompt_tokens=num_prompt_tokens,
                                 num_generation_tokens=req_stats.num_generation_tokens,
                                 max_tokens_param=max_tokens_param,
                                 queued_time=queued_time,
                                 prefill_time=prefill_time,
                                 inference_time=inference_time,
                                 decode_time=decode_time)
        self.finished_requests.append(finished_req)
class LoRARequestStates:
    """Per-LoRA request state stats: tracks, for each adapter name, which
    request ids are currently waiting and which are running."""

    def __init__(self):
        self.lora_name_to_stats: dict[str, LoRAStats] = {}

    def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
        """Return (creating if needed) the stats bucket for the request's
        adapter, or None for non-LoRA requests."""
        name = req_state.lora_name
        if name is None:
            return None
        return self.lora_name_to_stats.setdefault(name, LoRAStats())

    def add_request(self, req_state: 'RequestState'):
        stats = self.get_stats(req_state)
        if stats is not None:
            stats.waiting_requests.add(req_state.request_id)

    def finish_request(self, req_state: 'RequestState'):
        if req_state.lora_name is not None:
            stats = self.lora_name_to_stats[req_state.lora_name]
            stats.running_requests.remove(req_state.request_id)

    def abort_request(self, req_state: 'RequestState'):
        if req_state.lora_name is not None:
            # An aborted request may be in either set; discard from both.
            stats = self.lora_name_to_stats[req_state.lora_name]
            stats.waiting_requests.discard(req_state.request_id)
            stats.running_requests.discard(req_state.request_id)

    # Break the pattern for these lifecycle methods so we can
    # call them from IterationStats.update_from_events()
    @staticmethod
    def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str):
        if lora_stats is None:
            return
        lora_stats.waiting_requests.remove(request_id)
        lora_stats.running_requests.add(request_id)

    @staticmethod
    def preempted_request(lora_stats: Optional[LoRAStats], request_id: str):
        if lora_stats is None:
            return
        lora_stats.running_requests.remove(request_id)
        lora_stats.waiting_requests.add(request_id)

    def update_iteration_stats(self,
                               iteration_stats: Optional[IterationStats]):
        """Publish non-empty per-adapter queue sizes into the iteration."""
        if iteration_stats is None:
            return
        for name, stats in self.lora_name_to_stats.items():
            if stats.waiting_requests:
                iteration_stats.waiting_lora_adapters[name] = len(
                    stats.waiting_requests)
            if stats.running_requests:
                iteration_stats.running_lora_adapters[name] = len(
                    stats.running_requests)

116
vllm/v1/outputs.py Normal file
View File

@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import NamedTuple, Optional
import torch
class LogprobsLists(NamedTuple):
    """Logprob results as plain Python lists (cheap to serialize)."""

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: list[list[int]]
    # [num_reqs, max_num_logprobs + 1]
    logprobs: list[list[float]]
    # [num_reqs]
    sampled_token_ranks: list[int]

    def slice(self, start: int, end: int):
        """Return a new LogprobsLists covering requests [start, end)."""
        token_ids, probs, ranks = self
        return LogprobsLists(token_ids[start:end], probs[start:end],
                             ranks[start:end])
class LogprobsTensors(NamedTuple):
    """Logprob results as torch tensors."""

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: torch.Tensor
    # [num_reqs, max_num_logprobs + 1]
    logprobs: torch.Tensor
    # [num_reqs]
    selected_token_ranks: torch.Tensor

    def tolists(self):
        """Convert to a LogprobsLists of plain Python lists."""
        token_ids, probs, ranks = self
        return LogprobsLists(token_ids.tolist(), probs.tolist(),
                             ranks.tolist())

    @staticmethod
    def empty_cpu(num_positions: int,
                  num_tokens_per_position: int) -> "LogprobsTensors":
        """Create empty LogprobsTensors on CPU."""
        token_ids = torch.empty((num_positions, num_tokens_per_position),
                                dtype=torch.int32,
                                device="cpu")
        return LogprobsTensors(
            logprob_token_ids=token_ids,
            logprobs=torch.empty_like(token_ids, dtype=torch.float32),
            selected_token_ranks=torch.empty(num_positions,
                                             dtype=torch.int32,
                                             device="cpu"),
        )
@dataclass
class SamplerOutput:
    """Raw per-step sampler results on the model-runner side."""

    # [num_reqs, max_num_generated_tokens]
    # Different requests can have different number of generated tokens.
    # All requests are padded to max_num_generated_tokens.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: torch.Tensor
    # May be None when no request in the batch asked for logprobs.
    logprobs_tensors: Optional[LogprobsTensors]
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
class ModelRunnerOutput:
    """Per-step model runner results, in serialization-friendly form."""

    # [num_reqs]
    req_ids: list[str]
    # req_id -> index
    req_id_to_index: dict[str, int]

    # num_reqs x num_generated_tokens
    # num_generated_tokens is the number of tokens
    # generated in the current step. It can be different for
    # each request due to speculative/jump decoding.
    sampled_token_ids: list[list[int]]

    # num_reqs x num_spec_tokens
    spec_token_ids: Optional[list[list[int]]]

    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs]
    logprobs: Optional[LogprobsLists]

    # req_id -> (token_ids, logprobs, ranks)
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len]
    prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]

    # [req_ids]
    # P/D disaggregation: requests whose KV transfer completed this step.
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None
# Shared sentinel for steps where the model runner has nothing to report.
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                              req_id_to_index={},
                                              sampled_token_ids=[],
                                              spec_token_ids=None,
                                              logprobs=None,
                                              prompt_logprobs_dict={},
                                              finished_sending=None,
                                              finished_recving=None)

193
vllm/v1/request.py Normal file
View File

@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
class Request:
    """Engine-core representation of a single generation request.

    Tracks the prompt, generated/speculative tokens, scheduling status,
    multi-modal inputs, and lifecycle events for one request.
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int],
        multi_modal_inputs: Optional[list[MultiModalKwargs]],
        multi_modal_hashes: Optional[list[str]],
        multi_modal_placeholders: Optional[list[PlaceholderRange]],
        sampling_params: SamplingParams,
        eos_token_id: Optional[int],
        client_index: int = 0,
        lora_request: Optional["LoRARequest"] = None,
        structured_output_request: Optional["StructuredOutputRequest"] = None,
        cache_salt: Optional[str] = None,
    ) -> None:
        self.request_id = request_id
        self.client_index = client_index
        self.sampling_params = sampling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.structured_output_request = structured_output_request

        # Guided-decoding requests must wait for their FSM to compile.
        self.status = (RequestStatus.WAITING_FOR_FSM
                       if sampling_params.guided_decoding is not None else
                       RequestStatus.WAITING)
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: Union[int, str, None] = None
        assert sampling_params.max_tokens is not None
        self.max_tokens = sampling_params.max_tokens

        self.prompt_token_ids = prompt_token_ids
        self.num_prompt_tokens = len(self.prompt_token_ids)
        self._output_token_ids: list[int] = []
        # _all_token_ids = prompt + generated tokens; kept in sync with
        # _output_token_ids by append_output_token_ids().
        self._all_token_ids: list[int] = self.prompt_token_ids.copy()
        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: Optional[str] = cache_salt

        # Multi-modal related
        self.mm_positions = multi_modal_placeholders or []
        self.mm_inputs = multi_modal_inputs or []
        self.mm_hashes: list[str] = multi_modal_hashes or []
        self.num_encoder_inputs = len(self.mm_inputs)
        self.has_encoder_inputs = self.num_encoder_inputs > 0

        # P/D: Connector-specific KV transfer parameters.
        kv_params = (None if sampling_params.extra_args is None else
                     sampling_params.extra_args.get("kv_transfer_params"))
        self.kv_transfer_params: Optional[dict[str, Any]] = kv_params

        # Sanity check
        assert len(self.mm_inputs) == len(self.mm_positions)
        if self.mm_hashes:
            assert len(self.mm_inputs) == len(self.mm_hashes)

        # Read-only views
        # Prevent directly appending to these lists since
        # they should also be updated simultaneously.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

        # State
        # The number of tokens with prefix cache hits.
        # -1 means "not computed yet".
        self.num_cached_tokens = -1

    @classmethod
    def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
        """Alternate constructor from a serialized EngineCoreRequest."""
        if request.mm_inputs is not None:
            assert isinstance(request.mm_inputs, list)
            assert is_list_of(request.mm_inputs, MultiModalKwargs), (
                "mm_inputs was not updated in EngineCore.add_request")

        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            multi_modal_inputs=request.mm_inputs,
            multi_modal_hashes=request.mm_hashes,
            multi_modal_placeholders=request.mm_placeholders,
            sampling_params=request.sampling_params,
            eos_token_id=request.eos_token_id,
            lora_request=request.lora_request,
            structured_output_request=StructuredOutputRequest(
                sampling_params=request.sampling_params),
            cache_salt=request.cache_salt,
        )

    def append_output_token_ids(
        self,
        token_ids: Union[int, list[int]],
    ) -> None:
        """Append generated token(s), keeping both token lists in sync."""
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

    @property
    def num_tokens(self) -> int:
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        return len(self._output_token_ids)

    def is_finished(self) -> bool:
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> Union[FinishReason, None]:
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_tokens(self, input_id: int) -> int:
        """Number of placeholder tokens for the given multi-modal input."""
        assert input_id < len(self.mm_positions)
        num_tokens = self.mm_positions[input_id].length
        return num_tokens

    @property
    def use_structured_output(self) -> bool:
        return self.sampling_params.guided_decoding is not None

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: Optional[float] = None,
    ) -> None:
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> Optional[list[EngineCoreEvent]]:
        """Return and clear the accumulated events (None if empty)."""
        if not self.events:
            return None
        events, self.events = self.events, []
        return events
class RequestStatus(enum.IntEnum):
    """Status of a request.

    Member order is significant: is_finished() treats every member
    declared after PREEMPTED as a terminal state.
    """
    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(
            status: "RequestStatus") -> Union[FinishReason, None]:
        # Returns None for non-terminal statuses.
        return _FINISHED_REASON_MAP.get(status)
# Mapping of finished statuses to their finish reasons.
# Non-terminal statuses are intentionally absent (lookup returns None).
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
}

View File

View File

@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
@dataclass
class SamplingMetadata:
    """Batched sampling parameters for one model-runner step.

    Tensor fields are batch-indexed (one entry per request); Optional
    fields are None when no request in the batch uses that feature.
    """

    temperature: Optional[torch.Tensor]
    # Fast-path flags: all requests greedy / all requests random.
    all_greedy: bool
    all_random: bool

    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]
    min_p: Optional[torch.Tensor]

    # Per-request seeded RNGs, keyed by request index in the batch.
    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    # Per-request tokens generated so far (ragged).
    output_token_ids: list[list[int]]

    # req_index -> (min_tokens, stop_token_ids)
    min_tokens: dict[int, tuple[int, set[int]]]

    logit_bias: list[Optional[dict[int, float]]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]

View File

View File

@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
# Value used to mask out banned tokens so they can never be sampled.
_SMALLEST_LOGIT = float("-inf")


def _apply_bad_words_single_batch(
    logits: torch.Tensor,
    bad_words_token_ids: list[list[int]],
    past_tokens_ids: list[int],
) -> None:
    """Mask (in place) any token that would complete a banned sequence.

    Args:
        logits: 1D logits for a single request; modified in place.
        bad_words_token_ids: each entry is the token-id sequence of one
            banned word/phrase.
        past_tokens_ids: tokens generated so far for this request.
    """
    for bad_word_ids in bad_words_token_ids:
        # Guard against empty entries: previously an empty sequence fell
        # through the length check and crashed on bad_word_ids[-1].
        if not bad_word_ids:
            continue
        if len(bad_word_ids) > len(past_tokens_ids) + 1:
            # Not enough history for this banned sequence to complete.
            continue

        prefix_length = len(bad_word_ids) - 1
        last_token_id = bad_word_ids[-1]
        if prefix_length > 0:
            actual_prefix = past_tokens_ids[-prefix_length:]
        else:
            actual_prefix = []
        expected_prefix = bad_word_ids[:prefix_length]

        assert len(actual_prefix) == len(expected_prefix)

        # Ban the final token only if the generated suffix matches the
        # banned sequence's prefix.
        if actual_prefix == expected_prefix:
            logits[last_token_id] = _SMALLEST_LOGIT
def apply_bad_words(
    logits: torch.Tensor,
    bad_words_token_ids: dict[int, list[list[int]]],
    past_tokens_ids: list[list[int]],
) -> None:
    """Apply bad-words masking (in place) for every request in the batch.

    Args:
        logits: [num_reqs, vocab_size] logits, modified in place.
        bad_words_token_ids: request index -> banned token-id sequences.
        past_tokens_ids: per-request tokens generated so far.
    """
    for req_idx, banned_seqs in bad_words_token_ids.items():
        _apply_bad_words_single_batch(logits[req_idx], banned_seqs,
                                      past_tokens_ids[req_idx])

View File

@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.utils import apply_penalties
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
def apply_min_token_penalties(
        logits: torch.Tensor, output_token_ids: list[list[int]],
        min_tokens: dict[int, tuple[int, set[int]]]) -> None:
    """
    Applies minimum token penalty by setting the logits of the stop tokens
    to -inf, for every request that has not yet generated its minimum
    number of tokens.
    """
    penalize = [(req_idx, stop_id)
                for req_idx, (min_count, stop_ids) in min_tokens.items()
                if len(output_token_ids[req_idx]) < min_count
                for stop_id in stop_ids]
    if penalize:
        # Single advanced-indexing write for all (row, token) pairs.
        logits[tuple(zip(*penalize))] = -float("inf")
def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.

    The ragged per-request output tokens are first padded into a single
    tensor on the logits' device.
    """
    _, vocab_size = logits.shape
    padded_output_tokens = _convert_to_tensors(output_token_ids, vocab_size,
                                               logits.device)
    return apply_penalties(logits, prompt_token_ids, padded_output_tokens,
                           presence_penalties, frequency_penalties,
                           repetition_penalties)
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Pad the ragged per-request token lists into one int64 tensor and move
    it to `device`.
    """
    # vocab_size can never be a real token id, so it is a safe pad value.
    padded = make_tensor_with_pad(
        output_token_ids,
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    return padded.to(device, non_blocking=True)

View File

@@ -0,0 +1,293 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn as nn
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
# FlashInfer provides fused GPU sampling kernels; it is an optional
# dependency, so record availability instead of failing at import time.
try:
    import flashinfer.sampling
    is_flashinfer_available = True
except ImportError:
    is_flashinfer_available = False
class TopKTopPSampler(nn.Module):
    """
    Module that performs optional top-k and top-p filtering followed by
    weighted random sampling of logits.

    Implementations may update the logits tensor in-place.

    The concrete `forward` is chosen once at construction time based on
    platform and FlashInfer availability.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda():
            if is_flashinfer_available:
                flashinfer_version = flashinfer.__version__
                # NOTE(review): this is a lexicographic string comparison,
                # which mis-orders multi-digit components (e.g. "0.10.0" <
                # "0.2.3" is True) — confirm and consider a proper version
                # parse.
                if flashinfer_version < "0.2.3":
                    logger.warning(
                        "FlashInfer version >= 0.2.3 required. "
                        "Falling back to default sampling implementation.")
                    self.forward = self.forward_native
                elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    logger.info("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "top-p & top-k sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward = self.forward_native
        elif current_platform.is_tpu():
            self.forward = self.forward_tpu
        else:
            self.forward = self.forward_native

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        PyTorch-native implementation of top-k and top-p sampling.

        The logits tensor may be updated in-place.
        """
        logits = apply_top_k_top_p(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """More optimized implementation for top-k and top-p sampling."""
        if k is None and p is None:
            # We prefer `random_sample` over `flashinfer_sample` when sorting is
            # not needed. This is because `random_sample` does not require
            # CPU-GPU synchronization while `flashinfer_sample` does.
            probs = logits.softmax(dim=-1, dtype=torch.float32)
            return random_sample(probs, generators)
        if generators:
            # FlashInfer cannot honor per-request seeded generators.
            logger.warning("FlashInfer 0.2.3+ does not support "
                           "per-request generators. Falling back to "
                           "PyTorch-native implementation.")
            return self.forward_native(logits, generators, k, p)
        return flashinfer_sample(logits, k, p, generators)

    def forward_tpu(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """TPU-friendly path: scatter-free top-k/top-p, then sampling."""
        logits = apply_top_k_top_p_tpu(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)
def apply_top_k_top_p_tpu(
    logits: torch.Tensor,
    k: torch.Tensor,
    p: torch.Tensor,
) -> torch.Tensor:
    """
    Apply top-k and top-p optimized for TPU.

    This algorithm avoids using torch.scatter which is extremely slow on
    TPU. Instead, a "cut-off" probability is located per row and the
    logits are thresholded against it; the surviving elements form the
    top-k / top-p set.

    Note: in the case of a tie (multiple cut-off elements present in the
    logit), all tied elements are included — ties are not broken here.
    Tied tokens then have equal chance of being chosen during final
    sampling, so the tie can be considered broken at that point.
    """
    probs = logits.softmax(dim=-1)
    sorted_probs, _ = probs.sort(dim=-1, descending=False)

    if k is not None:
        # Per-row index of the smallest probability that is kept.
        cutoff_index = (sorted_probs.size(1) - k.to(torch.long)).unsqueeze(
            dim=1)
        cutoff = sorted_probs.gather(-1, cutoff_index)
        # Rows requesting the full vocab must be a no-op: force an
        # impossible cutoff.
        full_vocab = (k == logits.shape[1]).unsqueeze(dim=1)
        cutoff.masked_fill_(full_vocab, -float("inf"))
        logits.masked_fill_(probs < cutoff, -float("inf"))

    if p is not None:
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        tail = cumulative <= 1 - p.unsqueeze(dim=1)
        tail[:, -1] = False  # always keep at least one token
        cutoff_index = tail.sum(dim=-1).unsqueeze(1)
        cutoff = sorted_probs.gather(-1, cutoff_index)
        logits.masked_fill_(probs < cutoff, -float("inf"))
    return logits
def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    If a top-p is used, this function will sort the logits tensor, which
    can be slow for large batches. The logits tensor may be updated
    in-place.
    """
    if p is None:
        # No top-p: skip the expensive full-vocab sort entirely.
        return logits if k is None else apply_top_k_only(logits, k)

    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Top-k: mask everything strictly below the k-th largest value.
        kth_index = (sorted_logits.size(1) - k.to(torch.long)).unsqueeze(dim=1)
        kth_value = sorted_logits.gather(1, kth_index)
        sorted_logits.masked_fill_(sorted_logits < kth_value, -float("inf"))

    # Top-p: drop the low-probability tail of the distribution.
    sorted_probs = sorted_logits.softmax(dim=-1)
    cum_probs = torch.cumsum(sorted_probs, dim=-1, out=sorted_probs)
    tail_mask = cum_probs <= 1 - p.unsqueeze(dim=1)
    tail_mask[:, -1] = False  # always keep at least one token
    sorted_logits.masked_fill_(tail_mask, -float("inf"))

    # Undo the sort to restore the original token order.
    return sorted_logits.scatter(dim=-1, index=sorted_idx, src=sorted_logits)
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Apply a top-k mask to the logits without sorting the entire vocab.

    The logits tensor may be updated in-place.
    """
    vocab_size = logits.shape[1]
    full_vocab_rows = k == vocab_size
    # Rows keeping the full vocab get k=1 so the gather below stays
    # valid; their cutoff is overwritten with -inf, making them a no-op.
    clamped_k = k.masked_fill(full_vocab_rows, 1)
    max_top_k = clamped_k.max()
    # Convert k to a 0-based column index in [0, max_top_k).
    kth_column = clamped_k.sub_(1).unsqueeze(1)
    # topk.values has shape [batch_size, max_top_k]; pick each row's
    # k-th largest value as its cutoff.
    cutoff = logits.topk(max_top_k, dim=1).values.gather(1, kth_column.long())
    cutoff.masked_fill_(full_vocab_rows.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < cutoff, -float("inf"))
    return logits
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    Uses the exponential-noise trick (divide by Exp(1) noise, take the
    argmax) instead of torch.multinomial, because torch.multinomial
    causes CPU-GPU synchronization. The probs tensor is modified
    in-place.
    """
    noise = torch.empty_like(probs)
    # Fast path: draw noise for the whole batch from the default RNG,
    # then overwrite only the rows that carry their own seeded generator.
    if len(generators) != probs.shape[0]:
        noise.exponential_()
    if generators:
        # TODO(woosuk): This can be slow because we handle each request
        # one by one. Optimize this.
        for row, gen in generators.items():
            noise[row].exponential_(generator=gen)
    return probs.div_(noise).argmax(dim=-1).view(-1)
def flashinfer_sample(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the logits using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it avoids sorting the logits tensor
    via rejection sampling.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.

    NOTE(review): `generators` is accepted but unused here — callers with
    per-request generators fall back to forward_native (see
    TopKTopPSampler.forward_cuda).
    """
    assert not (k is None and p is None)
    if k is None:
        # Top-p only.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
            probs, p, deterministic=True)
    elif p is None:
        # Top-k only.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
            probs, k, deterministic=True)
    else:
        # Both top-k and top-p.
        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
            logits, k, p, deterministic=True)

    return next_token_ids.view(-1)

View File

@@ -0,0 +1,631 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
logger = init_logger(__name__)
# Token id used to pad unused output slots (see SamplerOutput docs).
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
# Sentinel temperature value — presumably marks greedy-sampling requests
# in the kernels below; confirm against the temperature producer.
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
class RejectionSampler(nn.Module):
    """Rejection sampler for speculative decoding.

    Implements the verification algorithm described in
    https://arxiv.org/abs/2211.17192.

    Terminology used throughout this module:
      * accepted tokens: draft tokens kept because of the relationship
        between the "raw" draft and target probabilities.
      * recovered tokens: tokens re-sampled from the adjusted distribution,
        which is derived from both the draft and target probabilities.
      * bonus tokens: appended when every draft token of a request is
        accepted. They are sampled only from the target probabilities, and
        we sample them outside of this module for flexibility: e.g. top-p /
        top-k sampling can be used for bonus tokens even though spec decode
        itself does not support those strategies.
      * output tokens: accepted tokens + recovered tokens + bonus tokens.
    """

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Run rejection sampling over one batch of speculative tokens.

        Args:
            metadata: Metadata for spec decoding (draft token IDs,
                cumulative counts, max spec length).
            draft_probs: Draft distribution with shape
                [num_tokens, vocab_size], or None when the proposer does
                not produce probabilities (e.g. ngram spec decode).
            target_logits: Target model logits with shape
                [num_tokens, vocab_size]; tokens of all requests are
                flattened into one tensor. NOTE: may be updated in place
                to save memory.
            bonus_token_ids: [batch_size, 1] bonus token per request,
                appended only when all of its draft tokens are accepted.
            sampling_metadata: Temperature / top-k / top-p and related
                sampling state.

        Returns:
            [batch_size, max_spec_len + 1] tensor of output token IDs;
            positions with no token hold PLACEHOLDER_TOKEN_ID.
        """
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        # [num_tokens, vocab_size]
        # NOTE: `compute_probs` may overwrite `target_logits` in place.
        target_probs = compute_probs(
            target_logits,
            metadata.cu_num_draft_tokens,
            sampling_metadata,
        )
        return rejection_sample(
            metadata.draft_token_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
            metadata.cu_num_draft_tokens,
            draft_probs,
            target_probs,
            bonus_token_ids,
            sampling_metadata,
        )

    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
        vocab_size: int,
    ) -> list[list[int]]:
        """Convert the sampler's padded output into per-request lists.

        Args:
            output_token_ids: Sampled token IDs with shape
                [batch_size, max_spec_len + 1]; rejected positions were
                filled with PLACEHOLDER_TOKEN_ID and are dropped here.
            vocab_size: Vocabulary size; out-of-vocab IDs are dropped too.

        Returns:
            One list of valid token IDs per request.
        """
        ids = output_token_ids.cpu().numpy()
        keep = (ids != PLACEHOLDER_TOKEN_ID) & (ids < vocab_size)
        return [row[mask].tolist() for row, mask in zip(ids, keep)]
def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Dispatch the greedy and/or random rejection-sampling Triton kernels.

    Returns:
        [batch_size, max_spec_len + 1] int32 tensor of output token IDs.
        Slots that received no token keep PLACEHOLDER_TOKEN_ID.
    """
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_probs.shape[-1]
    device = target_probs.device
    # The kernels use raw pointer arithmetic, so contiguity is required.
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)

    # Create output buffer.
    output_token_ids = torch.empty(
        (batch_size, max_spec_len + 1),
        dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
        device=device,
    )
    output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)

    # is_greedy is a per-request boolean mask; None means "all greedy" and
    # lets the greedy kernel skip the load entirely.
    if sampling_metadata.all_greedy:
        is_greedy = None
    else:
        is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
    if not sampling_metadata.all_random:
        # Rejection sampling for greedy sampling requests.
        target_argmax = target_probs.argmax(dim=-1)
        rejection_greedy_sample_kernel[(batch_size, )](
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
            bonus_token_ids,
            is_greedy,
            max_spec_len,
            num_warps=1,
        )
        if sampling_metadata.all_greedy:
            # No random requests in the batch; done.
            return output_token_ids

    # Generate uniform probabilities for rejection sampling.
    # [num_tokens]
    uniform_probs = generate_uniform_probs(
        num_tokens,
        num_draft_tokens,
        sampling_metadata.generators,
        device,
    )

    # Sample recovered tokens for each position.
    # [num_tokens]
    recovered_token_ids = sample_recovered_tokens(
        max_spec_len,
        num_draft_tokens,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
        device,
    )

    # Rejection sampling for random sampling requests.
    rejection_random_sample_kernel[(batch_size, )](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_token_ids,
        uniform_probs,
        is_greedy,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS=draft_probs is None,
        num_warps=1,
    )
    return output_token_ids
def compute_probs(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Compute probability distribution from logits based on sampling metadata.

    This function applies temperature scaling to the logits and converts
    them to probabilities using softmax. For greedy decoding, it returns
    the original logits.

    Args:
        logits: Input logits tensor to be converted to probabilities.
        cu_num_draft_tokens: Cumulative number of draft tokens.
        sampling_metadata: Metadata containing sampling parameters such as
            temperature and whether greedy sampling is used.

    Returns:
        torch.Tensor: Probability distribution (softmax of scaled logits)
            if non-greedy sampling is used, otherwise returns the
            original logits.
    """
    assert logits.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    if sampling_metadata.all_greedy:
        # Greedy path only needs argmax; raw logits are sufficient.
        return logits

    num_tokens = logits.shape[0]
    # Per-batch temperature expanded to one value per token. The
    # GREEDY_TEMPERATURE sentinel is replaced with 1 so that division below
    # leaves greedy requests' logits unscaled.
    temperature = expand_batch_to_tokens(
        sampling_metadata.temperature,
        cu_num_draft_tokens,
        num_tokens,
        replace_from=GREEDY_TEMPERATURE,
        replace_to=1,
    )
    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
    logits.div_(temperature.unsqueeze(-1))

    # Get expanded top_k and top_p tensors.
    top_k = None
    if sampling_metadata.top_k is not None:
        top_k = expand_batch_to_tokens(
            sampling_metadata.top_k,
            cu_num_draft_tokens,
            num_tokens,
        )
    top_p = None
    if sampling_metadata.top_p is not None:
        top_p = expand_batch_to_tokens(
            sampling_metadata.top_p,
            cu_num_draft_tokens,
            num_tokens,
        )

    # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
    # which is slow for large vocab sizes. This may cause performance issues.
    logits = apply_top_k_top_p(logits, top_k, top_p)
    output_prob = logits.softmax(dim=-1, dtype=torch.float32)
    return output_prob
def expand_batch_to_tokens(
    x: torch.Tensor,  # [batch_size]
    cu_num_tokens: torch.Tensor,  # [batch_size]
    num_tokens: int,
    replace_from: int = 0,
    replace_to: int = 0,
) -> torch.Tensor:
    """Expand [batch_size] tensor to [num_tokens] tensor based on the number of
    tokens per batch in cu_num_tokens.

    For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
    num_tokens = 6, and expanded_x = [a, a, b, b, b, c].

    Args:
        x: [batch_size] tensor to expand.
        cu_num_tokens: [batch_size] tensor containing the cumulative number of
            tokens per batch. Each element represents the total number of
            tokens up to and including that batch.
        num_tokens: Total number of tokens.
        replace_from: Value to be replaced if it is found in x.
        replace_to: Value to replace with when replace_from is found.

    Returns:
        expanded_x: [num_tokens] tensor.
    """
    batch_size = x.shape[0]
    assert cu_num_tokens.shape[0] == batch_size
    expanded_x = x.new_empty(num_tokens)
    # One Triton program per request broadcasts x[i] into its token span.
    expand_kernel[(batch_size, )](
        expanded_x,
        x,
        cu_num_tokens,
        replace_from,
        replace_to,
        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
        num_warps=1,
    )
    return expanded_x
def generate_uniform_probs(
    num_tokens: int,
    num_draft_tokens: list[int],
    generators: dict[int, torch.Generator],
    device: torch.device,
) -> torch.Tensor:
    """Draw one uniform random value in [0, 1) per draft token.

    Requests that registered a seeded ``torch.Generator`` in `generators`
    have their slice re-drawn from that generator so their results are
    reproducible; all other positions keep the globally-seeded draws.

    Args:
        num_tokens: Total number of draft tokens in the batch.
        num_draft_tokens: Number of draft tokens per request.
        generators: Maps request index in the batch to its generator.
        device: Device on which to allocate the tensor.

    Returns:
        float32 tensor of shape ``(num_tokens,)`` with values in [0, 1).
    """
    samples = torch.rand((num_tokens, ), dtype=torch.float32, device=device)
    offset = 0
    for req_idx, count in enumerate(num_draft_tokens):
        if count == 0:
            # Requests without draft tokens must not consume random numbers;
            # this keeps seeded streams reproducible across batches.
            continue
        seeded_gen = generators.get(req_idx)
        if seeded_gen is not None:
            samples[offset:offset + count].uniform_(generator=seeded_gen)
        offset += count
    return samples
def sample_recovered_tokens(
    max_spec_len: int,
    num_draft_tokens: list[int],
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    device: torch.device,
) -> torch.Tensor:
    """Sample one "recovered" token per draft position.

    The kernel draws from the adjusted distribution via exponential noise:
    it computes argmax(prob / q) with q ~ Exponential(1), which samples
    proportionally to `prob` without needing normalization.

    Returns:
        [num_tokens] tensor of recovered token IDs.
    """
    # NOTE(woosuk): Create only one distribution for each request.
    batch_size = len(num_draft_tokens)
    vocab_size = target_probs.shape[-1]
    # Exponential noise shared by all positions of a request.
    q = torch.empty(
        (batch_size, vocab_size),
        dtype=torch.float32,
        device=device,
    )
    q.exponential_()
    for i, generator in sampling_metadata.generators.items():
        # Do not generate random numbers for requests with no draft tokens.
        # This can be important for reproducibility.
        if num_draft_tokens[i] > 0:
            # Re-draw seeded requests from their own generator.
            q[i].exponential_(generator=generator)

    recovered_token_ids = torch.empty_like(draft_token_ids)
    # 2D grid: (request, draft position); out-of-range positions exit early
    # inside the kernel.
    sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
        recovered_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        q,
        vocab_size,
        triton.next_power_of_2(vocab_size),
        NO_DRAFT_PROBS=draft_probs is None,
    )
    return recovered_token_ids
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    max_spec_len,
):
    # Greedy rejection sampling, one program per request: emit the target
    # argmax for each position while it matches the draft token; the first
    # mismatch stops the request (later slots keep PLACEHOLDER_TOKEN_ID).
    # If every draft token is accepted, the bonus token is appended.
    req_idx = tl.program_id(0)
    # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
    # re-compilation may happen during runtime when is_greedy_ptr is None.
    if is_greedy_ptr is None:
        is_greedy = True
    else:
        is_greedy = tl.load(is_greedy_ptr + req_idx)
    # BUG FIX: this previously tested `is_greedy is None`, which is never
    # true (is_greedy is either the literal True or a loaded scalar), so the
    # kernel never skipped non-greedy requests and overwrote their outputs.
    if not is_greedy:
        # Early exit for non-greedy sampling requests.
        return

    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     target_argmax_id)
            if draft_token_id != target_argmax_id:
                # Reject.
                rejected = True

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens, bonus_token_id)
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    bonus_token_ids_ptr,  # [batch_size]
    recovered_token_ids_ptr,  # [num_tokens]
    uniform_probs_ptr,  # [num_tokens]
    is_greedy_ptr,  # [batch_size]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
):
    # Random rejection sampling, one program per request. A draft token is
    # accepted iff target_prob / draft_prob >= uniform_prob; the first
    # rejected position emits the pre-sampled recovered token and later
    # slots keep PLACEHOLDER_TOKEN_ID. If all are accepted, the bonus token
    # is appended.
    req_idx = tl.program_id(0)
    is_greedy = tl.load(is_greedy_ptr + req_idx)
    # BUG FIX: this previously tested `is_greedy is not None`, which is
    # always true for a loaded scalar, so the kernel returned immediately for
    # every request and random sampling never produced any tokens.
    if is_greedy:
        # Early exit for greedy sampling requests (handled by
        # rejection_greedy_sample_kernel).
        return

    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            if NO_DRAFT_PROBS:
                # E.g. ngram spec decode: regard the draft prob as 1.
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            # NOTE(woosuk): While the draft probability should never be 0,
            # we check it to avoid NaNs. If it happens to be 0, we reject.
            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
                # Accept.
                token_id = draft_token_id
            else:
                # Reject. Use recovered token.
                rejected = True
                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens, bonus_token_id)
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
    output_ptr,  # [num_tokens]
    input_ptr,  # [batch_size]
    cu_num_tokens_ptr,  # [batch_size]
    replace_from,
    replace_to,
    MAX_NUM_TOKENS: tl.constexpr,
):
    # One program per request: broadcast input_ptr[req_idx] into the
    # request's token span output_ptr[start_idx:end_idx], substituting
    # `replace_to` for values equal to `replace_from` (e.g. mapping the
    # GREEDY_TEMPERATURE sentinel to 1 in compute_probs).
    req_idx = tl.program_id(0)
    if req_idx == 0:  # noqa: SIM108
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
    num_tokens = end_idx - start_idx

    src_val = tl.load(input_ptr + req_idx)
    src_val = tl.where(src_val == replace_from, replace_to, src_val)
    offset = tl.arange(0, MAX_NUM_TOKENS)
    # Masked store: only the first `num_tokens` lanes write.
    tl.store(output_ptr + start_idx + offset,
             src_val,
             mask=offset < num_tokens)
@triton.jit
def sample_recovered_tokens_kernel(
    output_token_ids_ptr,  # [num_tokens]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    q_ptr,  # [batch_size, vocab_size]
    vocab_size,
    PADDED_VOCAB_SIZE: tl.constexpr,
    NO_DRAFT_PROBS: tl.constexpr,
):
    # One program per (request, draft position). Samples a recovered token
    # from max(target - draft, 0) via argmax(prob / q) with exponential
    # noise q (no normalization needed since argmax is scale-invariant).
    req_idx = tl.program_id(0)
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    # Early exit for out-of-range positions.
    pos = tl.program_id(1)
    if pos >= num_draft_tokens:
        return

    vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                            draft_token_id)
        # Temporarily zero out the probability of the draft token.
        # This is essentially the same as target_prob - draft_prob, except that
        # n-gram does not have draft_prob. We regard it as 1.
        # NOTE: in-place write to target_probs; restored at the end below.
        tl.store(
            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
            0)
        prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                       vocab_offset,
                       mask=vocab_offset < vocab_size,
                       other=0)
    else:
        draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
                             vocab_offset,
                             mask=vocab_offset < vocab_size,
                             other=0)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size + vocab_offset,
                              mask=vocab_offset < vocab_size,
                              other=0)
        # Adjusted (unnormalized) distribution for recovered sampling.
        prob = tl.maximum(target_prob - draft_prob, 0)
        # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
        # `tl.argmax` will select the maximum value.

    # Padding lanes load -inf noise so they can never win the argmax.
    q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
                mask=vocab_offset < vocab_size,
                other=float("-inf"))
    recovered_id = tl.argmax(prob / q, axis=-1)
    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)

    if NO_DRAFT_PROBS:
        # Restore the original probability.
        tl.store(
            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
            orig_prob)

286
vllm/v1/sample/sampler.py Normal file
View File

@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""
import torch
import torch.nn as nn
from vllm.utils import async_tensor_h2d, is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
    """Samples next tokens from logits: applies allowed-token masks, bad
    words, logit bias, penalties, and temperature/top-k/top-p/min-p, then
    optionally gathers top-k logprobs."""

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()
        # Whether host tensors can be pinned for faster H2D copies.
        self.pin_memory = is_pin_memory_available()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Run the full sampling pipeline and return sampled token IDs
        (shape [num_requests, 1]) plus optional logprobs tensors."""
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        # TODO(rob): provide option for logprobs post sampling.
        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            raw_logprobs = self.compute_logprobs(logits)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)
        # Apply logits bias.
        logits = self.apply_logits_bias(logits, sampling_metadata)
        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)
        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)
        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested)
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        """Scale logits by per-request temperature (in place)."""
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Argmax over the vocab dim, flattened to [num_requests]."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            # Greedy sampling must run before temperature scaling mutates
            # the logits in place.
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                return greedy_sampled

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply min_p.
        if sampling_metadata.min_p is not None:
            logits = self.apply_min_p(logits, sampling_metadata.min_p)

        # Apply top_k and/or top_p.
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )
        if greedy_sampled is None:
            return random_sampled

        # Per-request pick between the greedy and random results; a
        # near-zero temperature marks greedy requests.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Log-softmax of the raw logits in float32."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to
                          retain per token
            token_ids: prompt tokens (if prompt logprobs)
                       or sampled tokens (if sampled
                       logprobs); 1D token ID tensor
                       with (num tokens) elements
                       Must be int64.

        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get with the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        # Rank = number of tokens with logprob >= the token's (1 = best).
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Apply min-token and presence/frequency/repetition penalties."""
        if sampling_metadata.min_tokens:
            apply_min_token_penalties(logits,
                                      sampling_metadata.output_token_ids,
                                      sampling_metadata.min_tokens)
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_min_p(
        self,
        logits: torch.Tensor,
        min_p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Filters logits using adaptive probability thresholding.
        """
        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values,
                                       dim=-1,
                                       keepdim=True)
        # Reshape min_p for broadcasting
        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
        # Identify valid tokens using threshold comparison
        valid_token_mask = probability_values >= adjusted_min_p
        # Apply mask using boolean indexing
        logits[~valid_token_mask] = -float('inf')
        return logits

    def apply_logits_bias(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Add per-request, per-token bias values to the logits.

        Raises:
            ValueError: if a biased token_id is outside the vocabulary.
        """
        # TODO(houseroad): this implementation is extremely inefficient.
        # One idea is implement this as a PyTorch C++ op, and we may
        # even optimize the logit_bias layout.

        # Get vocabulary size from logits
        vocab_size = logits.shape[-1]

        # Accumulate sparse (row, col, value) triples, then apply in one
        # index_put_ to limit H2D traffic.
        rows: list[int] = []
        cols: list[int] = []
        vals: list[float] = []
        for i, logit_bias in enumerate(sampling_metadata.logit_bias):
            if logit_bias:
                for token_id, bias in logit_bias.items():
                    # Check token_id bounds to ensure within vocabulary
                    if token_id < 0 or token_id >= vocab_size:
                        raise ValueError(
                            f"token_id {token_id} in logit_bias contains "
                            f"out-of-vocab token id. Vocabulary size: "
                            f"{vocab_size}")
                    rows.append(i)
                    cols.append(token_id)
                    vals.append(bias)
        if rows:
            indices = async_tensor_h2d([rows, cols], torch.int64,
                                       logits.device, self.pin_memory)
            values = async_tensor_h2d(vals, torch.float, logits.device,
                                      self.pin_memory)
            logits.index_put_(tuple(indices), values=values, accumulate=True)
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Mask out tokens flagged True in allowed_token_ids_mask."""
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Suppress tokens that would complete a configured bad-word
        sequence (mutates logits in place)."""
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits

View File

View File

@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Optional
import torch
from vllm.v1.worker.gpu_input_batch import InputBatch
# Neutral sampling values used to pad the batch up to the pre-compiled
# padded size (see TPUSupportedSamplingMetadata.from_input_batch):
# temperature -1.0 marks greedy, top_k=0 / top_p=1.0 disable filtering.
DEFAULT_SAMPLING_PARAMS = dict(
    temperature=-1.0,
    min_p=0.0,
    # strictly disabled for now
    top_k=0,
    top_p=1.0,
    # frequency_penalties=0.0,
    # presence_penalties=0.0,
    # repetition_penalties=0.0,
)
@dataclass
class TPUSupportedSamplingMetadata:
    # This class exposes a more xla-friendly interface than SamplingMetadata
    # on TPU, in particular all arguments should be traceable and no optionals
    # are allowed, to avoid graph recompilation on Nones.
    temperature: Optional[torch.Tensor] = None
    min_p: Optional[torch.Tensor] = None
    # strictly disabled for now (see DEFAULT_SAMPLING_PARAMS)
    top_k: Optional[torch.Tensor] = None
    top_p: Optional[torch.Tensor] = None

    all_greedy: bool = True

    # Whether logprobs are to be gathered in this batch of request. To balance
    # out compile time and runtime, a fixed `max_number_logprobs` value is used
    # when gathering logprobs, regardless of the values specified in the batch.
    logprobs: bool = False

    # TODO No penalties for now
    no_penalties: bool = True
    # NOTE: the following have no annotation on purpose: they are plain class
    # attributes (NOT dataclass fields) kept only to satisfy the
    # SamplingMetadata-like interface expected by the sampler.
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None

    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=lambda: list())

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(
        default_factory=lambda: list())

    allowed_token_ids_mask = None
    bad_words_token_ids = None

    # Generator not supported by xla
    _generators: dict[int,
                      torch.Generator] = field(default_factory=lambda: dict())

    @property
    def generators(self) -> dict[int, torch.Generator]:
        # Generator not supported by torch/xla. This field must be immutable.
        return self._generators

    @classmethod
    def from_input_batch(
        cls,
        input_batch: InputBatch,
        padded_num_reqs: int,
        xla_device: torch.device,
        generate_params_if_all_greedy: bool = False
    ) -> "TPUSupportedSamplingMetadata":
        """
        Copy sampling tensors slices from `input_batch` to on device tensors.

        `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
        slices dynamic shapes on device tensors. This impl moves the dynamic
        ops to CPU and produces tensors of fixed `padded_num_reqs` size.

        Args:
            input_batch: The input batch containing sampling parameters.
            padded_num_reqs: The padded number of requests.
            xla_device: The XLA device.
            generate_params_if_all_greedy: If True, generate sampling parameters
                even if all requests are greedy. this is useful for cases where
                we want to pre-compile a graph with sampling parameters, even if
                they are not strictly needed for greedy decoding.
        """
        needs_logprobs = input_batch.max_num_logprobs>0 if \
            input_batch.max_num_logprobs else False
        # Early return to avoid unnecessary cpu to tpu copy
        if (input_batch.all_greedy is True
                and generate_params_if_all_greedy is False):
            return cls(all_greedy=True, logprobs=needs_logprobs)

        num_reqs = input_batch.num_reqs

        # Mutates the CPU tensor in place; no return value.
        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> None:
            # Pad value is the default one.
            cpu_tensor[num_reqs:padded_num_reqs] = fill_val

        fill_slice(input_batch.temperature_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["temperature"])
        fill_slice(input_batch.min_p_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["min_p"])
        fill_slice(input_batch.top_k_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["top_k"])
        fill_slice(input_batch.top_p_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["top_p"])

        # Slice persistent device tensors to a fixed pre-compiled padded shape.
        return cls(
            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
            to(xla_device),
            all_greedy=input_batch.all_greedy,
            # TODO enable more and avoid returning None values
            top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            logprobs=needs_logprobs)

View File

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampler layer implementing TPU supported operations."""
import torch
import torch.nn as nn
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
    """TPU-friendly sampler: a reduced version of the GPU sampler that only
    applies temperature/min-p/top-k/top-p and never returns logprobs from
    forward (all ops must be XLA-traceable)."""

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> SamplerOutput:
        """Sample one token per request; logprobs_tensors is always None."""
        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)

        # These are TPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=None)
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        # In-place division; mutates `logits`.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Argmax over the vocab dim, flattened to [num_requests]."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """Both greedy and random paths are always computed (no data-dependent
        branching, to keep the traced XLA graph static); the result is chosen
        per request with torch.where."""
        greedy_sampled = self.greedy_sample(logits)

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply min_p.
        if sampling_metadata.min_p is not None:
            logits = self.apply_min_p(logits, sampling_metadata.min_p)

        # Apply top_k and/or top_p.
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        # Near-zero temperature marks greedy requests.
        sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
                              greedy_sampled, random_sampled)
        return sampled

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Log-softmax of the logits in float32."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to
                          retain per token
            token_ids: prompt tokens (if prompt logprobs)
                       or sampled tokens (if sampled
                       logprobs); 1D token ID tensor
                       with (num tokens) elements

        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get with the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        # Rank = number of tokens with logprob >= the token's (1 = best).
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_min_p(
        self,
        logits: torch.Tensor,
        min_p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Filters logits using adaptive probability thresholding.
        """
        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values,
                                       dim=-1,
                                       keepdim=True)
        # Reshape min_p for broadcasting
        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
        # Identify valid tokens using threshold comparison
        valid_token_mask = probability_values >= adjusted_min_p
        # Apply mask using boolean indexing (xla friendly)
        logits.masked_fill_(~valid_token_mask, -float("inf"))
        return logits

315
vllm/v1/serial_utils.py Normal file
View File

@@ -0,0 +1,315 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import pickle
from collections.abc import Sequence
from inspect import isclass
from types import FunctionType
from typing import Any, Optional, Union
import cloudpickle
import numpy as np
import torch
import zmq
from msgspec import msgpack
from vllm import envs
from vllm.logger import init_logger
from vllm.multimodal.inputs import (BaseMultiModalField,
MultiModalBatchedField,
MultiModalFieldConfig, MultiModalFieldElem,
MultiModalFlatField, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField, NestedTensors)
logger = init_logger(__name__)
# msgpack Ext type codes used by MsgpackEncoder / MsgpackDecoder below.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3
# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}
# Any buffer-like type accepted or produced by the encoder/decoder.
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
def _log_insecure_serialization_warning():
    """Warn (once) that pickle/cloudpickle fallbacks are enabled via
    VLLM_ALLOW_INSECURE_SERIALIZATION."""
    logger.warning_once("Allowing insecure serialization using pickle due to "
                        "VLLM_ALLOW_INSECURE_SERIALIZATION=1")
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    By default, arrays below 256B are serialized inline. Larger arrays are
    sent via dedicated out-of-band messages. Note that this is a per-tensor
    limit.
    """

    def __init__(self, size_threshold: Optional[int] = None):
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj`, returning the main message buffer followed by any
        out-of-band tensor/array buffers."""
        try:
            self.aux_buffers = bufs = [b'']
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Encode `obj` into the caller-provided `buf`; returns `buf` plus
        any out-of-band tensor/array buffers."""
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """`msgspec` fallback hook for types it cannot natively encode."""
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step))

        if isinstance(obj, MultiModalKwargs):
            mm: MultiModalKwargs = obj
            if not mm.modalities:
                # just return the main dict if there are no modalities.
                return dict(mm)

            # ignore the main dict, it will be re-indexed.
            # Encode a list of MultiModalKwargsItems as plain dicts
            # + special handling for .field.
            # Any tensors *not* indexed by modality will be ignored.
            return [[{
                "modality": elem.modality,
                "key": elem.key,
                "data": self._encode_nested_tensors(elem.data),
                "field": self._encode_mm_field(elem.field),
            } for elem in item.values()]
                    for itemlist in mm._items_by_modality.values()
                    for item in itemlist]

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # Fix: the original message concatenated
            # "...serializable" + "Set ..." with no separator.
            raise TypeError(f"Object of type {type(obj)} is not serializable. "
                            "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                            "fallback to pickle-based serialization.")

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                           pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Serialize an ndarray as (dtype str, shape, inline-Ext-or-buffer-index)."""
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Serialize a tensor as (dtype str, shape, inline-Ext-or-buffer-index)."""
        assert self.aux_buffers is not None
        # view the tensor as a contiguous 1D array of bytes
        arr = obj.flatten().contiguous().view(torch.uint8).numpy()
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr.data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode (possibly nested) lists of tensors/scalars."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a multimodal field as (factory name, *field values)."""
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name)
                        for f in dataclasses.fields(field))
        return name, *field_values
class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.
    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.
    """
    def __init__(self, t: Optional[Any] = None):
        # `t` optionally fixes the top-level type that messages decode into
        # (msgspec typed decoding); without it, plain msgpack types come back.
        args = () if t is None else (t, )
        self.decoder = msgpack.Decoder(*args,
                                       ext_hook=self.ext_hook,
                                       dec_hook=self.dec_hook)
        # Out-of-band buffers of the in-progress `decode` call; entry 0 is
        # the main message, later entries back zero-copy tensors/arrays.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()
    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        """Decode a message produced by `MsgpackEncoder.encode(_into)`."""
        if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
            # TODO - This check can become `isinstance(bufs, bytestr)`
            # as of Python 3.10.
            return self.decoder.decode(bufs)
        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()
    def dec_hook(self, t: type, obj: Any) -> Any:
        """`msgspec` typed-decoding hook: rebuild custom types from
        their native-type encodings."""
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargs):
                if isinstance(obj, list):
                    return MultiModalKwargs.from_items(
                        self._decode_mm_items(obj))
                return MultiModalKwargs({
                    k: self._decode_nested_tensors(v)
                    for k, v in obj.items()
                })
        return obj
    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        # `arr` is the (dtype, shape, data) triple from _encode_ndarray.
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        return np.frombuffer(buffer, dtype=dtype).reshape(shape)
    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        # `arr` is the (dtype, shape, data) triple from _encode_tensor.
        dtype, shape, data = arr
        # Copy from inline representation, to decouple the memory storage
        # of the message from the original buffer. And also make Torch
        # not complain about a readonly memoryview.
        buffer = self.aux_buffers[data] if isinstance(data, int) \
            else bytearray(data)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)
    def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
        # Inverse of the MultiModalKwargs encoding in MsgpackEncoder.enc_hook.
        decoded_items = []
        for item in obj:
            elems = []
            for v in item:
                v["data"] = self._decode_nested_tensors(v["data"])
                # Reconstruct the field processor using MultiModalFieldConfig
                factory_meth_name, *field_args = v["field"]
                factory_meth = getattr(MultiModalFieldConfig,
                                       factory_meth_name)
                # Special case: decode the union "slices" field of
                # MultiModalFlatField
                if factory_meth_name == "flat":
                    field_args[0] = self._decode_nested_slices(field_args[0])
                v["field"] = factory_meth(None, *field_args).field
                elems.append(MultiModalFieldElem(**v))
            decoded_items.append(MultiModalKwargsItem.from_elems(elems))
        return decoded_items
    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        if obj and isinstance(obj[0], str):
            # A (dtype, shape, data) triple encodes a single tensor.
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]
    def _decode_nested_slices(self, obj: Any) -> Any:
        # Rebuild slice objects (possibly nested) from their tuple encodings.
        assert isinstance(obj, (list, tuple))
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]
    def ext_hook(self, code: int, data: memoryview) -> Any:
        """`msgspec` Ext hook for the custom extension type codes."""
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)
        raise NotImplementedError(
            f"Extension type code {code} is not supported")

View File

View File

@@ -0,0 +1,434 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
FlashAttentionMetadata)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
logger = init_logger(__name__)
PADDING_SLOT_ID = -1
class EagleProposer:
    """Draft-token proposer for EAGLE / EAGLE3 / deepseek MTP speculative
    decoding: runs a draft model over the target model's hidden states to
    propose up to `num_speculative_tokens` tokens per request."""
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method
        # `runner` is only required for the deepseek_mtp path (it provides
        # attn_metadata_builders).
        self.runner = runner
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = (
            self.speculative_config.num_speculative_tokens)
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE and
                               not self.vllm_config.model_config.enforce_eager)
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))
        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=device)
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)
        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
                                   1,
                                   device=device,
                                   dtype=torch.int32)
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Autoregressively propose draft tokens.
        Returns a [batch_size, num_speculative_tokens] tensor of draft token
        ids picked greedily (argmax) from the draft model's logits. Note that
        `sampling_metadata` is currently unused here (argmax-only sampling).
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1
        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            # EAGLE3 fuses several target-layer hidden states into one.
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids
        # FA requires seq_len to have dtype int32.
        seq_lens = (target_positions[last_token_indices] + 1).int()
        if self.method in ["eagle", "eagle3"]:
            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
            max_seq_len = seq_lens.max().item()
            max_num_tokens = (cu_num_tokens[1:] -
                              cu_num_tokens[:-1]).max().item()
            attn_metadata = FlashAttentionMetadata(
                num_actual_tokens=num_tokens,
                max_query_len=max_num_tokens,
                query_start_loc=cu_num_tokens,
                max_seq_len=max_seq_len,
                seq_lens=seq_lens,
                block_table=block_table,
                slot_mapping=target_slot_mapping,
                # TODO(woosuk): Support cascade attention.
                use_cascade=False,
                common_prefix_len=0,
                cu_prefix_query_lens=None,
                prefix_kv_lens=None,
                suffix_kv_lens=None,
            )
        elif self.method == "deepseek_mtp":
            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
            max_query_len = query_lens.max().item()
            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=cu_num_tokens,
                seq_lens=seq_lens,
                num_reqs=batch_size,
                num_actual_tokens=num_tokens,
                max_query_len=max_query_len,
            )
            assert self.runner is not None
            # FIXME: need to consider multiple kv_cache_groups
            attn_metadata = self.runner.attn_metadata_builders[0].build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise ValueError(f"Unsupported method: {self.method}")
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        if self.use_cuda_graph and \
            num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states
        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            ret_hidden_states = self.model(
                self.input_ids[:num_input_tokens],
                self.positions[:num_input_tokens],
                self.hidden_states[:num_input_tokens],
            )
            # MTP returns only the final hidden states; eagle models return
            # (last_hidden_states, per-token hidden_states).
            if self.method == "deepseek_mtp":
                last_hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids = logits.argmax(dim=-1)
        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            return draft_token_ids.view(-1, 1)
        # TODO: Currently, MTP module released by deepseek only has
        # one layer. Adapt this code to support multiple layers once
        # there's a multi-layer MTP module.
        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]
        positions = target_positions[last_token_indices]
        hidden_states = hidden_states[last_token_indices]
        if self.use_cuda_graph and \
            batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size
        # Subsequent steps decode exactly one token per request; the
        # attn_metadata built above is mutated in place accordingly.
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]
        for _ in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1
            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)
            # Increment the sequence lengths.
            attn_metadata.max_seq_len += 1
            attn_metadata.seq_lens += 1
            # Consider max model length.
            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                            self.max_model_len)
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = block_table.gather(dim=1,
                                           index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)
            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states
            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                last_hidden_states, hidden_states = self.model(
                    self.input_ids[:input_batch_size],
                    self.positions[:input_batch_size],
                    self.hidden_states[:input_batch_size],
                )
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                               None)
            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)
        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
    @staticmethod
    def prepare_inputs(
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
        num_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute per-request cumulative token counts and gather indices
        after rejection sampling dropped trailing draft tokens."""
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]
        num_tokens_per_req = query_len_per_req - num_rejected_tokens
        # [a - n1, b - n2, c - n3] ->
        # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        cu_num_tokens = torch.zeros_like(cu_target_query_lens)
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
        token_indices = torch.empty(
            num_tokens,
            dtype=torch.int32,
            device=cu_target_query_lens.device,
        )
        batch_size = num_rejected_tokens.shape[0]
        BLOCK_SIZE = 1024
        # Triton kernel fills token_indices with the per-request prefixes.
        prepare_eagle_input_kernel[(batch_size, )](
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return cu_num_tokens, token_indices
    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model; share embed_tokens / lm_head with the
        target model when shapes and config allow."""
        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())
        self.model = get_model(vllm_config=self.vllm_config,
                               model_config=draft_model_config)
        # Layers registered after loading the draft model belong to the draft.
        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)
        self.attn_layer_names = list(draft_attn_layer_names)
        # share embed_tokens with the target model if needed
        if get_pp_group().world_size == 1 \
            and self.model.model.embed_tokens.weight.shape \
            == target_model.model.embed_tokens.weight.shape:
            logger.info(
                "Assuming the EAGLE head shares the same vocab embedding" \
                " with the target model."
            )
            del self.model.model.embed_tokens
            self.model.model.embed_tokens = target_model.model.embed_tokens
        else:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately" \
                " from the target model."
            )
        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if self.vllm_config.speculative_config.method != "eagle3" and \
                hasattr(target_model, "lm_head"):
            logger.info("Loading EAGLE LM head weights from the target model.")
            if supports_multimodal(target_model):
                self.model.lm_head = target_model.get_language_model().lm_head
            else:
                self.model.lm_head = target_model.lm_head
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
    ) -> None:
        """Run the draft model on the persistent buffers (warmup/capture)."""
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            self.model(
                self.input_ids[:num_tokens],
                self.positions[:num_tokens],
                self.hidden_states[:num_tokens],
            )
    def validate_same_kv_cache_group(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.
        """
        kv_cache_groups: dict[str, int] = {}
        for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
            for layer_name in kv_cache_group.layer_names:
                kv_cache_groups[layer_name] = id
        assert len(
            set([
                kv_cache_groups[layer_name]
                for layer_name in self.attn_layer_names
            ])
        ) == 1, "All eagle layers should belong to the same kv cache group"
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one next token per row and return (token_ids, draft_probs).

    Only temperature is honored; other sampling parameters are ignored,
    which does not change the post-rejection-sampling distribution.
    Note: `logits` is scaled in place on the non-all-greedy path.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling,
        # so the raw logits can stand in for the probabilities.
        probs = logits
        return logits.argmax(dim=-1), probs
    greedy_mask = sampling_metadata.temperature == -1
    # Neutralize the greedy rows' temperature so the in-place scale is a
    # no-op for them.
    safe_temperature = torch.where(greedy_mask, 1.0,
                                   sampling_metadata.temperature)
    logits.div_(safe_temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    # TODO(woosuk): Consider seeds.
    # Exponential-noise trick: argmax of probs / Exp(1) samples from probs
    # while leaving `probs` itself intact for later rejection sampling
    # (hence probs.div(q), never probs.div_(q)).
    noise = torch.empty_like(probs)
    noise.exponential_()
    next_token_ids = probs.div(noise).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        # Greedy rows override the random draw.
        next_token_ids = torch.where(greedy_mask, probs.argmax(dim=-1),
                                     next_token_ids)
    return next_token_ids, probs

View File

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.sample.metadata import SamplingMetadata
# Initialize logger
logger = init_logger(__name__)
class MedusaProposer:
    """
    Medusa proposer class for generating token sequences
    """
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        # Save config parameters
        self.vllm_config = vllm_config
        self.device = device
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        # Hidden size comes from the draft (Medusa) model config.
        self.hidden_size = vllm_config.speculative_config.\
            draft_model_config.get_hidden_size(
            )
        self.dtype = vllm_config.model_config.dtype
    def propose(
        self,
        target_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> list[list[int]]:
        """Greedily propose draft tokens from the target hidden states.

        Returns one list of draft token ids per request (one token per
        Medusa head). NOTE: annotation corrected from torch.Tensor — the
        zip/transpose below yields a list of lists. `sampling_metadata`
        is currently unused (argmax-only sampling).
        """
        # Generate blocks and compute logits
        blocks = self.model(target_hidden_states)
        logits = self.model.compute_logits(blocks, None)
        # Get draft tokens and transpose the result
        # (per-head lists -> per-request lists).
        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
        return [list(row) for row in zip(*draft_tokens)]
    def load_model(self, target_model: nn.Module) -> None:
        """Load the Medusa draft model from the speculative config."""
        self.model = get_model(vllm_config=self.vllm_config,
                               model_config=self.vllm_config.
                               speculative_config.draft_model_config)
    @torch.inference_mode()
    def dummy_run(self, num_tokens: int) -> None:
        """Run the draft model on zero-filled hidden states (warmup)."""
        hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size),
                                    dtype=self.dtype,
                                    device=self.device)
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            self.model(hidden_states)

View File

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import numpy as np
import torch
@dataclass
class SpecDecodeMetadata:
    """Per-step draft-token bookkeeping handed to the sampler/verifier."""
    # [num_tokens]
    draft_token_ids: torch.Tensor
    # [batch_size]
    num_draft_tokens: list[int]
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor

    def __post_init__(self):
        # Longest draft in the batch; derived, not a dataclass field.
        self.max_spec_len = max(self.num_draft_tokens)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        """Build placeholder metadata (zeroed index tensors) from per-request
        draft token id lists, e.g. for warmup runs."""
        per_req_counts = [len(ids) for ids in draft_token_ids]
        flat_ids = [tok for ids in draft_token_ids for tok in ids]
        total_tokens = len(flat_ids)
        num_reqs = len(draft_token_ids)
        ids_tensor = torch.tensor(flat_ids, dtype=torch.int32, device=device)
        cu_counts = torch.from_numpy(
            np.cumsum(per_req_counts, dtype=np.int32)).to(device)

        def _zeros(n: int) -> torch.Tensor:
            # All index tensors are zero-filled placeholders.
            return torch.zeros(n, dtype=torch.int32, device=device)

        return cls(
            draft_token_ids=ids_tensor,
            num_draft_tokens=per_req_counts,
            cu_num_draft_tokens=cu_counts,
            target_logits_indices=_zeros(total_tokens),
            bonus_logits_indices=_zeros(num_reqs),
            logits_indices=_zeros(total_tokens + num_reqs),
        )

View File

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import prometheus_client
from vllm.config import SpeculativeConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class SpecDecodingStats:
    """Per-step iteration decoding stats from scheduler.
    Each scheduler step, statistics on spec decoding performance are
    aggregated across requests by the scheduler and returned to the
    frontend in EngineCoreOutputs->SchedulerStats.
    """
    num_spec_tokens: int
    num_drafts: int = 0
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0
    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)

    @classmethod
    def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
        """Create stats with a zeroed per-position acceptance histogram."""
        histogram = [0 for _ in range(num_spec_tokens)]
        return cls(num_spec_tokens=num_spec_tokens,
                   num_accepted_tokens_per_pos=histogram)

    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
        """Record one request's draft size and accepted token count."""
        self.num_drafts += 1
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens
        assert num_accepted_tokens <= self.num_spec_tokens
        # Positions 0..num_accepted_tokens-1 each gain one acceptance.
        hist = self.num_accepted_tokens_per_pos
        hist[:num_accepted_tokens] = [
            count + 1 for count in hist[:num_accepted_tokens]
        ]
class SpecDecodingLogging:
    """Aggregate and log spec decoding metrics.

    observe() accumulates per-iteration SpecDecodingStats over a logging
    interval; log() reduces the accumulated lists, emits one summary line,
    and resets the state.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # One entry appended per observed scheduler step.
        self.num_drafts: list[int] = []
        self.num_draft_tokens: list[int] = []
        self.num_accepted_tokens: list[int] = []
        self.accepted_tokens_per_pos_lists: list[list[int]] = []

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        stats = spec_decoding_stats
        self.num_drafts.append(stats.num_drafts)
        self.num_draft_tokens.append(stats.num_draft_tokens)
        self.num_accepted_tokens.append(stats.num_accepted_tokens)
        self.accepted_tokens_per_pos_lists.append(
            stats.num_accepted_tokens_per_pos)

    def log(self, log_fn=logger.info):
        if not self.num_drafts:
            return
        drafts_total = np.sum(self.num_drafts)
        draft_tokens_total = np.sum(self.num_draft_tokens)
        accepted_total = np.sum(self.num_accepted_tokens)
        acceptance_rate = (accepted_total / draft_tokens_total *
                           100 if draft_tokens_total > 0 else float("nan"))
        # Conventionally, mean acceptance length includes the bonus token.
        mean_accept_len = 1 + (accepted_total / drafts_total)
        per_pos_rates = np.sum(np.array(self.accepted_tokens_per_pos_lists),
                               axis=0) / drafts_total
        rates_str = ", ".join(f"{rate:.3f}" for rate in per_pos_rates)
        log_fn(
            "SpecDecoding metrics: "
            "Draft acceptance rate: %.1f%%, "
            "Mean acceptance length: %.2f, "
            "Accepted: %d tokens, "
            "Drafted: %d tokens, "
            "Per-position acceptance rate: %s",
            acceptance_rate,
            mean_accept_len,
            accepted_total,
            draft_tokens_total,
            rates_str,
        )
        self.reset()
class SpecDecodingProm:
    """Record spec decoding metrics in Prometheus.
    The acceptance rate can be calculated using a PromQL query:
    rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
    rate(vllm:spec_decode_num_draft_tokens_total[$interval])
    The mean acceptance length (conventionally including bonus tokens)
    can be calculated using:
    1 + (
    rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
    rate(vllm:spec_decode_num_drafts[$interval]))
    A per-position acceptance rate vector can be computed using
    vllm:spec_decode_num_accepted_tokens_per_pos[$interval] /
    vllm:spec_decode_num_drafts[$interval]
    """
    # Counter implementation; a class attribute so tests/subclasses can
    # swap in a different metric class.
    _counter_cls = prometheus_client.Counter
    def __init__(
        self,
        speculative_config: Optional[SpeculativeConfig],
        labelnames: list[str],
        labelvalues: list[str],
    ):
        # Counters are only registered when spec decoding is configured;
        # observe() becomes a no-op otherwise.
        self.spec_decoding_enabled = speculative_config is not None
        if not self.spec_decoding_enabled:
            return
        self.counter_spec_decode_num_drafts = \
            self._counter_cls(
                name="vllm:spec_decode_num_drafts",
                documentation="Number of spec decoding drafts.",
                labelnames=labelnames).labels(*labelvalues)
        self.counter_spec_decode_num_draft_tokens = \
            self._counter_cls(
                name="vllm:spec_decode_num_draft_tokens",
                documentation="Number of draft tokens.",
                labelnames=labelnames,).labels(*labelvalues)
        self.counter_spec_decode_num_accepted_tokens = \
            self._counter_cls(
                name="vllm:spec_decode_num_accepted_tokens",
                documentation="Number of accepted tokens.",
                labelnames=labelnames).labels(*labelvalues)
        assert speculative_config is not None
        # NOTE(review): the conditional below is redundant — when spec
        # decoding is disabled we returned early above.
        num_spec_tokens = (speculative_config.num_speculative_tokens
                           if self.spec_decoding_enabled else 0)
        pos_labelnames = labelnames + ["position"]
        # One labeled child per draft position, pre-created here so the
        # per-step observe() path only calls .inc().
        base_counter = self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens_per_pos",
            documentation="Accepted tokens per draft position.",
            labelnames=pos_labelnames,
        )
        self.counter_spec_decode_num_accepted_tokens_per_pos: list[
            prometheus_client.Counter] = []
        for pos in range(num_spec_tokens):
            pos_labelvalues = labelvalues + [str(pos)]
            self.counter_spec_decode_num_accepted_tokens_per_pos.append(
                base_counter.labels(*pos_labelvalues))
    def observe(self, spec_decoding_stats: SpecDecodingStats):
        # Accumulate one step's aggregated stats into the counters.
        if not self.spec_decoding_enabled:
            return
        self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts)
        self.counter_spec_decode_num_draft_tokens.inc(
            spec_decoding_stats.num_draft_tokens)
        self.counter_spec_decode_num_accepted_tokens.inc(
            spec_decoding_stats.num_accepted_tokens)
        for pos, counter in enumerate(
                self.counter_spec_decode_num_accepted_tokens_per_pos):
            counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])

View File

@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
from numba import jit
from vllm.config import VllmConfig
class NgramProposer:
    """Proposes draft tokens by n-gram lookup in the existing context."""

    def __init__(self, vllm_config: VllmConfig):
        spec_config = vllm_config.speculative_config
        # Minimum / maximum n-gram lengths to try to match.
        self.min_n = spec_config.prompt_lookup_min
        self.max_n = spec_config.prompt_lookup_max
        # Number of tokens proposed after a match; fewer are returned when
        # the context ends before that many tokens are available.
        self.k = spec_config.num_speculative_tokens
        # Maximum length of the model.
        self.max_model_len = vllm_config.model_config.max_model_len
        # Trigger Numba JIT compilation for the N-gram proposer.
        # This usually takes less than 1 second.
        self.propose(np.zeros(1024, dtype=np.int32))

    def propose(
        self,
        context_token_ids: np.ndarray,
    ) -> Optional[np.ndarray]:
        """Propose draft tokens via n-gram pattern matching.

        Matches the last n tokens (longest n first) against the earlier
        context and returns up to k tokens that followed the match.

        Args:
            context_token_ids: Numpy array of token IDs representing the
                context sequence.

        Returns:
            The tokens that followed the matched n-gram, or None when no
            n-gram in [min_n, max_n] re-occurs earlier in the context.

        Example:
            With context [1,2,3,4,2,3], min_n=2, max_n=3, k=4: the last 3
            tokens [4,2,3] have no earlier match, but [2,3] matches at
            position 1, so [4,2,3] is returned (only three tokens remain
            after the match).
        """
        # Do not generate draft tokens beyond the max model length.
        budget = min(self.k, self.max_model_len - context_token_ids.shape[0])
        if budget <= 0:
            return None
        # Prefer the longest matching n-gram.
        # TODO(woosuk): Optimize this.
        for pattern_len in range(self.max_n, self.min_n - 1, -1):
            draft = _find_subarray_kmp(context_token_ids, pattern_len, budget)
            if draft is not None:
                return draft
        return None

    def load_model(self, *args, **kwargs):
        # The n-gram proposer is model-free; nothing to load.
        pass
@jit(nopython=True)
def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
    """
    Build the lps (longest proper prefix which is also suffix)
    array for the pattern.

    lps[i] is the length of the longest proper prefix of pattern[:i+1]
    that is also a suffix of it; KMP uses it to skip re-comparisons.
    NOTE: compiled with Numba in nopython mode, so only Numba-supported
    constructs may be used in this body.
    """
    lps = np.zeros(len(pattern), dtype=np.int32)
    prev_lps = 0  # length of the previous longest prefix suffix
    i = 1
    while i < len(pattern):
        if pattern[i] == pattern[prev_lps]:
            # Current char extends the running prefix-suffix.
            prev_lps += 1
            lps[i] = prev_lps
            i += 1
        else:
            if prev_lps != 0:
                # Fall back to the next-shorter candidate prefix;
                # do not advance i yet.
                prev_lps = lps[prev_lps - 1]
            else:
                lps[i] = 0
                i += 1
    return lps
@jit(nopython=True)
def _find_subarray_kmp(
    context_token_ids: np.ndarray,
    n: int,
    k: int,
) -> Optional[np.ndarray]:
    """KMP search for the context's last-n-token suffix in the earlier
    context; returns up to k tokens following the leftmost match, or None
    when the suffix does not re-occur. Numba nopython-compiled.
    """
    context_len = context_token_ids.shape[0]
    assert n > 0
    # Pattern Y = the trailing n tokens of the context.
    pattern = context_token_ids[-n:]
    # Precompute lps array for Y
    lps = _kmp_lps_array(pattern)
    i = 0
    j = 0
    # -n because the last n tokens are used as pattern
    while i < context_len - n:
        if context_token_ids[i] == pattern[j]:
            i += 1
            j += 1
            # If we have matched the entire Y
            if j == n:
                # Found pattern in context, gather the next K elements
                # (the slice may yield fewer than k near the end).
                return context_token_ids[i:i + k]
        else:
            # Mismatch
            if j != 0:
                # Use the lps array to avoid re-checking elements
                j = lps[j - 1]
            else:
                i += 1
    # Y not found
    return None

View File

@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu_input_batch import InputBatch
def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
    """Return True when the request's sampling params are compatible with
    speculative decoding (no min_p, no penalties, no logprobs)."""
    if req_id in input_batch.min_p_reqs:
        # Spec decode doesn't support min_p sampling.
        return False
    # Spec decode doesn't support penalties.
    penalty_pools = (input_batch.frequency_penalties_reqs,
                     input_batch.presence_penalties_reqs,
                     input_batch.repetition_penalties_reqs)
    if any(req_id in pool for pool in penalty_pools):
        return False
    if req_id in input_batch.num_logprobs:
        # Spec decode doesn't support logprobs.
        return False
    return True
@triton.jit
def prepare_eagle_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    """One program per request: writes the consecutive index range
    [index_start, index_start + num_tokens) into out[start_pos:end_pos],
    with both bounds taken from the cumulative-length arrays.
    """
    pid = tl.program_id(0)
    # Output slice for this request: [start_pos, end_pos)
    start_pos = tl.load(cu_num_tokens_ptr + pid)
    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
    num_tokens = end_pos - start_pos
    # First source index for this request.
    index_start = tl.load(cu_query_lens_ptr + pid)
    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
    # Tile the write in BLOCK_SIZE chunks; mask out the tail.
    for i in tl.range(num_blocks):
        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        tl.store(
            out_ptr + start_pos + offset,
            index_start + offset,
            mask=offset < num_tokens,
        )

View File

@@ -0,0 +1,222 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar)
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import torch
from vllm.reasoning import ReasoningParser
from vllm.v1.request import Request
else:
torch = LazyLoader("torch", globals(), "torch")
logger = init_logger(__name__)
class StructuredOutputManager:
    """Engine-level manager for structured output requests.

    Owns the (single, lazily created) structured output backend, compiles
    grammars asynchronously on a thread pool, and builds the per-step token
    bitmask that the GPU workers apply during sampling.
    """

    def __init__(self, vllm_config: VllmConfig):
        self.backend: Optional[StructuredOutputBackend] = None
        self.reasoner: Optional[ReasoningParser] = None
        self.vllm_config = vllm_config
        self._grammar_bitmask: Optional[torch.Tensor] = None
        # All bits set in an int32 bitmask row == "every token allowed".
        self._full_mask = torch.tensor(-1, dtype=torch.int32)
        # The default max_workers if not specified is the number of CPUs * 5,
        # which is way too high since these tasks are CPU-bound, not I/O bound.
        # We also know we would never dominate CPU usage with just grammar
        # compilation, so we set it to half the number of CPUs.
        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.vllm_config.model_config,
            scheduler_config=self.vllm_config.scheduler_config,
            lora_config=self.vllm_config.lora_config,
        ).get_lora_tokenizer(None)
        reasoning_backend = vllm_config.decoding_config.reasoning_backend
        if reasoning_backend:
            reasoner_cls = ReasoningParserManager.get_reasoning_parser(
                reasoning_backend)
            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

    def grammar_init(self, request: Request) -> None:
        """Start async grammar compilation for *request* (no-op when the
        request carries no structured output)."""
        if request.structured_output_request is None:
            return
        if TYPE_CHECKING:
            assert request.sampling_params.guided_decoding is not None
        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        if self.backend is None:
            backend = request.sampling_params.guided_decoding.backend
            vocab_size = self.vllm_config.model_config.get_vocab_size()
            if backend == "xgrammar":
                self.backend = XgrammarBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "guidance":
                self.backend = GuidanceBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            else:
                raise ValueError(
                    f"Unsupported structured output backend: {backend}")
        # Compile off-thread; consumers resolve the Future when they first
        # need the grammar.
        grammar = self.executor.submit(self._async_create_grammar, request)
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _async_create_grammar(
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
        """Compile the request's grammar on the executor thread."""
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]
        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key
        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> Optional[npt.NDArray[np.int32]]:
        """Build the batched token bitmask for this scheduler step.

        Returns None when the batch has no structured output requests;
        otherwise an int32 array with one row per (request, position)
        pair, ordered by batch index.
        """
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None
        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = \
                self.vllm_config.speculative_config.num_speculative_tokens
        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = \
                self.backend.allocate_token_bitmask(
                    max_batch_size * (1 + max_num_spec_tokens))
        bitmask_tensor = self._grammar_bitmask
        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        cumulative_index = 0
        ordered_seq = sorted(structured_output_request_ids.items(),
                             key=lambda x: x[1])
        # Note that for thinking support, we will need to
        # reset the relevant part of the bitmask for consequent
        # request here.
        bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
            self._full_mask)
        # NOTE: This outer loop can likely be parallelized to improve
        # performance of bitmask generation for large batches.
        for req_id, _ in ordered_seq:
            request = requests[req_id]
            structured_output_request = request.structured_output_request
            if TYPE_CHECKING:
                assert structured_output_request is not None
                assert structured_output_request.grammar is not None
            apply_bitmask: bool = True
            if self.reasoner is not None:
                if structured_output_request.reasoning_ended is None:
                    structured_output_request.reasoning_ended = \
                        self.reasoner.is_reasoning_end(request.prompt_token_ids)
                apply_bitmask = structured_output_request.reasoning_ended
            state_advancements = 0
            # Trailing None marks the bonus-token position: a mask row is
            # emitted for it, but the FSM is not advanced.
            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
            for token in req_tokens:
                if apply_bitmask and not \
                    structured_output_request.grammar.is_terminated():
                    structured_output_request.grammar.fill_bitmask(
                        bitmask_tensor, cumulative_index)
                if token is not None:
                    # In order to generate the correct bitmask for each
                    # position in the speculative sequence, we advance
                    # the FSM state for each speculative token and rollback
                    # to restore the previous state when we are finished.
                    # Keep the side-effecting call outside the assert so the
                    # FSM still advances when asserts are disabled (-O).
                    accepted = structured_output_request.grammar.accept_tokens(
                        req_id, [token])
                    assert accepted
                    state_advancements += 1
                cumulative_index += 1
            if state_advancements > 0:
                structured_output_request.grammar.rollback(state_advancements)
        if cumulative_index < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:cumulative_index]
        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()

    def should_advance(self, request: Request) -> bool:
        """Whether the FSM for *request* should be advanced this step.

        Returns False while the request is still inside a reasoning
        (thinking) section, True otherwise.
        """
        if not request.use_structured_output:
            return False
        # To determine whether we can advance the FSM.
        # Supports thinking usage where we skip the reasoning components.
        if TYPE_CHECKING:
            assert request.structured_output_request is not None
            assert request.structured_output_request.grammar is not None
        # by default, we should always advance
        # for cases that doesn't uses thinking mode.
        if self.reasoner is not None:
            structured_req = request.structured_output_request
            if structured_req.reasoning_ended:
                return True
            # Check if reasoning ends in *this* step
            if self.reasoner.is_reasoning_end(request.all_token_ids):
                # Reasoning just ended, so we shouldn't advance til
                # next pass
                structured_req.reasoning_ended = True
            # BUGFIX: previously the fall-through here implicitly returned
            # None; return False explicitly to honor the declared bool
            # return (same truthiness, correct type).
            return False
        else:
            return True

    def clear_backend(self) -> None:
        """Release backend resources (compiled grammar caches, etc.)."""
        if self.backend is not None:
            self.backend.destroy()

View File

@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import copy
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar,
StructuredOutputOptions)
from vllm.v1.structured_output.request import get_structured_output_key
if TYPE_CHECKING:
import llguidance
import llguidance.hf as llguidance_hf
import llguidance.torch as llguidance_torch
else:
llguidance = LazyLoader("llguidance", globals(), "llguidance")
llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
llguidance_torch = LazyLoader("llguidance.torch", globals(),
"llguidance.torch")
logger = init_logger(__name__)
def _walk_json_for_additional_properties(data: object):
if isinstance(data, dict):
for value in data.values():
_walk_json_for_additional_properties(value)
if 'additionalProperties' not in data and \
('properties' in data or 'patternProperties' in data):
data['additionalProperties'] = False
elif isinstance(data, list):
for item in data:
_walk_json_for_additional_properties(item)
def process_for_additional_properties(
        guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]:
    """Return a schema dict with ``additionalProperties: false`` injected.

    A string input is parsed as JSON; a dict input is deep-copied so the
    caller's object is never mutated.
    """
    schema_obj = (json.loads(guide_json) if isinstance(guide_json, str)
                  else copy.deepcopy(guide_json))
    _walk_json_for_additional_properties(schema_obj)
    return schema_obj
@dataclass
class GuidanceBackend(StructuredOutputBackend):
    """llguidance-based structured output backend."""
    def __post_init__(self):
        # Pull the grammar-shaping knobs off the engine config.
        self.disable_any_whitespace = \
            self.vllm_config.decoding_config.disable_any_whitespace
        self.disable_additional_properties = \
            self.vllm_config.decoding_config.disable_additional_properties
        # Build the llguidance tokenizer view once; shared by all matchers.
        self.ll_tokenizer = llguidance_hf.from_tokenizer(
            self.tokenizer, self.vocab_size)
    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        # Serialize the request's spec to llguidance's grammar format, then
        # wrap a fresh matcher in a per-request GuidanceGrammar.
        # NOTE(review): storing the result on self.serialized_grammar makes
        # this method stateful; if compile_grammar runs concurrently on the
        # manager's thread pool, calls may clobber each other's attribute —
        # confirm whether a local variable is sufficient here.
        self.serialized_grammar = serialize_guidance_grammar(
            request_type, grammar_spec, self.disable_any_whitespace,
            self.disable_additional_properties)
        ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            self.serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )
        r = GuidanceGrammar(
            ll_matcher=ll_matcher,
            ll_tokenizer=self.ll_tokenizer,
            vocab_size=self.vocab_size,
        )
        # Surface any matcher construction error in the log immediately.
        r.check_error()
        return r
    def allocate_token_bitmask(self, max_num_seqs: int):
        # One int32 mask row per sequence, sized for the tokenizer vocab.
        return llguidance_torch.allocate_token_bitmask(
            max_num_seqs, self.ll_tokenizer.vocab_size)
    def destroy(self):
        # No backend-level resources to release for llguidance.
        pass
@dataclass
class GuidanceGrammar(StructuredOutputGrammar):
    """Per-request llguidance matcher state."""
    ll_matcher: llguidance.LLMatcher
    ll_tokenizer: llguidance.LLTokenizer
    vocab_size: int
    # Ensures each matcher error is logged at most once.
    printed_error: bool = False
    # Set once EOS has been consumed; see is_terminated().
    terminated: bool = False
    def check_error(self):
        # Log the first matcher error (llguidance reports errors out of
        # band rather than raising).
        if not self.printed_error:
            err = self.ll_matcher.get_error()
            if err:
                self.printed_error = True
                logger.warning("LLMatcher error: %s", err)
    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the parser.
        Returns True if the parser was advanced successfully.
        Returns False if the parser failed to advance.
        """
        if self.ll_tokenizer.eos_token in tokens:
            self.terminated = True
        # A stopped matcher accepts anything further without advancing.
        if self.ll_matcher.is_stopped():
            return True
        # TODO - Add jump decoding support in the future:
        # self.ll_matcher.compute_ff_bytes() - this should always work
        # self.ll_matcher.compute_ff_tokens() - this only works for
        # "canonical" tokenizers
        # For conversion between the two, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md
        r = self.ll_matcher.consume_tokens(tokens)
        self.check_error()
        return r
    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the parser in sequence.
        Will not advance the parser.
        Returns the prefix list of tokens that are accepted by the parser.
        """
        if len(tokens) == 0:
            return []
        if self.ll_matcher.is_stopped():
            return []
        num_tokens = self.ll_matcher.validate_tokens(tokens)
        self.check_error()
        return tokens[:num_tokens]
    def rollback(self, num_tokens: int) -> None:
        # Rewind the matcher by num_tokens (used for speculative decoding).
        self.ll_matcher.rollback(num_tokens)
        self.check_error()
    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        # this will automatically return [EOS] mask if the matcher is stopped
        # or otherwise in an error state
        llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
        self.check_error()
    def is_terminated(self) -> bool:
        # Terminated only when EOS was seen in accept_tokens().
        return self.terminated
    def reset(self):
        # This method may be not needed anymore? TODO
        self.ll_matcher.reset()
def serialize_guidance_grammar(
    request_type: StructuredOutputOptions,
    grammar_spec: Union[str, dict[str, Any]],
    disable_any_whitespace: bool = False,
    disable_additional_properties: bool = False,
) -> str:
    """Convert a structured output request into an llguidance grammar string.

    Args:
        request_type: Which kind of constraint the spec encodes.
        grammar_spec: The raw spec (JSON schema, regex, EBNF, choice list,
            or structural-tag JSON) as a string or dict.
        disable_any_whitespace: Forbid flexible whitespace in JSON output.
        disable_additional_properties: Force additionalProperties=false on
            JSON schemas.
    Raises:
        ValueError: For unsupported request types or bad structural tags.
    """
    def _process_schema(grammar_spec: Union[str, dict[str, Any]], ) -> str:
        # Shared JSON-schema path, honoring both whitespace/property knobs.
        if disable_additional_properties:
            grammar_spec = process_for_additional_properties(grammar_spec)
        return llguidance.LLMatcher.grammar_from_json_schema(
            grammar_spec,
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })
    if request_type == StructuredOutputOptions.JSON:
        return _process_schema(grammar_spec)
    elif request_type == StructuredOutputOptions.JSON_OBJECT:
        # Any JSON object: use the trivial {"type": "object"} schema.
        return llguidance.LLMatcher.grammar_from_json_schema(
            '{"type": "object"}',
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })
    else:
        if request_type == StructuredOutputOptions.REGEX:
            tp = "regex"
        elif request_type == StructuredOutputOptions.GRAMMAR:
            tp = "grammar"
        elif request_type == StructuredOutputOptions.CHOICE:
            tp = "choice"
        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            if isinstance(grammar_spec, str):
                s_tag = json.loads(grammar_spec)
            else:
                s_tag = grammar_spec
            triggers: list[str] = s_tag["triggers"]
            tags: list[llguidance.StructTag] = []
            for s in s_tag["structures"]:
                begin: str = s["begin"]
                # Each structure's begin marker must start with one of the
                # declared triggers.
                trig = next((t for t in triggers if begin.startswith(t)), None)
                if trig is None:
                    raise ValueError(
                        f"Trigger {begin} not found in triggers {triggers}")
                tags.append(
                    llguidance.StructTag(
                        trigger=trig,
                        begin=s["begin"],
                        grammar=_process_schema(s["schema"]),
                        end=s["end"],
                    ))
            if not tags:
                raise ValueError(
                    "No structural tags found in the grammar spec.")
            return llguidance.StructTag.to_grammar(tags)
        else:
            logger.error("Validation should have already occurred. "
                         "Please file an issue.")
            raise ValueError("grammar is not of valid supported types. "
                             f"({request_type!s})")
        # Regex / grammar / choice all share the generic entry point.
        return llguidance.grammar_from(tp, grammar_spec)
def validate_guidance_grammar(
    sampling_params: SamplingParams,
    tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
    """Eagerly serialize and validate a request's grammar, raising
    ValueError with llguidance's diagnostic on failure."""
    tp, grm = get_structured_output_key(sampling_params)
    guidance_grm = serialize_guidance_grammar(tp, grm)
    err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
    if err:
        raise ValueError(f"Grammar error: {err}")

View File

@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import torch
from vllm.config import VllmConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StructuredOutputOptions(enum.Enum):
    """Closed set of structured-output constraint kinds a request may carry."""
    JSON = enum.auto()            # constrain output to a specific JSON schema
    JSON_OBJECT = enum.auto()     # constrain output to any JSON object
    REGEX = enum.auto()           # constrain output to a regular expression
    GRAMMAR = enum.auto()         # constrain output to an EBNF grammar
    CHOICE = enum.auto()          # constrain output to one of a list of strings
    STRUCTURAL_TAG = enum.auto()  # trigger-delimited tagged structures
# (request type, serialized spec) pair identifying a compiled grammar.
StructuredOutputKey = tuple[StructuredOutputOptions, str]
class StructuredOutputGrammar(ABC):
    """Request-level backend for structured output requests.

    One instance tracks the FSM/parser state of a single request; the
    engine drives it via accept/rollback calls and per-step bitmask fills.
    """
    @abstractmethod
    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """
        Determines whether the provided tokens are accepted for the
        given request, advancing the state for accepted tokens.
        Args:
            request_id (str): The unique identifier for the request.
            tokens (list[int]): A list of token IDs to evaluate.
        Returns:
            bool: True if the tokens are accepted, False otherwise.
        """
    @abstractmethod
    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """
        Validates the provided tokens against the grammar.
        Will not advance the FSM.
        Args:
            tokens (list[int]): A list of token IDs to validate.
        Returns:
            list[int]: A list of accepted token IDs. Will be a prefix
            of the input tokens, and empty if none are accepted.
        """
    @abstractmethod
    def rollback(self, num_tokens: int) -> None:
        """
        Rolls back the state of the grammar by a specified number of tokens.
        Will also revert counters for the number of processed tokens.
        Args:
            num_tokens (int): The number of tokens to roll back.
        """
    @abstractmethod
    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        """
        Fills the bitmask row for a specific batch index with the tokens
        currently allowed by the grammar state.
        Args:
            bitmask (torch.Tensor): The bitmask to fill
            batch_index (int): The index in the bitmask to fill
        """
    @abstractmethod
    def is_terminated(self) -> bool:
        """
        Checks whether the structured output process has terminated.
        Returns:
            bool: True if the process is terminated, False otherwise.
        """
    @abstractmethod
    def reset(self):
        """
        Resets the state of the structured output grammar.
        """
@dataclass
class StructuredOutputBackend(ABC):
    """Engine-level backend for structured output requests.

    Concrete backends (xgrammar, guidance) compile request specs into
    per-request grammars and allocate the shared token bitmask.
    """
    # Shared engine configuration (model, scheduler, decoding options).
    vllm_config: VllmConfig
    # Tokenizer used to build token-level grammar tables.
    tokenizer: AnyTokenizer
    # Model vocabulary size; bitmask rows are sized from this.
    vocab_size: int
    @abstractmethod
    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """
        Compiles a grammar specification into a structured output grammar.
        Args:
            request_type (StructuredOutputOptions): The type of structured
                output request.
            grammar_spec (str): The grammar specification to compile.
        Returns:
            StructuredOutputGrammar: The compiled structured output grammar.
        """
    @abstractmethod
    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        """
        Allocates a token bitmask for the specified maximum number of sequences.
        Args:
            max_num_seqs (int): The maximum number of sequences for which
                to allocate the bitmask.
        """
    @abstractmethod
    def destroy(self):
        """
        Backend-specific cleanup.
        """

View File

@@ -0,0 +1,318 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import torch
import vllm.envs
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar,
StructuredOutputOptions)
from vllm.v1.structured_output.utils import (choice_as_grammar,
convert_lark_to_ebnf,
grammar_is_likely_lark)
if TYPE_CHECKING:
import xgrammar as xgr
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
@dataclass
class XgrammarBackend(StructuredOutputBackend):
    """xgrammar-based structured output backend."""
    def __post_init__(self):
        self.disable_any_whitespace = \
            self.vllm_config.decoding_config.disable_any_whitespace
        if isinstance(self.tokenizer, MistralTokenizer):
            # Mistral tokenizers need a hand-built TokenizerInfo.
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
            try:
                if self.tokenizer.is_tekken:
                    # HACK: reaches into the tokenizer's private vocab list.
                    encoded_vocab = self.tokenizer._vocab
                else:
                    # Order tokens by their ids.
                    encoded_vocab = [
                        token for token, _ in sorted(
                            self.tokenizer.get_vocab().items(),
                            key=lambda x: x[1],
                        )
                    ]
                stop_token_ids = None
                if (hasattr(
                        self.tokenizer,
                        "eos_token_id",
                ) and self.tokenizer.eos_token_id is not None):
                    stop_token_ids = [self.tokenizer.eos_token_id]
            except AttributeError as e:
                raise ValueError(
                    f"Cannot get the vocabulary of the tokenizer "
                    f"{type(self.tokenizer)}. The tokenizer should have a "
                    "get_vocab method.") from e
            tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                encoded_vocab=encoded_vocab,
                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
                vocab_type=xgr.VocabType.RAW
                if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
                vocab_size=self.vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=True,
            )
        else:
            # HF tokenizers are understood natively by xgrammar.
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                self.tokenizer,
                vocab_size=self.vocab_size,
            )
        self.compiler = xgr.GrammarCompiler(
            tokenizer_info,
            max_threads=8,
            cache_enabled=True,
            cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024,
        )
        # Matchers need a rollback window sized for speculative decoding.
        self.num_speculative_tokens = 0
        if self.vllm_config.speculative_config is not None:
            self.num_speculative_tokens = \
                self.vllm_config.speculative_config.num_speculative_tokens
    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        # Dispatch the spec to the matching xgrammar compiler entry point.
        if request_type == StructuredOutputOptions.JSON:
            ctx = self.compiler.compile_json_schema(
                grammar_spec, any_whitespace=not self.disable_any_whitespace)
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            ctx = self.compiler.compile_json_schema(
                '{"type": "object"}',
                any_whitespace=not self.disable_any_whitespace)
        elif request_type == StructuredOutputOptions.GRAMMAR:
            ctx = self.compiler.compile_grammar(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            ctx = self.compiler.compile_regex(grammar_spec)
        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            s_tag = json.loads(grammar_spec)
            tags = [
                xgr.StructuralTagItem(
                    begin=s["begin"],
                    schema=json.dumps(s["schema"]),
                    end=s["end"],
                ) for s in s_tag["structures"]
            ]
            ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"])
        else:
            logger.error(
                "Validation should have already occurred. Please file an issue."
            )
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})")
        return XgrammarGrammar(
            matcher=xgr.GrammarMatcher(
                ctx,
                max_rollback_tokens=self.num_speculative_tokens,
            ),
            vocab_size=self.vocab_size,
            ctx=ctx,
        )
    def allocate_token_bitmask(self, max_num_seqs: int):
        # One int32 mask row per sequence.
        return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)
    def destroy(self):
        # Dropping the compiler releases its grammar cache.
        del self.compiler
@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding
    vocab_size: int
    matcher: xgr.GrammarMatcher = field(hash=False)
    ctx: xgr.CompiledGrammar = field(hash=False)
    num_processed_tokens: int = field(default_factory=lambda: 0,
                                      repr=False,
                                      hash=False,
                                      init=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Advance the FSM over *tokens* one at a time.

        Returns True when every token was accepted; False as soon as one
        is rejected (the FSM stays at the first rejected position).
        """
        for tok in tokens:
            accepted = self.matcher.accept_token(tok)
            if not accepted:
                logger.error(
                    "Failed to advance FSM for request %s "
                    "for tokens %s. Please file an issue.", request_id, tok)
                return False
            self.num_processed_tokens += 1
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Probe how long a prefix of *tokens* the FSM accepts, without
        advancing it (any accepted tokens are rolled back)."""
        num_ok = 0
        for tok in tokens:
            if not self.matcher.accept_token(tok):
                break
            num_ok += 1
        if num_ok > 0:
            # Restore the FSM to its state before this probe.
            self.matcher.rollback(num_ok)
        return tokens[:num_ok]

    def rollback(self, num_tokens: int) -> None:
        """Rewind the FSM and the processed-token counter by *num_tokens*."""
        self.matcher.rollback(num_tokens)
        self.num_processed_tokens -= num_tokens

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write the currently-allowed-token mask into row *idx*."""
        self.matcher.fill_next_token_bitmask(bitmask, idx)

    def is_terminated(self) -> bool:
        """True once the grammar has reached an accepting terminal state."""
        return self.matcher.is_terminated()

    def reset(self):
        """Return the FSM and token counter to their initial state."""
        self.num_processed_tokens = 0
        self.matcher.reset()
def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    Walks every dict node reachable through dict values and through dicts
    inside list values, returning True at the first unsupported keyword.
    """
    if not isinstance(schema, dict):
        return False
    pending: list[dict[str, Any]] = [schema]
    while pending:
        node = pending.pop()
        declared = node.get("type")
        # Numeric ranges.
        if declared in ("integer", "number") and "multipleOf" in node:
            return True
        # Unsupported array keywords.
        if declared == "array" and any(
                key in node for key in ("uniqueItems", "contains",
                                        "minContains", "maxContains")):
            return True
        # Unsupported string keywords.
        if declared == "string" and "format" in node:
            return True
        # Unsupported object keywords.
        if declared == "object" and any(
                key in node for key in ("minProperties", "maxProperties",
                                        "propertyNames", "patternProperties")):
            return True
        # Queue nested dicts: direct dict values, and dict items of list
        # values (matching the original recursive reach exactly).
        for value in node.values():
            if isinstance(value, dict):
                pending.append(value)
            elif isinstance(value, list):
                pending.extend(item for item in value
                               if isinstance(item, dict))
    return False
def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    Tries to build (but not compile) an xgrammar grammar from whichever
    guided-decoding field is set. As a side effect, `choice` requests are
    normalized into an EBNF grammar and Lark grammars are converted to EBNF.

    Raises ValueError if the request is not supported.
    """
    if sampling_params.guided_decoding is None:
        return

    gd_params = sampling_params.guided_decoding

    if gd_params.regex:
        try:
            xgr.Grammar.from_regex(gd_params.regex)
        except Exception as err:
            raise ValueError("Failed to transform regex into a grammar: "
                             f"{err}") from err

    if gd_params.choice:
        choice_grammar = choice_as_grammar(gd_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
            # Fix: the second fragment was previously a plain "{err}"
            # literal (missing f-prefix), so the root cause was never
            # included in the message.
            raise ValueError("Failed to transform choices into a grammar: "
                             f"{err}") from err
        # Normalize: downstream only needs to handle the grammar form.
        gd_params.choice = None
        gd_params.grammar = choice_grammar
        return

    if gd_params.json:
        if isinstance(gd_params.json, str):
            try:
                schema = json.loads(gd_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = gd_params.json

        try:
            xgr.Grammar.from_json_schema(schema)
        except Exception as err:
            raise ValueError("Failed to transform json schema into a grammar: "
                             f"{err}") from err

        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError("The provided JSON schema contains features not "
                             "supported by xgrammar.")
        return

    if gd_params.grammar:
        if grammar_is_likely_lark(gd_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar)
            except ValueError as e:
                raise ValueError(
                    "Failed to convert the grammar from Lark to EBNF. ") from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(gd_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e
        return

    if gd_params.structural_tag:
        try:
            s_tag = json.loads(gd_params.structural_tag)
            tags = [
                xgr.StructuralTagItem(
                    begin=s["begin"],
                    schema=json.dumps(s["schema"]),
                    end=s["end"],
                ) for s in s_tag["structures"]
            ]
            xgr.Grammar.from_structural_tag(tags, s_tag["triggers"])
        except Exception as e:
            raise ValueError("Invalid structural tag specification.") from e

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import dataclasses
import functools
import json
from concurrent.futures import Future
from concurrent.futures._base import TimeoutError
from typing import Optional, Union, cast
from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar,
StructuredOutputKey,
StructuredOutputOptions)
@dataclasses.dataclass
class StructuredOutputRequest:
    # Sampling params carrying the guided_decoding spec for this request.
    sampling_params: SamplingParams
    # The grammar is compiled asynchronously; until compilation finishes
    # this holds a Future that resolves to the grammar object.
    _grammar: Optional[Union[Future[StructuredOutputGrammar],
                             StructuredOutputGrammar]] = None
    reasoning_ended: Optional[bool] = None

    def _check_grammar_completion(self) -> bool:
        """Resolve the grammar Future if it has finished.

        Returns True once self._grammar holds a concrete grammar (or was
        never a Future); False while compilation is still pending.
        """
        # NOTE: We have to lazy import to gate circular imports
        from vllm.v1.request import RequestStatus

        if isinstance(self._grammar, Future):
            try:
                # We will check whether the future is ready within 100 us
                self._grammar = self._grammar.result(timeout=0.0001)
                # NOTE(review): `status` is not a declared field of this
                # dataclass — presumably set for consumption by the owning
                # Request object; confirm against callers.
                self.status = RequestStatus.WAITING
            except TimeoutError:
                return False
        return True

    @property
    def is_grammar_ready(self) -> bool:
        # Polling this property also resolves a completed Future.
        return self._check_grammar_completion()

    @property
    def grammar(self) -> Optional[StructuredOutputGrammar]:
        # None while compilation is still in flight.
        completed = self._check_grammar_completion()
        return cast(Optional[StructuredOutputGrammar],
                    self._grammar) if completed else None

    @grammar.setter
    def grammar(
        self, grammar: Union[StructuredOutputGrammar,
                             Future[StructuredOutputGrammar]]
    ) -> None:
        self._grammar = grammar

    @functools.cached_property
    def structured_output_key(self) -> StructuredOutputKey:
        # Cached: sampling_params is not expected to change after creation.
        return get_structured_output_key(self.sampling_params)
def get_structured_output_key(
        sampling_params: SamplingParams) -> StructuredOutputKey:
    """Derive the (option, spec-string) key for a structured-output request.

    Requests with equal keys can share one compiled grammar, so non-string
    specs (dict schemas, choice lists) are canonicalized via json.dumps.
    """
    params = sampling_params.guided_decoding
    assert params is not None, "params can't be None."
    if params.json is not None:
        spec = params.json if isinstance(params.json,
                                         str) else json.dumps(params.json)
        return (StructuredOutputOptions.JSON, spec)
    if params.json_object:
        return (StructuredOutputOptions.JSON_OBJECT, "")
    if params.regex is not None:
        return (StructuredOutputOptions.REGEX, params.regex)
    if params.choice is not None:
        spec = params.choice if isinstance(params.choice,
                                           str) else json.dumps(params.choice)
        return (StructuredOutputOptions.CHOICE, spec)
    if params.grammar is not None:
        return (StructuredOutputOptions.GRAMMAR, params.grammar)
    if params.structural_tag is not None:
        return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag)
    raise ValueError("No valid structured output parameter found")

View File

@@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import regex as re
def grammar_is_likely_lark(grammar_str: str) -> bool:
    """
    Check if grammar appears to use Lark syntax.

    Args:
        grammar_str: Input grammar string

    Returns:
        bool: True if grammar appears to be in Lark format, False otherwise

    Examples:
        >>> grammar_is_likely_lark("rule: 'abc'")
        True
        >>> grammar_is_likely_lark("rule ::= 'abc'")
        False
    """
    if not grammar_str or not isinstance(grammar_str, str):
        return False
    # Strip '#' and '//' comments first so a "::=" inside a comment does
    # not count; a single EBNF-style "::=" definition marks the grammar
    # as EBNF, otherwise we assume Lark.
    cleaned = (re.sub(r'(#|//).*$', '', raw).strip()
               for raw in grammar_str.split('\n'))
    return all('::=' not in line for line in cleaned if line)
def convert_lark_to_ebnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to EBNF format.

    EBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in EBNF format

    Raises:
        ValueError: on non-string/empty input, malformed rules, mismatched
            quotes, or references to undefined rules.

    Examples:
        >>> print(convert_lark_to_ebnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    defined_rules = set()
    referenced_rules = set()
    output_lines = []

    def clean_line(line: str) -> str:
        """Remove comments and whitespace from line."""
        return re.sub(r'(#|//).*$', '', line).strip()

    def check_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Validate quote matching in text."""
        # Simple parity check; does not handle escaped quotes.
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(
                f"Mismatched quotes in {rule_name} on line {line_num}")

    def extract_references(text: str) -> set:
        """Extract rule references from text."""
        # Remove quoted strings and special characters
        text = re.sub(r'"[^"]*"', '', text)
        text = re.sub(r'[+*?()|\[\]{}]', ' ', text)
        return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text))

    # First pass: Find root rule and validate rule definitions
    lines = [clean_line(line) for line in grammar_str.split('\n')]
    first_rule = None

    for line_num, line in enumerate(lines, 1):
        if not line or line.startswith('|'):
            continue

        if ':' in line:
            try:
                # Lark's leading '?' (inline marker) is stripped from names.
                name = line.split(':', 1)[0].strip().strip('?')
                defined_rules.add(name)
                if first_rule is None:
                    first_rule = name
                # An explicit 'start' rule always wins as the entry point.
                if name == 'start':
                    first_rule = 'start'
            except IndexError as e:
                raise ValueError(f"Invalid rule format on line {line_num}. "
                                 "Expected 'rule_name: definition'") from e

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # Add root rule
    output_lines.append(f"root ::= {first_rule}")

    # Second pass: Process rule definitions and alternatives
    current_rule = None
    current_definition = []

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue

        try:
            if ':' in line and not line.startswith('|'):
                # Save previous rule if exists
                if current_rule:
                    output_lines.append(
                        f"{current_rule} ::= {' | '.join(current_definition)}")

                # Process new rule
                name, definition = line.split(':', 1)
                current_rule = name.strip().strip('?')

                check_quotes(definition, f"rule '{current_rule}'", line_num)
                # Lark single-quoted terminals become EBNF double-quoted.
                definition = re.sub(r"'([^']*)'", r'"\1"', definition)
                referenced_rules.update(extract_references(definition))
                current_definition = [definition.strip()]

            elif line.startswith('|'):
                if not current_rule:
                    raise ValueError(f"Alternative '|' on line {line_num} "
                                     "without a preceding rule definition")

                alt_def = line[1:].strip()
                check_quotes(alt_def, f"alternative for rule '{current_rule}'",
                             line_num)
                alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
                referenced_rules.update(extract_references(alt_def))
                current_definition.append(alt_def)

        except ValueError as e:
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Add final rule if exists
    if current_rule:
        output_lines.append(
            f"{current_rule} ::= {' | '.join(current_definition)}")

    # Validate all rules are defined
    undefined_rules = referenced_rules - defined_rules - {'root'}
    if undefined_rules:
        raise ValueError("Referenced rules are not defined: "
                         f"{', '.join(sorted(undefined_rules))}")

    return '\n'.join(output_lines)
def choice_as_grammar(choice: list[str]) -> str:
    """Build an EBNF grammar whose root matches exactly one of *choice*."""

    def _escape(text: str) -> str:
        # Backslash-escape double quotes and backslashes so each option is
        # a valid EBNF string literal.
        return re.sub(r'(["\\])', r'\\\1', text)

    alternatives = ' | '.join(f'"{_escape(option)}"' for option in choice)
    return 'root ::= ' + alternatives

743
vllm/v1/utils.py Normal file
View File

@@ -0,0 +1,743 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import multiprocessing
import time
import weakref
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
import msgspec
import torch
import zmq
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
get_tcp_uri, kill_process_tree)
from vllm.v1.executor.abstract import Executor
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator
logger = init_logger(__name__)

T = TypeVar("T")

# Interval (ms) between poller wake-ups while waiting for engine core
# processes to connect/start during the startup handshake.
STARTUP_POLL_PERIOD_MS = 10000
class ConstantList(Generic[T], Sequence):
    """Immutable read-only view over a list.

    Every mutating operation raises; all reads delegate to the wrapped
    list. The underlying list is NOT copied, so mutations made through
    other references to it remain visible here.
    """

    def __init__(self, x: list[T]) -> None:
        self._x = x

    def append(self, item):
        raise Exception("Cannot append to a constant list")

    def extend(self, item):
        raise Exception("Cannot extend a constant list")

    def insert(self, item):
        raise Exception("Cannot insert into a constant list")

    def pop(self, item):
        raise Exception("Cannot pop from a constant list")

    def remove(self, item):
        raise Exception("Cannot remove from a constant list")

    def clear(self):
        raise Exception("Cannot clear a constant list")

    def index(self,
              item: T,
              start: int = 0,
              stop: Optional[int] = None) -> int:
        # Translate stop=None into the list's length for list.index.
        end = len(self._x) if stop is None else stop
        return self._x.index(item, start, end)

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: T, /):
        ...

    def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
        raise Exception("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise Exception("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self):
        return f"ConstantList({self._x})"
def get_engine_client_zmq_addr(local_only: bool,
                               host: str,
                               port: int = 0) -> str:
    """Pick a ZMQ address for engine<->client traffic.

    Local-only communication uses an IPC path; otherwise a TCP URI on
    *host*, allocating a free port when *port* is 0.
    """
    if local_only:
        return get_open_zmq_ipc_path()
    return get_tcp_uri(host, port or get_open_port())
class CoreEngineState(Enum):
    # Startup-handshake lifecycle (see wait_for_engine_startup):
    # NEW -> CONNECTED (after HELLO) -> READY (after READY).
    NEW = auto()
    CONNECTED = auto()
    READY = auto()
class CoreEngine:
    """One per data parallel rank."""

    def __init__(self, index: int = 0, local: bool = True):
        # True if this engine runs on the local (head) node.
        self.local = local
        self.index = index
        # Fixed 2-byte little-endian ZMQ identity derived from the DP rank;
        # used to match handshake messages to this engine.
        self.identity = index.to_bytes(2, "little")
        # Handshake progress, advanced by wait_for_engine_startup().
        self.state = CoreEngineState.NEW
@dataclass
class EngineZmqAddresses:
    """ZMQ addresses wiring front-end clients (and optionally the DP
    coordinator) to engine core processes."""
    # ZMQ input socket addresses for each front-end client (requests)
    inputs: list[str]
    # ZMQ output socket addresses for each front-end client (responses)
    outputs: list[str]
    # ZMQ input socket address of DP coordinator if applicable
    coordinator_input: Optional[str] = None
    # ZMQ output socket address of DP coordinator if applicable
    coordinator_output: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
    """Metadata sent to each engine process during startup handshake,
    including addresses of the front-end ZMQ queues that they should
    connect to.
    """
    addresses: EngineZmqAddresses
    # Plain-dict subset of ParallelConfig (DP master ip/port and
    # data_parallel_size) so it can be msgpack-encoded directly.
    parallel_config: dict[str, Union[int, str]]
class APIServerProcessManager:
    """Manages a group of API server processes.

    Handles creation, monitoring, and termination of API server worker
    processes. Also monitors extra processes to check if they are healthy.
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: Optional[str] = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        # "spawn" keeps workers free of state inherited from this process.
        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        for idx, request_addr, response_addr in zip(range(num_servers),
                                                    input_addresses,
                                                    output_addresses):
            client_config = {
                "input_address": request_addr,
                "output_address": response_addr,
                "client_index": idx,
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

            worker = spawn_context.Process(target=target_server_fn,
                                           name=f"ApiServer_{idx}",
                                           args=(listen_address, sock, args,
                                                 client_config))
            self.processes.append(worker)
            worker.start()

        logger.info("Started %d API server processes", len(self.processes))

        # Shutdown only the API server processes on garbage collection
        # The extra processes are managed by their owners
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

    def close(self) -> None:
        self._finalizer()
class CoreEngineProcManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of background processes used by the AsyncLLM and LLMEngine.
    """

    def __init__(
        self,
        target_fn: Callable,
        local_engine_count: int,
        start_index: int,
        local_start_index: int,
        vllm_config: VllmConfig,
        on_head_node: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        context = get_mp_context()
        # Kwargs shared by every engine process; per-process DP ranks are
        # merged in below.
        common_kwargs = {
            "vllm_config": vllm_config,
            "on_head_node": on_head_node,
            "handshake_address": handshake_address,
            "executor_class": executor_class,
            "log_stats": log_stats,
        }

        self.processes: list[BaseProcess] = []
        for index in range(local_engine_count):
            local_index = local_start_index + index
            global_index = start_index + index

            # Start EngineCore in background process.
            self.processes.append(
                context.Process(target=target_fn,
                                name=f"EngineCore_{global_index}",
                                kwargs=common_kwargs | {
                                    "dp_rank": global_index,
                                    "local_dp_rank": local_index,
                                }))

        # Register cleanup before start() so procs are reaped even if a
        # start below raises.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)
        try:
            for proc in self.processes:
                proc.start()
        finally:
            # Kill other procs if not all are running.
            if self.finished_procs():
                self.close()

    def close(self):
        """Shutdown all procs."""
        self._finalizer()

    def join_first(self):
        """Wait for any process to exit."""
        connection.wait(proc.sentinel for proc in self.processes)

    def sentinels(self) -> list:
        # A process's sentinel becomes readable when the process exits.
        return [proc.sentinel for proc in self.processes]

    def finished_procs(self) -> dict[str, int]:
        """Returns dict of proc name -> exit code for any finished procs."""
        return {
            proc.name: proc.exitcode
            for proc in self.processes if proc.exitcode is not None
        }
class CoreEngineActorManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of core engine Ray actors used by the AsyncLLM and LLMEngine.

    Different from CoreEngineProcManager, this class manages
    core engines for both local and remote nodes.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        placement_groups: Optional[list["PlacementGroup"]] = None,
        local_dp_ranks: Optional[list[int]] = None,
    ):
        # Ray imports are deferred so the module loads without Ray installed.
        import copy

        import ray
        from ray.util.scheduling_strategies import (
            PlacementGroupSchedulingStrategy)

        from vllm.v1.engine.core import DPEngineCoreActor

        self.local_engine_actors: list[ray.ActorHandle] = []
        self.remote_engine_actors: list[ray.ActorHandle] = []
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local
        world_size = vllm_config.parallel_config.world_size

        if ray.is_initialized():
            logger.info(
                "Ray is already initialized. Skipping Ray initialization.")
        else:
            ray.init()

        if placement_groups is not None:
            assert local_dp_ranks is not None, (
                "local_dp_ranks must be provided if "
                "placement_groups is provided")
            assert len(placement_groups) == len(local_dp_ranks), (
                "placement_groups and local_dp_ranks must "
                "have the same length")
            logger.info("Using provided placement groups")
            # TODO(rui): validate passed-in placement groups
            self.created_placement_groups = []
        else:
            placement_groups, local_dp_ranks = \
                CoreEngineActorManager.create_dp_placement_groups(vllm_config)
            self.created_placement_groups = placement_groups
        assert len(placement_groups) == dp_size, (
            "Number of placement groups must match data parallel size")

        refs = []
        for index in range(dp_size):
            local_index = local_dp_ranks[index]
            dp_vllm_config = copy.deepcopy(vllm_config)
            pg = placement_groups[index]
            dp_vllm_config.parallel_config.placement_group = pg
            # The first local_engine_count DP ranks run on the head node.
            on_head_node = index < local_engine_count
            actor = ray.remote(DPEngineCoreActor).options(
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=pg,
                    placement_group_bundle_index=world_size,
                )).remote(vllm_config=dp_vllm_config,
                          executor_class=executor_class,
                          log_stats=log_stats,
                          on_head_node=on_head_node,
                          addresses=addresses,
                          dp_rank=index,
                          local_dp_rank=local_index)
            if on_head_node:
                self.local_engine_actors.append(actor)
            else:
                self.remote_engine_actors.append(actor)
            refs.append(actor.wait_for_init.remote())

        # Block until every actor finished init before starting run loops.
        ray.get(refs)
        self.run_refs = []
        for actor in self.local_engine_actors + self.remote_engine_actors:
            self.run_refs.append(actor.run.remote())

    @staticmethod
    def create_dp_placement_groups(
        vllm_config: VllmConfig
    ) -> tuple[list["PlacementGroup"], list[int]]:
        """Create one STRICT_PACK placement group per data-parallel rank.

        Returns the placement groups and, for each, the engine's local DP
        rank on its node.
        """

        import ray
        from ray._private.state import available_resources_per_node
        from ray.util.state import list_nodes

        logger.info("Creating placement groups for data parallel")
        dp_master_ip = \
            vllm_config.parallel_config.data_parallel_master_ip
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local

        nodes = list_nodes()
        # Sort so the DP master (head) node is processed first.
        nodes = sorted(list_nodes(),
                       key=lambda node: node.node_ip != dp_master_ip)
        assert nodes[0].node_ip == dp_master_ip, (
            "The first node must be the head node")
        assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
            "There can only be one head node")

        available_resources = available_resources_per_node()
        world_size = vllm_config.parallel_config.world_size
        placement_groups: list[PlacementGroup] = []
        local_dp_ranks: list[int] = []

        for node in nodes:
            node_ip = node.node_ip
            node_resources = available_resources[node.node_id]
            # For now, each DP rank can only be assigned to one node
            # TODO(rui): support allocating a single DP rank
            # to multiple nodes
            available_engine_count = int(node_resources["GPU"]) // world_size
            if node_ip == dp_master_ip:
                assert available_engine_count >= local_engine_count, (
                    "Not enough resources to allocate DP ranks "
                    f"on DP master node {node_ip}")
                for i in range(local_engine_count):
                    # The tiny "node:<ip>" custom-resource request pins the
                    # head-node bundles to the master node.
                    bundles = [{
                        "GPU": 1.0,
                        "node:" + dp_master_ip: 0.001
                    }] * world_size + [{
                        "CPU": 1.0
                    }]
                    pg = ray.util.placement_group(
                        name=f"dp_rank_{len(placement_groups)}",
                        strategy="STRICT_PACK",
                        bundles=bundles,
                    )
                    placement_groups.append(pg)
                    local_dp_ranks.append(i)
            else:
                for i in range(available_engine_count):
                    if len(placement_groups) == dp_size:
                        break
                    bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
                    pg = ray.util.placement_group(
                        name=f"dp_rank_{len(placement_groups)}",
                        strategy="STRICT_PACK",
                        bundles=bundles,
                    )
                    placement_groups.append(pg)
                    local_dp_ranks.append(i)
        return placement_groups, local_dp_ranks

    def get_run_refs(self):
        # Object refs for each actor's run() loop; callers use these to
        # detect engine exit or failure.
        return self.run_refs

    def close(self):
        import ray
        for actor in self.local_engine_actors + self.remote_engine_actors:
            ray.kill(actor)
        for pg in self.created_placement_groups:
            ray.util.remove_placement_group(pg)
def wait_for_engine_startup(
    handshake_socket: zmq.Socket,
    addresses: EngineZmqAddresses,
    core_engines: list[CoreEngine],
    parallel_config: ParallelConfig,
    cache_config: CacheConfig,
    proc_manager: Optional[CoreEngineProcManager],
    coord_process: Optional[Process],
):
    """Run the startup handshake with every engine core process.

    Each engine sends HELLO (answered with addresses + DP config) and then
    READY (carrying its GPU block count, summed into cache_config).

    Raises:
        RuntimeError: if a watched process dies first, or an engine sends
            an unexpected or out-of-order message.
    """
    # Wait for engine core process(es) to send ready messages.
    local_count = parallel_config.data_parallel_size_local
    remote_count = len(core_engines) - local_count
    # [local, remote] counts
    conn_pending, start_pending = [local_count, remote_count], [0, 0]
    poller = zmq.Poller()
    poller.register(handshake_socket, zmq.POLLIN)

    # Also poll process sentinels so a crashed proc aborts the wait.
    if proc_manager is not None:
        for sentinel in proc_manager.sentinels():
            poller.register(sentinel, zmq.POLLIN)
    if coord_process is not None:
        poller.register(coord_process.sentinel, zmq.POLLIN)
    while any(conn_pending) or any(start_pending):
        events = poller.poll(STARTUP_POLL_PERIOD_MS)
        if not events:
            if any(conn_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to connect.", *conn_pending)
            if any(start_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to start.", *start_pending)
            continue
        if len(events) > 1 or events[0][0] != handshake_socket:
            # One of the local core processes exited.
            finished = proc_manager.finished_procs() if proc_manager else {}
            if coord_process is not None and coord_process.exitcode is not None:
                finished[coord_process.name] = coord_process.exitcode
            raise RuntimeError("Engine core initialization failed. "
                               "See root cause above. "
                               f"Failed core proc(s): {finished}")

        # Receive HELLO and READY messages from the input socket.
        eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
        eng_index = int.from_bytes(eng_identity, "little")
        engine = next((e for e in core_engines if e.identity == eng_identity),
                      None)
        if engine is None:
            raise RuntimeError(f"Message from engine with unexpected data "
                               f"parallel rank: {eng_index}")
        msg = msgspec.msgpack.decode(ready_msg_bytes)
        status, local = msg["status"], msg["local"]
        if local != engine.local:
            raise RuntimeError(f"{status} message from "
                               f"{'local' if local else 'remote'} "
                               f"engine {eng_index}, expected it to be "
                               f"{'local' if engine.local else 'remote'}")

        if status == "HELLO" and engine.state == CoreEngineState.NEW:
            # Send init message with DP config info.
            init_message = msgspec.msgpack.encode(
                EngineHandshakeMetadata(
                    addresses=addresses,
                    parallel_config={
                        "data_parallel_master_ip":
                        parallel_config.data_parallel_master_ip,
                        "data_parallel_master_port":
                        parallel_config.data_parallel_master_port,
                        "data_parallel_size":
                        parallel_config.data_parallel_size,
                    }))
            handshake_socket.send_multipart((eng_identity, init_message),
                                            copy=False)
            conn_pending[0 if local else 1] -= 1
            start_pending[0 if local else 1] += 1
            engine.state = CoreEngineState.CONNECTED
        elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
            # Setup KV cache config with initialization state from
            # engine core process. Sum values from all engines in DP case.
            num_gpu_blocks = cache_config.num_gpu_blocks or 0
            num_gpu_blocks += msg["num_gpu_blocks"]
            cache_config.num_gpu_blocks = num_gpu_blocks
            start_pending[0 if local else 1] -= 1
            engine.state = CoreEngineState.READY
        else:
            raise RuntimeError(f"Unexpected {status} message for "
                               f"{'local' if local else 'remote'} engine "
                               f"{eng_index} in {engine.state} state.")

        logger.debug("%s from %s core engine process %s.", status,
                     "local" if local else "remote", eng_index)
def wait_for_completion_or_failure(
        api_server_manager: APIServerProcessManager,
        engine_manager: Optional[Union[CoreEngineProcManager,
                                       CoreEngineActorManager]] = None,
        coordinator: Optional["DPCoordinator"] = None) -> None:
    """Wait for all processes to complete or detect if any fail.

    Raises an exception if any process exits with a non-zero status.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.
    """
    try:
        logger.info("Waiting for API servers to complete ...")
        # Create a mapping of sentinels to their corresponding processes
        # for efficient lookup
        sentinel_to_proc: dict[Any, BaseProcess] = {
            proc.sentinel: proc
            for proc in api_server_manager.processes
        }

        if coordinator:
            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

        actor_run_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            for proc in engine_manager.processes:
                sentinel_to_proc[proc.sentinel] = proc
        elif isinstance(engine_manager, CoreEngineActorManager):
            actor_run_refs = engine_manager.get_run_refs()

        # Check if any process terminates
        while sentinel_to_proc or actor_run_refs:
            # Wait for any process to terminate
            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
                                                         timeout=5)

            # Process any terminated processes
            for sentinel in ready_sentinels:
                proc = sentinel_to_proc.pop(sentinel)

                # Check if process exited with error
                if proc.exitcode != 0:
                    raise RuntimeError(
                        f"Process {proc.name} (PID: {proc.pid}) "
                        f"died with exit code {proc.exitcode}")

            if actor_run_refs:
                import ray
                # 5s timeout keeps the loop alternating between process
                # sentinels and Ray actor refs.
                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s",
                         str(e))
        raise
    finally:
        # Always tear everything down, whatever the exit path was.
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if engine_manager:
            engine_manager.close()
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(procs: list[BaseProcess]):
    """Terminate *procs*, escalating to a tree kill after a 5s grace period."""
    # Ask every live process to stop.
    for p in procs:
        if p.is_alive():
            p.terminate()

    # Allow 5 seconds for remaining procs to terminate.
    deadline = time.monotonic() + 5
    for p in procs:
        budget = deadline - time.monotonic()
        if budget <= 0:
            break
        if p.is_alive():
            p.join(budget)

    # Force-kill whatever survived the grace period (and its children).
    for p in procs:
        if p.is_alive() and (pid := p.pid) is not None:
            kill_process_tree(pid)
def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches, ordered by layer index.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner; must be
            empty on entry.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Group layer names by their numeric layer index.
    grouped = defaultdict(list)
    for name in kv_caches:
        grouped[extract_layer_index(name)].append(name)
    for idx in sorted(grouped):
        names = grouped[idx]
        if len(names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder
            # layer has different layer_name but the same layer_index.
            raise NotImplementedError
        runner_kv_caches.append(kv_caches[names[0]])

    # Bind kv_caches to forward context
    for name, cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[name].kv_cache = [cache]
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    """
    Copy the first length elements of a tensor into another tensor in a
    non-blocking manner.

    Used to copy pinned CPU tensor data to pre-allocated GPU tensors.

    Returns the sliced target tensor.
    """
    dst = to_tensor[:length]
    # copy_ returns its receiver, i.e. the destination slice.
    return dst.copy_(from_tensor[:length], non_blocking=True)
def report_usage_stats(
        vllm_config,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
    """Report usage statistics if enabled."""

    # No-op unless the user has opted into usage stats collection.
    if not is_usage_stats_enabled():
        return

    # Lazy import to avoid pulling in the model loader at module import.
    from vllm.model_executor.model_loader import get_architecture_class_name

    usage_message.report_usage(
        get_architecture_class_name(vllm_config.model_config),
        usage_context,
        extra_kvs={
            # Common configuration
            "dtype":
            str(vllm_config.model_config.dtype),
            "tensor_parallel_size":
            vllm_config.parallel_config.tensor_parallel_size,
            "block_size":
            vllm_config.cache_config.block_size,
            "gpu_memory_utilization":
            vllm_config.cache_config.gpu_memory_utilization,

            # Quantization
            "quantization":
            vllm_config.model_config.quantization,
            "kv_cache_dtype":
            str(vllm_config.cache_config.cache_dtype),

            # Feature flags
            "enable_lora":
            bool(vllm_config.lora_config),
            "enable_prompt_adapter":
            bool(vllm_config.prompt_adapter_config),
            "enable_prefix_caching":
            vllm_config.cache_config.enable_prefix_caching,
            "enforce_eager":
            vllm_config.model_config.enforce_eager,
            "disable_custom_all_reduce":
            vllm_config.parallel_config.disable_custom_all_reduce,
        })

View File

View File

@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.logger import init_logger
from vllm.utils import cdiv
logger = init_logger(__name__)
class BlockTable:
    """Per-request table of physical KV-cache block ids.

    The table is kept in three synchronized views: a device tensor
    (``block_table``), a host staging tensor (``block_table_cpu``, pinned
    when requested) and a numpy view of the host tensor
    (``block_table_np``). Row updates are applied on the host side and
    pushed to the device asynchronously with :meth:`commit`.
    Slot-mapping buffers for scheduled tokens are allocated here as well.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device

        # Device-side copy of the block table.
        self.block_table = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device=self.device,
            dtype=torch.int32,
        )
        # Host staging buffer and its numpy view for fast row updates.
        self.block_table_cpu = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.block_table_np = self.block_table_cpu.numpy()
        # Number of valid blocks currently stored in each row.
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        # Slot mapping buffers: one int64 slot index per scheduled token.
        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
                                            dtype=torch.int64,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                        dtype=torch.int64,
                                        device=self.device)

    def append_row(
        self,
        block_ids: list[int],
        row_idx: int,
    ) -> None:
        """Append ``block_ids`` to the end of row ``row_idx`` (host side)."""
        if not block_ids:
            return
        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]
        self.num_blocks_per_row[row_idx] += num_blocks
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        """Overwrite row ``row_idx`` with ``block_ids``."""
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy row ``src`` onto row ``tgt`` (``src`` is left unchanged)."""
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
            src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        """Exchange rows ``src`` and ``tgt`` (contents and lengths)."""
        num_blocks_src = self.num_blocks_per_row[src]
        num_blocks_tgt = self.num_blocks_per_row[tgt]
        self.num_blocks_per_row[src] = num_blocks_tgt
        self.num_blocks_per_row[tgt] = num_blocks_src

        self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]

    def commit(self, num_reqs: int) -> None:
        """Asynchronously push the first ``num_reqs`` rows to the device."""
        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                          non_blocking=True)

    def clear(self) -> None:
        """Zero both the device and host copies of the block table."""
        self.block_table.fill_(0)
        self.block_table_cpu.fill_(0)

    def get_device_tensor(self) -> torch.Tensor:
        """Returns the device tensor of the block table."""
        return self.block_table

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
        return self.block_table_cpu

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table_np
class MultiGroupBlockTable:
    """Bundle of one :class:`BlockTable` per KV cache group.

    Every mutating call is fanned out to each group's table so all tables
    stay row-aligned with one another.
    """

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 max_num_batched_tokens: int, pin_memory: bool,
                 device: torch.device, block_sizes: list[int]) -> None:
        self.block_tables = [
            BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
                       max_num_batched_tokens, pin_memory, device)
            for block_size in block_sizes
        ]

    def append_row(self, block_ids: tuple[list[int], ...],
                   row_idx: int) -> None:
        """Append each group's block ids to its table at ``row_idx``."""
        for group_idx, table in enumerate(self.block_tables):
            table.append_row(block_ids[group_idx], row_idx)

    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        """Overwrite row ``row_idx`` of every group's table."""
        for group_idx, table in enumerate(self.block_tables):
            table.add_row(block_ids[group_idx], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy row ``src`` onto row ``tgt`` in every group's table."""
        for table in self.block_tables:
            table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Exchange rows ``src`` and ``tgt`` in every group's table."""
        for table in self.block_tables:
            table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        """Push the first ``num_reqs`` rows of every table to the device."""
        for table in self.block_tables:
            table.commit(num_reqs)

    def clear(self) -> None:
        """Zero every group's table."""
        for table in self.block_tables:
            table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
class CPUModelRunner(GPUModelRunner):
    """GPUModelRunner specialization that executes entirely on the CPU.

    After the parent class allocates its buffers, every device tensor is
    replaced by its CPU counterpart so all reads/writes hit the same host
    memory and no host<->device copies occur.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        # CUDA graphs and cascade attention are GPU-only features.
        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tensors()

    def _postprocess_tensors(self) -> None:
        # Note: replace device tensors with cpu tensors
        def replace_tensor(obj: Any, cpu_attr_name: str,
                           device_attr_name: str) -> None:
            # Alias the "<name>" device attribute to the "<name>_cpu"
            # tensor when both exist on the object.
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            if cpu_tensor is not None and device_tensor is not None:
                assert isinstance(cpu_tensor, torch.Tensor)
                assert isinstance(device_tensor, torch.Tensor)
                setattr(obj, device_attr_name, cpu_tensor)

        for k, v in vars(self).items():
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self, k, k[:-4])

        for k, v in vars(self.input_batch).items():
            if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch, k, k[:-11])

        # NOTE(review): input_batch.block_table appears to be a
        # MultiGroupBlockTable (a list of per-group BlockTables); the "_cpu"
        # tensors may live one level deeper — confirm this loop reaches them.
        for k, v in vars(self.input_batch.block_table).items():
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch.block_table, k, k[:-4])

    def load_model(self) -> None:
        """Load the model weights (and LoRA adapters, if configured)."""
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.model_config,
                                              self.scheduler_config,
                                              self.lora_config, self.device)

    def warming_up_model(self) -> None:
        """Run one dummy batch so compilation covers the generic shape."""
        logger.info("Warming up model for the compilation...")
        # Only generate graph for the generic shape
        self._dummy_run(max(16, self.max_num_reqs))
        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        # No CUDA device properties to query on CPU.
        pass

    def _sync_device(self) -> None:
        # No device synchronization is needed on CPU.
        pass
@contextmanager
def _set_global_compilation_settings():
    """Temporarily force the inductor settings required for CPU compilation.

    The original values are restored on exit, even if the wrapped block
    raises (the original code skipped restoration on exceptions, leaking
    the mutated global settings).
    """
    import torch._inductor.config

    # Note: The CPPGEMM backend requires freezing parameters.
    freezing_value = torch._inductor.config.freezing
    # Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
    # including object type dict"
    force_disable_caches = torch._inductor.config.force_disable_caches
    try:
        torch._inductor.config.freezing = True
        torch._inductor.config.force_disable_caches = True
        yield
    finally:
        torch._inductor.config.freezing = freezing_value
        torch._inductor.config.force_disable_caches = force_disable_caches

View File

@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
import os
from importlib import util
from typing import Optional
import torch
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import (Worker,
init_worker_distributed_environment)
logger = init_logger(__name__)
class CPUWorker(Worker):
    """V1 worker that runs the model on CPU.

    Reuses the GPU worker's control flow but binds OpenMP threads to CPU
    cores, initializes the distributed environment with the "gloo" backend,
    and constructs a :class:`CPUModelRunner` instead of the GPU runner.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False):
        super().__init__(vllm_config,
                         local_rank,
                         rank,
                         distributed_init_method,
                         is_driver_worker=is_driver_worker)

        # Custom all-reduce is a CUDA-only path; force the portable one.
        self.parallel_config.disable_custom_all_reduce = True

    def init_device(self):
        """Bind OpenMP threads, set up the distributed env and model runner."""
        # Setup OpenMP threads affinity.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        self.local_omp_cpuid = "all"
        if omp_cpuids == "auto":
            self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
            )
        else:
            # Manual binding: one "|"-separated CPU list per rank.
            self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]

        if self.local_omp_cpuid != "all":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            if ret:
                logger.info(ret)

        # Note: unique identifier for creating allreduce shared memory
        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
            ":")[-1]
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank, "gloo")
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner: CPUModelRunner = CPUModelRunner(
            self.vllm_config, torch.device("cpu"))

    def sleep(self, level: int = 1) -> None:
        # Sleep/wake is a GPU memory-offload feature; no-op on CPU.
        logger.warning("sleep mode is not supported on CPU, ignore it.")
        pass

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        # Counterpart of sleep(); also a no-op on CPU.
        logger.warning("sleep mode is not supported on CPU, ignore it.")
        pass

    def determine_available_memory(self) -> int:
        """Return the bytes reserved for the CPU KV cache (user-configured)."""
        return self.cache_config.cpu_kvcache_space_bytes  # type: ignore

    def compile_or_warm_up_model(self) -> None:
        """Warm up / compile the model after profiling."""
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        self.model_runner.warming_up_model()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one scheduler step, handling pipeline-parallel recv/send.

        Only the last PP rank produces a ModelRunnerOutput, and only the
        driver worker returns it to the caller.
        """
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            # Receive activations from the previous pipeline stage.
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)

        if not get_pp_group().is_last_rank:
            # Forward activations to the next pipeline stage.
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return None

        assert isinstance(output, ModelRunnerOutput)
        return output if self.is_driver_worker else None

    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
        """Return CPUs id binding based on NUMA nodes.

        Requires the optional ``numa`` and ``psutil`` packages. Falls back
        to the current binding (no change) when they are missing or when
        world_size exceeds the number of usable NUMA nodes.
        """
        rank_to_cpus = self.local_omp_cpuid
        # Setup OpenMP thread affinity based on NUMA nodes automatically
        world_size = self.vllm_config.parallel_config.world_size
        libnuma_found = util.find_spec("numa") is not None
        psutil_found = util.find_spec("psutil") is not None
        if libnuma_found and psutil_found:
            import psutil
            from numa import info
            cpu_count = psutil.cpu_count(logical=False)
            cpus_allow_list = psutil.Process().cpu_affinity()
            numa_size = info.get_num_configured_nodes()
            cpu_count_per_numa = cpu_count // numa_size
            # Leave a few CPUs per NUMA node free for non-OpenMP work.
            num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                      cpu_count_per_numa // 2)

            # check allow node_to_cpus list
            node_to_cpus = []
            for i in range(numa_size):
                node_intersect = set(
                    info.node_to_cpus(i)).intersection(cpus_allow_list)
                if bool(node_intersect):
                    node_to_cpus.append(list(node_intersect))

            if world_size > len(node_to_cpus):
                logger.error(
                    "Auto thread-binding failed due to "
                    "world size: %d is larger than "
                    "allowed NUMA nodes number: %d."
                    "Please try to bind threads manually.", world_size,
                    len(node_to_cpus))
            else:
                end = cpu_count_per_numa - num_of_reserved_cpu
                rank_to_cpus_list = node_to_cpus[self.rank][:end]
                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
                logger.info("auto thread-binding list: %s", rank_to_cpus)
        else:
            logger.warning(
                "Auto thread-binding is not supported due to "
                "the lack of package numa and psutil,"
                "fallback to no thread-binding. To get better performance,"
                "please try to manually bind threads.")
        return rank_to_cpus

View File

@@ -0,0 +1,681 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining an input batch
from dataclasses import dataclass
from typing import Optional, cast
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
_SAMPLING_EPS = 1e-5
@dataclass
class CachedRequestState:
    """Per-request state cached by the model runner across scheduler steps."""

    req_id: str
    prompt_token_ids: list[int]
    mm_inputs: list[MultiModalKwargs]
    mm_positions: list[PlaceholderRange]
    sampling_params: SamplingParams
    generator: Optional[torch.Generator]

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    def __post_init__(self):
        # Cache the prompt length so it is not recomputed on every access.
        self.num_prompt_tokens = len(self.prompt_token_ids)

    @property
    def num_tokens(self) -> int:
        """Total token count for this request (prompt + generated)."""
        return self.num_prompt_tokens + len(self.output_token_ids)

    def get_token_id(self, idx: int) -> int:
        """Return the token id at position ``idx`` over prompt then output."""
        if idx < self.num_prompt_tokens:
            return self.prompt_token_ids[idx]
        return self.output_token_ids[idx - self.num_prompt_tokens]
class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
    ):
        """Pre-allocate all per-request batch buffers.

        Most state is kept twice: a device tensor plus a (pinned) CPU
        staging tensor with a numpy view, so per-request scalar updates are
        cheap and only copied to the device when sampling metadata is built.
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        # Batch-slot bookkeeping: slot index -> request id, and the reverse.
        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs, ),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = \
            self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
        )

        # Sampling-related.
        self.temperature = torch.empty((max_num_reqs, ),
                                       dtype=torch.float32,
                                       device=device)
        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                                  dtype=torch.float32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs, ),
                                 dtype=torch.int32,
                                 device=device)
        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        self.min_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
        self.min_p_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty((max_num_reqs, ),
                                               dtype=torch.float,
                                               device=device)
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty((max_num_reqs, ),
                                              dtype=torch.float,
                                              device=device)
        self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
                                                         dtype=torch.float,
                                                         device="cpu",
                                                         pin_memory=pin_memory)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
        )
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty((max_num_reqs, ),
                                                dtype=torch.float,
                                                device=device)
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # req_index -> (min_tokens, stop_token_ids)
        self.min_tokens: dict[int, tuple[int, set[int]]] = {}

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        self.logit_bias: list[Optional[dict[int,
                                            float]]] = [None] * max_num_reqs
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()
    @property
    def req_ids(self) -> list[str]:
        """Request ids of the batched requests, indexed by batch slot."""
        # None elements should only be present transiently
        # while performing state updates to the batch.
        return cast(list[str], self._req_ids)
    def add_request(
        self,
        request: "CachedRequestState",
        req_index: Optional[int] = None,
    ) -> None:
        """Insert ``request`` into batch slot ``req_index``.

        If ``req_index`` is None the request is appended at the end of the
        batch. All per-request buffers (token ids, sampling params,
        penalties, LoRA mapping, block table row) are populated here.
        """
        if req_index is None:
            req_index = self.num_reqs
        assert req_index < self.max_num_reqs

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        # Number of token ids in token_ids_cpu.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        sampling_params = request.sampling_params
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # Avoid later division by zero.
            self.temperature_cpu[req_index] = -1.0
            self.greedy_reqs.add(req_id)
        else:
            self.temperature_cpu[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        top_k = sampling_params.top_k
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            # Out-of-range top_k disables top-k filtering for this request.
            top_k = self.vocab_size
        self.top_k_cpu[req_index] = top_k
        self.min_p_cpu[req_index] = sampling_params.min_p
        self.frequency_penalties_cpu[
            req_index] = sampling_params.frequency_penalty
        if sampling_params.min_p > _SAMPLING_EPS:
            self.min_p_reqs.add(req_id)
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties_cpu[
            req_index] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties_cpu[
            req_index] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)
        if sampling_params.min_tokens:
            self.min_tokens[req_index] = (sampling_params.min_tokens,
                                          sampling_params.all_stop_token_ids)

        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator

        if sampling_params.logprobs is not None:
            self.num_logprobs[req_id] = sampling_params.logprobs
        if sampling_params.prompt_logprobs is not None:
            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
        if sampling_params.logit_bias is not None:
            self.logit_bias[req_index] = sampling_params.logit_bias

        if sampling_params.allowed_token_ids:
            self.has_allowed_token_ids.add(req_id)
            if self.allowed_token_ids_mask_cpu_tensor is None:
                # Lazy allocation for this tensor, which can be large.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
                                                          self.vocab_size,
                                                          dtype=torch.bool,
                                                          device=self.device)
                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device="cpu")
            self.allowed_token_ids_mask_cpu_tensor[req_index] = True
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index][
                sampling_params.allowed_token_ids] = False

        if sampling_params.bad_words_token_ids:
            self.bad_words_token_ids[
                req_index] = sampling_params.bad_words_token_ids

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0
    def remove_request(self, req_id: str) -> Optional[int]:
        """Remove ``req_id`` from all batch state and return its slot index.

        Returns None if the request is not in the batch. The freed slot is
        left dirty; this method must always be followed by a call to
        condense().
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.min_p_reqs.discard(req_id)
        self.min_tokens.pop(req_index, None)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.num_prompt_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            # Drop the LoRA adapter entirely once no request references it.
            self.lora_id_to_request_ids[lora_id].discard(req_id)
            if len(self.lora_id_to_request_ids[lora_id]) == 0:
                self.lora_id_to_request_ids.pop(lora_id)
                self.lora_id_to_lora_request.pop(lora_id)
            self.request_lora_mapping[req_index] = 0

        self.logit_bias[req_index] = None
        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-request state between batch slots ``i1`` and
        ``i2``: ids, token buffers, sampling params, LoRA mapping, and the
        block-table rows."""
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] =\
            self._req_ids[i2], self._req_ids[i1] # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
        self.num_tokens[i1], self.num_tokens[i2] =\
            self.num_tokens[i2], self.num_tokens[i1]
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
            self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
            self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
            self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
        self.temperature_cpu[i1], self.temperature_cpu[i2] =\
            self.temperature_cpu[i2], self.temperature_cpu[i1]
        self.top_p_cpu[i1], self.top_p_cpu[i2] =\
            self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] =\
            self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
        self.min_p_cpu[i1], self.min_p_cpu[i2] =\
            self.min_p_cpu[i2], self.min_p_cpu[i1]

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporiarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.min_tokens, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
        self.logit_bias[i1], self.logit_bias[i2] =\
            self.logit_bias[i2], self.logit_bias[i1]

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            self.allowed_token_ids_mask_cpu_tensor[i1], \
                self.allowed_token_ids_mask_cpu_tensor[i2] =\
                self.allowed_token_ids_mask_cpu_tensor[i2], \
                self.allowed_token_ids_mask_cpu_tensor[i1]
        self.block_table.swap_row(i1, i2)
    def condense(self, empty_req_indices: list[int]) -> None:
        """Compact the batch by moving trailing requests into freed slots so
        the first ``num_reqs`` slots are densely packed.

        ``empty_req_indices`` must be sorted in descending order.
        """
        num_reqs = self.num_reqs
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                break

            # Swap the states.
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            num_tokens = self.num_tokens[last_req_index]
            # Only the valid prefix of the token row needs to move.
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
            self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            min_token = self.min_tokens.pop(last_req_index, None)
            if min_token is not None:
                self.min_tokens[empty_index] = min_token

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

            self.logit_bias[empty_index] = self.logit_bias[last_req_index]

            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[
                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                        last_req_index]

            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[self.num_reqs:]
        del self.req_output_token_ids[self.num_reqs:]
    def refresh_sampling_metadata(self):
        """Rebuild ``self.sampling_metadata`` after batch membership changes."""
        self.sampling_metadata = self._make_sampling_metadata()
    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Assemble SamplingMetadata for the first ``num_reqs`` slots.

        Only the CPU-side parameter buffers that are actually in use by
        some request are copied to their device tensors.
        """
        num_reqs = self.num_reqs
        if not self.all_greedy:
            temperature = copy_slice(self.temperature_cpu_tensor,
                                     self.temperature, num_reqs)
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
        if not self.no_min_p:
            copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(self.frequency_penalties_cpu_tensor,
                       self.frequency_penalties, num_reqs)
            copy_slice(self.presence_penalties_cpu_tensor,
                       self.presence_penalties, num_reqs)
            copy_slice(self.repetition_penalties_cpu_tensor,
                       self.repetition_penalties, num_reqs)

            # The prompt tokens are used only for applying penalties during
            # the sampling process. Hence copy these tensors only when
            # there are requests which need penalties to be applied.
            prompt_token_ids = self._make_prompt_token_ids_tensor()
        else:
            prompt_token_ids = None

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, num_reqs)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            min_p=None if self.no_min_p else self.min_p[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
            min_tokens=self.min_tokens,
            no_penalties=self.no_penalties,
            logit_bias=self.logit_bias[:num_reqs],
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
        )
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[:self.
num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(
req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values())
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
    @property
    def num_reqs(self) -> int:
        # Number of requests currently tracked in the batch (one index
        # entry per request id).
        return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_min_p(self) -> bool:
return len(self.min_p_reqs) == 0
@property
def no_penalties(self) -> bool:
return (len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0)
@property
def max_num_logprobs(self) -> Optional[int]:
return max(self.num_logprobs.values()) if self.num_logprobs else None
    @property
    def no_prompt_logprob(self) -> bool:
        # True when no request in the batch asks for prompt logprobs.
        return not self.num_prompt_logprobs
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,393 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import SchedulerOutput
class Worker(WorkerBase):
    """GPU worker bound to a single CUDA device.

    Owns a GPUModelRunner and delegates model loading, KV-cache setup,
    execution, and LoRA management to it. Also implements sleep/wake-up
    memory offloading via CuMemAllocator and optional torch profiling.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()
        # Buffers saved before sleep
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def sleep(self, level: int = 1) -> None:
        """Free GPU memory through CuMemAllocator.sleep().

        Before a level-2 sleep, model buffers are cloned to CPU so
        wake_up() can restore them. Logs how much memory was freed.
        """
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone()
                for name, buffer in model.named_buffers()
            }
        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Re-allocate memory released by sleep() and restore any model
        buffers saved before a level-2 sleep."""
        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)
        # Restore the buffers after level 2 sleep
        if len(self._sleep_saved_buffers):
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def init_device(self):
        """Bind to the CUDA device, verify free-memory headroom against
        gpu_memory_utilization, initialize the distributed environment,
        and construct the GPUModelRunner.

        Raises:
            ValueError: if startup free memory is below the requested
                utilization budget.
            RuntimeError: if the configured device type is not cuda.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)
            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()
            # take current memory snapshot
            self.init_snapshot = MemorySnapshot()
            self.requested_memory = (self.init_snapshot.total_memory *
                                     self.cache_config.gpu_memory_utilization)
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = lambda b: round(b / GiB_bytes, 2)
                raise ValueError(
                    f"Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)
        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)
        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load model weights; when sleep mode is enabled, allocate them
        inside the CuMemAllocator "weights" pool so sleep() can offload
        them later."""
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
            context = allocator.use_memory_pool(tag="weights")
        else:
            from contextlib import nullcontext
            context = nullcontext()
        with context:
            self.model_runner.load_model()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.
        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculate the free memory that can be used for KV cache in
        bytes.
        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        GiB = lambda b: b / GiB_bytes
        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.init_snapshot,
                weights_memory=int(
                    self.model_runner.model_memory_usage)) as profile_result:
            self.model_runner.profile_run()
        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {GiB(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container.")
        available_kv_cache_memory = self.requested_memory \
            - profile_result.non_kv_cache_memory
        logger.debug(
            "Initial free memory: %.2f GiB, free memory: %.2f GiB, "
            "requested GPU memory: %.2f GiB",
            GiB(self.init_snapshot.free_memory), GiB(free_gpu_memory),
            GiB(self.requested_memory))
        logger.debug(profile_result)
        logger.info("Available KV cache memory: %.2f GiB",
                    GiB(available_kv_cache_memory))
        gc.collect()
        return int(available_kv_cache_memory)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        # Delegate to the runner, which knows the per-layer cache layout.
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            context = allocator.use_memory_pool(tag="kv_cache")
        else:
            from contextlib import nullcontext
            context = nullcontext()
        with context:
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Run warm-up/compile passes (and CUDA graph capture unless
        enforce_eager), then warm up the sampler on the last PP rank."""
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()
        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(self.scheduler_config.max_num_seqs,
                               self.scheduler_config.max_num_batched_tokens)
            self.model_runner._dummy_sampler_run(
                hidden_states=self.model_runner._dummy_run(
                    num_tokens=max_num_reqs))
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one model step.

        Non-first pipeline-parallel ranks first receive intermediate
        tensors from the previous rank; non-last ranks forward their
        intermediate tensors onward and return None. Only the driver
        worker returns the ModelRunnerOutput.
        """
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))
        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)
        parallel_config = self.vllm_config.parallel_config
        if parallel_config.distributed_executor_backend != "external_launcher" \
            and not get_pp_group().is_last_rank:
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return None
        assert isinstance(output, ModelRunnerOutput)
        return output if self.is_driver_worker else None

    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler; requires
        VLLM_TORCH_PROFILER_DIR to be set at construction time."""
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            print(self.profiler.key_averages().table(
                sort_by="self_cuda_time_total"))

    def execute_dummy_batch(self) -> None:
        # Minimal no-op-sized forward pass (batch of 1).
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the model weights as sharded state files under `path`."""
        from vllm.model_executor.model_loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Serialize the model via tensorizer using the given config."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )
def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize torch.distributed, the model-parallel groups, and the
    KV-transfer machinery for this worker."""
    pconfig = vllm_config.parallel_config
    set_custom_all_reduce(not pconfig.disable_custom_all_reduce)
    init_distributed_environment(pconfig.world_size, rank,
                                 distributed_init_method, local_rank, backend)
    ensure_model_parallel_initialized(pconfig.tensor_parallel_size,
                                      pconfig.pipeline_parallel_size)
    ensure_kv_transfer_initialized(vllm_config)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: # noqa: SIM102
if not current_platform.has_device_capability(80):
capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half.")

View File

@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define LoRA functionality mixin for model runners.
"""
from contextlib import contextmanager
import numpy as np
import torch.nn as nn
from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.v1.worker.gpu_input_batch import InputBatch
logger = init_logger(__name__)
# Defined as a mixin for GPUModelRunner
class LoRAModelRunnerMixin:
    """Mixin that adds LoRA adapter management (loading, activation,
    warmup dummies, add/remove/pin/list) to a model runner."""

    # Rank used for the dummy LoRAs created during warmup/profiling.
    LORA_WARMUP_RANK = 8

    def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
                        scheduler_config: SchedulerConfig,
                        lora_config: LoRAConfig, device: str) -> nn.Module:
        """Attach an LRUCacheWorkerLoRAManager to this runner and return the
        LoRA-managed model.

        Raises:
            ValueError: if the model does not support LoRA.
        """
        if not supports_lora(model):
            raise ValueError(
                f"{model.__class__.__name__} does not support LoRA yet.")
        if supports_multimodal(model):
            logger.warning("Regarding multimodal models, vLLM currently "
                           "only supports adding LoRA to language model.")
        # Use get_text_config() in case of multimodal models
        text_config = model_config.hf_config.get_text_config()
        # Add LoRA Manager to the Model Runner
        self.lora_manager = LRUCacheWorkerLoRAManager(
            scheduler_config.max_num_seqs,
            scheduler_config.max_num_batched_tokens,
            model_config.get_vocab_size(),
            lora_config,
            device,
            model.embedding_modules,
            model.embedding_padding_modules,
            max_position_embeddings=text_config.max_position_embeddings,
        )
        return self.lora_manager.create_lora_manager(model)

    def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...],
                          token_lora_mapping: tuple[int, ...],
                          lora_requests: set[LoRARequest]) -> None:
        """Activate the given adapters on the manager with the provided
        per-prompt and per-token LoRA id mappings.

        Raises:
            RuntimeError: if LoRA is not enabled.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        # Set is_prefill to True, so we always use the SGMV kernels on
        # non-cuda platforms.
        # On cuda platforms we use the same kernels for prefill and
        # decode and this flag is generally ignored.
        lora_mapping = LoRAMapping(token_lora_mapping,
                                   prompt_lora_mapping,
                                   is_prefill=True)
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

    def set_active_loras(self, input_batch: InputBatch,
                         num_scheduled_tokens: np.ndarray) -> None:
        """Derive LoRA mappings from the input batch and activate them."""
        prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
        token_lora_mapping: tuple[int,
                                  ...]  # of size np.sum(num_scheduled_tokens)
        lora_requests: set[LoRARequest]
        prompt_lora_mapping, token_lora_mapping, lora_requests = \
            input_batch.make_lora_inputs(num_scheduled_tokens)
        return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                      lora_requests)

    @contextmanager
    def maybe_setup_dummy_loras(self, lora_config):
        """Context manager that registers `max_loras` dummy adapters for the
        duration of the block (no-op when lora_config is None); all adapters
        are removed on exit."""
        if lora_config is None:
            yield
        else:
            # __enter__ code
            assert self.lora_manager is not None, "LoRA is not enabled"
            num_loras = lora_config.max_loras
            # Make dummy lora requests
            lora_requests: set[LoRARequest] = {
                LoRARequest(lora_name=f"warmup_{lora_id}",
                            lora_int_id=lora_id,
                            lora_path="/not/a/real/path")
                for lora_id in range(1, num_loras + 1)
            }
            with self.lora_manager.dummy_lora_cache():
                # Add the dummy LoRAs here so _set_active_loras doesn't try to
                # load from disk.
                for lr in lora_requests:
                    self.lora_manager.add_dummy_lora(
                        lr, rank=self.LORA_WARMUP_RANK)
                yield
            # __exit__ code
            self.lora_manager.remove_all_adapters()

    @contextmanager
    def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
                                 num_scheduled_tokens: np.ndarray):
        """Context manager that activates dummy adapters, cycling LoRA ids
        over requests to simulate a worst-case assignment (no-op when
        lora_config is None)."""
        if lora_config is None:
            yield
        else:
            # __enter__ code
            assert self.lora_manager is not None, "LoRA is not enabled"
            num_reqs = len(num_scheduled_tokens)
            num_loras = lora_config.max_loras
            # Make prompt lora mapping
            # Assign LoRA IDs cyclically to simulate a worst-case scenario.
            prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) %
                                   num_loras) + 1
            # Make token lora mapping
            token_lora_mapping = np.repeat(prompt_lora_mapping,
                                           num_scheduled_tokens)
            # Make dummy lora requests
            lora_requests: set[LoRARequest] = {
                LoRARequest(lora_name=f"warmup_{lora_id}",
                            lora_int_id=lora_id,
                            lora_path="/not/a/real/path")
                for lora_id in range(1, num_loras + 1)
            }
            self._set_active_loras(tuple(prompt_lora_mapping),
                                   tuple(token_lora_mapping), lora_requests)
            yield

    @contextmanager
    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
                                  num_scheduled_tokens: np.ndarray):
        """Combine maybe_setup_dummy_loras and maybe_select_dummy_loras for
        dummy/warmup runs."""
        with self.maybe_setup_dummy_loras(
                lora_config), self.maybe_select_dummy_loras(
                    lora_config, num_scheduled_tokens):
            yield

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register an adapter with the manager; raises RuntimeError when
        LoRA is disabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.add_adapter(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an adapter by id; raises RuntimeError when LoRA is
        disabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_adapter(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Pin an adapter so the LRU cache will not evict it; raises
        RuntimeError when LoRA is disabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.pin_adapter(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of currently registered adapters; raises
        RuntimeError when LoRA is disabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_adapters()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,299 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""
import os
from typing import Optional
import torch
import torch.distributed
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache, report_usage_stats
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
logger = init_logger(__name__)
class TPUWorker:
    """A TPU worker built on PyTorch/XLA.

    Owns a TPUModelRunner and handles device/distributed initialization,
    KV-cache sizing via a profiling run, execution, LoRA registration,
    and XLA profiling.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        self.is_driver_worker = is_driver_worker
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        # Keeps the pre-SPMD parallel config so the runner can still see it.
        self.original_parallel_config = None
        if self.use_spmd:
            # Under SPMD mode, distributed env is initialized as if there is
            # only one worker/device.
            self.original_parallel_config = self.parallel_config
            self.parallel_config.tensor_parallel_size = 1
            self.parallel_config.pipeline_parallel_size = 1
            self.parallel_config.world_size = 1
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        if self.cache_config.cache_dtype == "auto":
            # "auto" means the KV cache uses the model's dtype.
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()
        # Delay profiler initialization to the start of the profiling.
        # This is because in vLLM V1, MP runtime is initialized before the
        # TPU Worker is initialized. The profiler server needs to start after
        # MP runtime is initialized.
        self.profiler = None
        self.profile_dir = None
        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
            # For TPU, we can only have 1 active profiler session for 1 profiler
            # server. So we only profile on rank0.
            self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        self.profile_dir)
        if self.model_config.seed is None:
            # Default to a fixed seed so XLA runs stay deterministic.
            self.model_config.seed = 0

    def init_device(self):
        """Configure the XLA/TPU environment, init the distributed runtime,
        bind the XLA device, seed RNGs, set up the compilation cache, and
        construct the TPUModelRunner."""
        os.environ["PJRT_DEVICE"] = "TPU"
        # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
        # ring, the xla tpu compiler flag
        # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
        # fix this. It will be removed after the bug in XLA compiler is fixed.
        os.environ["LIBTPU_INIT_ARGS"] = (
            os.environ.get("LIBTPU_INIT_ARGS", "") +
            " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
            " --xla_jf_conv_input_fusion=False")
        # --xla_jf_conv_input_fusion=False is used to improve the perf of
        # quantized matmul.
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)
        # Initialize the distributed environment.
        self._init_tpu_worker_distributed_environment(
            self.parallel_config, self.rank, self.distributed_init_method,
            self.local_rank)
        # Device initialization should happen after initializing
        # the distributed runtime.
        self.device = xm.xla_device()
        self.device_config.device = self.device
        # Set random seed.
        set_random_seed(self.model_config.seed)
        if self.model_config.seed is not None:
            xm.set_rng_state(self.model_config.seed, self.device)
        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # TODO (NickLucche) On gsm we compile 80+ graphs.
        # Re-evaluate limit, with MM we may get close to this limit.
        torch._dynamo.config.cache_size_limit = 128
        # Use persistent cache to avoid XLA recompilation.
        # NOTE(woosuk): Set per-rank cache path since different ranks
        # can have slightly different XLA graphs.
        world_size = self.parallel_config.world_size
        rank = xr.global_ordinal()
        # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
        # Consequently, changes in optimization flags, which affect compilation
        # results, don't change the cache key. This can result in the wrong
        # compilation being used. To prevent this, disabling the XLA compilation
        # cache during development is recommended.We can disable it by
        # `export VLLM_XLA_CACHE_PATH=`
        if envs.VLLM_XLA_CACHE_PATH:
            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
                                         f"tp{world_size}_rank{rank}")
            xr.initialize_cache(per_rank_path, readonly=False)
        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = \
            TPUModelRunner(self.vllm_config, self.device,
                           self.original_parallel_config)
        if rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    def determine_available_memory(self) -> int:
        """Estimate, in bytes, how much TPU memory can be devoted to the
        KV cache by running a profiling pass with empty per-layer caches
        and measuring device memory usage afterwards."""
        kv_caches: dict[str, torch.Tensor] = {}
        kv_cache_spec = self.model_runner.get_kv_cache_spec()
        for layer_name, layer_spec in kv_cache_spec.items():
            if isinstance(layer_spec, AttentionSpec):
                dtype = layer_spec.dtype
                # Use an empty tensor instead of `None`` to force Dynamo to pass
                # it by reference, rather by specializing on the value ``None``.
                tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
                kv_caches[layer_name] = tpu_kv_cache
            else:
                raise NotImplementedError(
                    f"Unsupported KV cache spec '{type(layer_spec)}'")
        runner_kv_caches: list[torch.Tensor] = []
        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            runner_kv_caches)
        # `max_num_tokens >= max_num_batched_tokens` due to padding.
        with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
            self.model_runner.profile_run(self.model_runner.max_num_tokens)
        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()
        # During the profiling run, the model runs without KV cache. After
        # the profiling run, the model always runs with KV cache. Here we clear
        # the dynamo cache and cached bytecode to ensure the model always has
        # one compiled bytecode. Having one FX graph/cached bytecode per
        # compiled model is required for `support_torch_compile` decorator to
        # skip dynamo guard.
        self.model_runner.reset_dynamo_cache()
        # Get the maximum amount of memory used by the model weights and
        # intermediate activations.
        if self.use_spmd:
            # This is a workaround for the TPU SPMD mode. The get_memory_info
            # API doesn't work with SPMD mode in PyTorch/XLA.
            # TODO: use xm.get_memory_info for SPMD once it's supported in
            # PyTorch/XLA.
            import tpu_info
            chip_type, _ = tpu_info.device.get_local_chips()
            device_usage = tpu_info.metrics.get_chip_usage(chip_type)
            total_memory_size = device_usage[0].total_memory
            current_mem = device_usage[0].memory_usage
        else:
            m = xm.get_memory_info(self.device)
            total_memory_size = m["bytes_limit"]
            current_mem = m["bytes_used"]
        # Ideally we would use profiled = m["peak_bytes_used"] to
        # get weights + activations. But there is memory used during
        # compilation / weight loading that impacts the peak and
        # there is no way to reset peak memory in XLA, So we
        # use the heuristic of 2% of weights.
        profiled = current_mem * 1.02
        # Calculate the TPU KV cache size based on profiling.
        usable_memory_size = int(total_memory_size *
                                 self.cache_config.gpu_memory_utilization)
        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
        return int(tpu_kv_cache_bytes)

    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one model step; only the driver worker returns the output."""
        output = self.model_runner.execute_model(scheduler_output)
        return output if self.is_driver_worker else None

    def profile(self, is_start: bool = True):
        """Start or stop an XLA trace on rank 0; the profiler server is
        started lazily on the first call."""
        if self.rank < 1:
            if self.profile_dir is None:
                raise RuntimeError("Profiler is not enabled.")
            if is_start:
                if self.profiler is None:
                    self.profiler = xp.start_server(9012)
                xp.start_trace(self.profile_dir)
            else:
                xp.stop_trace()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def load_model(self) -> None:
        self.model_runner.load_model()

    def compile_or_warm_up_model(self) -> None:
        """Capture/compile the model unless enforce_eager, then re-seed."""
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def _init_tpu_worker_distributed_environment(
        self,
        parallel_config: ParallelConfig,
        rank: int,
        distributed_init_method: Optional[str] = None,
        local_rank: int = -1,
    ) -> None:
        """Initialize the distributed environment."""
        if self.use_spmd:
            xr.use_spmd()
        # NOTE(woosuk): This is just to initialize the TP group and broadcast
        # the input objects on CPU. The all-reduce and all-gather ops on TPU
        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
        # own context.
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            local_rank=local_rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)
# Prefer the tpu_commons TPUWorker implementation when the optional
# package is installed; otherwise keep the TPUWorker defined above.
try:
    from tpu_commons.worker import TPUWorker as TPUCommonsWorker
    TPUWorker = TPUCommonsWorker  # type: ignore
except ImportError:
    # Optional dependency: its absence is expected, not an error.
    logger.info("tpu_commons not found, using vLLM's TPUWorker.")

111
vllm/v1/worker/utils.py Normal file
View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
def sanity_check_mm_encoder_outputs(
    mm_embeddings: object,
    expected_num_items: int,
) -> None:
    """
    Perform sanity checks for the result of
    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
    """
    # Every failure below points at the same likely culprit: a model whose
    # `get_multimodal_embeddings` does not follow the expected contract.
    is_valid_container = isinstance(mm_embeddings, (list, tuple, torch.Tensor))
    assert is_valid_container, (
        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
        f"or a single 3D tensor, but got {type(mm_embeddings)} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")

    assert len(mm_embeddings) == expected_num_items, (
        "Expected number of multimodal embeddings to match number of "
        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")

    each_item_is_2d = all(e.ndim == 2 for e in mm_embeddings)
    assert each_item_is_2d, (
        "Expected multimodal embeddings to be a sequence of 2D tensors, "
        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")
def scatter_mm_placeholders(
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Scatter the multimodal embeddings into a contiguous tensor that represents
    the placeholder tokens.

    [`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].

    Args:
        embeds: The multimodal embeddings.
            Shape: `(num_embeds, embed_dim)`
        is_embed: A boolean mask indicating which positions in the placeholder
            tokens need to be filled with multimodal embeddings.
            Shape: `(num_placeholders, num_embeds)`
    """
    # No mask means every placeholder position is an embedding already.
    if is_embed is None:
        return embeds

    num_slots = is_embed.shape[0]
    embed_dim = embeds.shape[-1]
    # new_full keeps the dtype/device of `embeds`; unfilled slots stay NaN.
    out = embeds.new_full((num_slots, embed_dim), fill_value=torch.nan)
    out[is_embed] = embeds
    return out
def gather_mm_placeholders(
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Reconstructs the embeddings from the placeholder tokens.

    This is the operation of [scatter_mm_placeholders][].
    """
    # With no mask the placeholders already are the embeddings; otherwise
    # select only the positions that hold real embeddings.
    return placeholders if is_embed is None else placeholders[is_embed]
def initialize_kv_cache_for_kv_sharing(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: list[KVCacheGroupSpec],
    kv_caches: dict[str, torch.Tensor],
) -> None:
    """
    Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
    for layers that do not allocate its own KV cache, based on the mapping in
    `shared_kv_cache_layers`. Adds these layers to the corresponding KV cache
    group, which is needed to ensure that attention metadata is assigned later.

    Args:
        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
            If an Attention layer `layer_name` is in the keys of this dict, it
            means this layer will perform attention using the keys and values
            from the KV cache of `shared_kv_cache_layers[layer_name]`.
        kv_cache_groups: The KV cache groups of the model.
        kv_caches: The allocated kv_caches with layer names as keys.
            Note that layers in shared_kv_cache_layers.keys() are not
            originally included as it only contains layers which have its own
            KV cache allocation.
    """
    # Map each cache-owning layer name to the index of its KV cache group.
    group_idx_of_layer = {
        name: idx
        for idx, group in enumerate(kv_cache_groups)
        for name in group.layer_names
    }

    for sharing_layer, target_layer in shared_kv_cache_layers.items():
        # Alias the target's cache tensor and register the sharing layer in
        # the same group as its target.
        kv_caches[sharing_layer] = kv_caches[target_layer]
        target_group = kv_cache_groups[group_idx_of_layer[target_layer]]
        target_group.layer_names.append(sharing_layer)

View File

@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
logger = init_logger(__name__)
class WorkerBase(WorkerBaseV0):
"""
Abstract class for v1 worker, mainly define some methods for v1.
For methods shared by v0 and v1, define them in v0 WorkerBase
"""
def __init__(
    self,
    vllm_config: VllmConfig,
    local_rank: int,
    rank: int,
    distributed_init_method: str,
    is_driver_worker: bool = False,
):
    """
    Initialize common worker components.

    Args:
        vllm_config: Complete vLLM configuration
        local_rank: Local device index
        rank: Global rank in distributed setup
        distributed_init_method: Distributed initialization method
        is_driver_worker: Whether this worker handles driver
            responsibilities
    """
    # Configuration storage: the v0 base class consumes vllm_config and
    # exposes per-aspect configs (e.g. self.parallel_config used below).
    super().__init__(vllm_config=vllm_config)
    # Propagate this worker's global rank into the parallel config so
    # downstream components see a consistent rank.
    self.parallel_config.rank = rank
    self.local_rank = local_rank
    self.rank = rank
    self.distributed_init_method = distributed_init_method
    self.is_driver_worker = is_driver_worker

    # Device and model state — populated later by concrete subclasses
    # during device/model initialization.
    self.device: Optional[torch.device] = None
    # NOTE(review): annotated nn.Module but seemingly holds a model-runner
    # object in subclasses — confirm the intended type.
    self.model_runner: Optional[nn.Module] = None
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
    """Get specifications for KV cache implementation.

    Abstract: concrete workers must override this.
    """
    raise NotImplementedError
def compile_or_warm_up_model(self) -> None:
    """Prepare model for execution through compilation/warmup.

    Abstract: concrete workers must override this.
    """
    raise NotImplementedError
def check_health(self) -> None:
    """Basic health check (override for device-specific checks).

    Default is a no-op: a running worker is considered healthy.
    """
    return