[gpt-oss] Add gpt-oss bf16 support
0	vllm/v1/__init__.py	Normal file
0	vllm/v1/attention/__init__.py	Normal file
0	vllm/v1/attention/backends/__init__.py	Normal file
167	vllm/v1/attention/backends/cpu_attn.py	Normal file
@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
                                                TorchSDPAMetadata)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch


class TorchSDPABackend:
    accept_output_buffer: bool = False

    @staticmethod
    def get_name() -> str:
        return "TORCH_SDPA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TorchSDPABackendImpl"]:
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
        return TorchSDPAMetadataBuilderV1

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False


class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):

    def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable) -> None:
        self.runner = runner
        self.block_table = block_table

        # Buffers for reordering the batch so that prompt requests come
        # before decode requests.
        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
                                                      dtype=np.int64)
        self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
                                                      dtype=np.int64)
        self.num_prompt_req: int = 0

        self.seq_start_loc_cpu = torch.zeros(
            runner.max_num_reqs + 1,
            dtype=torch.int32,
            device="cpu",
        )
        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

    def reorder_batch(self, input_batch: InputBatch,
                      scheduler_output: SchedulerOutput) -> bool:
        prompt_list_idx = 0
        decode_list_idx = 0
        for req_index in range(input_batch.num_reqs):
            if input_batch.num_computed_tokens_cpu[
                    req_index] < input_batch.num_prompt_tokens[req_index]:
                # prompt stage
                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
                prompt_list_idx += 1
            else:
                # decode stage
                self.reorder_decode_req_index_list[decode_list_idx] = req_index
                decode_list_idx += 1
        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs

        # Update the number of prompt requests.
        self.num_prompt_req = prompt_list_idx

        reorder_req_num = 0
        for req_index in range(decode_list_idx):
            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
                reorder_req_num += 1
            else:
                break

        if reorder_req_num == 0:
            return False

        reorder_prompt_list = (
            self.reorder_prompt_req_index_list[:prompt_list_idx]
            [-reorder_req_num:])
        reorder_decode_list = (
            self.reorder_decode_req_index_list[:decode_list_idx]
            [:reorder_req_num])
        assert reorder_decode_list.size == reorder_prompt_list.size

        for idx in range(reorder_req_num):
            prompt_req_index = reorder_prompt_list[idx].item()
            decode_req_index = reorder_decode_list[idx].item()
            input_batch.swap_states(prompt_req_index, decode_req_index)

        return True

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        runner = self.runner
        block_table = self.block_table
        seq_lens_np = runner.seq_lens_np[:num_reqs]
        num_prompt_req = self.num_prompt_req
        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
        ) if num_prompt_req > 0 else 0
        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
        ) if num_prompt_req < num_reqs else 0
        self.seq_start_loc_np[0] = 0
        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
        ) - num_prefill_tokens
        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
        block_table_tensor = block_table.get_device_tensor()
        attn_metadata = TorchSDPAMetadata(
            num_prefills=num_prompt_req,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping,
            # decode
            seq_lens_tensor=runner.seq_lens_cpu[num_prompt_req:num_reqs],
            max_decode_seq_len=max_decode_seq_len,
            block_tables=block_table_tensor[num_prompt_req:num_reqs],
            chunked_prefill=True,
            max_query_len=max_query_len,
            max_kv_len=max_prefill_seq_len,
            # prefill
            prefill_query_start_loc=runner.
            query_start_loc_cpu[:num_prompt_req + 1],
            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req + 1],
            prefill_block_tables=block_table_tensor[:num_prompt_req],
            # for logits index
            query_start_loc=runner.query_start_loc_cpu[:num_reqs + 1],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
        )

        return attn_metadata
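Note: the reordering invariant enforced by `TorchSDPAMetadataBuilderV1.reorder_batch` above (all prompt requests before all decode requests, with a minimal number of `swap_states` calls) can be illustrated with a small standalone sketch. `FakeBatch` and `reorder` below are hypothetical stand-ins for `InputBatch` and the builder logic, not part of this commit:

class FakeBatch:
    """Hypothetical stand-in for InputBatch: tracks only request stages."""

    def __init__(self, stages):
        self.stages = list(stages)  # "P" = prompt/prefill, "D" = decode
        self.num_reqs = len(self.stages)

    def swap_states(self, i, j):
        self.stages[i], self.stages[j] = self.stages[j], self.stages[i]


def reorder(batch):
    """Move prompts to the front with a minimal number of swaps,
    mirroring reorder_batch() above."""
    prompts = [i for i, s in enumerate(batch.stages) if s == "P"]
    decodes = [i for i, s in enumerate(batch.stages) if s == "D"]
    num_prompt_req = len(prompts)
    # Decodes sitting in the prompt region [0, num_prompt_req) are swapped
    # with prompts taken from the back of the prompt index list.
    misplaced = [i for i in decodes if i < num_prompt_req]
    for k, decode_idx in enumerate(misplaced):
        batch.swap_states(prompts[-(k + 1)], decode_idx)


batch = FakeBatch(["D", "P", "P", "D", "P"])
reorder(batch)
assert batch.stages == ["P", "P", "P", "D", "D"]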
1060	vllm/v1/attention/backends/flash_attn.py	Normal file
(File diff suppressed because it is too large.)
657	vllm/v1/attention/backends/flashinfer.py	Normal file
@@ -0,0 +1,657 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashInfer."""
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                        BatchPrefillWithPagedKVCacheWrapper)
# MultiLevelCascadeAttentionWrapper

import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionType)
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024

logger = init_logger(__name__)


class FlashInferBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [64, 128, 256]

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type[FlashInferImpl]:
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> type[FlashInferMetadata]:
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> type[FlashInferMetadataBuilder]:
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        return (num_blocks, 2, block_size, num_kv_heads, head_size)


@dataclass
class PerLayerParameters:
    """
    Currently, the FlashInfer backend only supports models in which all
    layers share the same values for the following hyperparameters.
    """

    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float


def get_per_layer_parameters(
        vllm_config: VllmConfig) -> dict[str, PerLayerParameters]:
    """
    Scan all attention layers and determine some hyperparameters
    to use during `plan`.
    """

    layers = get_layers_from_vllm_config(vllm_config, Attention)
    per_layer_params: dict[str, PerLayerParameters] = {}

    for key, layer in layers.items():
        impl = layer.impl
        assert isinstance(impl, FlashInferImpl)

        # Infer hyperparameters from the attention layer
        window_size = impl.sliding_window
        window_left = window_size[0] if window_size is not None else -1
        logits_soft_cap = impl.logits_soft_cap
        sm_scale = impl.scale

        per_layer_params[key] = PerLayerParameters(window_left,
                                                   logits_soft_cap, sm_scale)

    return per_layer_params


def infer_global_hyperparameters(
        per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
    """
    Currently, the FlashInfer backend only supports models in which all
    layers share the same values for the following hyperparameters:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`

    So this function asserts that all layers share the same values for these
    hyperparameters and returns the global values.
    """

    assert len(per_layer_params) > 0, "No attention layers found in the model."

    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]
    for params in param_sets:
        assert params == global_params, (
            "FlashInfer backend currently only supports models in which all "
            "layers share the same values for the following hyperparameters: "
            "`window_left`, `logits_soft_cap`, `sm_scale`.")

    return global_params


@dataclass
class FlashInferMetadata:

    num_actual_tokens: int  # Number of tokens excluding padding.

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery lengths
    # are [4, 6], it is [0, 4, 10].
    qo_indptr: torch.Tensor
    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: torch.Tensor
    # The page indices of the paged kv cache
    paged_kv_indices: torch.Tensor
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: torch.Tensor
    # The number of query/output heads
    num_qo_heads: int
    # The number of key/value heads
    num_kv_heads: int
    # The dimension of the attention heads
    head_dim: int
    # Block size of vllm
    page_size: int
    # The data type of the paged kv cache
    data_type: torch.dtype
    # The data type of the query
    q_data_type: torch.dtype

    slot_mapping: torch.Tensor

    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # For cascade attention.
    use_cascade: bool
    shared_qo_indptr: Optional[torch.Tensor] = None
    shared_kv_page_indptr: Optional[torch.Tensor] = None
    shared_kv_page_indices: Optional[torch.Tensor] = None
    shared_kv_last_page_len: Optional[torch.Tensor] = None

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
    # cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

    @property
    def query_start_loc(self):
        # The GPUModelRunner expects to be able to access this property.
        return self.qo_indptr

    def __post_init__(self):
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

    def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        self.runner = runner
        self._workspace_buffer = None
        self._prefill_wrapper = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode
        self._cascade_wrapper = None  # Wrapper for cascade attention

        # Global hyperparameters shared by all attention layers
        self.global_hyperparameters: Optional[PerLayerParameters] = None

        self.vllm_config = runner.vllm_config
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table

    def reorder_batch(self, input_batch: InputBatch,
                      scheduler_output: SchedulerOutput) -> bool:
        # We want to reorder the batch so that the "decode" requests are at
        # the front and the "prefill" requests are at the back, using the
        # least number of swaps possible. (NOTE for now we loosely use
        # "decode" to mean requests where attention is likely memory-bound
        # and "prefill" to mean requests where attention is likely
        # compute-bound, TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # For now, treat 1 scheduled token as "decode" even if it's not.
            # We should update this to something like < 8 in the future, but
            # currently the decode run only supports num_tokens = 1.
            if num_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # We hope that this is fairly minimal since decodes should be around
        # for a number of iterations so hopefully they are relatively
        # stationary (and new requests are generally appended to the
        # persistent batch, so they should already be at the back).
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back",
        # i.e. past where the last decode should be in the reordered batch,
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based
        # on the above loop.
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is past the "back" of where the decodes should
            # end up, swap it with the prefill closest to the front of the
            # batch.
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                break

            input_batch.swap_states(prefills[i - 1], decode_idx)
            modified_batch = True

        # Save for the next `build` call.
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this.
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

    def _get_workspace_buffer(self):
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), "NHD")
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        if self._decode_wrapper is None:
            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config))
            num_kv_heads = self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config)
            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                num_qo_heads // num_kv_heads > 4)
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                "NHD",
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    # def _get_cascade_wrapper(self):
    #     if self._cascade_wrapper is None:
    #         self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
    #             2, self._get_workspace_buffer(), "NHD")
    #     return self._cascade_wrapper

    def _plan(self, attn_metadata: FlashInferMetadata):
        if self.global_hyperparameters is None:
            self.global_hyperparameters = infer_global_hyperparameters(
                get_per_layer_parameters(self.vllm_config))
        if attn_metadata.use_cascade and False:  # not supported
            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
                [
                    attn_metadata.shared_kv_page_indptr,
                    attn_metadata.paged_kv_indptr
                ],
                [
                    attn_metadata.shared_kv_page_indices,
                    attn_metadata.paged_kv_indices
                ],
                [
                    attn_metadata.shared_kv_last_page_len,
                    attn_metadata.paged_kv_last_page_len
                ],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                causal=True,
                sm_scale=self.global_hyperparameters.sm_scale,
                window_left=self.global_hyperparameters.window_left,
                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
                q_data_type=attn_metadata.q_data_type,
            )
        else:
            # Regular attention (common case).
            # Decodes are at the front and prefills are at the back,
            # according to reorder_batch()
            if self._num_prefills > 0:
                # Decodes are first so prefills start after the last decode
                prefill_start = self._num_decodes
                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
                assert attn_metadata.qo_indptr[prefill_start:].shape[
                    0] == self._num_prefills + 1
                assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
                    0] == self._num_prefills + 1
                assert attn_metadata.paged_kv_last_page_len[
                    prefill_start:].shape[0] == self._num_prefills
                # Since prefill_wrapper.run() will be called with
                # query[num_decode_tokens:] we need to adjust the qo_indptr
                # to be relative to the start of the prefill queries.
                qo_indptr = attn_metadata.qo_indptr[
                    prefill_start:] - attn_metadata.qo_indptr[prefill_start]
                attn_metadata.prefill_wrapper.plan(
                    qo_indptr,
                    attn_metadata.paged_kv_indptr[prefill_start:],
                    attn_metadata.paged_kv_indices,
                    attn_metadata.paged_kv_last_page_len[prefill_start:],
                    attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads,
                    attn_metadata.head_dim,
                    attn_metadata.page_size,
                    causal=True,
                    sm_scale=self.global_hyperparameters.sm_scale,
                    window_left=self.global_hyperparameters.window_left,
                    logits_soft_cap=self.global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=attn_metadata.q_data_type,
                    kv_data_type=attn_metadata.data_type,
                )

            if self._num_decodes > 0:
                attn_metadata.decode_wrapper = self._get_decode_wrapper()
                attn_metadata.decode_wrapper.plan(
                    attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
                    attn_metadata.paged_kv_indices,
                    attn_metadata.paged_kv_last_page_len[:self._num_decodes],
                    attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads,
                    attn_metadata.head_dim,
                    attn_metadata.page_size,
                    # Disable flashinfer's pos encoding and use vllm's rope.
                    pos_encoding_mode="NONE",
                    sm_scale=self.global_hyperparameters.sm_scale,
                    window_left=self.global_hyperparameters.window_left,
                    logits_soft_cap=self.global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=attn_metadata.q_data_type,
                    kv_data_type=attn_metadata.data_type,
                )

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens

        assert self._num_decodes + self._num_prefills == num_reqs
        assert (self._num_decode_tokens +
                self._num_prefill_tokens == num_actual_tokens)
        page_size = self.kv_cache_spec.block_size
        device = self.runner.device
        qo_indptr = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
        slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
            self.runner.device, non_blocking=True).long()

        block_table_bounds = (seq_lens + page_size - 1) // page_size

        use_cascade = common_prefix_len > 0
        if use_cascade:
            # Grab the blocks of the shared prefix from the first request.
            assert common_prefix_len % page_size == 0
            num_common_kv_blocks = common_prefix_len // page_size
            shared_qo_indptr = torch.tensor([0, num_actual_tokens],
                                            dtype=torch.int32,
                                            device=device)
            shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
                                                 dtype=torch.int32,
                                                 device=device)
            shared_kv_page_indices = block_table_tensor[
                0, :num_common_kv_blocks]
            shared_kv_last_page_len = torch.tensor([page_size],
                                                   dtype=torch.int32,
                                                   device=device)
            # Remove the blocks of the shared prefix from all requests.
            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
            block_table_bounds -= num_common_kv_blocks
        else:
            shared_qo_indptr = None
            shared_kv_page_indptr = None
            shared_kv_page_indices = None
            shared_kv_last_page_len = None

        mask = (torch.arange(block_table_tensor.size(1),
                             dtype=block_table_tensor.dtype,
                             device=block_table_tensor.device).unsqueeze(0)
                < block_table_bounds.unsqueeze(1))
        paged_kv_indices = block_table_tensor[mask]

        paged_kv_indptr = torch.cat([
            torch.zeros(1,
                        dtype=block_table_bounds.dtype,
                        device=block_table_bounds.device),
            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
        ])

        paged_kv_last_page_len = seq_lens % page_size
        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                             page_size,
                                             paged_kv_last_page_len)

        attn_metadata = FlashInferMetadata(
            num_actual_tokens=num_actual_tokens,
            qo_indptr=qo_indptr,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len,
            num_qo_heads=self.runner.num_query_heads,
            num_kv_heads=self.kv_cache_spec.num_kv_heads,
            head_dim=self.kv_cache_spec.head_size,
            page_size=page_size,
            data_type=self.kv_cache_spec.dtype,
            q_data_type=self.runner.dtype,
            slot_mapping=slot_mapping,
            num_decodes=self._num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            num_prefills=self._num_prefills,
            num_prefill_tokens=self._num_prefill_tokens,
            use_cascade=use_cascade,
            shared_qo_indptr=shared_qo_indptr,
            shared_kv_page_indptr=shared_kv_page_indptr,
            shared_kv_page_indices=shared_kv_page_indices,
            shared_kv_last_page_len=shared_kv_last_page_len,
        )

        self._plan(attn_metadata)

        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        logger.warning_once(
            "Using cascade attention in FlashInfer is not supported yet")
        return False
        # Unreachable while cascade attention is disabled above.
        if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
            # TODO: The cascade wrapper currently does not support setting
            # kv cache dtype to something different from query dtype.
            return False
        return use_cascade_attention(*args, **kwargs)


class FlashInferImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        if use_irope:
            logger.warning_once(
                "Using irope in FlashInfer is not supported yet, it will fall"
                " back to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape = [num_blocks, 2, block_size, num_kv_heads,
                head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed
        # in eager-mode PyTorch. Thus, we need to be careful about any CPU
        # overhead in this method. For example, `view` and `slice` (or
        # `[:n]`) operations are surprisingly slow even in the case they do
        # not invoke any GPU ops. Minimize the PyTorch ops in this method as
        # much as possible. Whenever making a change in this method, please
        # benchmark the performance to make sure it does not introduce any
        # overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping
            # is not padded. However, we don't need to do
            # key[:num_actual_tokens] and value[:num_actual_tokens] because
            # the reshape_and_cache_flash op uses the slot_mapping's shape to
            # determine the number of actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        window_left = (self.sliding_window[0]
                       if self.sliding_window is not None else -1)

        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
        output_padded = output
        output = output[:num_actual_tokens]

        # if attn_metadata.use_cascade:
        #     # Cascade attention (rare case).
        #     assert attn_metadata.cascade_wrapper is not None
        #     output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
        #     return output

        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
        if prefill_wrapper := attn_metadata.prefill_wrapper:
            prefill_query = query[num_decode_tokens:]
            assert prefill_query.shape[0] == num_prefill_tokens
            assert prefill_wrapper is not None
            assert prefill_wrapper._causal
            assert prefill_wrapper._window_left == window_left
            assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                        or 0.0)
            assert prefill_wrapper._sm_scale == self.scale
            prefill_wrapper.run(
                prefill_query,
                kv_cache,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[num_decode_tokens:],
            )

        if decode_wrapper := attn_metadata.decode_wrapper:
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
            assert decode_wrapper is not None
            assert decode_wrapper._window_left == window_left
            assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                       or 0.0)
            assert decode_wrapper._sm_scale == self.scale
            decode_wrapper.run(
                decode_query,
                kv_cache,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[:num_decode_tokens],
            )

        return output_padded
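Note: the `paged_kv_indices` / `paged_kv_indptr` / `paged_kv_last_page_len` layout documented in `FlashInferMetadata` and computed in `build()` can be checked in isolation. A minimal sketch with made-up values (two requests, `page_size=16`), mirroring the masking and cumsum logic above:

import torch

page_size = 16
seq_lens = torch.tensor([40, 17])          # two requests
block_table = torch.tensor([[3, 7, 9],     # request 0 uses 3 pages
                            [5, 2, 0]])    # request 1 uses 2 pages (0 = pad)

bounds = (seq_lens + page_size - 1) // page_size        # pages per request
mask = (torch.arange(block_table.size(1)).unsqueeze(0)
        < bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]                    # [3, 7, 9, 5, 2]
paged_kv_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int64), bounds.cumsum(0)])  # [0, 3, 5]
last_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(last_len == 0, page_size, last_len)
# request 0: 40 tokens = 2 full pages + 8  -> last_page_len 8
# request 1: 17 tokens = 1 full page  + 1  -> last_page_len 1
assert paged_kv_indices.tolist() == [3, 7, 9, 5, 2]
assert paged_kv_last_page_len.tolist() == [8, 1]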
473	vllm/v1/attention/backends/flex_attention.py	Normal file
@@ -0,0 +1,473 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlexAttention."""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import torch
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
                                               _score_mod_signature,
                                               create_block_mask,
                                               flex_attention)

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

if current_platform.is_cuda():
    pass

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

create_block_mask_compiled = torch.compile(create_block_mask,
                                           fullgraph=True,
                                           mode="reduce-overhead")
flex_attention_compiled = torch.compile(flex_attention, fullgraph=True)


def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
    device = offsets.device
    counts = offsets[1:] - offsets[:-1]
    return torch.repeat_interleave(
        torch.arange(len(counts), device=device, dtype=torch.int32), counts)


class FlexAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [16, 32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "FLEX_ATTENTION"

    @staticmethod
    def get_impl_cls() -> type["FlexAttentionImpl"]:
        return FlexAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return FlexAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]:
        return FlexAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False


# @torch.compile(fullgraph=True, mode="reduce-overhead")
def physical_to_logical_mapping(
        block_table: torch.Tensor,
        total_blocks: Optional[int] = None) -> torch.Tensor:
    """
    Creates an inverse mapping from physical block locations to logical
    indices.

    The original block_table maps from logical blocks to physical locations:

        Request 0:
        Logical Blocks:   0  1  2  3  4  5  6  7
        Physical Blocks:  3  5  1  7  4  2  0  6

    This function creates the inverse mapping:

        Request 0:
        Physical Blocks:  0  1  2  3  4  5  6  7
        Logical Blocks:   6  2  5  0  4  1  7  3

    If multiple logical blocks map to the same physical block,
    this function returns the first (minimum) logical block index.

    If a physical block is not mapped to by any logical block,
    its value in the result will be -1.

    Args:
        block_table: Tensor of shape [max_reqs, max_num_blocks]
            mapping logical blocks to physical locations

    Returns:
        A tensor of shape [max_reqs, max_physical_block]
    """
    max_reqs, max_num_blocks = block_table.shape
    device = block_table.device

    physical_to_logical = torch.full((max_reqs, total_blocks),
                                     -1,
                                     dtype=torch.long,
                                     device=device)

    logical_indices = (torch.arange(max_num_blocks,
                                    device=device).unsqueeze(0).expand(
                                        max_reqs, -1))

    physical_to_logical.scatter_(-1, block_table.to(torch.int64),
                                 logical_indices)
    # TODO Confirm - block 0 seems to always be empty, so we reset it manually
    physical_to_logical[:, 0] = -1
    return physical_to_logical


def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
                    kv_idx: torch.Tensor):
    return q_idx >= kv_idx


@dataclass
class FlexAttentionMetadata:
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Block info
    total_cache_tokens: int
    block_size: int
    max_possible_sequence_length: int
    num_reqs: int
    physical_to_logical: torch.Tensor
    decode_offset: torch.Tensor

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

    # Flex metadata
    num_blocks = 0
    block_mask: Optional[BlockMask] = None
    score_mod: Optional[_score_mod_signature] = None
    mask_mod: Optional[_mask_mod_signature] = None
    logical_mask_mod: _mask_mod_signature = causal_mask_mod

    def get_mask_mod(self) -> _mask_mod_signature:
        """Creates the mask_mod function for FlexAttention.

        This function creates the combined mask mod function that handles:
        1. The paged attention block mapping
        2. The mapping from packed query sequences to logical query entries

        It also by default adds the decoding offset to the query indices.
        With this info we create the "logical" indices that are passed to
        mask_mod functions. This allows mask mod functions to be agnostic to
        the layout of the query and key/value tensors.

        TODO is_within_lower_bound: do sequences start on block boundaries?
        """
        # Create a lookup mapping from query indices -> request number
        request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)

        def final_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            # Map query indices to corresponding request indices
            q_req = request_lookup[q_idx]

            # Convert physical KV indices to logical indices
            physical_kv_block = physical_kv_idx // self.block_size
            physical_kv_offset = physical_kv_idx % self.block_size
            logical_block_idx = self.physical_to_logical[q_req,
                                                         physical_kv_block]
            logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset  # noqa: E501

            # Determine valid kv indices
            live_block = logical_block_idx >= 0
            within_upper_bound = logical_kv_idx < self.seq_lens[q_req]
            within_lower_bound = logical_kv_idx >= 0

            is_valid = live_block & within_upper_bound & within_lower_bound

            # Convert physical query indices to logical indices
            local_q_idx = q_idx - self.query_start_loc[q_req]
            logical_q_idx = local_q_idx + self.decode_offset[q_req]

            # Apply mask modification only for valid indices
            return torch.where(
                is_valid,
                self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx),
                False,
            )

        return final_mask_mod

    def build_block_mask(self) -> BlockMask:
        assert self.mask_mod is not None
        return create_block_mask_compiled(
            self.mask_mod,
            None,
            None,
            self.num_actual_tokens,
            self.total_cache_tokens,
        )

    def __post_init__(self):
        assert self.use_cascade is False, "Not implemented yet."
        assert self.common_prefix_len == 0, "Not implemented yet."
        assert self.cu_prefix_query_lens is None, "Not implemented yet."
        assert self.prefix_kv_lens is None, "Not implemented yet."
        assert self.suffix_kv_lens is None, "Not implemented yet."
        self.num_blocks = self.total_cache_tokens // self.block_size
        self.mask_mod = self.get_mask_mod()
        self.block_mask = self.build_block_mask()


class FlexAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlexAttentionMetadata]):

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        model_config = runner.model_config

        self.runner = runner
        self.num_heads_q = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
            runner.parallel_config)
        self.headdim = model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        block_table = self.block_table
        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
        block_table.slot_mapping[:num_actual_tokens].copy_(
            block_table.slot_mapping_cpu[:num_actual_tokens],
            non_blocking=True)
        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

        use_cascade = common_prefix_len > 0
        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        if use_cascade:
            raise NotImplementedError(
                "Cascade attention is not implemented for FlexAttention yet.")

        block_size = self.kv_cache_spec.block_size
        max_possible_seq_len = self.runner.model_config.max_model_len
        total_cache_tokens = (self.runner.cache_config.num_gpu_blocks *
                              block_size)

        inverse_block_table = physical_to_logical_mapping(
            block_table_tensor, self.runner.cache_config.num_gpu_blocks)

        # Get the original offset tensor
        offset_tensor = torch.tensor(
            self.runner.input_batch.num_computed_tokens_cpu[:num_reqs]).to(
                self.runner.device, non_blocking=True)

        out = FlexAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            block_size=block_size,
            max_possible_sequence_length=max_possible_seq_len,
            num_reqs=num_reqs,
            physical_to_logical=inverse_block_table,
            total_cache_tokens=total_cache_tokens,
            decode_offset=offset_tensor,
        )
        return out


class FlexAttentionImpl(AttentionImpl):
    sliding_window: Optional[tuple[int, int]]
    alibi_slopes: Optional[torch.Tensor]
    logits_soft_cap: Optional[float]

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        if blocksparse_params is not None:
            # TODO: we should support this
            raise ValueError(
                "FlexAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        if alibi_slopes is not None:
            raise NotImplementedError(
                "FlexAttention does not support alibi slopes yet.")
        else:
            self.alibi_slopes = None
        if sliding_window is not None:
            raise NotImplementedError(
                "FlexAttention does not support sliding window yet.")
        else:
            self.sliding_window = (-1, -1)
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        if self.logits_soft_cap is not None:
            raise NotImplementedError(
                "FlexAttention does not support logits soft cap yet.")

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError(
                "FlexAttention does not support kv sharing yet.")

        support_head_sizes = FlexAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlexAttention. "
                f"Supported head sizes are: {support_head_sizes}. "
                "Set VLLM_USE_V1=0 to use another attention backend.")
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlexAttention does not support quantized kv-cache yet.")

    @staticmethod
    def view_as_4d(tensor: torch.Tensor) -> torch.Tensor:
        """View a 3D tensor as 4D."""
        if tensor.ndim == 4:
            return tensor
        assert tensor.ndim == 3
        return tensor[None, :, :, :]

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlexAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlexAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads,
                head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."
        enable_gqa = self.num_kv_heads != self.num_heads

        if attn_metadata is None:
            # Profiling run.
            return output
        # query = self.view_as_4d(query).permute(0, 2, 1, 3)
        # return torch.empty_like(query)

        num_actual_tokens = attn_metadata.num_actual_tokens

        key_cache, value_cache = kv_cache.unbind(0)

        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        # View out the block_size dim
        key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
        value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
        query, key_cache, value_cache = map(
            lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
            (query, key_cache, value_cache),
        )
        query = query[:, :, :num_actual_tokens, :]
        # Doesn't work for now -> constraint violation
        # torch._dynamo.try_mark_dynamic(query, 2)
        out = flex_attention_compiled(
            query,
            key_cache,
            value_cache,
            attn_metadata.score_mod,
            attn_metadata.block_mask,
            self.scale,
            enable_gqa=enable_gqa,
            kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
        )

        # Flex doesn't have an out variant today; rely on epilogue fusion
        out = out.permute(0, 2, 1, 3).squeeze(0)
        output[:num_actual_tokens, :, :].copy_(out)
        return output
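Note: `physical_to_logical_mapping` above can be sanity-checked with the example from its own docstring; a minimal sketch reusing that function, with illustrative values only:

import torch

block_table = torch.tensor([[3, 5, 1, 7, 4, 2, 0, 6]])  # logical -> physical
inv = physical_to_logical_mapping(block_table, total_blocks=8)
# Physical block p holds logical block inv[0, p]. Physical block 0 is reset
# to -1 by the function itself, so logical block 6 is intentionally dropped.
assert inv[0].tolist() == [-1, 2, 5, 0, 4, 1, 7, 3]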
0	vllm/v1/attention/backends/mla/__init__.py	Normal file
975	vllm/v1/attention/backends/mla/common.py	Normal file
@@ -0,0 +1,975 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
# MLA Common Components
|
||||
|
||||
This file implements common components for MLA implementations.
|
||||
|
||||
First we define:
|
||||
|
||||
Sq as Q sequence length
|
||||
Skv as KV sequence length
|
||||
|
||||
MLA has two possible ways of computing, a data-movement friendly approach and a
|
||||
compute friendly approach, we generally want to use the compute friendly
|
||||
approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
|
||||
and the data-movement friendly approach for "decode" (i.e. the ratio
|
||||
Sq / Skv is "large").
|
||||
|
||||
NOTE what we deem small and large is currently determined by if its labelled
|
||||
prefill or decode by the scheduler, but this is something we should probably
|
||||
tune.
|
||||
|
||||
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
||||
|
||||
Deepseek's MLA attention works the following way:
|
||||
* Use a single latent vector to represent the per-token entry of the KV cache.
|
||||
* For decode (i.e. the memory friendly approach) the attention "simulates" a
|
||||
multi-head attention, while the compute is similar to multi-query attention.
|
||||
|
||||
Below is example of both paths assuming batchsize = 1
|
||||
|
||||
## More Extent Definitions:
|
||||
|
||||
C Context length, `Skv - Sq`
|
||||
H hidden size
|
||||
N number of attention heads
|
||||
Lq latent dimension for Q 1536 in DSV3
|
||||
Lkv latent dimension for K/V 512 in DSV3
|
||||
P nope dimension, no rope. 128 in DSV3
|
||||
R rope dimension, goes through rope. 64 in DSV3
|
||||
V V head dim. 128 in DSV3
|
||||
|
||||
## Vector/Matrix Definitions
|
||||
|
||||
h_t hidden states (input to attention) shape [Sq, H]
|
||||
q_c latent/compressed Q shape [Sq, Lq]
|
||||
q_nope uncompressed Q (no-rope) shape [Sq, N, P]
|
||||
q_pe uncompressed Q (rope) shape [Sq, N, R]
|
||||
kv_c latent/compressed KV shape [Skv, Lkv]
|
||||
k_pe decoupled k position embeddings shape [Skv, R]
|
||||
new_kv_c new kv_c from current iter shape [Sq, Lkv]
|
||||
new_k_pe new k_pe from current iter shape [Sq, R]
|
||||
cache_kv_c cached k_c from previous iters shape [C, Lkv]
|
||||
cache_k_pe cached k_pe from previous iters shape [C, R]
|
||||
W_DQ project h_t to q_c shape [H, Lq]
|
||||
W_UQ project q_c to q_nope shape [Lq, N * P]
|
||||
W_QR project q_c to q_pe shape [Lq, N * R]
|
||||
W_DKV project h_t to kv_c shape [H, Lkv]
|
||||
W_UK project kv_c to k_nope shape [Lkv, N, P]
|
||||
W_KR project h_t to k_pe shape [H, R]
|
||||
W_UV project kv_c to v shape [Lkv, N, V]
|
||||
W_O project v to h_t shape [N * V, H]
|
||||
|
||||
|
||||
## Compute Friendly Approach (i.e. "_forward_prefill"):
|
||||
|
||||
q_c = h_t @ W_DQ
|
||||
q_nope = (q_c @ W_UQ).view(Sq, N, P)
|
||||
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
|
||||
new_kv_c = h_t @ W_DKV
|
||||
new_k_pe = RoPE(h_t @ W_KR)
|
||||
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
|
||||
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
|
||||
k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
|
||||
v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
|
||||
|
||||
// MHA with QK headdim = P + R
|
||||
// V headdim = V
|
||||
// spda_o shape [Sq, N, V]
|
||||
spda_o = scaled_dot_product_attention(
|
||||
torch.cat([q_nope, q_pe], dim=-1),
|
||||
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
|
||||
v
|
||||
)
|
||||
return spda_o @ W_O
|
||||
|
||||
NOTE: in the actual code,
|
||||
`kv_b_proj` is [W_UK; W_UV] concatenated per head
|
||||
`q_b_proj` is [W_UQ; W_QR] concatenated per head
|
||||
`out_proj` is W_O
|
||||
|
||||
|
||||
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
|
||||
|
||||
Runtime
|
||||
q_c = h_t @ W_DQ
|
||||
q_nope = (q_c @ W_UQ).view(-1, N, P)
|
||||
ql_nope = einsum("snh,lnh->snl", q, W_UK)
|
||||
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
|
||||
new_kv_c = h_t @ W_DKV
|
||||
new_k_pe = RoPE(h_t @ W_KR)
|
||||
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
|
||||
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
|
||||
|
||||
// MQA with QK headdim = Lkv + R
|
||||
// V headdim = Lkv
|
||||
// spda_o shape [Sq, N, Lkv]
|
||||
// NOTE: this is less compute-friendly since Lkv > P
|
||||
// but is more data-movement friendly since its MQA vs MHA
|
||||
spda_o = scaled_dot_product_attention(
|
||||
torch.cat([ql_nope, q_pe], dim=-1),
|
||||
torch.cat([kv_c, k_pe], dim=-1),
|
||||
kv_c
|
||||
)
|
||||
|
||||
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
|
||||
return o.view(-1, N * V) @ self.num_heads @ W_O
|
||||
|
||||
|
||||
## Chunked Prefill
|
||||
|
||||
For chunked prefill we want to use the compute friendly algorithm. We are
|
||||
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
|
||||
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
|
||||
|
||||
However, the compute-friendly approach can potentially run out of memory if Skv
|
||||
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
|
||||
|
||||
To mitigate this, we chunk the computation of attention with respect to the
|
||||
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a
|
||||
fixed workspace size.
|
||||
|
||||
The chunked prefill approach is as follows:
|
||||
|
||||
MCC Max chunk of context to process per iter, computed dynamically,
|
||||
used to bound the memory usage
|
||||
|
||||
q_c = h_t @ W_DQ
|
||||
q_nope = (q_c @ W_UQ).view(Sq, N, P)
|
||||
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
|
||||
new_kv_c = h_t @ W_DKV
|
||||
new_k_pe = RoPE(h_t @ W_KR)
|
||||
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
|
||||
new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
|
||||
|
||||
// MHA between queries and new KV
|
||||
// with QK headdim = P + R
|
||||
// V headdim = V
|
||||
// curr_o shape [Sq, N, V]
|
||||
// curr_lse shape [N, Sq], this is just order FA returns
|
||||
curr_o, curr_lse = scaled_dot_product_attention(
|
||||
torch.cat([q_nope, q_pe], dim=-1),
|
||||
torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
|
||||
new_v,
|
||||
casual=True,
|
||||
return_softmax_lse=True
|
||||
)
|
||||
|
||||
// Compute attention with the already existing context
|
||||
for chunk_idx in range(cdiv(C, MCC)):
|
||||
chunk_start = chunk_idx * MCC
|
||||
chunk_end = min(chunk_start + MCC, C)
|
||||
Sc = chunk_end - chunk_start
|
||||
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
|
||||
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
|
||||
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
|
||||
cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
|
||||
|
||||
chunk_o, chunk_lse = scaled_dot_product_attention(
|
||||
torch.cat([q_nope, q_pe], dim=-1),
|
||||
torch.cat([cache_k_nope_chunk,
|
||||
cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
|
||||
dim=-1),
|
||||
cache_v_chunk,
|
||||
causal=False,
|
||||
return_softmax_lse=True
|
||||
)
|
||||
|
||||
curr_o, curr_lse = merge_attn_states(
|
||||
suffix_output=curr_o,
|
||||
suffix_lse=curr_lse,
|
||||
prefix_output=chunk_o,
|
||||
prefix_lse=chunk_lse,
|
||||
)
|
||||
|
||||
return curr_o @ W_O
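
Here merge_attn_states combines the two partial softmax results using their
log-sum-exps; conceptually, per query and head (a hedged sketch):

alpha = 1 / (1 + exp(prefix_lse - suffix_lse))  # weight of the suffix part
out = alpha * suffix_output + (1 - alpha) * prefix_output
lse = suffix_lse + log1p(exp(prefix_lse - suffix_lse))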
|
||||
"""
|
||||
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
# from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, round_down
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
is_vllm_fa = True
|
||||
except ImportError:
|
||||
# For ROCm, use upstream flash attention
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
is_vllm_fa = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
from vllm import envs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
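# NOTE: stub for this port (the fa_utils import above is commented out);
# returning None makes MLACommonImpl call flash_attn_varlen_func without
# an explicit fa_version and keeps V padded to the QK head dim.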
def get_flash_attn_version():
|
||||
return None
|
||||
|
||||
class MLACommonBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return MLACommonMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
|
||||
return MLACommonMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> list[int]:
|
||||
return [576]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonPrefillMetadata:
|
||||
""" Prefill Specific Metadata """
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling chunked prefill
|
||||
cu_seq_lens: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
seq_tot: list[int]
|
||||
max_seq_lens: list[int]
|
||||
workspace: torch.Tensor
|
||||
|
||||
block_table: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
max_query_len: int
|
||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonDecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
|
||||
D = TypeVar("D", bound=MLACommonDecodeMetadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonMetadata(Generic[D]):
|
||||
"""Metadata for MLACommon.
|
||||
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
decode: Optional[D] = None
|
||||
prefill: Optional[MLACommonPrefillMetadata] = None
|
||||
|
||||
def __post_init__(self):
|
||||
supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
|
||||
if self.head_dim is not None and self.head_dim \
|
||||
not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
f"received {self.head_dim}.")
|
||||
|
||||
|
||||
M = TypeVar("M", bound=MLACommonMetadata)
|
||||
|
||||
|
||||
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
runner: "GPUModelRunner",
|
||||
kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable,
|
||||
metadata_cls: Optional[type[M]] = None):
|
||||
self.metadata_cls = metadata_cls \
|
||||
if metadata_cls is not None else MLACommonMetadata
|
||||
self.runner = runner
|
||||
scheduler_config = runner.scheduler_config
|
||||
model_config = runner.model_config
|
||||
cache_config = runner.cache_config
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
self.num_heads = model_config.get_num_attention_heads(
|
||||
runner.parallel_config)
|
||||
self.mla_dims = get_mla_dims(model_config)
|
||||
self.aot_schedule = False # current_platform.is_cuda()
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
|
||||
# Don't try to access the runner on AMD
|
||||
if self.aot_schedule:
|
||||
self.page_size = self.kv_cache_spec.block_size
|
||||
|
||||
if self.chunked_prefill_enabled:
|
||||
self.chunked_prefill_workspace_size = min(
|
||||
# Make sure there is enough for 8 full-length requests or at least
|
||||
# 4 pages of cache per request
|
||||
max(
|
||||
8 * model_config.max_model_len, 4 *
|
||||
scheduler_config.max_num_seqs * cache_config.block_size),
|
||||
# For long-context models, try not to over-allocate (which would limit
|
||||
# kv-cache space) by capping the workspace at 64k tokens,
|
||||
# which would result in the workspace being:
|
||||
# 2*(576)*(64*1024) = 144mb
|
||||
# (assuming 576 MLA head dim, and fp16)
|
||||
# which would result in up-projected context being
|
||||
# 2*(192*128)*(64*1024) = 3gb
|
||||
# (assuming 192 QK head dim, 128 heads, and fp16)
|
||||
128 * 1024)
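# Worked example (hypothetical): with max_model_len=32768,
# max_num_seqs=256 and block_size=16, max(8*32768, 4*256*16) = 262144,
# so the min() above clamps the workspace to 128 * 1024 = 131072 tokens.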
|
||||
assert self.chunked_prefill_workspace_size >= \
|
||||
scheduler_config.max_num_seqs * cache_config.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
model_config.get_head_size()),
|
||||
dtype=model_config.dtype,
|
||||
device=runner.device,
|
||||
)
|
||||
self.block_table = block_table
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are at
|
||||
# the front and the "prefill" requests are at the back, using the least
|
||||
# amount of swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
# where attention is likely memory-bound and "prefill" to mean requests
|
||||
# where attention is likely compute-bound, TODO(lucas): figure out a
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if it's not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the TritonMLA._forward_decode only supports
|
||||
# num_tokens = 1
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new requests are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reordered batch, with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
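# Worked example (hypothetical): with 6 requests where decodes sit at
# positions [1, 3, 5] and prefills at [0, 2, 4] (num_decodes = 3):
#   i=1: decodes[2] = 5 >= 3 -> swap positions 0 and 5
#   i=2: decodes[1] = 3 >= 3 -> swap positions 2 and 3
#   i=3: decodes[0] = 1 <  3 -> break
# leaving the three decodes at positions [0, 1, 2] after two swaps.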
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the i-th decode from the back still sits past the decode region,
|
||||
# we can swap it with the prefill closest to the front of the batch
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
self._num_decodes = num_decodes
|
||||
self._num_prefills = num_prefills
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor):
|
||||
return MLACommonDecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens,
|
||||
)
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> M:
|
||||
"""
|
||||
This method builds the metadata for full cudagraph capture.
|
||||
Currently, only decode is supported for full cudagraphs with MLA.
|
||||
"""
|
||||
m = common_attn_metadata
|
||||
assert m.num_reqs == m.num_actual_tokens, \
|
||||
"MLA only supports decode-only full CUDAGraph capture. " \
|
||||
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||
|
||||
m.max_query_len = 1 # decode-only
|
||||
|
||||
# Update state usually set in reorder_batch.
|
||||
self._num_decodes = m.num_reqs
|
||||
self._num_decode_tokens = m.num_actual_tokens
|
||||
self._num_prefills = 0
|
||||
self._num_prefill_tokens = 0
|
||||
return self.build(0, m)
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata) -> M:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||
# it blocks on all previous kernels.
|
||||
device = self.runner.device
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
non_blocking=True)
|
||||
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
|
||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
prefill_metadata = None
|
||||
if self._num_prefills > 0:
|
||||
reqs_start = self._num_decodes # prefill_start
|
||||
|
||||
context_lens_cpu = self.runner.input_batch.\
|
||||
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
|
||||
chunked_context_metadata = None
|
||||
if self.chunked_prefill_enabled and self._num_prefills > 0 \
|
||||
and max_context_len_cpu > 0:
|
||||
# NOTE: it is recommended you read the `Chunked Prefill` section
|
||||
# in the comment at the top of the file before trying to
|
||||
# understand the following code
|
||||
|
||||
# currently we allocate an equal amount of workspace for each
|
||||
# prefill in the batch, we could probably use a more advanced
|
||||
# algorithm here and allocate more workspace to prefills with
|
||||
# longer context lengths
|
||||
max_context_chunk = (self.chunked_prefill_workspace_size //
|
||||
num_prefills_with_context_cpu)
|
||||
|
||||
if self.aot_schedule:
|
||||
# align max_context_chunk to page_size by rounding down,
|
||||
# currently the `gather_cache` kernel cannot handle
|
||||
# `context_chunk_starts` that are not aligned to page_size
|
||||
max_context_chunk = round_down(max_context_chunk,
|
||||
self.page_size)
|
||||
|
||||
assert max_context_chunk > 0
|
||||
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
||||
|
||||
# if `max_context_chunk = 256`, `num_chunks = 3`, and
|
||||
# `num_prefills_with_context = 4`, create a tensor that looks
|
||||
# like
|
||||
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
|
||||
# Note(simon): this is done on CPU because of the downstream
|
||||
# use of `tolist`.
|
||||
chunk_starts = \
|
||||
torch.arange(num_chunks, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, self._num_prefills) \
|
||||
* max_context_chunk
|
||||
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
||||
chunk_starts + max_context_chunk)
|
||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
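# Continuing the example above (hypothetical
# context_lens_cpu = [100, 300, 600, 700], max_context_chunk = 256):
#   chunk_ends     = [[100, 256, 256, 256],
#                     [100, 300, 512, 512],
#                     [100, 300, 600, 700]]
#   chunk_seq_lens = [[100, 256, 256, 256],
#                     [  0,  44, 256, 256],
#                     [  0,   0,  88, 188]]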
|
||||
|
||||
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||
self._num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
torch.cumsum(chunk_seq_lens,
|
||||
dim=1,
|
||||
out=cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
|
||||
chunked_context_metadata = \
|
||||
MLACommonPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
|
||||
assert max(chunked_context_metadata.max_seq_lens) <= \
|
||||
self.chunked_prefill_workspace_size
|
||||
|
||||
prefill_metadata = MLACommonPrefillMetadata(
|
||||
block_table=block_table_tensor[reqs_start:, ...],
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
max_query_len=max_query_len,
|
||||
chunked_context=chunked_context_metadata,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if self._num_decodes > 0:
|
||||
decode_metadata = self._build_decode(
|
||||
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
|
||||
seq_lens=seq_lens[:self._num_decodes],
|
||||
)
|
||||
|
||||
return self.metadata_cls(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
# MLACommonMetadata Chunk prefill specific
|
||||
num_decodes=self._num_decodes,
|
||||
num_decode_tokens=self._num_decode_tokens,
|
||||
num_prefills=self._num_prefills,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
return common_attn_metadata.max_query_len == 1
|
||||
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported for MLA")
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.kv_b_proj = kv_b_proj
|
||||
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on ROCm and the
|
||||
# latter has an additional parameter to control FA2 vs FA3
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
if self.vllm_flash_attn_version is not None:
|
||||
self.flash_attn_varlen_func = \
|
||||
functools.partial(flash_attn_varlen_func,
|
||||
fa_version=self.vllm_flash_attn_version)
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim for attention backends that do
|
||||
# not support different headdims
|
||||
# We don't need to pad V if we are on a hopper system with FA3
|
||||
self._pad_v = self.vllm_flash_attn_version is None or not (
|
||||
self.vllm_flash_attn_version == 3
|
||||
and current_platform.get_device_capability()[0] == 9)
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=False,
|
||||
softmax_scale=None,
|
||||
**kwargs):
|
||||
maybe_padded_v = v
|
||||
if self._pad_v:
|
||||
maybe_padded_v = torch.nn.functional.pad(
|
||||
v, [0, q.shape[-1] - v.shape[-1]], value=0)
|
||||
|
||||
if is_vllm_fa:
|
||||
attn_out = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=maybe_padded_v,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
softmax_scale=softmax_scale,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# Use return_attn_probs instead of return_softmax_lse for ROCm
|
||||
attn_out = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=maybe_padded_v,
|
||||
return_attn_probs=return_softmax_lse,
|
||||
softmax_scale=softmax_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Unpack the output if there are multiple results
|
||||
lse = None
|
||||
if isinstance(attn_out, tuple):
|
||||
attn_out, lse = attn_out[0], attn_out[1]
|
||||
|
||||
# Remain consistent with old `flash_attn_varlen_func` where there
|
||||
# is only one output tensor if `return_softmax_lse` is False.
|
||||
if return_softmax_lse:
|
||||
return attn_out, lse
|
||||
return attn_out
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def get_layer_weight(layer):
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
for attr in WEIGHT_NAMES:
|
||||
if hasattr(layer, attr):
|
||||
return getattr(layer, attr)
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has no recognized weight attribute:"
|
||||
f" {WEIGHT_NAMES}.")
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(layer.input_size_per_partition,
|
||||
dtype=act_dtype,
|
||||
device=get_layer_weight(layer).device)
|
||||
dequant_weights = layer.quant_method.apply(layer,
|
||||
eye,
|
||||
bias=None)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight if not envs.MACA_VLLM_USE_TN_2_NN else layer.weight.T
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
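# With these layouts a single bmm absorbs W_UK_T into the query on the
# decode path (see forward) and _v_up_proj applies W_UV to the MQA output,
# matching the data-movement friendly formulation in the file docstring.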
|
||||
|
||||
def _compute_prefill_context(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
):
|
||||
assert attn_metadata.prefill is not None
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
assert prefill_metadata.chunked_context is not None
|
||||
|
||||
output = None
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
workspace = prefill_metadata.chunked_context.workspace
|
||||
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
|
||||
ops.gather_cache(
|
||||
src_cache=kv_c_and_k_pe_cache,
|
||||
dst=workspace,
|
||||
block_table=prefill_metadata.block_table,
|
||||
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
|
||||
batch_size=attn_metadata.num_prefills,
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
)
|
||||
|
||||
kv_c_normed = workspace[:toks]\
|
||||
[..., :self.kv_lora_rank]
|
||||
k_pe = workspace[:toks]\
|
||||
[..., self.kv_lora_rank:].unsqueeze(1)
|
||||
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
|
||||
dim=-1)
|
||||
|
||||
attn_output, attn_softmax_lse = \
|
||||
self._flash_attn_varlen_diff_headdims(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=prefill_metadata.query_start_loc,
|
||||
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
|
||||
max_seqlen_q=prefill_metadata.max_query_len,
|
||||
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
|
||||
softmax_scale=self.scale,
|
||||
causal=False, # Context is unmasked
|
||||
return_softmax_lse=True,
|
||||
)
|
||||
|
||||
if output is None:
|
||||
output = attn_output
|
||||
output_lse = attn_softmax_lse
|
||||
else:
|
||||
output_tmp = torch.empty_like(output)
|
||||
output_lse_tmp = torch.empty_like(output_lse)
|
||||
merge_attn_states(
|
||||
output=output_tmp,
|
||||
output_lse=output_lse_tmp,
|
||||
prefix_output=output,
|
||||
prefix_lse=output_lse,
|
||||
suffix_output=attn_output,
|
||||
suffix_lse=attn_softmax_lse,
|
||||
)
|
||||
output = output_tmp
|
||||
output_lse = output_lse_tmp
|
||||
|
||||
return output, output_lse
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata.prefill is not None
|
||||
|
||||
has_context = False # attn_metadata.prefill.chunked_context is not None
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
|
||||
output = self._flash_attn_varlen_diff_headdims(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
|
||||
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.prefill.max_query_len,
|
||||
max_seqlen_k=attn_metadata.prefill.max_query_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
# return_softmax_lse=has_context,
|
||||
)
|
||||
|
||||
if has_context:
|
||||
suffix_output, suffix_lse = output
|
||||
context_output, context_lse = self._compute_prefill_context( \
|
||||
q, kv_c_and_k_pe_cache, attn_metadata)
|
||||
|
||||
output = torch.empty_like(suffix_output)
|
||||
merge_attn_states(
|
||||
output=output,
|
||||
prefix_output=context_output,
|
||||
prefix_lse=context_lse,
|
||||
suffix_output=suffix_output,
|
||||
suffix_lse=suffix_lse,
|
||||
)
|
||||
|
||||
# unpad if necessary
|
||||
if self._pad_v:
|
||||
output = output.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]] \
|
||||
.reshape(-1, self.num_heads * v.shape[-1])
|
||||
|
||||
return output
|
||||
|
||||
@abstractmethod
|
||||
def _forward_decode(
|
||||
self,
|
||||
ql_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: M,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
q: torch.Tensor,
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: M,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if attn_metadata is None:
|
||||
# The zero fill is required when used with DP + EP
|
||||
# to ensure all ranks within a DP group compute the
|
||||
# same expert outputs.
|
||||
return output.fill_(0)
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
output = output[:num_actual_toks, ...]
|
||||
q = q[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
attn_metadata.num_prefills is not None and \
|
||||
attn_metadata.num_decode_tokens is not None
|
||||
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
decode_q = q[:num_decode_tokens]
|
||||
|
||||
prefill_q = q[num_decode_tokens:]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
if has_prefill:
|
||||
output[num_decode_tokens:] = self._forward_prefill(
|
||||
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
if has_decode:
|
||||
assert attn_metadata.decode is not None
|
||||
decode_q_nope, decode_q_pe = decode_q.split(
|
||||
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
decode_q_nope = decode_q_nope.transpose(0, 1)
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
decode_ql_nope = decode_ql_nope.transpose(0, 1)
|
||||
|
||||
output[:num_decode_tokens] = self._forward_decode(
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
|
||||
return output_padded
|
||||
97
vllm/v1/attention/backends/mla/cutlass_mla.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CutlassMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CUTLASS_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
||||
return CutlassMLAImpl
|
||||
|
||||
|
||||
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"CutlassMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLA V1 with FP8 KV cache not yet supported")
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Cutlass MLA not yet supported")
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
o = torch.empty((B, self.num_heads, self.kv_lora_rank),
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
# Run MLA
|
||||
# Clone q_nope and q_pe to make sure strides computation is correct.
|
||||
q_nope = q_nope.clone()
|
||||
q_pe = q_pe.clone()
|
||||
ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.block_table, self.scale)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
180
vllm/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["FlashMLAMetadata"]:
|
||||
return FlashMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
|
||||
return FlashMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLAImpl"]:
|
||||
return FlashMLAImpl
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
tile_scheduler_metadata: torch.Tensor
|
||||
num_splits: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
|
||||
|
||||
def __init__(self, runner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)
|
||||
|
||||
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config)
|
||||
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
|
||||
tile_scheduler_metadata, num_splits = \
|
||||
get_mla_metadata(
|
||||
seq_lens,
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
if self.runner.full_cuda_graph:
|
||||
# First time around (CUDAGraph capture), allocate the static buffer
|
||||
if self.cg_buf_tile_scheduler_metadata is None:
|
||||
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
|
||||
self.cg_buf_num_splits = num_splits
|
||||
else:
|
||||
assert self.cg_buf_num_splits is not None
|
||||
|
||||
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
|
||||
assert (self.cg_buf_tile_scheduler_metadata.size() ==
|
||||
tile_scheduler_metadata.size())
|
||||
self.cg_buf_tile_scheduler_metadata.\
|
||||
copy_(tile_scheduler_metadata)
|
||||
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
|
||||
|
||||
# Num splits is per-batch, varying size (batch_size,)
|
||||
n = num_splits.size(0)
|
||||
# make sure static buffer is large enough
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
num_splits_view = self.cg_buf_num_splits[:n]
|
||||
num_splits_view.copy_(num_splits)
|
||||
self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s
|
||||
num_splits = num_splits_view
|
||||
|
||||
return FlashMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens,
|
||||
tile_scheduler_metadata=tile_scheduler_metadata,
|
||||
num_splits=num_splits,
|
||||
)
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
assert is_flashmla_supported(), \
|
||||
"FlashMLA is not supported on this device"
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlashMLA V1 with FP8 KV cache not yet supported")
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)\
|
||||
.unsqueeze(1) # Add seqlen dim of 1 (decode)
|
||||
|
||||
o, _ = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=attn_metadata.decode.
|
||||
tile_scheduler_metadata,
|
||||
num_splits=attn_metadata.decode.num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
220
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,220 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def is_aiter_mla_enabled() -> bool:
|
||||
return envs.VLLM_ROCM_USE_AITER \
|
||||
and envs.VLLM_ROCM_USE_AITER_MLA
|
||||
|
||||
|
||||
class AiterMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["AiterMLAImpl"]:
|
||||
return AiterMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AiterMLAMetadata"]:
|
||||
return AiterMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
|
||||
return AiterMLAMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1]
|
||||
paged_kv_indptr: Optional[torch.Tensor] = None
|
||||
# The page indices of the paged kv cache
|
||||
paged_kv_indices: Optional[torch.Tensor] = None
|
||||
# The number of entries in the last page of each request in
|
||||
# the paged kv cache, shape: [batch_size]
|
||||
paged_kv_last_page_len: Optional[torch.Tensor] = None
|
||||
# The query indptr, shape: [num_decode + 1]
|
||||
qo_indptr: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
|
||||
def __init__(self, runner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
|
||||
assert self.kv_cache_spec.block_size == 1, "AITER MLA " \
|
||||
"only supports block size 1."
|
||||
|
||||
def _get_paged_kv_tensors(
|
||||
self, block_table: torch.Tensor,
|
||||
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
block_table_bounds = (seq_lens + page_size - 1) // page_size
|
||||
device = self.runner.device
|
||||
|
||||
mask = (torch.arange(block_table.size(1),
|
||||
dtype=block_table.dtype,
|
||||
device=device).unsqueeze(0)
|
||||
< block_table_bounds.unsqueeze(1))
|
||||
paged_kv_indices = block_table[mask]
|
||||
|
||||
paged_kv_indptr = torch.cat([
|
||||
torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
|
||||
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
|
||||
])
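# Worked example (hypothetical, page_size=1):
#   block_table = [[10, 11, 12], [20, 21, 0]], seq_lens = [3, 2]
#   -> block_table_bounds = [3, 2]
#   -> paged_kv_indices = [10, 11, 12, 20, 21]
#   -> paged_kv_indptr  = [0, 3, 5]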
|
||||
|
||||
paged_kv_last_page_len = seq_lens % page_size
|
||||
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
|
||||
page_size, paged_kv_last_page_len)
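# With the block size of 1 asserted in the builder, seq_lens % page_size
# is all zeros, so every last page is reported as full (length 1).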
|
||||
qo_indptr = torch.arange(0,
|
||||
self._num_decodes + 1,
|
||||
step=1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
return (
|
||||
paged_kv_indices,
|
||||
paged_kv_indptr,
|
||||
paged_kv_last_page_len,
|
||||
qo_indptr,
|
||||
)
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
|
||||
|
||||
(
|
||||
paged_kv_indices,
|
||||
paged_kv_indptr,
|
||||
paged_last_page_len,
|
||||
qo_indptr,
|
||||
) = self._get_paged_kv_tensors(block_table_tensor, seq_lens)
|
||||
|
||||
attn_metadata = AiterMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens,
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_last_page_len,
|
||||
qo_indptr=qo_indptr)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
assert (num_heads == 16 or num_heads == 128), (
|
||||
f"Aiter MLA only supports 16 or 128 number of heads.\n"
|
||||
f"Provided {num_heads} number of heads.\n"
|
||||
"Try adjusting tensor_parallel_size value.")
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"Aiter MLA does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
from aiter import flash_attn_varlen_func
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=False,
|
||||
softmax_scale=None,
|
||||
**kwargs):
|
||||
output = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_lse=return_softmax_lse,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
|
||||
if self.num_heads == 16:
|
||||
# AITER MLA decode kernel only supports
|
||||
# max_seqlen_q=1 when using 16 heads.
|
||||
max_seqlen_qo = 1
|
||||
else:
|
||||
# AITER MLA decode kernel handles arbitrary
|
||||
# max_seqlen_q values when using 128 heads.
|
||||
assert attn_metadata.prefill is not None
|
||||
max_seqlen_qo = attn_metadata.prefill.max_query_len
|
||||
|
||||
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
|
||||
attn_metadata.decode.qo_indptr, max_seqlen_qo,
|
||||
attn_metadata.decode.paged_kv_indptr,
|
||||
attn_metadata.decode.paged_kv_indices,
|
||||
attn_metadata.decode.paged_kv_last_page_len)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
162
vllm/v1/attention/backends/mla/triton_mla.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.attention.backends.triton_mla import (load_config,
|
||||
find_best_mla_para)
|
||||
logger = init_logger(__name__)
|
||||
|
||||
import os
|
||||
# TODO: these environment variables are a temporary workaround; newer versions do not need them
|
||||
os.environ['TRITON_ENABLE_MACA_OPT_MOVE_DOT_OPERANDS_OUT_LOOP'] = '1'
|
||||
os.environ['TRITON_ENABLE_MACA_CHAIN_DOT_OPT'] = '1'
|
||||
|
||||
JSON_DATA = load_config()
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["TritonMLAMetadata"]:
|
||||
return TritonMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["TritonMLAMetadataBuilder"]:
|
||||
return TritonMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TritonMLAImpl"]:
|
||||
return TritonMLAImpl
|
||||
@dataclass
|
||||
class TritonMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
num_kv_splits: int
|
||||
num_stages: int
|
||||
|
||||
@dataclass
|
||||
class TritonMLAMetadata(MLACommonMetadata[TritonMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
class TritonMLAMetadataBuilder(MLACommonMetadataBuilder[TritonMLAMetadata]):
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor) -> TritonMLADecodeMetadata:
|
||||
if seq_lens is not None:
|
||||
batch = seq_lens.shape[0]
|
||||
max_seq_len = int(seq_lens.max())
|
||||
num_kv_splits, num_stages = find_best_mla_para(JSON_DATA, batch, max_seq_len, 8)
|
||||
else:
|
||||
num_kv_splits = 4
|
||||
num_stages = 1
|
||||
return TritonMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens,
|
||||
num_kv_splits=num_kv_splits,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TritonMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"TritonMLA V1 with FP8 KV cache not yet supported")
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
# TODO(lucas) Allocate ahead of time
|
||||
attn_logits = torch.empty(
|
||||
(
|
||||
B,
|
||||
self.num_heads,
|
||||
attn_metadata.decode.num_kv_splits,
|
||||
# NOTE(lucas) idk why the +1 is here but sglang has it so we
|
||||
# just mirror that
|
||||
self.kv_lora_rank + 1,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
# Add a head dim of 1
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
|
||||
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
|
||||
|
||||
# Run MQA
|
||||
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
|
||||
attn_metadata.decode.block_table,
|
||||
attn_metadata.decode.seq_lens, attn_logits,
|
||||
attn_metadata.decode.num_kv_splits,
|
||||
attn_metadata.decode.num_stages,
|
||||
self.scale, PAGE_SIZE)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
240
vllm/v1/attention/backends/pallas.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
# Required to register custom ops.
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv, next_power_of_2
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "PALLAS_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
|
||||
return PallasAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["PallasMetadata"]:
|
||||
return PallasMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads * 2, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
raise RuntimeError("swap_blocks is not used for the TPU backend.")
|
||||
|
||||
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
|
||||
# block_tables within the PallasMetadata constitute almost the entire SMEM
|
||||
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
|
||||
# we simply make sure that the size is smaller than half of SMEM capacity.
|
||||
@staticmethod
|
||||
def get_min_page_size(vllm_config: VllmConfig) -> int:
|
||||
max_num_page_per_req = (1024 * 1024 // 2 //
|
||||
vllm_config.scheduler_config.max_num_seqs // 4)
|
||||
min_page_size = cdiv(vllm_config.model_config.max_model_len,
|
||||
max_num_page_per_req)
|
||||
min_page_size = 1 << (min_page_size - 1).bit_length()
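# Worked example (hypothetical): max_model_len=8192, max_num_seqs=256 ->
# max_num_page_per_req = 1024 * 1024 // 2 // 256 // 4 = 512, so
# min_page_size = cdiv(8192, 512) = 16, already a power of two.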
|
||||
return min_page_size
|
||||
|
||||
# TPU has limited SREGs (scalar registers), if page_size is too small, we
|
||||
# can spill SREGs easily which leads to bad performance. The strategy we
|
||||
# apply here is trying to split max-model-len to 16 pages which make the
|
||||
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
|
||||
@staticmethod
|
||||
def get_page_size(vllm_config: VllmConfig) -> int:
|
||||
page_size = next_power_of_2(
|
||||
vllm_config.model_config.max_model_len) // 16
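# Worked example (hypothetical): max_model_len=8192 ->
# next_power_of_2(8192) // 16 = 512, which is clamped to 256 below.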
|
||||
if page_size <= 16:
|
||||
return 16
|
||||
if page_size >= 256:
|
||||
return 256
|
||||
return page_size


@dataclass
class PallasMetadata:

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Used in the PallasAttentionBackendImpl
    slot_mapping: torch.Tensor
    block_tables: torch.Tensor
    context_lens: torch.Tensor
    query_start_loc: torch.Tensor
    num_seqs: torch.Tensor


class PallasAttentionBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        if use_irope:
            logger.warning_once(
                "Using irope in Pallas is not supported yet; it will fall "
                "back to global attention for long context.")
        if blocksparse_params is not None:
            raise ValueError("Paged attention Pallas kernel does "
                             "not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes are not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

        tpu_version = torch_xla.tpu.version()
        if tpu_version < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2,
                head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        # For the determine_available_memory case.
        if kv_cache.numel() == 0:
            if output is None:
                output = torch.ones_like(query)
            return output

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        num_tokens, hidden_size = query.shape
        query = query.view(num_tokens, self.num_heads, self.head_size)

        if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
            # Write input keys and values to the KV cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            slot_mapping = attn_metadata.slot_mapping
            write_to_kv_cache(key, value, kv_cache, slot_mapping)

        output = torch.ops.xla.ragged_paged_attention(
            query,
            kv_cache,
            attn_metadata.context_lens,
            attn_metadata.block_tables,
            attn_metadata.query_start_loc,
            attn_metadata.num_seqs,
            # By default, the system utilizes optimized block size and
            # vmem_limit_bytes parameters from the kernel repository. However,
            # these can be manually adjusted for debugging if necessary.
            num_kv_pages_per_block=None,
            num_queries_per_block=None,
            vmem_limit_bytes=None,
            use_kernel=True,
            sm_scale=self.scale,
            sliding_window=self.sliding_window,
            soft_cap=self.logits_soft_cap,
        )

        return output.reshape(num_tokens, hidden_size)


def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Write the keys and values to the KV cache.

    Args:
        key: shape = [num_tokens, num_kv_heads * head_size]
        value: shape = [num_tokens, num_kv_heads * head_size]
        kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2,
            head_size]
    """
    _, _, num_combined_kv_heads, head_size = kv_cache.shape
    num_kv_heads = num_combined_kv_heads // 2

    key = key.view(-1, num_kv_heads, head_size)
    value = value.view(-1, num_kv_heads, head_size)

    kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                  head_size)

    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)

    kv_cache = kv_cache.flatten(0, 1)
    kv_cache.index_copy_(0, slot_mapping, kv)
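
# Shape walk-through for write_to_kv_cache (a sketch, assuming illustrative
# sizes num_tokens=3, num_kv_heads=8, head_size=128): key/value arrive as
# [3, 8 * 128] and are viewed as [3, 8, 128]; torch.cat(..., axis=-1) gives
# [3, 8, 256], which reshapes to [3, 16, 128], matching the cache layout
# [num_blocks, block_size, num_kv_heads * 2, head_size] once the first two
# cache dims are flattened for index_copy_.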
295
vllm/v1/attention/backends/triton_attn.py
Normal file
@@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from typing import TYPE_CHECKING, Any, Optional

import torch

from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.ops.chunked_prefill_paged_decode import (
    chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
    FlashAttentionMetadata, FlashAttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)


class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table)
        self.aot_schedule = False


class TritonAttentionBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "TRITON_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TritonAttentionImpl"]:
        return TritonAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return FlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False

    @staticmethod
    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
        return TritonAttentionMetadataBuilder


class TritonAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "TritonAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap to 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.use_irope = use_irope

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by TritonAttention. "
                f"Supported head sizes are: {support_head_sizes}.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonAttentionImpl")

        self.fp8_dtype = current_platform.fp8_dtype()
        self.force_prefill_decode_attn = \
            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Triton attention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads,
                head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed
        # in eager-mode PyTorch. Thus, we need to be careful about any CPU
        # overhead in this method. For example, `view` and `slice` (or `[:n]`)
        # operations are surprisingly slow even when they do not invoke any
        # GPU ops. Minimize the PyTorch ops in this method as much as
        # possible. Whenever making a change in this method, please benchmark
        # the performance to make sure it does not introduce any overhead.

        use_prefill_decode_attn = self.force_prefill_decode_attn
        num_actual_tokens = attn_metadata.num_actual_tokens

        if use_prefill_decode_attn:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
        else:
            key_cache, value_cache = kv_cache.unbind(0)

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            if use_prefill_decode_attn:
                PagedAttention.write_to_paged_cache(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)
            num_tokens, num_heads, head_size = query.shape
            assert layer._q_scale == 1.0, \
                "A non-1.0 q_scale is not currently supported."
            if not current_platform.is_rocm():
                # Skip Q quantization on ROCm, since dequantizing back to
                # f32 in the attention kernel is not supported.
                query, _ = ops.scaled_fp8_quant(
                    query.reshape(
                        (num_tokens, num_heads * head_size)).contiguous(),
                    layer._q_scale)
                query = query.reshape((num_tokens, num_heads, head_size))

        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if use_local_attn:
            assert attn_metadata.local_attn_metadata is not None
            local_metadata = attn_metadata.local_attn_metadata
            cu_seqlens_q = local_metadata.local_query_start_loc
            seqused_k = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
            max_seqlen_k = local_metadata.local_max_seq_len
            block_table = local_metadata.local_block_table
        else:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table

        if use_prefill_decode_attn:
            # Compute attention and update output up to `num_actual_tokens`.
            chunked_prefill_paged_decode(query=query[:num_actual_tokens],
                                         key=key[:num_actual_tokens],
                                         value=value[:num_actual_tokens],
                                         output=output[:num_actual_tokens],
                                         kv_cache_dtype=self.kv_cache_dtype,
                                         key_cache=key_cache,
                                         value_cache=value_cache,
                                         block_table=block_table,
                                         query_start_loc=cu_seqlens_q,
                                         seq_lens=seqused_k,
                                         max_seq_len=max_seqlen_k,
                                         max_query_len=max_seqlen_q,
                                         k_scale=layer._k_scale,
                                         v_scale=layer._v_scale,
                                         alibi_slopes=self.alibi_slopes,
                                         sliding_window=self.sliding_window[0],
                                         sm_scale=self.scale,
                                         sinks=self.sinks)
        else:
            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

            unified_attention(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                sinks=self.sinks,
            )

        return output
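
# Note on the sliding_window convention above (a sketch): the tuple follows
# the flash-attn (left, right) window convention, so a causal window of W
# tokens is stored as (W - 1, 0) and "disabled" is (-1, -1). For example,
# sliding_window=1024 lets token t attend to tokens [t - 1023, t].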
118
vllm/v1/attention/backends/utils.py
Normal file
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar

import numpy as np
import torch

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch


@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.
    """

    query_start_loc: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""
    seq_lens: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""

    num_reqs: int
    """Number of requests"""
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""

M = TypeVar("M")


class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    # Whether this backend/builder supports CUDA Graphs for attention.
    full_cudagraph_supported: ClassVar[bool] = False

    @abstractmethod
    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.
        """
        raise NotImplementedError

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        """
        Can this batch (with the given metadata) use CUDA Graphs for
        attention?
        """
        return False

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by
        default. Subclasses that override this method should call self.build
        or super().build_for_cudagraph_capture.
        """
        return self.build(common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata)

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        num_sms: int,
    ) -> bool:
        return False

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """
        This method can reorder the batch if desired by the backend.
        :return: Has the batch been reordered (default False).
        """
        return False


def validate_kv_sharing_target(current_layer_name, target_layer_name,
                               static_forward_context):
    error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
                 f"is not valid: target layer {target_layer_name} ")

    if current_layer_name == target_layer_name:
        raise ValueError(error_msg +
                         "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # If the target layer name is not in the static fwd context, it means
        # either
        # a) the target layer does not come BEFORE the current layer, or
        # b) the target layer is not an Attention layer that exists in the
        #    model
        current_layer_idx = extract_layer_index(current_layer_name)
        target_layer_idx = extract_layer_index(target_layer_name)
        if current_layer_idx <= target_layer_idx:
            raise ValueError(error_msg + "must come before the current layer.")
        else:
            raise ValueError(error_msg +
                             "is not a valid Attention layer in the model.")

    # Currently KV sharing is only supported between layers of the same type.
    target_layer_attn_type = static_forward_context[
        target_layer_name].attn_type
    expected = static_forward_context[current_layer_name].attn_type
    if target_layer_attn_type != expected:
        raise ValueError(
            error_msg +
            f"must be the same type as the current layer ({expected}).")
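
# Example of the checks above (a sketch with hypothetical layer names):
# sharing "model.layers.1.self_attn.attn" -> "model.layers.0.self_attn.attn"
# passes (the target exists and comes earlier), while targeting layer 2 from
# layer 1 fails with "must come before the current layer."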
0
vllm/v1/core/__init__.py
Normal file
349
vllm/v1/core/block_pool.py
Normal file
@@ -0,0 +1,349 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable
from typing import Callable, Optional

from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
                                        BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
                                         FreeKVCacheBlockQueue, KVCacheBlock,
                                         generate_block_hash_extra_keys,
                                         hash_block_tokens)
from vllm.v1.request import Request

logger = init_logger(__name__)


class BlockPool:
    """BlockPool that manages KVCacheBlocks.
    It provides methods to allocate, free and cache the kv cache blocks. The
    free_block_queue stores the free blocks in eviction order to enable
    allocation, free, and cache eviction. The cached_block_hash_to_block
    maps between block hash and cached block to support finding cached blocks
    by their block hash.

    Args:
        num_gpu_blocks: The number of blocks in the pool.
        enable_caching: Whether to enable prefix caching.
        enable_kv_cache_events: Whether to enable kv cache events.
    """

    def __init__(
        self,
        num_gpu_blocks: int,
        enable_caching: bool,
        enable_kv_cache_events: bool = False,
    ):
        assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
        self.num_gpu_blocks = num_gpu_blocks
        self.enable_caching = enable_caching
        # All kv-cache blocks.
        self.blocks: list[KVCacheBlock] = [
            KVCacheBlock(idx) for idx in range(num_gpu_blocks)
        ]
        # Free block queue that constructs and manipulates a doubly linked
        # list of free blocks (including eviction candidates when caching is
        # enabled).
        self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)

        # {block_hash: {block ID: block}}. A cached block is
        # a full block with a block hash that can be used for prefix caching.
        # The cached block may be used by running requests or in the
        # free_block_queue that could potentially be evicted.
        # NOTE: We currently don't de-duplicate the blocks in the cache,
        # meaning that if a block becomes full and is cached, we don't check
        # if there is already an identical block in the cache. This is because
        # we want to make sure the allocated block IDs won't change so that
        # block tables are append-only.
        self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[
            int, KVCacheBlock]] = defaultdict(dict)

        # To represent a placeholder block with block_id=0.
        # The ref_cnt of null_block is not maintained; it needs special care
        # to avoid freeing it.
        self.null_block = self.free_block_queue.popleft()
        self.null_block.is_null = True

        self.enable_kv_cache_events = enable_kv_cache_events
        self.kv_event_queue: list[KVCacheEvent] = []

    def get_cached_block(
            self, block_hash: BlockHash,
            kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]:
        """Get the cached block by the block hash for each group in
        `kv_cache_group_ids`, or None on a cache miss for any group.
        If there are duplicated blocks, we return the first block in the
        cache.

        Args:
            block_hash: The hash value of the block.
            kv_cache_group_ids: The ids of the KV cache groups.

        Returns:
            The cached blocks if they exist, or None.
        """
        cached_blocks = []
        for group_id in kv_cache_group_ids:
            cached_blocks_one_group = self.cached_block_hash_to_block.get(
                BlockHashWithGroupId(block_hash, group_id))
            if not cached_blocks_one_group:
                return None
            first_block = next(iter(cached_blocks_one_group.values()))
            cached_blocks.append(first_block)
        return cached_blocks

    def cache_full_blocks(
        self,
        request: Request,
        blocks: list[KVCacheBlock],
        block_hashes: list[BlockHash],
        num_cached_blocks: int,
        num_full_blocks: int,
        block_size: int,
        kv_cache_group_id: int,
        hash_fn: Callable,
    ) -> None:
        """Cache a list of full blocks for prefix caching.
        This function takes a list of blocks whose block hash metadata will
        be updated and cached. Given a request, it computes the block hashes
        for the blocks starting from `num_cached_blocks` to `num_full_blocks`,
        updating the metadata for each block and caching them in
        `cached_block_hash_to_block`.

        Args:
            request: The request to cache the blocks.
            blocks: All blocks in the request.
            block_hashes: Block hashes of the blocks in the request. Note that
                this list may be shorter than the blocks list. In this case
                the missing block hashes will be computed in this function.
            num_cached_blocks: The number of blocks that are already cached.
            num_full_blocks: The number of blocks that are full and should
                be cached after this function.
            block_size: Number of tokens in each block.
            kv_cache_group_id: The id of the KV cache group.
            hash_fn: The hash function to use for block hashes.
        """
        if num_cached_blocks == num_full_blocks:
            return
        new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
        assert len(block_hashes) >= num_cached_blocks
        new_block_hashes = block_hashes[num_cached_blocks:]

        # Update the new blocks with the block hashes through the chain.
        if num_cached_blocks == 0:
            prev_block_hash_value = None
        else:
            prev_block = blocks[num_cached_blocks - 1]
            assert prev_block.block_hash is not None
            prev_block_hash_value = prev_block.block_hash.get_hash_value()

        parent_block_hash = prev_block_hash_value
        new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
                                           else None)
        for i, blk in enumerate(new_full_blocks):
            assert blk.block_hash is None

            if i < len(new_block_hashes):
                # The block hash may already be computed in
                # "get_computed_blocks" if the tokens are not generated by
                # this request (either the prompt tokens or the previously
                # generated tokens with preemption), or by other
                # single_type_managers with the same block_size.
                # In this case we simply reuse the block hash.
                block_hash = new_block_hashes[i]
            else:
                # Otherwise compute the block hash and cache it in the request
                # in case it will be preempted in the future.
                blk_idx = num_cached_blocks + i
                start_token_idx = blk_idx * block_size
                end_token_idx = (blk_idx + 1) * block_size
                block_tokens = request.all_token_ids[
                    start_token_idx:end_token_idx]
                assert len(block_tokens) == block_size, (
                    f"Expected {block_size} tokens, got "
                    f"{len(block_tokens)} at {blk_idx}th block for request "
                    f"{request.request_id}({request})")

                # Generate extra keys for multi-modal inputs. Note that since
                # we reach this branch only when the block is completed with
                # generated tokens, we only need to consider the last mm
                # input.
                extra_keys, _ = generate_block_hash_extra_keys(
                    request, start_token_idx, end_token_idx, -1)

                # Compute the hash of the current block.
                block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
                                               block_tokens, extra_keys)
                block_hashes.append(block_hash)

            # Update and add the full block to the cache.
            block_hash_with_group_id = BlockHashWithGroupId(
                block_hash, kv_cache_group_id)
            blk.block_hash = block_hash_with_group_id
            self.cached_block_hash_to_block[block_hash_with_group_id][
                blk.block_id] = blk
            if new_hashes is not None:
                new_hashes.append(block_hash.hash_value)
            prev_block_hash_value = block_hash.hash_value

        if self.enable_kv_cache_events:
            self.kv_event_queue.append(
                BlockStored(
                    block_hashes=new_hashes,
                    parent_block_hash=parent_block_hash,
                    token_ids=request.
                    all_token_ids[num_cached_blocks *
                                  block_size:num_full_blocks * block_size],
                    block_size=block_size,
                    lora_id=request.lora_request.id
                    if request.lora_request else None,
                ))
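
    # Sketch of the hash chain above (illustrative, not extra validation in
    # this commit): for consecutive full blocks B0 and B1 with token slices
    # T0 and T1, B0's hash is derived from (None, T0, extra_keys) and B1's
    # from (hash(B0), T1, ...), so a prefix-cache hit on B1 implies that B0
    # also matched.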
    def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
        """Get new blocks from the free block pool.

        Note that we do not check the block cache in this function.

        Args:
            num_blocks: The number of blocks to allocate.

        Returns:
            A list of new blocks.
        """
        if num_blocks > self.get_num_free_blocks():
            raise ValueError(
                f"Cannot get {num_blocks} free blocks from the pool")

        ret: list[KVCacheBlock] = []
        idx = 0
        while idx < num_blocks:
            # First allocate blocks.
            curr_block = self.free_block_queue.popleft()
            assert curr_block.ref_cnt == 0

            # If the block is cached, evict it.
            if self.enable_caching:
                self._maybe_evict_cached_block(curr_block)

            curr_block.incr_ref()
            ret.append(curr_block)
            idx += 1

        return ret

    def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
        """
        If a block is cached in `cached_block_hash_to_block`, we reset its
        hash metadata and evict it from the cache.

        Args:
            block: The block to evict.

        Returns:
            True if the block is evicted, False otherwise.
        """
        block_hash = block.block_hash
        if block_hash and block_hash in self.cached_block_hash_to_block:
            block.reset_hash()
            del self.cached_block_hash_to_block[block_hash][block.block_id]

            if len(self.cached_block_hash_to_block[block_hash]) == 0:
                del self.cached_block_hash_to_block[block_hash]

            if self.enable_kv_cache_events:
                # FIXME (Chen): Not sure whether we should return `hash_value`
                # or `(hash_value, group_id)` here. But it's fine now because
                # we disable the hybrid kv cache manager when kv cache events
                # are enabled, so there is only one group.
                self.kv_event_queue.append(
                    BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
            return True
        return False

    def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
        """Touching a block increases its reference count by 1, and may
        remove the block from the free queue. This is used when a block is
        hit by another request with the same prefix.

        Args:
            blocks: A list of blocks to touch.
        """
        for blocks_per_group in blocks:
            for block in blocks_per_group:
                # ref_cnt=0 means this block is in the free list (i.e. an
                # eviction candidate), so remove it.
                if block.ref_cnt == 0 and not block.is_null:
                    self.free_block_queue.remove(block)
                block.incr_ref()

    def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
        """Free a list of blocks. The blocks should be ordered by their
        eviction priority, where the first block will be evicted first.

        Args:
            ordered_blocks: A list of blocks to free, ordered by their
                eviction priority.
        """
        for block in ordered_blocks:
            block.decr_ref()
            # null_block should not be added to the free list.
            if block.ref_cnt == 0 and not block.is_null:
                self.free_block_queue.append(block)

    def reset_prefix_cache(self) -> bool:
        """Reset the prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
        or for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
        if num_used_blocks != 1:  # The null block is always marked as used
            logger.warning(
                "Failed to reset prefix cache because some "
                "blocks (%d) are not freed yet", num_used_blocks - 1)
            return False

        # Remove all hashes so that no new blocks will hit.
        self.cached_block_hash_to_block = defaultdict(dict)

        # Remove all hashes from all blocks.
        for block in self.blocks:
            block.reset_hash()

        logger.info("Successfully reset prefix cache")

        if self.enable_kv_cache_events:
            self.kv_event_queue.append(AllBlocksCleared())

        return True

    def get_num_free_blocks(self) -> int:
        """Get the number of free blocks in the pool.

        Returns:
            The number of free blocks.
        """
        return self.free_block_queue.num_free_blocks

    def get_usage(self) -> float:
        """Get the KV cache usage.

        Returns:
            The KV cache usage (between 0.0 and 1.0).
        """
        return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks)
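
    # Worked example for get_usage (illustrative numbers): with
    # num_gpu_blocks=1000 and 250 blocks free, usage = 1.0 - 250 / 1000
    # = 0.75, i.e. 75% of the KV cache is in use (the null block counts as
    # used).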
    def take_events(self) -> list[KVCacheEvent]:
        """Atomically takes all events and clears the queue.

        Returns:
            A list of KV cache events.
        """
        if not self.enable_kv_cache_events:
            return []
        events = self.kv_event_queue
        self.kv_event_queue = []
        return events
150
vllm/v1/core/encoder_cache_manager.py
Normal file
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING

from vllm.logger import init_logger
from vllm.multimodal import MultiModalRegistry
from vllm.v1.request import Request

if TYPE_CHECKING:
    from vllm.config import ModelConfig, SchedulerConfig

logger = init_logger(__name__)


class EncoderCacheManager:

    def __init__(self, cache_size: int):
        self.cache_size = cache_size
        self.num_free_slots = cache_size
        # req_id -> cached input ids
        self.cached: dict[str, set[int]] = {}
        # list of (req_id, input_id)
        self.freed: list[tuple[str, int]] = []

    def has_cache(self, request: Request, input_id: int) -> bool:
        req_id = request.request_id
        return req_id in self.cached and input_id in self.cached[req_id]

    def can_allocate(self, request: Request, input_id: int) -> bool:
        num_tokens = request.get_num_encoder_tokens(input_id)
        return num_tokens <= self.num_free_slots

    def allocate(self, request: Request, input_id: int) -> None:
        req_id = request.request_id
        if req_id not in self.cached:
            self.cached[req_id] = set()
        self.cached[req_id].add(input_id)
        self.num_free_slots -= request.get_num_encoder_tokens(input_id)

    def get_cached_input_ids(self, request: Request) -> set[int]:
        return self.cached.get(request.request_id, set())

    def free_encoder_input(self, request: Request, input_id: int) -> None:
        """Free a single encoder input id for the request."""
        req_id = request.request_id
        if req_id not in self.cached:
            return

        self.cached[req_id].discard(input_id)
        if len(self.cached[req_id]) == 0:
            del self.cached[req_id]
        self.num_free_slots += request.get_num_encoder_tokens(input_id)
        self.freed.append((req_id, input_id))

    def free(self, request: Request) -> None:
        """Free all cached input ids for the request."""
        input_ids = self.get_cached_input_ids(request).copy()
        for input_id in input_ids:
            self.free_encoder_input(request, input_id)

    def get_freed_ids(self) -> list[tuple[str, int]]:
        freed = self.freed
        self.freed = []
        return freed
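
    # Accounting sketch (illustrative numbers): allocate() debits
    # num_free_slots by the input's encoder token count and
    # free_encoder_input() credits it back, so with cache_size=2048 and a
    # 575-token image input, allocating leaves 1473 free slots until that
    # (req_id, input_id) pair is freed.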


def compute_encoder_budget(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
    mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.
        mm_registry: Provides information about the token cost.

    Returns:
        - Compute budget for encoder execution, in units of tokens
          in the input sequence.
        - Space budget for the encoder cache size, in units of tokens
          in the input sequence.
    """

    if not model_config.is_multimodal_model:
        return 0, 0

    # TODO: handle encoder-decoder models once we support them.
    (
        encoder_compute_budget,
        encoder_cache_size,
    ) = _compute_encoder_budget_multimodal(
        model_config,
        scheduler_config,
        mm_registry,
    )

    return encoder_compute_budget, encoder_cache_size


def _compute_encoder_budget_multimodal(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
    mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations for a multimodal model.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.
        mm_registry: Provides information about the token cost.

    Returns:
        - Compute budget for encoder execution, in units of tokens
          in the input sequence.
        - Space budget for the encoder cache size, in units of tokens
          in the input sequence.
    """

    max_tokens_by_modality_dict = mm_registry \
        .get_max_tokens_per_item_by_nonzero_modality(model_config)

    if not max_tokens_by_modality_dict:
        logger.warning(
            "All non-text modalities supported by the model have been "
            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
            "not be initialized.")
        return 0, 0

    _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
                                    key=lambda item: item[1])

    if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item
            > scheduler_config.max_num_batched_tokens):
        raise ValueError(
            "Chunked MM input is disabled but max_tokens_per_mm_item "
            f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens"
            f" ({scheduler_config.max_num_batched_tokens}). Please increase "
            "max_num_batched_tokens.")

    encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
                                 max_tokens_per_mm_item)
    encoder_cache_size = max(scheduler_config.encoder_cache_size,
                             max_tokens_per_mm_item)

    return encoder_compute_budget, encoder_cache_size
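
# Worked example for the budget above (illustrative numbers): with
# max_num_encoder_input_tokens=8192, encoder_cache_size=8192 and a largest
# multimodal item of 16384 tokens, both budgets are raised to 16384 so a
# single item always fits; with disable_chunked_mm_input=True and
# max_num_batched_tokens=8192 the same configuration would raise instead.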
363
vllm/v1/core/kv_cache_coordinator.py
Normal file
@@ -0,0 +1,363 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Callable, Optional

from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import (
    FullAttentionManager, get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.request import Request


class KVCacheCoordinator(ABC):
    """
    Coordinate the KV cache of different KV cache groups.
    """

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        use_eagle: bool,
        enable_caching: bool,
        caching_hash_fn: Callable,
        enable_kv_cache_events: bool,
    ):
        self.kv_cache_config = kv_cache_config
        self.max_model_len = max_model_len

        self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
                                    enable_kv_cache_events)

        # Needs special handling for find_longest_cache_hit if eagle is
        # enabled.
        self.use_eagle = use_eagle
        self.single_type_managers = tuple(
            get_manager_for_kv_cache_spec(
                kv_cache_spec=kv_cache_group.kv_cache_spec,
                block_pool=self.block_pool,
                kv_cache_group_id=i,
                caching_hash_fn=caching_hash_fn,
            ) for i, kv_cache_group in enumerate(
                self.kv_cache_config.kv_cache_groups))

    def get_num_blocks_to_allocate(
            self, request_id: str, num_tokens: int,
            new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
        """
        Get the number of blocks that need to be allocated for the request.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            new_computed_blocks: The new computed blocks just hitting the
                prefix cache.

        Returns:
            The number of blocks.
        """
        num_blocks_to_allocate = 0
        for i, manager in enumerate(self.single_type_managers):
            num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
                request_id, num_tokens, new_computed_blocks[i])
        return num_blocks_to_allocate

    def save_new_computed_blocks(
            self, request_id: str,
            new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None:
        """
        Add the new computed blocks to the request.

        Args:
            request_id: The request ID.
            new_computed_blocks: The new computed blocks just hitting the
                prefix cache.
        """
        for i, manager in enumerate(self.single_type_managers):
            manager.save_new_computed_blocks(request_id,
                                             new_computed_blocks[i])

    def allocate_new_blocks(self, request_id: str,
                            num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
        """
        Allocate new blocks for the request to give it at least `num_tokens`
        token slots.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).

        Returns:
            The newly allocated blocks.
        """
        return tuple(
            manager.allocate_new_blocks(request_id, num_tokens)
            for manager in self.single_type_managers)

    def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
                     num_computed_tokens: int) -> None:
        """
        Cache the blocks for the request.

        Args:
            request: The request.
            block_hashes: The block hashes of the request.
            num_computed_tokens: The total number of tokens that need to be
                cached (including tokens that are already cached).
        """
        for manager in self.single_type_managers:
            manager.cache_blocks(request, block_hashes, num_computed_tokens)

    def free(self, request_id: str) -> None:
        """
        Free the blocks for the request.

        Args:
            request_id: The request ID.
        """
        for manager in self.single_type_managers:
            manager.free(request_id)

    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> list[int]:
        """
        Get the number of common prefix blocks for a request.

        Args:
            request_id: The request ID.
            num_running_requests: The number of currently running requests.

        Returns:
            The number of common prefix blocks per group.
        """
        num_blocks_per_group = [
            manager.get_num_common_prefix_blocks(request_id,
                                                 num_running_requests)
            for manager in self.single_type_managers
        ]
        return num_blocks_per_group

    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        """
        Remove the blocks that are no longer needed from `blocks` and replace
        the removed blocks with null_block.

        Args:
            request_id: The request ID.
            num_computed_tokens: The number of tokens that have been computed.
        """
        for manager in self.single_type_managers:
            manager.remove_skipped_blocks(request_id, num_computed_tokens)

    def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
        """
        Get the blocks for the request.
        """
        return tuple(
            manager.req_to_blocks.get(request_id) or []
            for manager in self.single_type_managers)

    @abstractmethod
    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        pass


class UnitaryKVCacheCoordinator(KVCacheCoordinator):
    """
    KV cache coordinator for models with only one KV cache group. This is the
    case for models with only one KV cache type, e.g., all attention layers
    use full attention or all attention layers use sliding window attention.
    """

    def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
                 use_eagle: bool, enable_caching: bool,
                 caching_hash_fn: Callable, enable_kv_cache_events: bool):
        super().__init__(kv_cache_config, max_model_len, use_eagle,
                         enable_caching, caching_hash_fn,
                         enable_kv_cache_events)
        self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
            0].kv_cache_spec
        self.block_size = self.kv_cache_spec.block_size
        assert len(self.kv_cache_config.kv_cache_groups) == 1, (
            "UnitaryKVCacheCoordinator assumes only one kv cache group")

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
            block_hashes=block_hashes,
            max_length=max_cache_hit_length,
            kv_cache_group_ids=[0],
            block_pool=self.block_pool,
            kv_cache_spec=self.kv_cache_spec,
            use_eagle=self.use_eagle,
        )
        return hit_blocks, len(hit_blocks[0]) * self.block_size
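
    # Example (illustrative numbers): with block_size=16 and a 3-block prefix
    # hit, find_longest_cache_hit returns (hit_blocks, 48), since for the
    # single group the hit length is simply len(hit_blocks[0]) * block_size.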


class HybridKVCacheCoordinator(KVCacheCoordinator):
    """
    KV cache coordinator for hybrid models with multiple KV cache types, and
    thus multiple kv cache groups.
    To simplify `find_longest_cache_hit`, it only supports the combination of
    two types of KV cache groups, and one of them must be full attention.
    This may be extended to more general cases in the future.
    """

    def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
                 use_eagle: bool, enable_caching: bool,
                 caching_hash_fn: Callable, enable_kv_cache_events: bool):
        super().__init__(kv_cache_config, max_model_len, use_eagle,
                         enable_caching, caching_hash_fn,
                         enable_kv_cache_events)
        self.verify_and_split_kv_cache_groups()

    def verify_and_split_kv_cache_groups(self) -> None:
        """
        Verifies that the model has exactly two types of KV cache groups, one
        of which is full attention. Then, splits the kv cache groups into
        full attention groups and other groups.
        """
        full_attention_type_id: Optional[str] = None
        other_type_id: Optional[str] = None
        self.full_attention_group_ids: list[int] = []
        self.other_group_ids: list[int] = []
        for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
            if isinstance(g.kv_cache_spec, FullAttentionSpec):
                if full_attention_type_id is None:
                    full_attention_type_id = g.kv_cache_spec.type_id
                else:
                    assert full_attention_type_id == g.kv_cache_spec.type_id, (
                        "HybridKVCacheCoordinator assumes exactly one type of "
                        "full attention groups now.")
                self.full_attention_group_ids.append(i)
            else:
                if other_type_id is None:
                    other_type_id = g.kv_cache_spec.type_id
                else:
                    assert other_type_id == g.kv_cache_spec.type_id, (
                        "HybridKVCacheCoordinator assumes "
                        "exactly one other type of groups now.")
                self.other_group_ids.append(i)

        assert full_attention_type_id is not None, (
            "HybridKVCacheCoordinator assumes exactly one type of full "
            "attention groups now.")
        assert other_type_id is not None, (
            "HybridKVCacheCoordinator assumes exactly one type of other "
            "groups now.")

        self.full_attention_manager_cls = FullAttentionManager
        self.other_attention_cls = self.single_type_managers[
            self.other_group_ids[0]].__class__

        self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
            self.full_attention_group_ids[0]].kv_cache_spec
        self.other_spec = self.kv_cache_config.kv_cache_groups[
            self.other_group_ids[0]].kv_cache_spec

        self.full_attention_block_size = self.full_attention_spec.block_size
        self.other_block_size = self.other_spec.block_size
        assert self.other_block_size % self.full_attention_block_size == 0, (
            "KVCacheCoordinator assumes the block_size of other layers is "
            "divisible by that of full attention layers now.")

        if max(self.full_attention_group_ids) < min(self.other_group_ids):
            self.full_attn_first = True
        elif max(self.other_group_ids) < min(self.full_attention_group_ids):
            self.full_attn_first = False
        else:
            raise ValueError(
                "HybridKVCacheCoordinator assumes the full "
                "attention group ids and other attention group ids "
                "do not interleave: either the full attention group ids "
                "come before the other attention group ids or vice versa. "
                "This simplifies merging hit_blocks_full_attn and "
                "hit_blocks_other_attn into hit_blocks.")

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """
        Find the longest cache hit for the request.

        Args:
            block_hashes: The block hashes of the request.
            max_cache_hit_length: The maximum length of the cache hit.

        Returns:
            A tuple containing:
                - A list of the cache hit blocks for each single type manager.
                - The number of tokens of the longest cache hit.
        """
        # First, find the longest cache hit for full attention.
        hit_blocks_full_attn = (
            self.full_attention_manager_cls.find_longest_cache_hit(
                block_hashes=block_hashes,
                max_length=max_cache_hit_length,
                kv_cache_group_ids=self.full_attention_group_ids,
                block_pool=self.block_pool,
                kv_cache_spec=self.full_attention_spec,
                use_eagle=self.use_eagle,
            ))
        hit_length = len(
            hit_blocks_full_attn[0]) * self.full_attention_block_size

        # Next, find the cache hit for the other attention WITHIN
        # the cache hit of full attention.
        hit_blocks_other_attn = (
            self.other_attention_cls.find_longest_cache_hit(
                block_hashes=block_hashes,
                max_length=hit_length,
                kv_cache_group_ids=self.other_group_ids,
                block_pool=self.block_pool,
                kv_cache_spec=self.other_spec,
                use_eagle=self.use_eagle,
            ))
        hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size

        # NOTE: the prefix cache hit length must be a multiple of block_size
        # as we don't support partial block cache hits yet. The cache hit
        # length of other attention is ensured to be a multiple of the block
        # size of full attention layers in the current implementation,
        # because hit_length is a multiple of other attention's block size,
        # and other attention's block size is a multiple of full attention's
        # block size (verified in `verify_and_split_kv_cache_groups`).
        assert hit_length % self.full_attention_block_size == 0

        # Truncate the full attention cache hit to the length of the
        # cache hit of the other attention.
        for group_hit_blocks in hit_blocks_full_attn:
            del group_hit_blocks[hit_length // self.full_attention_block_size:]

        # Merge the hit blocks of full attention and other attention.
        if self.full_attn_first:
            hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
        else:
            hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
        return hit_blocks, hit_length
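
    # Example of the two-pass search above (illustrative sizes): with full
    # attention block_size=16 and other (e.g. sliding-window) block_size=32,
    # a 96-token full attention hit (6 blocks) is re-checked against the
    # other groups; if only 64 tokens (2 blocks) hit there, hit_length drops
    # to 64 and the full attention hit is truncated to 4 blocks before the
    # per-group lists are merged.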


def get_kv_cache_coordinator(
        kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
        enable_caching: bool, caching_hash_fn: Callable,
        enable_kv_cache_events: bool) -> KVCacheCoordinator:
    if len(kv_cache_config.kv_cache_groups) == 1:
        return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
                                         use_eagle, enable_caching,
                                         caching_hash_fn,
                                         enable_kv_cache_events)
    return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
                                    enable_caching, caching_hash_fn,
                                    enable_kv_cache_events)
392
vllm/v1/core/kv_cache_manager.py
Normal file
@@ -0,0 +1,392 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
                                         hash_request_tokens)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus

logger = init_logger(__name__)


@dataclass
class KVCacheBlocks:
    """
    The allocation result of KVCacheManager, working as the interface between
    the Scheduler and KVCacheManager to hide KVCacheManager's internal data
    structure from the Scheduler.
    """
    blocks: tuple[list[KVCacheBlock], ...]
    """
    blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
    We don't use blocks of tokens as the outer dimension because it assumes all
    kv_cache_groups have the same number of blocks, which is true for now but
    would break if we gave different block_sizes to different kv_cache_groups
    in the future.
    """

    def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
        """Adds two KVCacheBlocks instances."""
        return KVCacheBlocks(
            tuple(blk1 + blk2
                  for blk1, blk2 in zip(self.blocks, other.blocks)))

    def get_block_ids(self) -> tuple[list[int], ...]:
        """
        Converts the KVCacheBlocks instance to block_ids.

        Returns:
            tuple[list[int], ...]: A tuple of lists where
            * the outer tuple corresponds to KV cache groups
            * each inner list contains the block_ids of the blocks in that group
        """
        return tuple([blk.block_id for blk in group] for group in self.blocks)

    def get_unhashed_block_ids(self) -> list[int]:
        """Get block_ids of unhashed blocks from the KVCacheBlocks instance."""
        assert len(self.blocks) == 1, "Only one group is supported"
        return [
            block.block_id for block in self.blocks[0]
            if block.block_hash is None
        ]

    def new_empty(self) -> "KVCacheBlocks":
        """Creates a new KVCacheBlocks instance with no blocks."""
        return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
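
# A minimal usage sketch of KVCacheBlocks (hypothetical block IDs, not part
# of this diff): `+` concatenates per KV cache group, so the outer group
# dimension stays fixed while each group's block list grows.
#
#     a = KVCacheBlocks(([KVCacheBlock(block_id=0)], [KVCacheBlock(block_id=4)]))
#     b = KVCacheBlocks(([KVCacheBlock(block_id=1)], []))
#     assert (a + b).get_block_ids() == ([0, 1], [4])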


class KVCacheManager:

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        enable_caching: bool = True,
        caching_hash_algo: str = "builtin",
        use_eagle: bool = False,
        log_stats: bool = False,
        enable_kv_cache_events: bool = False,
    ) -> None:
        self.max_model_len = max_model_len

        self.enable_caching = enable_caching
        self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
        self.use_eagle = use_eagle
        self.log_stats = log_stats
        # FIXME: make prefix cache stats conditional on log_stats
        self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
        assert len(
            set(g.kv_cache_spec.block_size
                for g in kv_cache_config.kv_cache_groups)
        ) == 1, "Only one block size is supported for now"
        self.block_size = kv_cache_config.kv_cache_groups[
            0].kv_cache_spec.block_size

        self.coordinator = get_kv_cache_coordinator(
            kv_cache_config=kv_cache_config,
            max_model_len=self.max_model_len,
            use_eagle=self.use_eagle,
            enable_caching=enable_caching,
            caching_hash_fn=self.caching_hash_fn,
            enable_kv_cache_events=enable_kv_cache_events,
        )
        self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
        self.block_pool = self.coordinator.block_pool
        self.kv_cache_config = kv_cache_config

        # Mapping from request ID to kv block hashes.
        # This is to avoid recomputing the block hashes for each call of
        # `get_computed_blocks` or `allocate_slots`.
        self.req_to_block_hashes: defaultdict[
            str, list[BlockHash]] = defaultdict(list)

    @property
    def usage(self) -> float:
        """Get the KV cache usage.

        Returns:
            The KV cache usage (between 0.0 and 1.0).
        """
        return self.block_pool.get_usage()

    def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
        """Get (and reset) the prefix cache stats.

        Returns:
            The current prefix caching stats, or None if logging is disabled.
        """
        if not self.log_stats:
            return None
        stats = self.prefix_cache_stats
        self.prefix_cache_stats = PrefixCacheStats()
        return stats

    def get_computed_blocks(self,
                            request: Request) -> tuple[KVCacheBlocks, int]:
        """Get the computed (cached) blocks for the request.
        Note that the computed blocks must be full.

        Args:
            request: The request to get the computed blocks.

        Returns:
            A tuple containing:
            - A list of blocks that are computed for the request.
            - The number of computed tokens.
        """
        # Skip prefix caching if it is disabled or if the request requires
        # prompt logprobs.
        if (not self.enable_caching
                or request.sampling_params.prompt_logprobs is not None):
            return self.create_empty_block_list(), 0

        # The block hashes for the request may already be computed
        # if the scheduler has tried to schedule the request before.
        block_hashes = self.req_to_block_hashes[request.request_id]
        if not block_hashes:
            block_hashes = hash_request_tokens(self.caching_hash_fn,
                                               self.block_size, request)
            self.req_to_block_hashes[request.request_id] = block_hashes

        if self.log_stats:
            assert self.prefix_cache_stats is not None
            self.prefix_cache_stats.requests += 1

        # NOTE: When all tokens hit the cache, we must recompute the last token
        # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
        # This can trigger recomputation of an entire block, rather than just
        # the single last token, because allocate_slots() requires
        # num_computed_tokens to be block-size aligned. Removing this limitation
        # could slightly improve performance in the future.
        max_cache_hit_length = request.num_tokens - 1
        computed_blocks, num_new_computed_tokens = (
            self.coordinator.find_longest_cache_hit(block_hashes,
                                                    max_cache_hit_length))

        if self.log_stats:
            assert self.prefix_cache_stats is not None
            self.prefix_cache_stats.queries += request.num_tokens
            self.prefix_cache_stats.hits += num_new_computed_tokens

        return KVCacheBlocks(computed_blocks), num_new_computed_tokens
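
    # Worked example of the NOTE above (hypothetical sizes, not from this
    # diff): with block_size=16 and a 48-token prompt fully cached,
    # max_cache_hit_length = 47, so the longest block-aligned hit is
    # 2 blocks (32 tokens) and the last block is recomputed to produce
    # logits for the next token.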

    def allocate_slots(
        self,
        request: Request,
        num_new_tokens: int,
        num_new_computed_tokens: int = 0,
        new_computed_blocks: Optional[KVCacheBlocks] = None,
        num_draft_tokens: int = 0,
        num_lookahead_tokens: int = 0,
        delay_cache_blocks: bool = False,
    ) -> Optional[KVCacheBlocks]:
        """Add slots for a request with new tokens to append.

        Args:
            request: The request to allocate slots.
            num_new_tokens: The number of tokens to allocate, including external
                tokens. Note that this does not include tokens that have
                already been computed locally (i.e. new_computed_blocks).
            num_new_computed_tokens: The number of new computed tokens just
                hitting the prefix cache, excluding external tokens.
            new_computed_blocks: The cached blocks for the above new computed
                tokens.
            num_lookahead_tokens: The number of speculative tokens to allocate.
                This is used by spec decode proposers with kv-cache such
                as eagle.
            delay_cache_blocks: Whether to skip caching the blocks. This is
                used by P/D when allocating blocks used in a KV transfer
                which will complete in a future step.

        Blocks layout:
        ```
        -----------------------------------------------------------------------
        | < computed > | < new computed > |    < new >    | < pre-allocated > |
        -----------------------------------------------------------------------
        |                  < required >                   |
        --------------------------------------------------
        |                    < full >                  |
        ------------------------------------------------
                                          | <new full> |
                                          --------------
        ```
        The following *_blocks are illustrated in this layout.

        Returns:
            A list of new allocated blocks.
        """
        if num_new_tokens == 0:
            raise ValueError("num_new_tokens must be greater than 0")

        if new_computed_blocks is not None:
            new_computed_block_list = new_computed_blocks.blocks
        else:
            new_computed_block_list = tuple(
                [] for _ in range(len(self.kv_cache_config.kv_cache_groups)))

        # Free the blocks that are skipped during the attention computation
        # (e.g., tokens outside the sliding window).
        # We can do this even if we cannot schedule this request due to
        # insufficient free blocks.
        # Should call this function before allocating new blocks to reduce
        # the number of evicted blocks.
        self.coordinator.remove_skipped_blocks(request.request_id,
                                               request.num_computed_tokens)

        # The total number of computed tokens is the tokens already computed
        # locally plus the new prefix cache hits.
        num_computed_tokens = (request.num_computed_tokens +
                               num_new_computed_tokens)
        num_tokens_need_slot = min(
            num_computed_tokens + num_new_tokens + num_lookahead_tokens,
            self.max_model_len)

        num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
            request_id=request.request_id,
            num_tokens=num_tokens_need_slot,
            new_computed_blocks=new_computed_block_list,
        )

        if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
            # Cannot allocate new blocks
            return None

        # Touch the computed blocks to make sure they won't be evicted.
        if self.enable_caching:
            self.block_pool.touch(new_computed_block_list)
        else:
            assert not any(new_computed_block_list), (
                "Computed blocks should be empty when "
                "prefix caching is disabled")

        # Append the new computed blocks to the request blocks until now to
        # avoid the case where the new blocks cannot be allocated.
        self.coordinator.save_new_computed_blocks(request.request_id,
                                                  new_computed_block_list)

        new_blocks = self.coordinator.allocate_new_blocks(
            request.request_id, num_tokens_need_slot)

        # P/D: delay caching blocks if we have to recv from
        # remote. Update state for locally cached blocks.
        if not self.enable_caching or delay_cache_blocks:
            return KVCacheBlocks(new_blocks)

        # Speculated tokens might be rejected in the future, so we do not
        # cache any speculated tokens. We only cache blocks with
        # generated (accepted) tokens.
        self.coordinator.cache_blocks(
            request, self.req_to_block_hashes[request.request_id],
            num_computed_tokens + num_new_tokens - num_draft_tokens)

        return KVCacheBlocks(new_blocks)
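
    # Worked example of the slot arithmetic above (hypothetical numbers, not
    # from this diff): with request.num_computed_tokens=32,
    # num_new_computed_tokens=0, num_new_tokens=16, num_lookahead_tokens=8,
    # and max_model_len=4096, num_tokens_need_slot = min(32 + 16 + 8, 4096)
    # = 56, i.e. 4 blocks of 16 tokens must be reserved for the request.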

    def free(self, request: Request) -> None:
        """Free the blocks allocated for the request.
        We free the blocks in reverse order so that the tail blocks are evicted
        first when caching is enabled.

        Args:
            request: The request to free the blocks.
        """
        self.coordinator.free(request.request_id)

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
        or used for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        if not self.block_pool.reset_prefix_cache():
            return False
        if self.log_stats:
            assert self.prefix_cache_stats is not None
            self.prefix_cache_stats.reset = True
        return True

    def get_num_common_prefix_blocks(
        self,
        request: Request,
        num_running_requests: int,
    ) -> list[int]:
        """Calculate the number of common prefix blocks shared by all requests
        in the RUNNING state for each kv cache group.

        The function determines this by selecting any request and iterating
        through its blocks. A block is considered a common prefix block if its
        `ref_cnt` equals the total number of requests in the RUNNING state.

        NOTE(woosuk): The number of requests in the RUNNING state is **greater
        than or equal to** the number of requests scheduled in the current step.
        This is because the RUNNING state only indicates that:
        1. The request has not yet finished, and
        2. The request holds its blocks unfreed.

        While all scheduled requests must be in the RUNNING state, the inverse
        is not necessarily true. There may be RUNNING requests that are not
        scheduled in the current step.

        This can result in an edge case where the number of common prefix blocks
        is 0, even though all scheduled requests share a common prefix. This
        occurs because there may be unscheduled RUNNING requests that do not
        share the common prefix. Currently, this case cannot be easily detected,
        so the function returns 0 in such cases.

        Args:
            request: Any request in the RUNNING state, used to identify the
                common prefix blocks.
            num_running_requests: The total number of requests in the RUNNING
                state. This can be different from the number of scheduled
                requests in the current step.

        Returns:
            list[int]: The number of common prefix blocks for each kv cache
            group.
        """
        assert request.status == RequestStatus.RUNNING
        return self.coordinator.get_num_common_prefix_blocks(
            request.request_id, num_running_requests)

    def free_block_hashes(self, request: Request) -> None:
        """Discard the block hashes for the request.

        NOTE: Unlike `free`, this method should be called only when the request
        is finished, not when it is preempted.
        """
        self.req_to_block_hashes.pop(request.request_id, None)

    def take_events(self) -> list[KVCacheEvent]:
        """Take the KV cache events from the block pool.

        Returns:
            A list of KV cache events.
        """
        return self.block_pool.take_events()

    def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
        """Get the block ids of a request."""
        return KVCacheBlocks(
            self.coordinator.get_blocks(request_id)).get_block_ids()

    def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
        """Cache the blocks for the request."""
        block_hashes = self.req_to_block_hashes[request.request_id]
        self.coordinator.cache_blocks(request, block_hashes,
                                      num_computed_tokens)

    def create_empty_block_list(self) -> KVCacheBlocks:
        """Creates a new KVCacheBlocks instance with no blocks."""
        return KVCacheBlocks(tuple([]
                                   for _ in range(self.num_kv_cache_groups)))
996
vllm/v1/core/kv_cache_utils.py
Normal file
@@ -0,0 +1,996 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-Cache Utilities."""

import os
from collections import defaultdict, deque
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheSpec,
                                        KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request

logger = init_logger(__name__)


class BlockHash(NamedTuple):
    """Hash value of a block (int), the token IDs in the block, and extra keys.
    We keep a tuple of token IDs and extra keys to reduce the likelihood of
    hash collisions when the hash value is the same. When SHA256 is used,
    however, hash collisions are practically impossible.
    """
    # Hash value of the block as an integer.
    hash_value: int
    # Token IDs in the block.
    token_ids: tuple[int, ...]
    # Extra keys for the block.
    extra_keys: Optional[Any] = None


class BlockHashWithGroupId(NamedTuple):
    # The hash value for the contents (e.g., token_ids) of a block without group
    # ID. The value is the same for blocks representing the same tokens but for
    # different groups.
    block_hash: BlockHash
    # The KV cache group ID.
    group_id: int

    def get_hash_value(self) -> int:
        return self.block_hash.hash_value


# The hash seed for the first block of the prefix block sequence.
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
# the initial hash to simplify the code. This is not performance critical
# as it is done once per process.
#
# We use a random value to avoid hash collisions, or the PYTHONHASHSEED
# environment variable if it is set, so that processes can share the seed
# when needed. This aligns with the behavior of Python's hash() function,
# which also uses a random seed if PYTHONHASHSEED is not set.
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
    "PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED"))


class PrefixCachingMetrics:
    """Metrics for prefix caching with a hit rate of the max recent N requests.

    Args:
        max_recent_requests: The number of the most recent requests to
            aggregate. Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000):
        self.max_recent_requests = max_recent_requests
        # The current aggregated values.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        # A deque of (requests, queries, hits) for the most recent requests.
        self.query_queue: deque[tuple[int, int, int]] = deque()

    def observe(self, stats: PrefixCacheStats):
        """Observe the prefix caching for a set of requests.

        This function is called with information gathered when new requests
        are being scheduled and are looking for computed blocks.

        When there are more than `max_recent_requests` requests, the oldest
        set of requests is removed from the metrics.

        Args:
            stats: The prefix cache stats.
        """
        # reset_prefix_cache was invoked before the current update.
        # Reset the metrics before aggregating the current stats.
        if stats.reset:
            self.reset()

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Remove the oldest stats if the number of requests exceeds the limit.
        if self.aggregated_requests > self.max_recent_requests:
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self):
        """Reset the metrics."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def hit_rate(self) -> float:
        """Calculate the hit rate for the past N requests."""
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total
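
# A minimal usage sketch (hypothetical numbers; assumes PrefixCacheStats
# accepts its fields as keyword arguments):
#
#     metrics = PrefixCachingMetrics(max_recent_requests=1000)
#     metrics.observe(PrefixCacheStats(requests=2, queries=256, hits=192))
#     assert metrics.hit_rate == 0.75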


@dataclass
class KVCacheBlock:
    """KV-cache block metadata."""
    # Block ID, ranging from 0 to num_gpu_blocks - 1.
    block_id: int
    # Reference count.
    ref_cnt: int = 0
    # The hash of the block, composed of (block hash, tuple of token IDs).
    # It is only available when the block is full.
    _block_hash: Optional[BlockHashWithGroupId] = None

    # Used to construct a doubly linked list for free blocks.
    # These two attributes should only be manipulated by FreeKVCacheBlockQueue.
    prev_free_block: Optional["KVCacheBlock"] = None
    next_free_block: Optional["KVCacheBlock"] = None

    # Whether the block is a null block that should never be cached.
    is_null: bool = False

    def incr_ref(self):
        self.ref_cnt += 1

    def decr_ref(self):
        self.ref_cnt -= 1

    @property
    def block_hash(self) -> Optional[BlockHashWithGroupId]:
        return self._block_hash

    @block_hash.setter
    def block_hash(self, block_hash: BlockHashWithGroupId):
        assert self.block_hash is None, (
            "The block already has a hash. This should not happen.")
        self._block_hash = block_hash

    def reset_hash(self):
        """Reset the block hash when the block is evicted."""
        self._block_hash = None

    def __repr__(self) -> str:
        # Use block_id instead of the KVCacheBlock object to avoid calling
        # __repr__ on the KVCacheBlock object recursively.
        prev_block_id = (self.prev_free_block.block_id
                         if self.prev_free_block else None)
        next_block_id = (self.next_free_block.block_id
                         if self.next_free_block else None)
        return (f"KVCacheBlock(block_id={self.block_id}, "
                f"ref_cnt={self.ref_cnt}, "
                f"_block_hash={self._block_hash}, "
                f"prev_free_block={prev_block_id}, "
                f"next_free_block={next_block_id})")


class FreeKVCacheBlockQueue:
    """This class organizes a list of KVCacheBlock objects into a doubly linked
    list of free blocks. We implement this class instead of using the Python
    builtin deque to support removing a block from the middle of the queue
    in O(1) time. To close the performance gap to the builtin deque, which is
    implemented in C++, this class does not allocate any Python objects when
    manipulating the linked list. Instead, it manipulates the
    prev_free_block and next_free_block attributes of the given blocks.

    The queue is initially ordered by block ID. When a block is allocated
    and then freed, it is appended back with the eviction order:
    1. The least recently used block is at the front (LRU).
    2. If two blocks have the same last accessed time (allocated by the
       same sequence), the one with more hash tokens (the tail of a block
       chain) is at the front.
    Note that we maintain this order by reversing the block order when freeing
    the blocks of a request. This operation is outside of this class.

    Args:
        blocks: A list of KVCacheBlock objects.
    """

    def __init__(self, blocks: list[KVCacheBlock]) -> None:
        self.num_free_blocks = len(blocks)

        # Initialize the doubly linked list of free blocks.
        self.free_list_head: Optional[KVCacheBlock] = blocks[0]
        self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
        for i in range(self.num_free_blocks):
            if i > 0:
                blocks[i].prev_free_block = blocks[i - 1]
            if i < self.num_free_blocks - 1:
                blocks[i].next_free_block = blocks[i + 1]

    def popleft(self) -> KVCacheBlock:
        """Pop the first free block and reduce num_free_blocks by 1.

        Returns:
            The first free block.
        """
        if not self.free_list_head:
            raise ValueError("No free blocks available")

        block = self.free_list_head
        self.remove(block)
        return block

    def remove(self, block: KVCacheBlock) -> None:
        """Remove a block from the free list and reduce num_free_blocks by 1.

        Args:
            block: The block to remove.
        """
        if block.prev_free_block is not None:
            # Link the previous block to the next block.
            block.prev_free_block.next_free_block = block.next_free_block
        if block.next_free_block is not None:
            # Link the next block to the previous block.
            block.next_free_block.prev_free_block = block.prev_free_block

        if block == self.free_list_head:
            # Update the head if the block is the head.
            self.free_list_head = block.next_free_block
        if block == self.free_list_tail:
            # Update the tail if the block is the tail.
            self.free_list_tail = block.prev_free_block

        # Remove the block from the linked list.
        block.prev_free_block = block.next_free_block = None
        self.num_free_blocks -= 1

    def append(self, block: KVCacheBlock) -> None:
        """Put a block back into the free list and increase
        num_free_blocks by 1.

        Args:
            block: The block to append.
        """
        if self.free_list_tail is not None:
            # Link the last block to the new block.
            self.free_list_tail.next_free_block = block
            block.prev_free_block = self.free_list_tail
            self.free_list_tail = block
        else:
            # The free list is empty.
            assert self.free_list_head is None
            self.free_list_head = self.free_list_tail = block

        block.next_free_block = None
        self.num_free_blocks += 1

    def get_all_free_blocks(self) -> list[KVCacheBlock]:
        """Get all free blocks in the free list. Mainly used for testing.

        Returns:
            A list of free blocks.
        """
        ret = []
        curr_block = self.free_list_head
        while curr_block is not None:
            ret.append(curr_block)
            curr_block = curr_block.next_free_block
        return ret
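
    # A minimal sketch of the O(1) middle removal this queue exists for
    # (hypothetical block IDs, not part of this diff):
    #
    #     blocks = [KVCacheBlock(block_id=i) for i in range(4)]
    #     queue = FreeKVCacheBlockQueue(blocks)
    #     queue.remove(blocks[2])   # unlink from the middle without scanning
    #     assert queue.popleft() is blocks[0]
    #     queue.append(blocks[2])   # re-freed blocks rejoin at the tail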


def need_extra_keys(request: Request) -> bool:
    """Check whether the blocks allocated to this request need extra hash keys.

    Args:
        request (Request): The request.

    Returns:
        bool: Whether blocks allocated to this request need extra hash keys.
    """

    # Multimodal requests need to include the MM hash.
    # LoRA requests need to include the LoRA ID.
    # Requests with a provided cache salt need to include the salt.
    return bool(request.mm_positions) or (request.lora_request
                                          is not None) or (request.cache_salt
                                                           is not None)


def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
                            end_token_idx: int,
                            start_mm_idx: int) -> tuple[list[Any], int]:
    """Generate extra keys related to a MultiModal request for block hash
    computation. For multi-modal inputs, the extra keys are
    (mm_hash, start_offset) that indicate a mm input contained in the
    block and its starting offset in the block tokens.

    Args:
        request: The request object.
        start_token_idx: The start token index of the block.
        end_token_idx: The end token index of the block.
        start_mm_idx: The start multi-modal index of the block.

    Returns:
        A tuple of extra keys and the next multi-modal index.
    """
    extra_keys: list[Any] = []

    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
    if not mm_positions:
        return extra_keys, start_mm_idx

    if mm_positions and len(mm_positions) != len(mm_hashes):
        raise ValueError(
            "The number of multi-modal positions and hashes must match. This "
            "is likely because you did not enable MM preprocessor hashing. "
            "Please set disable_mm_preprocessor_cache=False.")

    # Note that we assume mm_positions is sorted by offset.
    # We do not need to check all mm inputs if the start token index is out of
    # range. This usually happens in the late prefill phase and decoding phase.
    if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
        return extra_keys, start_mm_idx

    # Support start_mm_idx == -1 to indicate the last mm input.
    if start_mm_idx < 0:
        assert -start_mm_idx <= len(mm_positions)
        start_mm_idx = len(mm_positions) + start_mm_idx

    curr_mm_idx = start_mm_idx
    while mm_positions and curr_mm_idx < len(mm_positions):
        assert mm_hashes[curr_mm_idx] is not None
        offset = mm_positions[curr_mm_idx].offset
        length = mm_positions[curr_mm_idx].length
        if end_token_idx > offset:
            if start_token_idx > offset + length:
                # This block has passed the current mm input.
                curr_mm_idx += 1
                continue

            # The block contains the current mm input.
            extra_keys.append(mm_hashes[curr_mm_idx])

            if end_token_idx >= offset + length:
                # If this block contains the end of the current mm input,
                # move to the next mm input as this block may also contain
                # the next mm input.
                curr_mm_idx += 1
            else:
                # Otherwise this block is done with mm inputs.
                break
        else:
            # This block has not reached the current mm input.
            break
    return extra_keys, curr_mm_idx


def _gen_lora_extra_hash_keys(request: Request) -> list[int]:
    """Generate extra keys related to LoRA for block hash computation.

    Args:
        request: The request object.

    Returns:
        The LoRA id of the request if it is a LoRA request; an empty
        list otherwise.
    """
    if not request.lora_request:
        return []
    return [request.lora_request.lora_int_id]


def generate_block_hash_extra_keys(
        request: Request, start_token_idx: int, end_token_idx: int,
        start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]:
    """Generate extra keys for the block hash. The extra keys can come from
    the multi-modal inputs and request specific metadata (e.g., LoRA ID).

    Args:
        request: The request object.
        start_token_idx: The start token index of the block.
        end_token_idx: The end token index of the block.
        start_mm_idx: The start multi-modal index of the block.

    Returns:
        A tuple of extra keys and the next multi-modal index.
    """
    mm_extra_keys: list[Any]
    mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
        request, start_token_idx, end_token_idx, start_mm_idx)
    lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request)
    cache_salt_keys: list[str] = [request.cache_salt] if (
        start_token_idx == 0 and request.cache_salt) else []

    extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys

    if not extra_keys:
        return None, new_start_mm_idx

    return tuple(extra_keys), new_start_mm_idx


def hash_block_tokens(
        hash_function: Callable,
        parent_block_hash: Optional[int],
        curr_block_token_ids: Sequence[int],
        extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
    """Computes a hash value corresponding to the contents of a block and
    the contents of the preceding block(s). The hash value is used for
    prefix caching. We use an LRU cache for this function to avoid recomputing
    hash values for the same block contents.

    Args:
        parent_block_hash: The hash of the parent block. None
            if this is the first block.
        curr_block_token_ids: A list of token ids in the current
            block. The current block is assumed to be full.
        extra_keys: Extra keys for the block.

    Returns:
        The hash value of the block and the token ids in the block.
        The entire tuple is used as the hash key of the block.
    """
    if not parent_block_hash:
        parent_block_hash = NONE_HASH

    curr_block_token_ids_tuple = tuple(curr_block_token_ids)
    return BlockHash(
        hash_function(
            (parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
        curr_block_token_ids_tuple, extra_keys)


def hash_request_tokens(hash_function: Any, block_size: int,
                        request: Request) -> list[BlockHash]:
    """Computes hash values of a chain of blocks given a sequence of
    token IDs. The hash value is used for prefix caching.

    Args:
        block_size: The size of each block.
        request: The request object.

    Returns:
        The list of computed hash values.
    """
    token_ids = request.all_token_ids

    req_need_extra_keys = need_extra_keys(request)
    req_extra_keys = None
    curr_mm_idx = 0

    ret = []
    parent_block_hash_value = None
    for start in range(0, len(token_ids), block_size):
        end = start + block_size
        block_token_ids = token_ids[start:end]
        # Do not hash the block if it is not full.
        if len(block_token_ids) < block_size:
            break

        if req_need_extra_keys:
            # MM and LoRA requests need extra keys for block-hash computation.
            req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
                request, start, end, curr_mm_idx)

        block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
                                       block_token_ids, req_extra_keys)
        ret.append(block_hash)
        parent_block_hash_value = block_hash.hash_value
    return ret
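
# Sketch of the hash chaining above (hypothetical values, not from this
# diff): with block_size=2 and token_ids=[1, 2, 3, 4, 5],
#   block 0 -> hash_function((NONE_HASH, (1, 2), None))
#   block 1 -> hash_function((block 0's hash_value, (3, 4), None))
# and token 5 is never hashed because its block is not full.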


def max_memory_usage_bytes(vllm_config: VllmConfig,
                           kv_cache_specs: Iterable[KVCacheSpec]) -> int:
    """
    Get the maximum memory usage in bytes for the given KV cache specs.
    """
    return sum(
        spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs)


def estimate_max_model_len(vllm_config: VllmConfig,
                           kv_cache_spec: dict[str, KVCacheSpec],
                           available_memory: int) -> int:
    """
    Estimates the maximum model length that can fit in the available memory
    using binary search.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The estimated maximum model length that can fit in the available memory.
    """

    # Define a function to check if a given model length fits in memory
    def fits_in_memory(model_len: int) -> bool:
        # Modify the max_model_len for this calculation
        vllm_config.model_config.max_model_len = model_len
        # Calculate memory needed for the given model length
        memory_needed = max_memory_usage_bytes(vllm_config,
                                               kv_cache_spec.values())
        return memory_needed <= available_memory

    # Binary search for the maximum model length
    current_max = vllm_config.model_config.max_model_len
    left, right = 1, current_max

    # If even the smallest model length doesn't fit, return 0
    if not fits_in_memory(left):
        return 0

    # Binary search for the maximum model length that fits
    result = 1
    while left <= right:
        mid = (left + right) // 2
        if fits_in_memory(mid):
            result = mid
            left = mid + 1
        else:
            right = mid - 1
    return result


def check_enough_kv_cache_memory(vllm_config: VllmConfig,
                                 kv_cache_spec: dict[str, KVCacheSpec],
                                 available_memory: int):
    """
    Checks whether `available_memory` is enough for the KV cache to hold at
    least one request with the model's max_model_len.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Raises:
        ValueError: If there is not enough memory available for the KV cache.
    """

    if available_memory <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    max_model_len = vllm_config.model_config.max_model_len
    needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())

    if needed_memory > available_memory:
        # Estimate the maximum model length that can fit in the available memory
        estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
                                                   available_memory)
        estimated_msg = ""
        if estimated_max_len > 0:
            estimated_msg = (
                "Based on the available memory, "
                f"the estimated maximum model length is {estimated_max_len}.")

        raise ValueError(
            f"To serve at least one request with the model's max seq len "
            f"({max_model_len}), {needed_memory/GiB_bytes:.2f} GiB KV "
            f"cache is needed, which is larger than the available KV cache "
            f"memory ({available_memory/GiB_bytes:.2f} GiB). "
            f"{estimated_msg} "
            f"Try increasing `gpu_memory_utilization` or decreasing "
            f"`max_model_len` when initializing the engine.")


def create_kv_cache_group_specs(
        kv_cache_spec: dict[str, KVCacheSpec],
        grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
    """
    Create a KVCacheGroupSpec object for each kv cache group.
    The layers in the same group should share the same
    KVCacheSpec.

    Args:
        kv_cache_spec:
            A mapping from each layer name to its corresponding KVCacheSpec.
        grouped_layer_names:
            A list of kv cache groups, where each element is a list of layer
            names that belong to the same group and should share the same
            KVCacheSpec.
    Returns:
        A list of KVCacheGroupSpec objects, one for each group.
    """
    kv_cache_groups = []
    for layer_names_one_group in grouped_layer_names:
        layer_specs = [
            kv_cache_spec[layer_name] for layer_name in layer_names_one_group
        ]
        merged_layer_spec = layer_specs[0].merge(layer_specs)
        kv_cache_groups.append(
            KVCacheGroupSpec(layer_names_one_group, merged_layer_spec))
    return kv_cache_groups


def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
    """
    Whether all layers in the given KVCacheSpec have the same type of KV cache.

    Args:
        kv_cache_spec: The kv cache spec of each attention layer in the model

    Returns:
        True if all layers have the same type, False otherwise.
    """

    layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
    return len(layer_keys) == 1


def get_max_concurrency_for_kv_cache_config(
        vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float:
    """
    Get the maximum concurrency for the given KV cache configuration.
    """
    num_layer_per_group = max(
        len(group.layer_names) for group in kv_cache_config.kv_cache_groups)
    max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes(
        vllm_config,
        (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups))
    memory_per_block = kv_cache_config.kv_cache_groups[
        0].kv_cache_spec.page_size_bytes * num_layer_per_group
    num_block_per_request = cdiv(max_memory_usage_per_request,
                                 memory_per_block)
    max_concurrency = kv_cache_config.num_blocks / num_block_per_request
    return max_concurrency
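
# Worked example (hypothetical numbers, not from this diff): with 2 KV cache
# groups of 16 layers each and a per-layer max usage of 2 MiB per request,
# max_memory_usage_per_request = 16 * (2 MiB + 2 MiB) = 64 MiB; with 16 KiB
# pages, memory_per_block = 16 KiB * 16 = 256 KiB, so each request needs
# cdiv(64 MiB, 256 KiB) = 256 blocks, and num_blocks = 2560 gives a maximum
# concurrency of 10.0x.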


def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
                   available_memory: int, page_size: int) -> int:
    """
    Get the number of kv cache blocks.

    Args:
        vllm_config: The global VllmConfig
        num_layers: The number of layers
        available_memory: Memory available for KV cache in bytes.
        page_size: The page size of the KV cache.
    """
    num_blocks = int(available_memory // page_size // num_layers)
    num_blocks = max(num_blocks, 0)
    if vllm_config.cache_config.num_gpu_blocks_override is not None:
        num_gpu_blocks_override = \
            vllm_config.cache_config.num_gpu_blocks_override
        logger.info(
            "Overriding num_gpu_blocks=%d with "
            "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
        # Apply the override.
        num_blocks = num_gpu_blocks_override
    return num_blocks


def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
    """
    Get the page size of the KV cache.
    """
    page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
    assert len(page_sizes) == 1
    return page_sizes.pop()


def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                      kv_cache_spec: dict[str, KVCacheSpec],
                                      available_memory: int) -> KVCacheConfig:
    """
    Generates the KV cache configuration for a model with one type of KV cache.
    Divide the available memory equally among all layers.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The generated KVCacheConfig
    """

    page_size = get_uniform_page_size(kv_cache_spec)
    num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec),
                                available_memory, page_size)

    per_layer_size = page_size * num_blocks
    # All layers have the same KV cache spec, so we create one kv cache group
    # for all layers.
    grouped_layer_names = [list(kv_cache_spec.keys())]

    # Each layer uses a separate Tensor to store its KV cache.
    kv_cache_tensors = [
        KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
        for layer_name in kv_cache_spec
    ]

    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
                                                    grouped_layer_names),
    )

    num_tokens = num_blocks * vllm_config.cache_config.block_size
    num_tokens_str = f"{num_tokens:,}"
    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
    max_concurrency = get_max_concurrency_for_kv_cache_config(
        vllm_config, kv_cache_config)
    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                max_model_len_str, max_concurrency)
    return kv_cache_config


def is_kv_cache_page_size_uniform(
        kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
    """
    Whether all layers in the given KVCacheSpec have the same page size.

    Args:
        kv_cache_spec: The KVCacheSpec of each attention layer in the model

    Returns:
        True if all layers have the same page size, False otherwise.
    """

    page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
    return len(page_sizes) == 1


def _get_kv_cache_config_uniform_page_size(
        vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
        available_memory: int) -> KVCacheConfig:
    """
    Generates the KV cache configuration for hybrid models with multiple
    attention types but still with a uniform page size (physical memory per
    block per layer) for all layers.

    Detailed explanation about kv cache management of hybrid models:
    The layers in the model are repeated with some patterns, e.g., a model
    with 10 full attention layers and 20 sliding window attention layers can be
    regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
    The KVCacheManager allocates different block tables for each of the 3 layers
    in the pattern, and repeats each of them 10 times to generate the
    block_table for the 30 layers in the model.
    Therefore, we can group the layers in the model into 3 kv_cache_groups, each
    of which contains 10 layers in the model.
    The KVCacheManager allocates the block_table for each group based on its
    kv_cache spec, and the model runner applies the block table to each layer
    in the group.
    For example:
    1. A model only uses full attention. The pattern is
    (num_hidden_layers * full), so there is only one group and the block table
    is shared by all layers. It is already handled by
    `_get_kv_cache_config_uniform_type`.
    2. A model with 10 full attention layers and 20 sliding window
    attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
    there are 3 kv_cache_groups, each of which represents 10 layers.

    To simplify the implementation, we make the following assumptions:
    1. Physical memory per block: Must be the same across all KV cache groups.
    Breaking this assumption is non-trivial due to memory fragmentation concerns
    when allocating blocks of different sizes.
    2. Tokens per block (block_size): Currently, we directly use
    `CacheConfig.block_size` for all layers. It can be extended to vary by KV
    cache group, but within each KV cache group, all layers must share the same
    block size.
    3. Physical memory per token per layer: This property is decided by model
    config. Currently we only support models that have the same physical memory
    per token per layer for all layers. Can be relaxed with a simple extension,
    but still need to keep physical memory per block the same for all groups.
    4. Number of layers per group: Currently assumed the same for all layers.
    Can be relaxed with a simple extension, but still need to keep physical
    memory per block the same for all groups.
    5. Attention type within groups: All layers in a group must share the same
    attention type. One exception is that, when
    `--disable-hybrid-kv-cache-manager` is true, the single group for full
    attention layers may also include attention layers using sliding window or
    LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details.
    6. Support for multiple attention types: The design of most components is
    general to an arbitrary number of attention types. But
    `find_longest_cache_hit` only supports either a single attention type, or
    full attention plus exactly one other attention type. A fully general
    implementation of this function is feasible but we don't know how to
    implement it cleanly yet.

    As we assume tokens per block, physical memory per token per layer, and
    number of layers per group are the same now, we can ensure that physical
    memory per block is the same for all groups.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The KVCacheSpec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.
    Returns:
        The generated KVCacheConfig
    """
    # Group all layers by type_id.
    # E.g., 2 full attention layers and 3 sliding window attention layers,
    # -> (full.0, full.1), (sw.0, sw.1, sw.2).
    same_type_layers: dict[str, list[str]] = defaultdict(list)
    for layer_name, layer_spec in kv_cache_spec.items():
        same_type_layers[layer_spec.type_id].append(layer_name)

    # Split each group into smaller groups, to make the number of layers in each
    # group identical. Add padding to the last group of each type if necessary.
    # E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
    # split to 3 groups with 2 layers each:
    # (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
    # FIXME(Chen): At the moment of writing this code (2025-06-02), all
    # open-source hybrid models follow an n:1 pattern between different attention
    # types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and
    # full), so we can use the "1" in the n:1 pattern as the group size, which
    # is the minimum number of layers among all attention types. Need a better
    # strategy if we want to support more complex patterns (e.g., 20 full + 30
    # sw, where the group size should be 10).
    group_size = min([len(layers) for layers in same_type_layers.values()])
    grouped_layers = []
    for layers in same_type_layers.values():
        num_padding_layers = group_size - len(layers) % group_size
        if num_padding_layers != group_size:
            logger.warning(
                "Add %d padding layers, may waste at most %.2f%% KV cache memory",  # noqa
                num_padding_layers,
                num_padding_layers / len(layers) * 100,
            )
        for i in range(0, len(layers), group_size):
            grouped_layers.append(layers[i:i + group_size])
    kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec,
                                                  grouped_layers)

    # Determine how model runners should initialize the KV cache tensors.
    # We will have group_size memory pools, each shared by one layer from
    # each group. As layers of different groups have different block tables,
    # they will use different parts of the shared Tensor.
    # The memory layout in the example will be:
    # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
    # full.1, sw.1: share another Tensor with size=available_memory//2
    page_size = get_uniform_page_size(kv_cache_spec)
    num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
                                page_size)
    per_memory_pool_size = page_size * num_blocks
    kv_cache_tensors = []
    for i in range(group_size):
        shared_by = []
        for j in range(len(kv_cache_groups)):
            if i < len(grouped_layers[j]):
                shared_by.append(grouped_layers[j][i])
        kv_cache_tensors.append(
            KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))

    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_groups=kv_cache_groups,
    )

    # Print the KV cache size and maximum concurrency.
    num_tokens = num_blocks // len(
        grouped_layers) * vllm_config.cache_config.block_size
    num_tokens_str = f"{num_tokens:,}"
    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
    max_concurrency = get_max_concurrency_for_kv_cache_config(
        vllm_config, kv_cache_config)
    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                max_model_len_str, max_concurrency)
    return kv_cache_config
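
# Padding arithmetic sketch for the grouping above (hypothetical layer
# counts, not from this diff): with 2 full-attention layers and 5
# sliding-window layers, group_size = min(2, 5) = 2; the sliding-window type
# needs 2 - 5 % 2 = 1 padding layer, giving groups (full.0, full.1),
# (sw.0, sw.1), (sw.2, sw.3), (sw.4, padding) and group_size = 2 shared
# memory pools.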


def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
    """
    This function tries to convert the KV cache specs to one type if the model
    is a hybrid model with multiple types of KV cache. It converts all
    SlidingWindowSpec to FullAttentionSpec if both types are present.

    Args:
        kv_cache_spec: The kv cache spec of each attention layer in the model
    """

    def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
        type_ids = set(layer_spec.type_id
                       for layer_spec in kv_cache_spec.values())
        return len(type_ids) > 1

    if not is_hybrid(kv_cache_spec):
        return

    logger.warning(
        "Hybrid KV cache manager is disabled for this hybrid model. "
        "This means we do not enable any optimizations for saving KV cache "
        "memory (e.g., dropping the KV cache outside the sliding window). "
        "The compute of layers like sliding window is still saved.")

    has_full_attention = any(
        isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
    has_sliding_window = any(
        isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
    if has_full_attention and has_sliding_window:
        for layer_name, spec in kv_cache_spec.items():
            if isinstance(spec, SlidingWindowSpec):
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=spec.block_size,
                    num_kv_heads=spec.num_kv_heads,
                    head_size=spec.head_size,
                    dtype=spec.dtype,
                    use_mla=spec.use_mla,
                    sliding_window=spec.sliding_window,
                )

    if is_hybrid(kv_cache_spec):
        raise ValueError("Hybrid KV cache manager is disabled but failed to "
                         "convert the KV cache specs to one unified type.")


def get_kv_cache_config(
    vllm_config: VllmConfig,
    kv_cache_spec: dict[str, KVCacheSpec],
    available_memory: int,
) -> KVCacheConfig:
    """
    Generates the KV cache configuration for a model.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The generated KVCacheConfig
    """
    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)

    if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
        unify_hybrid_kv_cache_specs(kv_cache_spec)

    if is_kv_cache_type_uniform(kv_cache_spec):
        # The KV cache of all layers is the same, which is true for
        # most models. Allocate the same amount of memory for
        # each layer.
        return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
                                                 available_memory)
    elif is_kv_cache_page_size_uniform(kv_cache_spec):
        # The model contains multiple attention types, but the KV caches of all
        # layers have the same physical memory per block per layer. Split the
        # layers into groups with the same number of layers, and thus the same
        # total page size.
        return _get_kv_cache_config_uniform_page_size(vllm_config,
                                                      kv_cache_spec,
                                                      available_memory)

    raise NotImplementedError


def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
    """
    Make the KV cache configurations for each worker consistent, so that all
    workers can be controlled by the same KVCacheManager.
    This function verifies that the layer groups of each worker are the same,
    and changes the num_blocks of each worker to the smallest among all workers.

    Args:
        kv_cache_configs: The KV cache configurations for each worker. Will be
            in-place modified to make them consistent.
    """

    # Sort the kv cache groups by the type_id of their KV cache spec.
    # This can avoid the inconsistency caused by the order of groups.
    for kv_cache_config in kv_cache_configs:
        kv_cache_config.kv_cache_groups.sort(
            key=lambda x: x.kv_cache_spec.type_id)

    # Verify that the groups of each rank are the same.
    for kv_cache_config in kv_cache_configs[1:]:
        for group_rank_0, group_rank_i in zip(
                kv_cache_configs[0].kv_cache_groups,
                kv_cache_config.kv_cache_groups):
            assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec

    # Change the num_blocks of each rank to the smallest among all ranks. We
    # do not need to shrink the tensor size because it is valid to only use the
    # first `num_blocks` blocks of the tensor.
    min_num_blocks = min(kv_cache_config.num_blocks
                         for kv_cache_config in kv_cache_configs)
    for kv_cache_config in kv_cache_configs:
        kv_cache_config.num_blocks = min_num_blocks

    return kv_cache_configs
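
# Sketch of the unification above (hypothetical numbers, not from this
# diff): if rank 0 computed num_blocks=1000 but rank 1, with less free
# memory, computed num_blocks=900, both configs are clamped to 900 so a
# single KVCacheManager can drive all workers.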
0
vllm/v1/core/sched/__init__.py
Normal file
150
vllm/v1/core/sched/interface.py
Normal file
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.engine import EngineCoreOutputs
    from vllm.v1.metrics.stats import SchedulerStats
    from vllm.v1.outputs import ModelRunnerOutput
    from vllm.v1.request import Request, RequestStatus


class SchedulerInterface(ABC):

    @abstractmethod
    def schedule(self) -> "SchedulerOutput":
        """Schedule the requests to process in this scheduling step.

        The scheduling decision is made at the iteration level. Each scheduling
        step corresponds to a single forward pass of the model. Therefore, this
        method is called repeatedly by a busy loop in the engine.

        Essentially, the scheduler produces a dictionary of {req_id: num_tokens}
        that specifies how many tokens to process for each request in this
        scheduling step. For example, num_tokens can be as large as the number
        of prompt tokens for new requests, or it can be 1 for the requests that
        are auto-regressively generating new tokens one by one. Otherwise, it
        can be somewhere in between in case of chunked prefills, prefix caching,
        speculative decoding, etc.

        Additionally, the scheduler also returns useful data about each request
        or the batch as a whole. The model runner will use this information in
        preparing inputs to the model.

        Returns:
            A SchedulerOutput object containing information about the scheduled
            requests.
        """
        raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update_from_output(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
model_runner_output: "ModelRunnerOutput",
|
||||
) -> dict[int, "EngineCoreOutputs"]:
|
||||
"""Update the scheduler state based on the model runner output.
|
||||
|
||||
This method is called after the model runner has processed the scheduled
|
||||
requests. The model runner output includes generated token ids, draft
|
||||
token ids for next step, etc. The scheduler uses this information to
|
||||
update its states, checks the finished requests, and returns the output
|
||||
for each request.
|
||||
|
||||
Returns:
|
||||
A dict of client index to EngineCoreOutputs object containing the
|
||||
outputs for each request originating from that client.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_request(self, request: "Request") -> None:
|
||||
"""Add a new request to the scheduler's internal queue.
|
||||
|
||||
Args:
|
||||
request: The new request being added.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def finish_requests(
|
||||
self,
|
||||
request_ids: Union[str, Iterable[str]],
|
||||
finished_status: "RequestStatus",
|
||||
) -> None:
|
||||
"""Finish the requests in the scheduler's internal queue. If the request
|
||||
is not in the queue, this method will do nothing.
|
||||
|
||||
This method is called in two cases:
|
||||
1. When the request is aborted by the client.
|
||||
2. When the frontend process detects a stop string of the request after
|
||||
de-tokenizing its generated tokens.
|
||||
|
||||
Args:
|
||||
request_ids: A single or a list of request IDs.
|
||||
finished_status: The finished status of the given requests.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
"""Number of unfinished requests in the scheduler's internal queue."""
|
||||
raise NotImplementedError
|
||||
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
"""Returns True if there are unfinished requests in the scheduler's
|
||||
internal queue."""
|
||||
return self.get_num_unfinished_requests() > 0
|
||||
|
||||
@abstractmethod
|
||||
def has_finished_requests(self) -> bool:
|
||||
"""Returns True if there are finished requests that need to be cleared.
|
||||
NOTE: This is different from `not self.has_unfinished_requests()`.
|
||||
|
||||
The scheduler maintains an internal list of the requests finished in the
|
||||
previous step. This list is returned from the next call to schedule(),
|
||||
to be sent to the model runner in the next step to clear cached states
|
||||
for these finished requests.
|
||||
|
||||
This method checks if this internal list of finished requests is
|
||||
non-empty. This information is useful for DP attention.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def has_requests(self) -> bool:
|
||||
"""Returns True if there are unfinished requests, or finished requests
|
||||
not yet returned in SchedulerOutputs."""
|
||||
return self.has_unfinished_requests() or self.has_finished_requests()
|
||||
|
||||
@abstractmethod
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset the prefix cache for KV cache.
|
||||
|
||||
This is particularly required when the model weights are live-updated.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_request_counts(self) -> tuple[int, int]:
|
||||
"""Returns (num_running_reqs, num_waiting_reqs)."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def make_stats(self) -> Optional["SchedulerStats"]:
|
||||
"""Make a SchedulerStats object for logging.
|
||||
|
||||
The SchedulerStats object is created for every scheduling step.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the scheduler."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
|
||||
return None
|
||||
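The busy loop described in schedule()'s docstring pairs each schedule() call with one forward pass and one update_from_output() call. A minimal sketch of one iteration (the `executor` object and its `execute_model` method are hypothetical stand-ins, not a vLLM API):

def engine_step(scheduler: "SchedulerInterface", executor) -> dict:
    # One scheduling decision == one forward pass of the model.
    scheduler_output = scheduler.schedule()
    model_runner_output = executor.execute_model(scheduler_output)
    # Feed results back so the scheduler can advance request states.
    return scheduler.update_from_output(scheduler_output, model_runner_output)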
154
vllm/v1/core/sched/output.py
Normal file
@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt

    from vllm.distributed.kv_transfer.kv_connector.v1.base import (
        KVConnectorMetadata)
    from vllm.lora.request import LoRARequest
    from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
    from vllm.sampling_params import SamplingParams
    from vllm.v1.request import Request


@dataclass
class NewRequestData:

    req_id: str
    prompt_token_ids: list[int]
    mm_inputs: list[MultiModalKwargs]
    mm_hashes: list[str]
    mm_positions: list[PlaceholderRange]
    sampling_params: SamplingParams
    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    lora_request: Optional[LoRARequest]

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: tuple[list[int], ...],
    ) -> NewRequestData:
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            mm_inputs=request.mm_inputs,
            mm_hashes=request.mm_hashes,
            mm_positions=request.mm_positions,
            sampling_params=request.sampling_params,
            block_ids=block_ids,
            num_computed_tokens=request.num_computed_tokens,
            lora_request=request.lora_request,
        )

    def __repr__(self):
        return (f"NewRequestData("
                f"req_id={self.req_id},"
                f"prompt_token_ids={self.prompt_token_ids},"
                f"mm_inputs={self.mm_inputs},"
                f"mm_hashes={self.mm_hashes},"
                f"mm_positions={self.mm_positions},"
                f"sampling_params={self.sampling_params},"
                f"block_ids={self.block_ids},"
                f"num_computed_tokens={self.num_computed_tokens},"
                f"lora_request={self.lora_request}"
                ")")

    # Version of __repr__ with the prompt data obfuscated
    def anon_repr(self):
        return (f"NewRequestData("
                f"req_id={self.req_id},"
                f"prompt_token_ids_len={len(self.prompt_token_ids)},"
                f"mm_inputs={self.mm_inputs},"
                f"mm_hashes={self.mm_hashes},"
                f"mm_positions={self.mm_positions},"
                f"sampling_params={self.sampling_params},"
                f"block_ids={self.block_ids},"
                f"num_computed_tokens={self.num_computed_tokens},"
                f"lora_request={self.lora_request}"
                ")")


@dataclass
class CachedRequestData:

    req_id: str
    # If resumed_from_preemption is False, new_block_ids will be appended to
    # the request's block IDs. If True, new_block_ids will be used as the
    # request's block IDs instead of appending to the existing block IDs.
    resumed_from_preemption: bool
    new_token_ids: list[int]
    new_block_ids: tuple[list[int], ...]
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        resumed_from_preemption: bool,
        new_token_ids: list[int],
        new_block_ids: tuple[list[int], ...],
    ) -> CachedRequestData:
        return cls(
            req_id=request.request_id,
            resumed_from_preemption=resumed_from_preemption,
            new_token_ids=new_token_ids,
            new_block_ids=new_block_ids,
            num_computed_tokens=request.num_computed_tokens,
        )


@dataclass
class SchedulerOutput:

    # list of the requests that are scheduled for the first time.
    # We cache the request's data in each worker process, so that we don't
    # need to re-send it every scheduling step.
    scheduled_new_reqs: list[NewRequestData]
    # list of the requests that have been scheduled before.
    # Since the request's data is already cached in the worker processes,
    # we only send the diff to minimize the communication cost.
    scheduled_cached_reqs: list[CachedRequestData]

    # req_id -> num_scheduled_tokens
    # Number of tokens scheduled for each request.
    num_scheduled_tokens: dict[str, int]
    # Total number of tokens scheduled for all requests.
    # Equal to sum(num_scheduled_tokens.values())
    total_num_scheduled_tokens: int
    # req_id -> spec_token_ids
    # If a request does not have any spec decode tokens, it will not be
    # included in the dictionary.
    scheduled_spec_decode_tokens: dict[str, list[int]]
    # req_id -> encoder input indices that need processing.
    # E.g., if a request has [0, 1], it could mean the vision encoder needs
    # to process the request's 0-th and 1-th images in the current step.
    scheduled_encoder_inputs: dict[str, list[int]]
    # Number of common prefix blocks for all requests in each KV cache group.
    # This can be used for cascade attention.
    num_common_prefix_blocks: list[int]

    # Request IDs that are finished in between the previous and the current
    # steps. This is used to notify the workers about the finished requests
    # so that they can free the cached states for those requests.
    finished_req_ids: set[str]
    # list of (req_id, encoder_input_index) tuples.
    # Used to free the encoder cache.
    free_encoder_input_ids: list[tuple[str, int]]

    # Dict of request ids to their index within the batch
    # for filling the next token bitmask
    structured_output_request_ids: dict[str, int]
    # the bitmask for the whole batch
    grammar_bitmask: Optional[npt.NDArray[np.int32]]

    # KV Cache Connector metadata.
    kv_connector_metadata: Optional[KVConnectorMetadata] = None
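A toy illustration of the invariant between num_scheduled_tokens and total_num_scheduled_tokens (the request IDs and token counts are hypothetical): one chunked-prefill request and one decode request scheduled in the same step.

# "req-A" gets a 512-token prefill chunk; "req-B" decodes a single token.
num_scheduled_tokens = {"req-A": 512, "req-B": 1}
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens == 513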
1044
vllm/v1/core/sched/scheduler.py
Normal file
File diff suppressed because it is too large
23
vllm/v1/core/sched/utils.py
Normal file
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.request import Request, RequestStatus


def check_stop(request: Request, max_model_len: int) -> bool:
    if (request.num_tokens >= max_model_len
            or request.num_output_tokens >= request.max_tokens):
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True

    sampling_params = request.sampling_params
    last_token_id = request.output_token_ids[-1]
    if (not sampling_params.ignore_eos
            and last_token_id == request.eos_token_id):
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    if last_token_id in (sampling_params.stop_token_ids or ()):
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = last_token_id
        return True
    return False
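A small sketch of the precedence encoded above: the length cap is checked before EOS, so a request that reaches max_tokens on an EOS token still finishes as length-capped. The stub classes below are hypothetical duck-typed stand-ins for Request and SamplingParams; check_stop is assumed importable from the module above.

from dataclasses import dataclass, field

from vllm.v1.core.sched.utils import check_stop
from vllm.v1.request import RequestStatus

@dataclass
class _StubParams:  # hypothetical stand-in for SamplingParams
    ignore_eos: bool = False
    stop_token_ids: tuple = ()

@dataclass
class _StubRequest:  # hypothetical stand-in for Request
    num_tokens: int = 8
    num_output_tokens: int = 4
    max_tokens: int = 4
    eos_token_id: int = 2
    output_token_ids: list = field(default_factory=lambda: [5, 7, 9, 2])
    sampling_params: _StubParams = field(default_factory=_StubParams)
    status: object = None
    stop_reason: object = None

# The last output token (2) equals eos_token_id, but the max_tokens branch
# fires first, so the request finishes as FINISHED_LENGTH_CAPPED.
req = _StubRequest()
assert check_stop(req, max_model_len=1024)
assert req.status == RequestStatus.FINISHED_LENGTH_CAPPED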
403
vllm/v1/core/single_type_kv_cache_manager.py
Normal file
@@ -0,0 +1,403 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable

from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
                                        SlidingWindowSpec)
from vllm.v1.request import Request


class SingleTypeKVCacheManager(ABC):
    """
    An abstract base class for a manager that handles the KV cache management
    logic of one specific type of attention layer.
    """

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        block_pool: BlockPool,
        kv_cache_group_id: int,
        caching_hash_fn: Callable,
    ) -> None:
        """
        Initializes the SpecializedManager.
        Args:
            kv_cache_spec: The kv_cache_spec for this manager.
            block_pool: The block pool.
            kv_cache_group_id: The id of the kv cache group of this manager.
            caching_hash_fn: The caching hash function.
        """

        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
        self.block_pool = block_pool

        # Mapping from request ID to blocks to track the blocks allocated
        # for each request, so that we can free the blocks when the request
        # is finished.
        self.req_to_blocks: defaultdict[str,
                                        list[KVCacheBlock]] = defaultdict(list)

        # {req_id: The number of cached blocks for this given request}
        # This is used to track the number of cached blocks for each request.
        # This is only used to track the RUNNING requests; we do not track the
        # data for preempted ones.
        self.num_cached_block: dict[str, int] = {}

        self.caching_hash_fn = caching_hash_fn
        self.kv_cache_group_id = kv_cache_group_id

    def get_num_blocks_to_allocate(
            self, request_id: str, num_tokens: int,
            new_computed_blocks: list[KVCacheBlock]) -> int:
        """
        Get the number of blocks needed to be allocated for the request.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            new_computed_blocks: The new computed blocks just hitting the
                prefix cache.

        Returns:
            The number of blocks.
        """

        num_required_blocks = cdiv(num_tokens, self.block_size)
        num_new_blocks = (num_required_blocks - len(new_computed_blocks) -
                          len(self.req_to_blocks[request_id]))
        # If a computed block of a request is an eviction candidate (in the
        # free queue and ref_cnt == 0), it will be changed from a free block
        # to a computed block when the request is allocated, so we also count
        # it as needed to be allocated.
        num_evictable_computed_blocks = sum(
            blk.ref_cnt == 0 and not blk.is_null
            for blk in new_computed_blocks)
        return num_new_blocks + num_evictable_computed_blocks

    def save_new_computed_blocks(
            self, request_id: str,
            new_computed_blocks: list[KVCacheBlock]) -> None:
        """
        Add the new computed blocks to the request.

        Args:
            request_id: The request ID.
            new_computed_blocks: The new computed blocks just hitting the
                prefix cache.
        """
        if request_id not in self.num_cached_block:
            # A new request.
            req_blocks = self.req_to_blocks[request_id]
            assert len(req_blocks) == 0
            req_blocks.extend(new_computed_blocks)
            self.num_cached_block[request_id] = len(new_computed_blocks)
        else:
            # A running request. Should not have new computed blocks.
            assert len(new_computed_blocks) == 0

    def allocate_new_blocks(self, request_id: str,
                            num_tokens: int) -> list[KVCacheBlock]:
        """
        Allocate new blocks for the request to give it at least `num_tokens`
        token slots.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).

        Returns:
            The new allocated blocks.
        """
        req_blocks = self.req_to_blocks[request_id]
        num_required_blocks = cdiv(num_tokens, self.block_size)
        num_new_blocks = num_required_blocks - len(req_blocks)
        if num_new_blocks <= 0:
            return []
        else:
            new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
            req_blocks.extend(new_blocks)
            return new_blocks

    def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
                     num_tokens: int) -> None:
        """
        Cache the blocks for the request.

        Args:
            request: The request.
            block_hashes: The block hashes of the request.
            num_tokens: The total number of tokens that need to be cached
                (including tokens that are already cached).
        """
        num_cached_blocks = self.num_cached_block[request.request_id]
        num_full_blocks = num_tokens // self.block_size

        self.block_pool.cache_full_blocks(
            request=request,
            blocks=self.req_to_blocks[request.request_id],
            block_hashes=block_hashes,
            num_cached_blocks=num_cached_blocks,
            num_full_blocks=num_full_blocks,
            block_size=self.block_size,
            kv_cache_group_id=self.kv_cache_group_id,
            hash_fn=self.caching_hash_fn,
        )

        self.num_cached_block[request.request_id] = num_full_blocks

    def free(self, request_id: str) -> None:
        """
        Free the blocks for the request.

        Args:
            request_id: The request ID.
        """
        # Default to [] in case a request is freed (aborted) before alloc.
        req_blocks = self.req_to_blocks.pop(request_id, [])

        # Free blocks in reverse order so that the tail blocks are
        # freed first.
        ordered_blocks = reversed(req_blocks)

        self.block_pool.free_blocks(ordered_blocks)
        self.num_cached_block.pop(request_id, None)

    @abstractmethod
    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> int:
        """
        Get the number of common prefix blocks for a request.

        Args:
            request_id: The request ID.
            num_running_requests: The total number of requests in the RUNNING
                state.

        Returns:
            The number of common prefix blocks.
        """

        raise NotImplementedError

    @classmethod
    @abstractmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
    ) -> tuple[list[KVCacheBlock], ...]:
        """
        Get the longest cache hit prefix of the blocks that is not longer than
        `max_length`. The prefix should be a common prefix hit for all the
        kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
        return empty lists.
        If eagle is enabled, drop the last matched block to force recomputation
        of the last block, which yields the hidden states required by the eagle
        drafting head. Needs to be customized for each attention type.

        Args:
            block_hashes: The block hashes of the request.
            max_length: The maximum length of the cache hit prefix.
            kv_cache_group_ids: The ids of the kv cache groups.
            block_pool: The block pool.
            kv_cache_spec: The kv cache spec.
            use_eagle: Whether to use eagle.

        Returns:
            A list of cached blocks, with skipped blocks replaced by the null
            block, for each kv cache group in `kv_cache_group_ids`.
            Returns a tuple of length `len(kv_cache_group_ids)`, where the i-th
            element is a list of cached blocks for the i-th kv cache group
            in `kv_cache_group_ids`.
            For example, a sliding window manager should return
            ([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)],) for block size 4,
            sliding window 8, and len(kv_cache_group_ids) = 1.
        """

        raise NotImplementedError

    @abstractmethod
    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        """
        Remove the blocks that are no longer needed from the request's block
        list and free them. The removed blocks should be replaced by the null
        block. Needs to be customized for each attention type.

        Args:
            request_id: The request ID.
            num_computed_tokens: The number of tokens that have been computed.
        """
        raise NotImplementedError


class FullAttentionManager(SingleTypeKVCacheManager):

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
    ) -> tuple[list[KVCacheBlock], ...]:
        assert isinstance(kv_cache_spec, FullAttentionSpec), (
            "FullAttentionManager can only be used for full attention groups")
        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in range(len(kv_cache_group_ids)))
        max_num_blocks = max_length // kv_cache_spec.block_size
        for i, block_hash in zip(range(max_num_blocks), block_hashes):
            # block_hashes is a chain of block hashes. If a block hash is not
            # in the cached_block_hash_to_id, the following block hashes are
            # not computed yet for sure.
            if cached_block := block_pool.get_cached_block(
                    block_hash, kv_cache_group_ids):
                for computed, cached in zip(computed_blocks, cached_block):
                    computed.append(cached)
            else:
                break
        if use_eagle and computed_blocks[0]:
            for computed in computed_blocks:
                computed.pop()
        return computed_blocks

    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        # No need to remove blocks for full attention.
        pass

    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> int:
        blocks = self.req_to_blocks[request_id]
        num_common_blocks = 0
        for block in blocks:
            if block.ref_cnt == num_running_requests:
                num_common_blocks += 1
            else:
                break
        return num_common_blocks


class SlidingWindowManager(SingleTypeKVCacheManager):

    def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
                 **kwargs) -> None:
        super().__init__(kv_cache_spec, block_pool, **kwargs)
        self.sliding_window = kv_cache_spec.sliding_window
        self._null_block = block_pool.null_block

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
    ) -> tuple[list[KVCacheBlock], ...]:
        assert isinstance(kv_cache_spec, SlidingWindowSpec), (
            "SlidingWindowManager can only be used for sliding window groups")

        # The number of contiguous blocks needed for a prefix cache hit.
        # -1 since the input token itself is also included in the window.
        sliding_window_contiguous_blocks = cdiv(
            kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size)
        if use_eagle:
            # Need to drop the last matched block if eagle is enabled. For
            # sliding window layers, we achieve this by increasing the number
            # of contiguous blocks needed for a prefix cache hit by one and
            # dropping the last matched block.
            sliding_window_contiguous_blocks += 1

        # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
        # optimize the time complexity from O(max_num_blocks) to
        # O(max_num_blocks / sliding_window_contiguous_blocks +
        # sliding_window_contiguous_blocks),
        # which is good for low cache hit rate scenarios.
        max_num_blocks = max_length // kv_cache_spec.block_size
        computed_blocks = tuple([block_pool.null_block] * max_num_blocks
                                for _ in range(len(kv_cache_group_ids)))
        num_contiguous_blocks = 0
        match_found = False
        # Search from right to left and early stop when a match is found.
        for i in range(max_num_blocks - 1, -1, -1):
            if cached_block := block_pool.get_cached_block(
                    block_hashes[i], kv_cache_group_ids):
                for computed, cached in zip(computed_blocks, cached_block):
                    computed[i] = cached
                num_contiguous_blocks += 1
                if num_contiguous_blocks >= sliding_window_contiguous_blocks:
                    # Trim the trailing blocks.
                    # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                    # when sliding_window_contiguous_blocks=2.
                    for computed in computed_blocks:
                        del computed[i + num_contiguous_blocks:]
                    match_found = True
                    break
            else:
                num_contiguous_blocks = 0
        if not match_found:
            # The first `num_contiguous_blocks` is a cache hit even if
            # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
            for computed in computed_blocks:
                del computed[num_contiguous_blocks:]
        if use_eagle and computed_blocks[0]:
            for computed in computed_blocks:
                computed.pop()
        return computed_blocks

    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        # Remove the blocks that are no longer in the sliding window and are
        # skipped during the attention computation.
        last_useful_token = num_computed_tokens - self.sliding_window + 1
        last_useful_block = last_useful_token // self.block_size
        blocks = self.req_to_blocks[request_id]
        removed_blocks: list[KVCacheBlock] = []
        for i in range(last_useful_block - 1, -1, -1):
            if blocks[i] == self._null_block:
                # If the block is already a null block, the blocks before it
                # should also have been set to null blocks by the previous
                # calls to this function.
                break
            removed_blocks.append(blocks[i])
            blocks[i] = self._null_block
        self.block_pool.free_blocks(removed_blocks)

    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> int:
        """
        NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
        So it's not correct to count ref_cnt like FullAttentionManager. Return
        0 here for correctness. Need to support cascade attention + sliding
        window in the future.
        """
        return 0


spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
    FullAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
}


def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec,
                                  **kwargs) -> SingleTypeKVCacheManager:
    manager_class = spec_manager_map[type(kv_cache_spec)]
    manager = manager_class(kv_cache_spec, **kwargs)
    return manager
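To make the block accounting in get_num_blocks_to_allocate concrete, a worked sketch with hypothetical numbers (math.ceil stands in for vLLM's cdiv ceiling division):

from math import ceil

# Hypothetical: block_size=16, request needs slots for 100 tokens, already
# holds 2 blocks, and the prefix cache just returned 3 computed blocks, one
# of which is an eviction candidate (ref_cnt == 0).
block_size = 16
num_tokens = 100
num_required_blocks = ceil(num_tokens / block_size)   # 7
num_new_blocks = num_required_blocks - 3 - 2          # 2 fresh blocks
num_evictable_computed_blocks = 1                     # must be re-reserved too
assert num_new_blocks + num_evictable_computed_blocks == 3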
173
vllm/v1/engine/__init__.py
Normal file
@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import enum
import time
from collections.abc import Sequence
from typing import Any, Optional, Union

import msgspec

from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")


class FinishReason(enum.IntEnum):
    """
    Reason a request finished - stop, length, or abort.

    Int rather than Str for more compact serialization.

    stop - a stop string was emitted
    length - max_tokens was consumed, or max_model_len was reached
    abort - aborted for another reason

    """
    STOP = 0
    LENGTH = 1
    ABORT = 2

    def __str__(self):
        return FINISH_REASON_STRINGS[self.value]


class EngineCoreRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]

    request_id: str
    prompt_token_ids: list[int]
    mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
    mm_hashes: Optional[list[str]]
    mm_placeholders: Optional[list[PlaceholderRange]]
    sampling_params: SamplingParams
    eos_token_id: Optional[int]
    arrival_time: float
    lora_request: Optional[LoRARequest]
    cache_salt: Optional[str]
    data_parallel_rank: Optional[int]

    # Index of the client, used to ensure outputs are sent back to the same
    # client for this request when scaling out the front-end.
    client_index: int = 0

    # Used in the DP case to indicate which wave of requests this is expected
    # to belong to, to cover a race condition where the request is sent before
    # a wave-finished notification is received.
    current_wave: int = 0


class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event."""
    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3


class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp is a monotonic timestamp and is used by the engine
    frontend to calculate intervals between engine core events. These
    timestamps should not be compared with timestamps from other processes.
    """
    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(cls,
                  event_type: EngineCoreEventType,
                  timestamp: Optional[float] = None) -> "EngineCoreEvent":
        timestamp = time.monotonic() if timestamp is None else timestamp
        return cls(event_type, timestamp)


class EngineCoreOutput(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]

    request_id: str
    new_token_ids: list[int]

    new_logprobs: Optional[LogprobsLists] = None
    new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None

    finish_reason: Optional[FinishReason] = None
    stop_reason: Union[int, str, None] = None
    events: Optional[list[EngineCoreEvent]] = None
    kv_transfer_params: Optional[dict[str, Any]] = None

    # The number of tokens with prefix cache hits.
    num_cached_tokens: int = 0

    @property
    def finished(self) -> bool:
        return self.finish_reason is not None


class UtilityOutput(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]

    call_id: int

    # Non-None implies the call failed, result should be None.
    failure_message: Optional[str] = None
    result: Any = None


class EngineCoreOutputs(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]

    # NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    engine_index: int = 0

    # [num_reqs]
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: Optional[SchedulerStats] = None
    timestamp: float = 0.0

    utility_output: Optional[UtilityOutput] = None
    finished_requests: Optional[set[str]] = None

    # In the DP case, used to signal that the current wave of requests
    # has finished and the engines are paused.
    wave_complete: Optional[int] = None
    # In the DP case, used to signal that a request was received for an
    # "old" wave, so the next wave needs to be started in other engines.
    start_wave: Optional[int] = None

    def __post_init__(self):
        if self.timestamp == 0.0:
            self.timestamp = time.monotonic()


class EngineCoreRequestType(enum.Enum):
    """
    Request types defined as hex byte strings, so they can be sent over
    sockets without a separate encoding step.
    """
    ADD = b'\x00'
    ABORT = b'\x01'
    START_DP_WAVE = b'\x02'
    UTILITY = b'\x03'
    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b'\x04'
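A self-contained sketch (with a hypothetical struct, not vLLM's) of why array_like=True plus IntEnum fields keeps the wire format compact: fields are encoded positionally as a msgpack array, and enum members serialize as plain ints.

import enum

import msgspec

class Reason(enum.IntEnum):  # hypothetical mini version of FinishReason
    STOP = 0
    LENGTH = 1

class Output(msgspec.Struct, array_like=True, omit_defaults=True):
    request_id: str
    finish_reason: Reason = Reason.STOP

# Encodes as a positional msgpack array, e.g. ["req-1", 1]; no field names.
buf = msgspec.msgpack.encode(Output("req-1", Reason.LENGTH))
out = msgspec.msgpack.decode(buf, type=Output)
assert out.finish_reason is Reason.LENGTH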
558
vllm/v1/engine/async_llm.py
Normal file
@@ -0,0 +1,558 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Any, Optional, Union

import numpy as np

import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import (OutputProcessor,
                                             RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
                                     setup_default_loggers)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats

logger = init_logger(__name__)


class AsyncLLM(EngineClient):

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats.
            usage_context: Usage context of the LLM.
            mm_registry: Multi-modal registry.
            use_cached_outputs: Whether to use cached outputs.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.

        Returns:
            None
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.log_requests = log_requests
        self.log_stats = log_stats

        # Set up stat loggers; independent set for each DP rank.
        self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
            vllm_config=vllm_config,
            log_stats=self.log_stats,
            engine_num=vllm_config.parallel_config.data_parallel_size,
            custom_stat_loggers=stat_loggers,
        )

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(
            vllm_config=vllm_config,
            tokenizer=self.tokenizer,
            mm_registry=mm_registry,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_index=client_index,
        )
        if self.stat_loggers:
            for stat_logger in self.stat_loggers[0]:
                stat_logger.log_engine_initialized()
        self.output_handler: Optional[asyncio.Task] = None
        try:
            # Start output handler eagerly if we are in the asyncio eventloop.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            pass

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> "AsyncLLM":
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=not disable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs."""

        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC."""

        shutdown_prometheus()

        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()

        if handler := getattr(self, "output_handler", None):
            handler.cancel()

    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM."""

        if self.errored:
            raise EngineDeadError()

        assert isinstance(params, SamplingParams), \
            "Pooling is not supported in V1"

        # Create a new output collector for the request.
        queue = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, prompt_adapter_request,
            priority, data_parallel_rank)

        if params.n == 1:
            await self._add_request(request, prompt_str, None, 0, queue)
            return queue

        # Fan out child requests (for n>1).
        parent_request = ParentRequest(request_id, params)
        for idx in range(params.n):
            request_id, params = parent_request.get_child_info(idx)
            child_request = request if idx == params.n - 1 else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = params
            await self._add_request(child_request, prompt_str, parent_request,
                                    idx, queue)
        return queue

    async def _add_request(self, request: EngineCoreRequest,
                           prompt: Optional[str],
                           parent_req: Optional[ParentRequest], index: int,
                           queue: RequestOutputCollector):

        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index,
                                          queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
        * 1) Making an AsyncStream corresponding to the Request.
        * 2) Processing the Input.
        * 3) Adding the Request to the Detokenizer.
        * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.
        """

        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams."""

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        stat_loggers = self.stat_loggers if log_stats else None

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = IterationStats() if (
                        log_stats and num_outputs) else None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                        slices = (outputs.outputs, )
                    else:
                        slices = np.array_split(
                            outputs.outputs,
                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

                    for i, outputs_slice in enumerate(slices):
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats)
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if i + 1 < len(slices):
                            await asyncio.sleep(0)

                    # 3) Abort any reqs that finished due to stop strings.
                    await engine_core.abort_requests_async(
                        processed_outputs.reqs_to_abort)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_loggers:
                        AsyncLLM._record_stats(
                            stat_loggers[outputs.engine_index],
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(self, request_id: str) -> None:
        """Abort RequestId in OutputProcessor and EngineCore."""

        request_ids = self.output_processor.abort_requests((request_id, ))
        await self.engine_core.abort_requests_async(request_ids)

        if self.log_requests:
            logger.info("Aborted request %s.", request_id)

    @staticmethod
    def _record_stats(
        stat_loggers: list[StatLoggerBase],
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
    ):
        """static so that it can be used from the output_handler task
        without a circular ref to AsyncLLM."""
        for stat_logger in stat_loggers:
            stat_logger.record(scheduler_stats=scheduler_stats,
                               iteration_stats=iteration_stats)

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ):
        raise ValueError("Not Supported on V1 yet.")

    async def get_vllm_config(self) -> VllmConfig:
        return self.vllm_config

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def get_decoding_config(self):
        raise ValueError("Not Supported on V1 yet.")

    async def get_input_preprocessor(self) -> InputPreprocessor:
        return self.processor.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        return self.tokenizer.get_lora_tokenizer(lora_request)

    async def is_tracing_enabled(self) -> bool:
        return False

    async def do_log_stats(
        self,
        scheduler_outputs=None,
        model_output=None,
    ) -> None:
        for loggers in self.stat_loggers:
            for stat_logger in loggers:
                stat_logger.log()

    async def check_health(self) -> None:
        logger.debug("Called check_health.")

    async def start_profile(self) -> None:
        await self.engine_core.profile_async(True)

    async def stop_profile(self) -> None:
        await self.engine_core.profile_async(False)

    async def reset_mm_cache(self) -> None:
        self.processor.mm_registry.reset_processor_cache()
        self.processor.mm_input_cache_client.reset()
        await self.engine_core.reset_mm_cache_async()

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        if device == Device.CPU:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

    async def sleep(self, level: int = 1) -> None:
        await self.engine_core.sleep_async(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        await self.engine_core.wake_up_async(tags)

    async def is_sleeping(self) -> bool:
        return await self.engine_core.is_sleeping_async()

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return await self.engine_core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return await self.engine_core.remove_lora_async(lora_id)

    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return await self.engine_core.list_loras_async()

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return await self.engine_core.pin_lora_async(lora_id)

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
            method, timeout, args, kwargs)

    @property
    def is_running(self) -> bool:
        # Is None before the loop is started.
        return self.output_handler is None or not self.output_handler.done()

    @property
    def is_stopped(self) -> bool:
        return self.errored

    @property
    def errored(self) -> bool:
        return self.engine_core.resources.engine_dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        return EngineDeadError()
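A minimal usage sketch of the streaming API above (the model name and prompt are hypothetical; assumes vLLM V1 is enabled and a GPU is available, so the final call is left commented out):

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM

async def main():
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # hypothetical model
    try:
        last = None
        async for out in engine.generate(
                prompt="Hello, my name is",
                sampling_params=SamplingParams(max_tokens=16),
                request_id="req-0"):
            # Each iteration yields an incrementally updated RequestOutput.
            last = out
        if last is not None:
            print(last.outputs[0].text)
    finally:
        engine.shutdown()

# asyncio.run(main())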
253
vllm/v1/engine/coordinator.py
Normal file
@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import time
import weakref
from typing import Optional

import msgspec.msgpack
import zmq

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown

logger = init_logger(__name__)


class DPCoordinator:
    """Coordinator process used for data-parallel deployments (DP>1).

    Intermediates between multiple DP engine rank processes and one or more
    front-end API server processes.

    * Collects stats from each DP engine (currently just waiting and running
      queue lengths), and publishes these to all front-ends for use in
      load-balancing decisions.

    * Keeps track of the current DP "request wave" number and running state
      of the engines. This is received from the DP rank 0 engine and published
      to the front-end processes along with the current load stats.

      The engines alternate between a global running/paused state. The global
      "request wave" number is a count of the number of times that the workers
      collectively move from a running state to a paused state. This transition
      is synchronized via the all-reduce operation performed in the
      DPEngineCoreProc._has_global_unfinished_reqs method.

    * Broadcasts the START_DP_WAVE message to engines to move them from paused
      to running state when one engine receives a new request. This can happen
      in two cases:
      1) A front-end sending a new request while the engines are paused will
         concurrently notify the coordinator.
      2) An engine receiving a request for a stale request wave while in paused
         state will notify the coordinator.

    Engines will move into running state when receiving a new request or
    START_DP_WAVE message.
    """

    def __init__(self, parallel_config: ParallelConfig):

        # Assume coordinator is colocated with front-end procs.
        front_publish_address = get_open_zmq_ipc_path()

        dp_size = parallel_config.data_parallel_size
        assert dp_size > 1, "Coordinator only used for data parallel"

        local_only = dp_size == parallel_config.data_parallel_size_local
        host = parallel_config.data_parallel_master_ip
        back_publish_address = get_engine_client_zmq_addr(local_only, host)
        back_output_address = get_engine_client_zmq_addr(local_only, host)

        context = get_mp_context()
        self.proc: multiprocessing.Process = context.Process(
            target=CoordinatorProc.run_coordinator,
            name="VLLM_DP_Coordinator",
            kwargs={
                "engine_count": parallel_config.data_parallel_size,
                "front_publish_address": front_publish_address,
                "back_output_address": back_output_address,
                "back_publish_address": back_publish_address,
            },
            daemon=True)
        self.proc.start()

        self.stats_publish_address = front_publish_address
        self.coord_in_address = back_publish_address
        self.coord_out_address = back_output_address
        self._finalizer = weakref.finalize(self, shutdown, [self.proc])

    def get_stats_publish_address(self) -> str:
        return self.stats_publish_address

    def get_engine_socket_addresses(self) -> tuple[str, str]:
        """Returns tuple of ZMQ input address, output address."""
        return self.coord_in_address, self.coord_out_address

    def close(self):
        self._finalizer()
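A minimal launch sketch (not part of this commit) showing how the coordinator's addresses would be consumed; constructing ParallelConfig directly like this is an assumption for illustration:

    from vllm.config import ParallelConfig
    from vllm.v1.engine.coordinator import DPCoordinator

    parallel_config = ParallelConfig(data_parallel_size=2)  # must be > 1
    coordinator = DPCoordinator(parallel_config)

    stats_addr = coordinator.get_stats_publish_address()  # front-ends subscribe here
    in_addr, out_addr = coordinator.get_engine_socket_addresses()  # engines connect here
    # ... hand these addresses to the front-end and engine processes ...
    coordinator.close()  # terminate the background coordinator process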


class EngineState:

    def __init__(self):
        self.request_counts = [0, 0]  # [waiting, running]


class CoordinatorProc:

    def __init__(self, engine_count: int):

        self.ctx = zmq.Context()

        self.engines = [EngineState() for _ in range(engine_count)]

        self.current_wave = 0
        self.engines_running = False
        self.stats_changed = False

    @staticmethod
    def run_coordinator(
        engine_count: int,
        front_publish_address: str,
        back_output_address: str,
        back_publish_address: str,
    ):
        coordinator = CoordinatorProc(engine_count=engine_count)
        try:
            coordinator.process_input_socket(
                front_publish_address,
                back_output_address,
                back_publish_address,
            )
        except KeyboardInterrupt:
            logger.info("DP Coordinator process exiting")

    def process_input_socket(self, front_publish_address: str,
                             back_output_address: str,
                             back_publish_address: str):

        decoder = MsgpackDecoder(EngineCoreOutputs)

        with make_zmq_socket(
                path=front_publish_address,  # IPC
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
        ) as publish_front, make_zmq_socket(
                path=back_output_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.PULL,
                bind=True,
        ) as output_back, make_zmq_socket(
                path=back_publish_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
        ) as publish_back:

            poller = zmq.Poller()
            poller.register(publish_front, zmq.POLLIN)
            poller.register(output_back, zmq.POLLIN)
            last_publish_time = 0
            while True:
                elapsed = int(time.time() * 1000) - last_publish_time
                # Send at 100 ms intervals if the stats have changed,
                # or otherwise every 3 seconds.
                wait_for = 100 if self.stats_changed else 3000
                events = poller.poll(timeout=max(0, wait_for - elapsed))
                if not events:
                    # Poller timeout - publish current stats to front-ends.
                    engine_req_counts_list = self._get_engine_counts()
                    to_publish = (engine_req_counts_list, self.current_wave,
                                  self.engines_running)
                    publish_front.send(msgspec.msgpack.encode(to_publish))
                    last_publish_time = int(time.time() * 1000)
                    self.stats_changed = False
                    continue

                events = dict(events)

                if publish_front in events:
                    buffer = publish_front.recv()
                    if buffer == b'\x01':
                        # Ignore subscription messages.
                        continue

                    # We received a message on the front-end XPUB socket,
                    # from an API server sending a new request while the
                    # engines are paused, so that we can wake the other
                    # engines.
                    engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
                    if wave < self.current_wave:
                        # If the wave number is stale, ensure the message is
                        # handled by all the engines.
                        engine_to_exclude = None
                    if not self.engines_running:
                        self.engines_running = True
                        self.stats_changed = True
                    self._send_start_wave(publish_back, self.current_wave,
                                          engine_to_exclude)

                if output_back in events:
                    # We received a message from one of the engines.

                    buffer = output_back.recv()
                    outputs: EngineCoreOutputs = decoder.decode(buffer)

                    assert not outputs.outputs
                    assert outputs.utility_output is None

                    eng_index = outputs.engine_index
                    if outputs.scheduler_stats:
                        # 1. Updated request load stats - update our local
                        # state with these.
                        stats = self.engines[eng_index].request_counts
                        stats[0] = outputs.scheduler_stats.num_waiting_reqs
                        stats[1] = outputs.scheduler_stats.num_running_reqs
                        self.stats_changed = True

                    if (wave := outputs.wave_complete) is not None:
                        # 2. Notification from rank 0 engine that we've
                        # moved into the global paused state
                        # (engines_running==False).
                        if self.current_wave <= wave:
                            logger.debug("Moving DP wave from %d to %d.",
                                         self.current_wave, wave)
                            self.current_wave = wave + 1
                        self.engines_running = False
                        self.stats_changed = True
                    elif (wave := outputs.start_wave) is not None and (
                            wave > self.current_wave or
                            (wave == self.current_wave
                             and not self.engines_running)):
                        # 3. The engine received a request for a non-current
                        # wave, so we must ensure that other engines progress
                        # to the next wave (race condition handling).
                        logger.debug(
                            "Starting wave %d after notification of "
                            "stale wave request from engine.", wave)
                        self.current_wave = wave
                        self.engines_running = True
                        self.stats_changed = True
                        self._send_start_wave(publish_back, wave, eng_index)

    @staticmethod
    def _send_start_wave(socket: zmq.Socket, wave: int,
                         exclude_engine_index: Optional[int]):
        """Broadcast the START_DP_WAVE message to all the engines.
        It includes the current wave number and the index of the engine which
        has already received a request with this wave number and so doesn't
        require additional notification.
        """
        wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
        socket.send_multipart(
            (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))

    def _get_engine_counts(self) -> list[list[int]]:
        """Return list of [waiting, running] count lists for each engine."""
        return [e.request_counts for e in self.engines]
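For reference, a sketch (not part of this commit) of how a front-end could consume the feed published on the front-end XPUB socket above; `stats_addr` is assumed to come from DPCoordinator.get_stats_publish_address():

    import msgspec.msgpack
    import zmq

    ctx = zmq.Context()
    sub = ctx.socket(zmq.SUB)
    sub.connect(stats_addr)
    sub.setsockopt(zmq.SUBSCRIBE, b"")

    # Each frame is a msgpack-encoded tuple:
    # ([[waiting, running], ...], current_wave, engines_running)
    counts, wave, running = msgspec.msgpack.decode(sub.recv())
    # Simple load-balancing choice: pick the least-loaded engine.
    target = min(range(len(counts)), key=lambda i: sum(counts[i]))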
962
vllm/v1/engine/core.py
Normal file
@@ -0,0 +1,962 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import queue
import signal
import sys
import threading
import time
from collections import deque
from collections.abc import Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union

import msgspec
import zmq

from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

POLLING_TIMEOUT_S = 2.5
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar('_R')  # Return type for collective_rpc


class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        assert vllm_config.model_config.runner_type != "pooling"

        # Plugins need to be loaded at the engine/scheduler level too.
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.vllm_config = vllm_config
        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it.
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size
            > 1,
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MirroredProcessingCache(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        start = time.time()

        # Get all kv cache specs needed by the model.
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profile the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size.
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all([
            cfg.num_blocks == kv_cache_configs[0].num_blocks
            for cfg in kv_cache_configs
        ])
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution.
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.time() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler."""

        if request.mm_hashes is not None:
            # Here, if a hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously.
            self.structured_output_manager.grammar_init(req)

        if req.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model(self, scheduler_output: SchedulerOutput):
        try:
            return self.model_executor.execute_model(scheduler_output)
        except BaseException as err:
            # NOTE: This method is exception-free.
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            # Re-raise exception.
            raise err

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns a tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return (engine_core_outputs,
                scheduler_output.total_num_scheduled_tokens > 0)
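A hypothetical driver sketch (not from this commit) of how step() would be pumped in a simple synchronous loop; `core` and `handle` are stand-ins:

    while core.scheduler.has_requests():
        outputs, executed = core.step()
        # `outputs` maps a client index to its EngineCoreOutputs batch.
        for client_idx, engine_outputs in outputs.items():
            handle(client_idx, engine_outputs)  # stand-in output sink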

    def step_with_batch_queue(
            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
        """Schedule and execute batches with the batch queue.
        Note that if there is nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, filling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        if not self.batch_queue.full():
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Block until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = (self.scheduler.update_from_output(
                scheduler_output, model_output))

        return engine_core_outputs, scheduled_batch

    def shutdown(self):
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror).
        if self.scheduler.has_unfinished_requests():
            logger.warning("Resetting the multi-modal cache when requests are "
                           "in progress may lead to desynced internal caches.")

        self.mm_input_cache_server.reset()

    def reset_prefix_cache(self):
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config, )
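An illustrative call (not from this commit); `core` and the worker-side method name `report_device_memory` are stand-ins:

    # One result per worker is returned, in rank order.
    free_bytes = core.collective_rpc("report_device_memory", timeout=5.0)
    print(free_bytes)  # e.g. [8589934592, 8589934592]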


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        engine_index: int = 0,
    ):
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                              bytes]]()
        executor_fail_callback = lambda: self.input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b''))

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshake(handshake_address, identity, on_head_node,
                                     vllm_config) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self._init_data_parallel(vllm_config)

            super().__init__(vllm_config, executor_class, log_stats,
                             executor_fail_callback)

        self.step_fn = (self.step if self.batch_queue is None else
                        self.step_with_batch_queue)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        threading.Thread(target=self.process_input_sockets,
                         args=(addresses.inputs, addresses.coordinator_input,
                               identity),
                         daemon=True).start()
        self.output_thread = threading.Thread(
            target=self.process_output_sockets,
            args=(addresses.outputs, addresses.coordinator_output,
                  self.engine_index),
            daemon=True)
        self.output_thread.start()
    @contextmanager
    def _perform_handshake(
            self, handshake_address: str, identity: bytes, on_head_node: bool,
            vllm_config: VllmConfig
    ) -> Generator[EngineZmqAddresses, None, None]:
        input_ctx = zmq.Context()
        with make_zmq_socket(input_ctx,
                             handshake_address,
                             zmq.DEALER,
                             identity=identity,
                             linger=5000,
                             bind=False) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(handshake_socket, on_head_node,
                                               vllm_config.parallel_config)

            # Update config which may have changed from the handshake.
            vllm_config.__post_init__()

            yield addresses

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            handshake_socket.send(
                msgspec.msgpack.encode({
                    "status": "READY",
                    "local": on_head_node,
                    "num_gpu_blocks": num_gpu_blocks,
                }))

    @staticmethod
    def startup_handshake(
            handshake_socket: zmq.Socket, on_head_node: bool,
            parallel_config: ParallelConfig) -> EngineZmqAddresses:

        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode({
                "status": "HELLO",
                "local": on_head_node,
            }))

        # Receive initialization message.
        logger.info("Waiting for init message from front-end.")
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError("Did not receive response from front-end "
                               f"process within {HANDSHAKE_TIMEOUT_MINS} "
                               f"minutes")
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata)
        logger.debug("Received init message: %s", init_message)

        received_parallel_config = init_message.parallel_config
        for key, value in received_parallel_config.items():
            setattr(parallel_config, key, value)

        return init_message.addresses
    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        **kwargs):
        """Launch EngineCore busy loop in background process."""

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error.
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning.
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core.
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception as e:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            raise e
        finally:
            if engine_core is not None:
                engine_core.shutdown()
    def _init_data_parallel(self, vllm_config: VllmConfig):
        pass

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM.
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
        while not self.engines_running and not self.scheduler.has_requests():
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self) -> bool:
        """Called only when there are unfinished local requests."""

        # Step the engine core.
        outputs, model_executed = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        for output in (outputs.items() if outputs else ()):
            self.output_queue.put_nowait(output)

        return model_executed
    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client."""

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output)))
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error("Unrecognized input request type encountered: %s",
                         request_type)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match the corresponding target
        method arg type, try converting it to a msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))

    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal("vLLM shutdown signal from EngineCore failed "
                         "to send. Please report this issue.")
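For context, a standalone illustration (not from this commit) of the msgspec coercion performed by _convert_msgspec_args; the Point struct and move function are made up for the example:

    import msgspec


    class Point(msgspec.Struct):
        x: int
        y: int


    def move(p: Point) -> Point:
        return Point(p.x + 1, p.y)


    # A dict decoded off the wire is converted to the annotated Struct type.
    args = ({"x": 1, "y": 2},)
    converted = msgspec.convert(args[0], type=Point)  # Point(x=1, y=2)
    print(move(converted))  # Point(x=2, y=2)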
    def process_input_sockets(self, input_addresses: list[str],
                              coord_input_address: Optional[str],
                              identity: bytes):
        """Input socket IO thread."""

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx,
                                    input_address,
                                    zmq.DEALER,
                                    identity=identity,
                                    bind=False))
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(ctx,
                                    coord_input_address,
                                    zmq.XSUB,
                                    identity=identity,
                                    bind=False))
                # Send subscription message to coordinator.
                coord_socket.send(b'\x01')

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b'')
                poller.register(input_socket, zmq.POLLIN)
            if coord_socket is not None:
                poller.register(coord_socket, zmq.POLLIN)

            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(
                        copy=False)
                    request_type = EngineCoreRequestType(
                        bytes(type_frame.buffer))

                    # Deserialize the request data.
                    decoder = add_request_decoder if (
                        request_type
                        == EngineCoreRequestType.ADD) else generic_decoder
                    request = decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))
    def process_output_sockets(self, output_paths: list[str],
                               coord_output_path: Optional[str],
                               engine_index: int):
        """Output socket IO thread."""

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
                for output_path in output_paths
            ]
            coord_socket = stack.enter_context(
                make_zmq_socket(
                    ctx, coord_output_path, zmq.PUSH, bind=False,
                    linger=4000)) if coord_output_path is not None else None
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(buffers,
                                                               copy=False,
                                                               track=True)
                if not tracker.done:
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)
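A minimal standalone sketch (not from this commit) of the tracked zero-copy send pattern used in the output thread above; the endpoint is an assumption:

    import zmq

    ctx = zmq.Context()
    push = ctx.socket(zmq.PUSH)
    push.bind("ipc:///tmp/example-output")

    payload = bytearray(b"large serialized payload")
    tracker = push.send(payload, copy=False, track=True)
    if not tracker.done:
        # Either block here, or park (tracker, payload) and poll .done
        # later, as the loop above does with its `pending` deque.
        tracker.wait()
    # Only now is it safe to mutate or reuse `payload`.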


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):

        self._decorate_logs()

        # Counts forward-passes of the model so that we can synchronize
        # finished requests with DP peers every N steps.
        self.counter = 0
        self.current_wave = 0
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(vllm_config, on_head_node, handshake_address,
                         executor_class, log_stats, dp_rank)

    def _decorate_logs(self):
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)
    def _init_data_parallel(self, vllm_config: VllmConfig):

        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # Modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        os.environ[device_control_env_var] = ",".join(
            str(current_platform.device_id_to_physical_device_id(i))
            for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
                           world_size))
        os.environ["MACA_VISIBLE_DEVICES"] = os.environ[device_control_env_var]

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)
    def add_request(self, request: EngineCoreRequest):
        if self.has_coordinator and request.current_wave != self.current_wave:
            if request.current_wave > self.current_wave:
                self.current_wave = request.current_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))

        super().add_request(request)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                    new_wave >= self.current_wave):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.",
                                 new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        if not self.has_coordinator:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(*counts)
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(scheduler_stats=stats)))
    def run_busy_loop(self):
        """Core busy loop of the EngineCore for the data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM.
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.engines_running:
                if self.dp_rank == 0:
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.",
                                 self.current_wave)
                    self.output_queue.put_nowait(
                        (-1,
                         EngineCoreOutputs(wave_complete=self.current_wave)))
                self.current_wave += 1

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:

        # Optimization - only perform the finish-sync all-reduce every
        # 24 steps.
        self.counter += 1
        if self.counter != 24:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)


class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = \
            local_dp_rank

        # Ray sets CUDA_VISIBLE_DEVICES to an empty string;
        # we clean this up to be able to properly initialize
        # data parallel groups.
        # del os.environ['CUDA_VISIBLE_DEVICES']

        super().__init__(vllm_config, on_head_node, "", executor_class,
                         log_stats)

    def _decorate_logs(self):
        pass

    @contextmanager
    def _perform_handshake(self, handshake_address: str, identity: bytes,
                           on_head_node: bool, vllm_config: VllmConfig):
        """
        For Ray, we don't need to actually perform a handshake.
        All address information is known before actor creation,
        so we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()
1129
vllm/v1/engine/core_client.py
Normal file
File diff suppressed because it is too large
286
vllm/v1/engine/detokenizer.py
Normal file
@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Optional

import tokenizers
from packaging import version
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
    AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.v1.engine import EngineCoreRequest

logger = init_logger(__name__)

# Only tokenizers >= 0.21.1 supports the DecodeStream used by
# FastIncrementalDetokenizer.
USE_FAST_DETOKENIZER = version.parse(
    tokenizers.__version__) >= version.parse("0.21.1")

# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"

class IncrementalDetokenizer:

    def __init__(self):
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        return self.token_ids

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        self.token_ids.extend(new_token_ids)
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":

        if tokenizer is None:
            # No tokenizer => skipping detokenization.
            return IncrementalDetokenizer()

        if USE_FAST_DETOKENIZER and isinstance(tokenizer,
                                               PreTrainedTokenizerFast):
            # Fast tokenizer => use tokenizers library DecodeStream.
            return FastIncrementalDetokenizer(tokenizer, request)

        # Fall back to slow python-based incremental detokenization.
        return SlowIncrementalDetokenizer(tokenizer, request)


class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        # Stop strings
        params = request.sampling_params
        self.stop = stop = params.stop
        self.include_stop_str_in_output = params.include_stop_str_in_output

        # Number of chars to hold back when stop strings are to be excluded
        # from streamed output.
        if stop and not self.include_stop_str_in_output:
            self.stop_buffer_length = max(len(s) for s in stop) - 1
        else:
            self.stop_buffer_length = 0
        self._last_output_text_offset: int = 0

        # Generation data
        self.output_text = ""

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        """
        Update RequestState for the request_id by:
        1) Detokenizing the new token ids incrementally.
        2) Evaluating stop criteria.

        Return matched stop string or None.
        """
        if not new_token_ids:
            # Skip detokenization if no new token ids.
            return None

        if stop_terminated and not self.include_stop_str_in_output:
            # If stop-terminated, exclude the last token from detokenization
            # based on the include_stop_str_in_output parameter.
            skipped_stop_token_id = new_token_ids[-1]
            new_token_ids = new_token_ids[:-1]
        else:
            skipped_stop_token_id = None

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        offset_before = len(self.output_text)
        for new_token_id in new_token_ids:
            self.token_ids.append(new_token_id)
            self.output_text += self.decode_next(new_token_id)

        if stop_terminated:
            if skipped_stop_token_id is not None:
                # Cleanup after skipping detokenization.
                self.token_ids.append(skipped_stop_token_id)
            # Stop token triggered; skip stop string check.
            return None

        # 2) Evaluate stop strings.
        stop_string = None
        if self.stop:
            stop = StopChecker.check_stop_strings(
                output_text=self.output_text,
                new_char_count=len(self.output_text) - offset_before,
                stop=self.stop,
                include_in_output=self.include_stop_str_in_output,
            )
            if stop is not None:
                stop_string, truncate_to = stop
                if truncate_to != -1:
                    self.output_text = self.output_text[:truncate_to]

        return stop_string

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned."""

        # We return the full output text if the sequence is finished.
        buffer_length = 0 if finished else self.stop_buffer_length
        if not delta:
            return self.output_text[:-buffer_length] if buffer_length else (
                self.output_text)
        length = len(self.output_text) - buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""
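To make the stop-string holdback concrete, a small worked example (not from this commit), assuming stop=["<END>"] and include_stop_str_in_output=False:

    stop = ["<END>"]
    stop_buffer_length = max(len(s) for s in stop) - 1  # 4 chars held back

    output_text = "Hello wor"
    # With delta streaming and the request unfinished, only text that can no
    # longer be part of a matching stop string is released:
    length = len(output_text) - stop_buffer_length  # 9 - 4 = 5
    print(output_text[:length])  # "Hello" (" wor" stays buffered for now)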

class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):

    def __init__(self, tokenizer: PreTrainedTokenizerFast,
                 request: EngineCoreRequest):
        super().__init__(request)

        sampling_params = request.sampling_params

        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(
            skip_special_tokens=self.skip_special_tokens)

        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Find a safe place to start.
        prompt_suffix = request.prompt_token_ids
        prompt_len = len(prompt_suffix)
        if prompt_len > 4:
            for i in range(4, min(prompt_len + 1, 24)):
                suffix = request.prompt_token_ids[-i:]
                if '�' not in self.tokenizer.decode(suffix):
                    prompt_suffix = suffix
                    break

        # Prime the stream.
        for tid in prompt_suffix:
            self._protected_step(tid)

        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens)

        if not self.spaces_between_special_tokens:
            # Store dict of added token ids so that we can suppress
            # the spaces between them.
            if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
                                           None)) is None:
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in
                    self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens.
                self.spaces_between_special_tokens = True

    def decode_next(self, next_token_id: int) -> str:
        token = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""

    def _protected_step(self, next_token_id: int) -> Optional[str]:
        try:
            token = self.stream.step(self.tokenizer, next_token_id)
        except Exception as e:
            if str(e) != INVALID_PREFIX_ERR_MSG:
                raise e
            # Recover from an edge case where the tokenizer can produce
            # non-monotonic, invalid UTF-8 output, which breaks the internal
            # state of tokenizers' DecodeStream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.", self.request_id)
            self.stream = DecodeStream(self.skip_special_tokens)
            token = self.stream.step(self.tokenizer, next_token_id)
        return token

class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):

    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
        super().__init__(request)

        self.tokenizer = tokenizer

        # Metadata for incremental detokenization.
        self.tokens, self.prefix_offset, self.read_offset = (
            convert_prompt_ids_to_tokens(
                tokenizer=tokenizer,
                prompt_ids=request.prompt_token_ids,
                skip_special_tokens=request.sampling_params.
                skip_special_tokens,
            ))

        self.token_ids.extend(request.prompt_token_ids)
        self.prompt_len = len(request.prompt_token_ids)

        params = request.sampling_params
        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = (
            params.spaces_between_special_tokens)

    @property
    def output_token_ids(self) -> list[int]:
        return self.token_ids if not self.prompt_len else (
            self.token_ids[self.prompt_len:])

    def decode_next(self, next_token_id: int) -> str:
        new_tokens, decoded_text, prefix_offset, read_offset = (
            detokenize_incrementally(
                tokenizer=self.tokenizer,
                all_input_ids=self.token_ids,
                prev_tokens=self.tokens,
                prefix_offset=self.prefix_offset,
                read_offset=self.read_offset,
                skip_special_tokens=self.skip_special_tokens,
                spaces_between_special_tokens=self.
                spaces_between_special_tokens,
            ))

        self.tokens.extend(new_tokens)
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

        return decoded_text
17
vllm/v1/engine/exceptions.py
Normal file
@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
class EngineGenerateError(Exception):
    """Raised when an AsyncLLM.generate() call fails. Recoverable."""
    pass


class EngineDeadError(Exception):
    """Raised when the EngineCore dies. Unrecoverable."""

    def __init__(self, *args, suppress_context: bool = False, **kwargs):
        ENGINE_DEAD_MESSAGE = "EngineCore encountered an issue. See stack trace (above) for the root cause."  # noqa: E501

        super().__init__(ENGINE_DEAD_MESSAGE, *args, **kwargs)
        # Make stack trace clearer when using with LLMEngine by
        # silencing irrelevant ZMQError.
        self.__suppress_context__ = suppress_context
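
A standalone illustration of the `__suppress_context__` mechanism used above (plain Python semantics, not vLLM code): raising inside an except block implicitly chains the original exception; setting the flag hides that chain from the printed traceback.

    try:
        try:
            raise ConnectionError("irrelevant transport error")
        except ConnectionError:
            err = RuntimeError("engine dead")
            err.__suppress_context__ = True  # what suppress_context=True does
            raise err
    except RuntimeError as e:
        assert e.__context__ is not None  # context is still recorded...
        assert e.__suppress_context__     # ...but hidden from the traceback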
317
vllm/v1/engine/llm_engine.py
Normal file
@@ -0,0 +1,317 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Mapping
from copy import copy
from typing import Any, Callable, Optional, Union

from typing_extensions import TypeVar

import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import (
    TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
                                     StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats

logger = init_logger(__name__)

_R = TypeVar("_R", default=Any)


class LLMEngine:
    """Legacy LLMEngine for backwards compatibility."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        if stat_loggers is not None:
            raise NotImplementedError(
                "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
                "Set VLLM_USE_V1=0 and file an issue on Github.")

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats
        self.stat_logger: Optional[StatLoggerBase] = None
        if self.log_stats:
            self.stat_logger = PrometheusStatLogger(vllm_config)

        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        parallel_config = vllm_config.parallel_config
        if not multiprocess_mode and parallel_config.data_parallel_size > 1:
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        self.should_execute_dummy_batch = False

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)

        # Processor (convert Inputs --> EngineCoreRequests)
        self.processor = Processor(vllm_config=vllm_config,
                                   tokenizer=self.tokenizer,
                                   mm_registry=mm_registry)

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        return cls(vllm_config=vllm_config,
                   executor_class=Executor.get_class(vllm_config),
                   log_stats=(not disable_log_stats),
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""

        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(vllm_config=vllm_config,
                   executor_class=executor_class,
                   log_stats=not engine_args.disable_log_stats,
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=enable_multiprocessing)

    def get_num_unfinished_requests(self) -> int:
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        has_unfinished = self.output_processor.has_unfinished_requests()
        if self.dp_group is None:
            return has_unfinished
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished)
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        return outputs

    def abort_request(self, request_ids: list[str]) -> None:
        """Remove request_ids from EngineCore and Detokenizer."""

        request_ids = self.output_processor.abort_requests(request_ids)
        self.engine_core.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        # Process raw inputs into the request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, prompt_adapter_request,
            priority)

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, prompt_str, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request_id, params)
        for idx in range(n):
            request_id, params = parent_req.get_child_info(idx)
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = params

            # Make a new RequestState and queue.
            self.output_processor.add_request(child_request, prompt_str,
                                              parent_req, idx)
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)

    def step(self) -> list[RequestOutput]:

        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        iteration_stats = IterationStats() if self.log_stats else None
        processed_outputs = self.output_processor.process_outputs(
            outputs.outputs,
            engine_core_timestamp=outputs.timestamp,
            iteration_stats=iteration_stats)

        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        # 4) Record stats
        if self.stat_logger is not None:
            assert outputs.scheduler_stats is not None
            self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
                                    iteration_stats=iteration_stats)

        return processed_outputs.request_outputs

    def get_vllm_config(self):
        return self.vllm_config

    def get_model_config(self):
        return self.model_config

    def start_profile(self):
        self.engine_core.profile(True)

    def stop_profile(self):
        self.engine_core.profile(False)

    def reset_mm_cache(self):
        self.processor.mm_registry.reset_processor_cache()
        self.processor.mm_input_cache_client.reset()
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(self, device: Optional[Device] = None):
        self.engine_core.reset_prefix_cache()

    def sleep(self, level: int = 1):
        self.engine_core.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        self.engine_core.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.engine_core.is_sleeping()

    def get_metrics(self) -> list[Metric]:
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

    def get_tokenizer_group(self) -> TokenizerGroup:
        if self.tokenizer is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        return self.tokenizer

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def __del__(self):
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)
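
A minimal usage sketch of the synchronous loop above, using only the methods shown in this file (the model name and prompt are placeholders, not part of the diff):

    # Hypothetical driver: build an engine, submit one request, drain step().
    from vllm.engine.arg_utils import EngineArgs
    from vllm.sampling_params import SamplingParams
    from vllm.v1.engine.llm_engine import LLMEngine

    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
    engine.add_request("req-0", "Hello, my name is",
                       SamplingParams(max_tokens=16))
    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if request_output.finished:
                print(request_output.outputs[0].text)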
199
vllm/v1/engine/logprobs.py
Normal file
@@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional

from vllm.logger import init_logger
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.detokenizer_utils import (
    AnyTokenizer, convert_ids_list_to_tokens)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

logger = init_logger(__name__)

NONES = itertools.repeat(None)


@dataclass
class LogprobsProcessor:

    # Tokenizer for this request,
    # None if detokenization is disabled.
    tokenizer: Optional[AnyTokenizer]

    # Logprobs for this request
    logprobs: Optional[SampleLogprobs]
    prompt_logprobs: Optional[PromptLogprobs]
    cumulative_logprob: Optional[float]
    num_logprobs: Optional[int]
    num_prompt_logprobs: Optional[int]

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
    ) -> "LogprobsProcessor":
        num_logprobs = request.sampling_params.logprobs
        num_prompt_logprobs = request.sampling_params.prompt_logprobs
        return cls(
            tokenizer=tokenizer,
            cumulative_logprob=(None if num_logprobs is None else 0.),
            logprobs=(None if num_logprobs is None else []),
            # NOTE: logprob of first prompt token is None.
            prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
            num_prompt_logprobs=num_prompt_logprobs,
            num_logprobs=num_logprobs,
        )

    def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
        """Update with sample logprobs from EngineCore.

        Outer lists are only of len > 1 if EngineCore generated
        more than one token in the prior step (e.g. in spec decoding).

        Args:
          logprobs_lists: the lists of logprob tokens, logprobs, and ranks.

        """

        assert self.num_logprobs is not None
        assert self.logprobs is not None
        assert self.cumulative_logprob is not None

        token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists

        for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
                                             token_ids_lst):

            # Detokenize (non-incrementally).
            decoded_tokens = NONES if self.tokenizer is None else (
                convert_ids_list_to_tokens(self.tokenizer, token_ids))

            # Sampler puts the sampled logprob in first.
            sampled_token_logprob = logprobs[0]
            self.cumulative_logprob += sampled_token_logprob

            # Update with the Logprob dictionary for this pos.
            self.logprobs.append(
                self._make_logprob_dict(
                    logprobs,
                    token_ids,
                    decoded_tokens,
                    rank,
                    self.num_logprobs,
                ))

    def _update_prompt_logprobs(
        self,
        prompt_logprobs_tensors: LogprobsTensors,
    ) -> None:
        """Update with prompt logprobs from EngineCore.

        Args:
          prompt_logprobs_tensors: tuple containing the prompt logprobs
                                   tensors.

        """

        # Prompt logprobs are enabled.
        assert self.num_prompt_logprobs is not None
        assert self.prompt_logprobs is not None

        token_ids, logprobs, ranks = prompt_logprobs_tensors

        # Detokenize non-incrementally.
        # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
        decoded_tokens = None if self.tokenizer is None else (
            convert_ids_list_to_tokens(self.tokenizer,
                                       token_ids.flatten().tolist()))

        # Recover shapes.
        num_prompt_tokens, num_logprobs = logprobs.shape

        # Pythonize the torch tensors.
        prompt_token_ranks = ranks.tolist()
        prompt_logprobs = logprobs.tolist()
        token_ids = token_ids.tolist()

        # Make Logprob for each position.
        for pos in range(num_prompt_tokens):
            # Handle flattening.
            offset = pos * num_logprobs
            offset_end = offset + num_logprobs
            decoded_tokens_for_pos = NONES \
                if decoded_tokens is None else decoded_tokens[offset:offset_end]

            # Update with the Logprob dictionary for this pos.
            self.prompt_logprobs.append(
                self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos],
                                        decoded_tokens_for_pos,
                                        prompt_token_ranks[pos],
                                        self.num_prompt_logprobs))

    def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
        """Pop and return all request prompt logprobs

        The logprobs processor aggregates prompt chunk logprobs
        over one or more prefill chunks. This method returns
        all prompt logprobs at once and then forgets them.
        Ensures correct RequestOutputKind.DELTA semantics
        wherein all prompt logprobs are returned at once at
        the end of prefill.

        Returns:
          None if prompt logprobs are disabled for this request.
          List of all prompt logprobs, otherwise.
        """
        plp = self.prompt_logprobs
        if plp:
            self.prompt_logprobs = []
        return plp

    @staticmethod
    def _make_logprob_dict(
        logprobs: list[float],
        logprob_token_ids: list[int],
        decoded_tokens: Iterable[Optional[str]],
        rank: int,
        num_logprobs: int,
    ) -> dict[int, Logprob]:
        """Make a Logprob dictionary for a position.

        Args:
          logprobs: list of log probabilities
          logprob_token_ids: list of top token ids
          decoded_tokens: list of decoded top tokens
          rank: rank of the sampled token
          num_logprobs: number of logprobs requested
            by the user (in addition to sampled logprob)

        Returns:
          dict[token id, Logprob]
        """

        # We do not need a special case for the sampled token
        # being in the topk, since inserting duplicated data
        # into a dictionary twice is the same as doing it once.
        topk_ranks = range(1, num_logprobs + 1)
        ranks = itertools.chain((rank, ), topk_ranks)

        return {
            token_id: Logprob(
                logprob=logprob,
                rank=rank,
                decoded_token=token,
            )
            for token_id, logprob, rank, token in zip(
                logprob_token_ids, logprobs, ranks, decoded_tokens)
        }

    def update_from_output(self, output: EngineCoreOutput) -> None:
        if output.new_logprobs is not None:
            self._update_sample_logprobs(output.new_logprobs)
        if output.new_prompt_logprobs_tensors is not None:
            self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)
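
The comment in _make_logprob_dict relies on dict construction collapsing duplicate keys; a standalone sketch of that rank-chaining trick with toy numbers (plain tuples in place of Logprob objects):

    # Standalone sketch of the rank-chaining trick used in _make_logprob_dict.
    import itertools

    logprob_token_ids = [42, 42, 7]   # sampled token first, then top-k
    logprobs = [-0.1, -0.1, -2.3]
    rank = 1                          # sampled token happens to be top-1
    topk_ranks = range(1, 2 + 1)      # ranks for the 2 requested logprobs
    ranks = itertools.chain((rank, ), topk_ranks)

    # Duplicate key 42 is written twice with identical data; the dict keeps one.
    out = {tid: (lp, r)
           for tid, lp, r in zip(logprob_token_ids, logprobs, ranks)}
    assert out == {42: (-0.1, 1), 7: (-2.3, 2)}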
91
vllm/v1/engine/mm_input_cache.py
Normal file
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional

from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache
from vllm.utils import is_list_of

# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- Client:
#  - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
#    with built-in caching functionality, with mm_hash as its identifier.
#  - MirroredProcessingCache to keep track of the cached entries and
#    determine whether to send the MultiModalKwargs to P1.
#
# -- Server:
#  - MirroredProcessingCache to store the MultiModalKwargs from P0.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
# client (=P0) and server (=P1) processes if the mm_hash is found in the client
# cache.

# Both Client and Server must use the same cache size
# (to perform mirrored caching). This cache size is set by the environment
# variable VLLM_MM_INPUT_CACHE_GIB.


class MirroredProcessingCache:

    def __init__(self, model_config):
        mm_config = model_config.multimodal_config
        disable_mm_preprocessor_cache = (
            mm_config is not None and mm_config.disable_mm_preprocessor_cache)
        self.use_cache = not disable_mm_preprocessor_cache
        self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
                                                      MultiModalKwargs)

    def get_and_update_p0(
        self,
        mm_inputs: Sequence[MultiModalKwargs],
        mm_hashes: list[str],
    ) -> Sequence[Optional[MultiModalKwargs]]:
        assert len(mm_inputs) == len(mm_hashes)

        if not self.use_cache:
            assert is_list_of(mm_inputs, MultiModalKwargs)
            return mm_inputs

        full_mm_inputs = list[Optional[MultiModalKwargs]]()
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if self.mm_cache.get(mm_hash) is not None:
                mm_input = None
            else:
                self.mm_cache[mm_hash] = mm_input

            full_mm_inputs.append(mm_input)

        return full_mm_inputs

    def get_and_update_p1(
        self,
        mm_inputs: Sequence[Optional[MultiModalKwargs]],
        mm_hashes: list[str],
    ) -> Sequence[MultiModalKwargs]:
        assert len(mm_inputs) == len(mm_hashes)

        if not self.use_cache:
            assert is_list_of(mm_inputs, MultiModalKwargs)
            return mm_inputs

        full_mm_inputs = list[MultiModalKwargs]()
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_input is None:
                mm_input = self.mm_cache[mm_hash]
            else:
                self.mm_cache[mm_hash] = mm_input

            full_mm_inputs.append(mm_input)

        return full_mm_inputs

    def reset(self) -> bool:
        self.mm_cache.clear()

        return True
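
A toy round-trip of the mirrored P0/P1 protocol described in the comment block above, with plain dicts in place of the real LRU cache and strings in place of MultiModalKwargs (all names here are illustrative):

    # Toy model of the mirrored cache protocol: a hit on P0 sends None over
    # the wire; P1 reconstructs the payload from its own mirrored copy.
    p0_cache: dict[str, str] = {}
    p1_cache: dict[str, str] = {}

    def p0_send(payloads, hashes):
        wire = []
        for payload, h in zip(payloads, hashes):
            if h in p0_cache:
                wire.append(None)      # already mirrored on P1: skip payload
            else:
                p0_cache[h] = payload
                wire.append(payload)
        return wire

    def p1_recv(wire, hashes):
        full = []
        for payload, h in zip(wire, hashes):
            if payload is None:
                payload = p1_cache[h]  # hit: recover from mirrored cache
            else:
                p1_cache[h] = payload
            full.append(payload)
        return full

    hashes = ["img-a", "img-a"]
    out = p1_recv(p0_send(["pixels", "pixels"], hashes), hashes)
    assert out == ["pixels", "pixels"]  # second item crossed the wire as None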
428
vllm/v1/engine/output_processor.py
Normal file
@@ -0,0 +1,428 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Optional, Union

from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
                                   RequestStateStats)


class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """

    def __init__(self, output_kind: RequestOutputKind):
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.output: Optional[Union[RequestOutput, Exception]] = None
        self.ready = asyncio.Event()

    def put(self, output: Union[RequestOutput, Exception]) -> None:
        """Non-blocking put operation."""
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(self.output, RequestOutput):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            self.output.add(output, aggregate=self.aggregate)

    async def get(self) -> RequestOutput:
        """Get operation blocks on put event."""
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(self) -> Optional[RequestOutput]:
        """Non-blocking get operation."""
        output = self.output
        if output is not None:
            self.output = None
            self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output


@dataclass
class OutputProcessorOutput:

    request_outputs: list[RequestOutput]
    reqs_to_abort: list[str]


class RequestState:

    def __init__(
        self,
        request_id: str,
        parent_req: Optional[ParentRequest],
        request_index: int,
        lora_name: Optional[str],
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: list[int],
        logprobs_processor: LogprobsProcessor,
        detokenizer: IncrementalDetokenizer,
        max_tokens_param: Optional[int],
        arrival_time: float,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_len = len(prompt_token_ids)
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        self.is_prefilling = True
        self.queue = queue

        self.stats = RequestStateStats(
            arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest],
        request_index: int,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ) -> "RequestState":
        if not request.sampling_params.detokenize:
            tokenizer = None
        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(request.lora_request.name
                       if request.lora_request is not None else None),
            output_kind=request.sampling_params.output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            logprobs_processor=LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            detokenizer=IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            max_tokens_param=(request.sampling_params.max_tokens if
                              request.sampling_params is not None else None),
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
        kv_transfer_params: Optional[dict[str, Any]] = None,
        num_cached_tokens: int = 0,
    ) -> Optional[RequestOutput]:

        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        completion_output = self._new_completion_output(
            new_token_ids, finish_reason, stop_reason)

        request_id = self.request_id
        if self.parent_req is None:
            outputs = [completion_output]
        else:
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, completion_output)
            if not outputs:
                return None

        return self._new_request_output(request_id, outputs, finished,
                                        kv_transfer_params, num_cached_tokens)

    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput],
        finished: bool,
        kv_transfer_params: Optional[dict[str, Any]] = None,
        num_cached_tokens: int = 0,
    ) -> RequestOutput:

        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=self.prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=outputs,
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=num_cached_tokens,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
    ) -> CompletionOutput:

        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            logprobs = logprobs[-len(token_ids):]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None)


class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: TokenizerGroup,
        log_stats: bool,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.request_states: dict[str, RequestState] = {}
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()

    def get_num_unfinished_requests(self):
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""

        for _, state in self.request_states.items():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)
            else:
                parent = self.parent_requests.pop(request_id, None)
                if parent and parent.child_requests:
                    self.abort_requests(parent.child_requests)
                    request_ids_to_abort.extend(parent.child_requests)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest] = None,
        request_index: int = 0,
        queue: Optional[RequestOutputCollector] = None,
    ) -> None:
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats)
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            new_token_ids = engine_core_output.new_token_ids
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            # 2) Detokenize the token ids into text and perform stop checks.
            stop_string = req_state.detokenizer.update(
                new_token_ids, finish_reason == FinishReason.STOP)
            if stop_string:
                finish_reason = FinishReason.STOP
                stop_reason = stop_string

            # 3) Compute sample and prompt logprobs for request, if required.
            req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                    new_token_ids, finish_reason, stop_reason,
                    kv_transfer_params, num_cached_tokens):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(req_state, finish_reason,
                                                 iteration_stats)

        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
                                  iteration_stats: Optional[IterationStats]):
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(engine_core_output,
                                           engine_core_timestamp,
                                           req_state.is_prefilling,
                                           req_state.prompt_len,
                                           req_state.stats, lora_stats)

    def _update_stats_from_finished(self, req_state: RequestState,
                                    finish_reason: Optional[FinishReason],
                                    iteration_stats: Optional[IterationStats]):
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=len(req_state.prompt_token_ids),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats)
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats,
            req_state.stats.num_generation_tokens)
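
A toy model of the RequestOutputCollector coalescing behavior above, with strings standing in for RequestOutputs (illustrative only; the real class merges via RequestOutput.add):

    # Toy collector: when the producer runs ahead of the consumer, puts are
    # merged in place instead of queueing unboundedly.
    import asyncio

    class ToyCollector:
        def __init__(self):
            self.output = None
            self.ready = asyncio.Event()

        def put(self, text: str) -> None:
            if self.output is None:
                self.output = text
                self.ready.set()
            else:
                self.output += text      # merge deltas in place

        async def get(self) -> str:
            while (output := self.output) is None:
                await self.ready.wait()
            self.output = None
            self.ready.clear()
            return output

    async def main():
        collector = ToyCollector()
        collector.put("Hel")
        collector.put("lo")              # consumer not scheduled yet: merged
        assert await collector.get() == "Hello"

    asyncio.run(main())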
133
vllm/v1/engine/parallel_sampling.py
Normal file
@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from copy import copy
from typing import Optional

from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.metrics.stats import IterationStats


class ParentRequest:
    """Info, state & processing for parallel sampling request.

    Store parent request ID and sampling params.
    Facilitate generating child request sampling params.
    """

    request_id: str
    sampling_params: SamplingParams

    # To track the completion of child requests
    child_requests: set[str]

    # To aggregate child completions when not streaming
    output_aggregator: list[CompletionOutput]

    # To find the max number of generated tokens across all children
    max_num_generation_tokens: int

    # To efficiently obtain child sampling params
    cached_child_sampling_params: Optional[SamplingParams]

    def __init__(self, request_id: str,
                 sampling_params: SamplingParams) -> None:
        self.request_id = request_id
        self.sampling_params = sampling_params

        self.child_requests = set()
        self.output_aggregator = [None] * sampling_params.n if (
            sampling_params.output_kind
            == RequestOutputKind.FINAL_ONLY) else []
        self.max_num_generation_tokens = 0
        self.cached_child_sampling_params = None

    def _get_child_sampling_params(
        self,
        index: int,
    ) -> SamplingParams:
        """Efficiently obtain child `sampling_params`

        If `sampling_params.seed` is not `None` then
        each child request requires a unique clone of
        parent `sampling_params` with a unique seed.

        Args:
          index: index within `n` child requests

        Returns:
          Child `sampling_params` instance.
        """
        seed = self.sampling_params.seed
        if self.cached_child_sampling_params:
            # Reuse child sampling_params data structure
            return self.cached_child_sampling_params
        # Build child sampling_params
        child_sampling_params = copy(self.sampling_params)
        child_sampling_params.n = 1
        if seed is None:
            # Cache child sampling_params for later reuse
            self.cached_child_sampling_params = child_sampling_params
        else:
            # Each child gets a clone with a unique seed
            child_sampling_params.seed = seed + index
        return child_sampling_params

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        """Get child request ID and sampling params.

        Args:
          index: index within `n` child requests.

        Returns:
          (request ID, sampling_params) tuple
        """
        child_req_id = f"{index}_{self.request_id}"
        self.child_requests.add(child_req_id)
        return child_req_id, self._get_child_sampling_params(index)

    @property
    def n(self) -> int:
        return self.sampling_params.n

    def get_outputs(
        self,
        child_request_id: str,
        completion_output: CompletionOutput,
    ) -> tuple[str, list[CompletionOutput], bool]:
        if completion_output.finished():
            self.child_requests.remove(child_request_id)

        if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
            # If streaming, just return the current output.
            outputs = [completion_output]
        else:
            # If not streaming, aggregate the n final outputs.
            self.output_aggregator[completion_output.index] = completion_output
            outputs = [] if self.child_requests else self.output_aggregator

        finished = not self.child_requests
        return self.request_id, outputs, finished

    def observe_num_generation_tokens(self, num_generation_tokens: int):
        self.max_num_generation_tokens = max(num_generation_tokens,
                                             self.max_num_generation_tokens)
        return self.max_num_generation_tokens

    @staticmethod
    def observe_finished_request(parent_req: Optional['ParentRequest'],
                                 iteration_stats: IterationStats,
                                 num_generation_tokens: int):

        n_param = parent_req.n if parent_req is not None else 1

        if parent_req is not None:
            num_generation_tokens = parent_req.observe_num_generation_tokens(
                num_generation_tokens)

        # Child requests finished, we can now record to iteration stats
        if parent_req is None or not parent_req.child_requests:
            iteration_stats.max_num_generation_tokens_iter.append(
                num_generation_tokens)
            iteration_stats.n_params_iter.append(n_param)
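
A small sketch of the child fan-out above, assuming the module path shown in this diff and a parent request with n=3 and a fixed seed; the derived ids and seeds follow directly from get_child_info and _get_child_sampling_params:

    # Sketch: how ParentRequest derives child ids and per-child seeds.
    from vllm.sampling_params import SamplingParams
    from vllm.v1.engine.parallel_sampling import ParentRequest

    parent = ParentRequest("req-42", SamplingParams(n=3, seed=7))
    for idx in range(parent.n):
        child_id, child_params = parent.get_child_info(idx)
        # -> ("0_req-42", seed=7), ("1_req-42", seed=8), ("2_req-42", seed=9)
        assert child_params.n == 1 and child_params.seed == 7 + idx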
407
vllm/v1/engine/processor.py
Normal file
@@ -0,0 +1,407 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Optional, Union

from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                             MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)


class Processor:

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: TokenizerGroup,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.decoding_config = vllm_config.decoding_config
        self.tokenizer = tokenizer

        self.generation_config_fields = (
            self.model_config.try_get_generation_config())
        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.mm_input_cache_client = MirroredProcessingCache(self.model_config)

        # Multi-modal hasher (for images)
        self.use_hash = self.mm_input_cache_client.use_cache or \
            self.cache_config.enable_prefix_caching

    @property
    def mm_registry(self):
        return self.input_preprocessor.mm_registry

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        max_logprobs = self.model_config.max_logprobs
        # Validate sample logprobs.
        if params.logprobs and params.logprobs > max_logprobs:
            raise ValueError(
                f"Requested sample logprobs of {params.logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

        # Validate prompt logprobs.
        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
            raise ValueError(
                f"Requested prompt logprobs of {params.prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

    def _validate_sampling_params(
        self,
        params: SamplingParams,
        lora_request: Optional[LoRARequest],
    ) -> None:
        self._validate_structured_output(params)
        self._validate_logit_bias(params)

        if params.allowed_token_ids is None:
            return
        if not params.allowed_token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")
        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
        vocab_size = len(tokenizer)
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id!")

    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate logit_bias token IDs are within vocabulary range."""
        if not params.logit_bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        invalid_token_ids = []

        for token_id in params.logit_bias:
            if token_id < 0 or token_id >= vocab_size:
                invalid_token_ids.append(token_id)

        if invalid_token_ids:
            raise ValueError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        # Best of not yet supported.
        if params.best_of is not None and params.best_of > 1:
            raise ValueError("vLLM V1 does not yet support best_of.")
        # Logits processors not supported.
        if params.logits_processors:
            raise ValueError("vLLM V1 does not support per request "
                             "user provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest],
    ):
        """
        Validate supported SamplingParams.
        Should raise ValueError if unsupported for API Server.
        """

        if not isinstance(params, SamplingParams):
            raise ValueError("V1 does not yet support Pooling models.")

        self._validate_logprobs(params)
        self._validate_sampling_params(params, lora_request)
        self._validate_supported_sampling_params(params)

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

    def _validate_structured_output(self, params: SamplingParams) -> None:
        if not params.guided_decoding or not self.decoding_config:
            return

        engine_level_backend = self.decoding_config.backend
        if params.guided_decoding.backend:
            # Request-level backend selection is not supported in V1.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_auto` option set on the backend in the params.
            if (params.guided_decoding.backend != engine_level_backend
                    and not (engine_level_backend == "auto"
                             and params.guided_decoding.backend_was_auto)):
                raise ValueError(
                    "Request-level structured output backend selection is no "
                    "longer supported. The request specified "
                    f"'{params.guided_decoding.backend}', but vLLM was "
                    f"initialised with '{engine_level_backend}'. This error "
                    "can be resolved by removing backend selection from the "
                    "request.")
        else:
            params.guided_decoding.backend = engine_level_backend

        # Request content validation
        if engine_level_backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif engine_level_backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
        else:
            # NOTE: engine_level_backend must be "auto" here, because we have
            # checked supported_backends above.
            # "auto" is an opt-in to opinionated behavior where we try to
            # choose a backend based on request contents. This is not the
            # default as it is less predictable and subject to change
            # between releases as feature support changes.
            try:
                validate_xgrammar_grammar(params)
                params.guided_decoding.backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar. Fall back to guidance.
                validate_guidance_grammar(params, tokenizer=None)
                params.guided_decoding.backend = "guidance"
            # Remember that this backend was set automatically
            params.guided_decoding.backend_was_auto = True

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:

        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.
        self._validate_lora(lora_request)
        self._validate_params(params, lora_request)
        if priority != 0:
            raise ValueError("V1 does not support priority yet.")
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")
        if prompt_adapter_request is not None:
            raise ValueError("V1 does not support prompt_adapter_request.")

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
                                                   data_parallel_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {data_parallel_size}).")

        if arrival_time is None:
            arrival_time = time.time()

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #    multimodal data and expand prompt token ids accordingly.
        # 3. Apply prompt adapter to prompt token ids if one exists.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=self.use_hash,
        )
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        # TODO: Impl encoder-decoder
        if encoder_inputs is not None:
            raise NotImplementedError

        assert isinstance(params, SamplingParams)
        # TODO: can we avoid cloning here in multiproc case?
        sampling_params = params.clone()
        # If unset max tokens, then generate up to the max_model_len.
        if sampling_params.max_tokens is None:
            sampling_params.max_tokens = (
                self.model_config.max_model_len -
                len(decoder_inputs["prompt_token_ids"]))
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)
        sampling_params.update_from_tokenizer(
            self.tokenizer.get_lora_tokenizer(lora_request))

        # Multimodal related.
        sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
        sorted_mm_positions: Optional[list[PlaceholderRange]] = None
        sorted_mm_hashes: Optional[list[str]] = None
        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            (
                sorted_item_modalities,
                sorted_mm_positions,
                sorted_mm_hashes,
            ) = merge_and_sort_multimodal_metadata(
                decoder_inputs["mm_placeholders"],
                decoder_inputs["mm_hashes"] if self.use_hash else None,
            )

            # The output of merged multi-modal processor (`decoder_mm_inputs`)
            # is a single MultiModalKwargs for all items from all modalities.
            # This code flattens kwargs for individual items in a list and
            # sorts them by each item's position in the input sequence if there
            # are multiple modalities.
            unique_modalities = set(sorted_item_modalities)
            if len(unique_modalities) > 1:
                orig_sorted_mm_inputs = []
                used_indices = {modality: 0 for modality in unique_modalities}

                for modality in sorted_item_modalities:
                    items = decoder_mm_inputs.get_items(modality)
                    item = items[used_indices[modality]]

                    orig_sorted_mm_inputs.append(
                        MultiModalKwargs.from_items([item]))
                    used_indices[modality] += 1
            else:
                orig_sorted_mm_inputs = [
                    MultiModalKwargs.from_items([item]) for item in
                    decoder_mm_inputs.get_items(sorted_item_modalities[0])
                ]

            if sorted_mm_hashes is not None:
                sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
                    orig_sorted_mm_inputs, sorted_mm_hashes)
            else:
                sorted_mm_inputs = orig_sorted_mm_inputs

        return decoder_inputs.get("prompt"), EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
            mm_inputs=sorted_mm_inputs,
            mm_hashes=sorted_mm_hashes,
            mm_placeholders=sorted_mm_positions,
            sampling_params=sampling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            data_parallel_rank=data_parallel_rank,
        )

    def _validate_model_inputs(self,
                               inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest] = None):
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        model_config = self.model_config
        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)

        prompt_ids = prompt_inputs["prompt_token_ids"]
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        max_input_id = max(prompt_ids, default=0)
        if max_input_id > tokenizer.max_token_id:
            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

        max_prompt_len = self.model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper

            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
|
||||
f"{suggestion}")
|
||||
|
||||
# TODO: Find out how many placeholder tokens are there so we can
|
||||
# check that chunked prefill does not truncate them
|
||||
# max_batch_len = self.scheduler_config.max_num_batched_tokens
|
||||
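The max_tokens defaulting above is plain context-budget arithmetic. A minimal
standalone sketch of that rule (the numbers are illustrative, not from this
diff):

from typing import Optional


def default_max_tokens(max_model_len: int, prompt_token_ids: list[int],
                       max_tokens: Optional[int]) -> int:
    # An unset max_tokens means: generate up to whatever context budget
    # remains after the prompt.
    if max_tokens is None:
        return max_model_len - len(prompt_token_ids)
    return max_tokens


assert default_max_tokens(4096, list(range(1000)), None) == 3096
assert default_max_tokens(4096, list(range(1000)), 256) == 256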
0
vllm/v1/executor/__init__.py
Normal file
113
vllm/v1/executor/abstract.py
Normal file
@@ -0,0 +1,113 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from concurrent.futures import Future
from typing import Callable, Union

import torch
import torch.distributed as dist

from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import (  # noqa
    ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import (  # noqa
    UniProcExecutor as UniProcExecutorV0)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput

FailureCallback = Callable[[], None]


class Executor(ExecutorBase):
    """
    Abstract class for v1 executors, mainly defining v1-specific methods.
    Methods shared by v0 and v1 are defined in ExecutorBase."""

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> type["Executor"]:
        executor_class: type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = (
            parallel_config.distributed_executor_backend)
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            from vllm.v1.executor.ray_distributed_executor import (  # noqa
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("Unknown distributed executor backend: "
                             f"{distributed_executor_backend}")
        return executor_class

    def initialize_from_config(self,
                               kv_cache_configs: list[KVCacheConfig]) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_from_config",
                            args=(kv_cache_configs, ))
        self.collective_rpc("compile_or_warm_up_model")

    def register_failure_callback(self, callback: FailureCallback):
        """
        Register a function to be called if the executor enters a permanent
        failed state.
        """
        pass

    def determine_available_memory(self) -> list[int]:  # in bytes
        output = self.collective_rpc("determine_available_memory")
        return output

    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
        output = self.collective_rpc("get_kv_cache_spec")
        return output

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        output = self.collective_rpc("execute_model",
                                     args=(scheduler_output, ))
        return output[0]

    @property
    def max_concurrent_batches(self) -> int:
        return 1

    def profile(self, is_start: bool = True):
        self.collective_rpc("profile", args=(is_start, ))


class UniProcExecutor(UniProcExecutorV0, Executor):
    pass


class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):

    def determine_available_memory(self) -> list[int]:  # in bytes
        # Same as determine_num_available_blocks in v0:
        # we need to take the min across all ranks.
        memory = super().determine_available_memory()
        from vllm.distributed.parallel_state import get_world_group
        cpu_group = get_world_group().cpu_group
        memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        return [memory_tensor.item()]
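Executor.get_class above is a string-to-class dispatch with an escape hatch
for user-supplied classes. The same pattern in isolation (the backend classes
here are placeholders, not the real executors):

from typing import Union


class Base:
    pass


class RayBackend(Base):
    pass


class MpBackend(Base):
    pass


_BACKENDS = {"ray": RayBackend, "mp": MpBackend}


def get_class(backend: Union[str, type]) -> type:
    # A class passed directly must still subclass Base.
    if isinstance(backend, type):
        if not issubclass(backend, Base):
            raise TypeError(f"{backend} is not a subclass of Base")
        return backend
    if backend not in _BACKENDS:
        raise ValueError(f"Unknown backend: {backend}")
    return _BACKENDS[backend]


assert get_class("ray") is RayBackend
assert get_class(MpBackend) is MpBackend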
537
vllm/v1/executor/multiproc_executor.py
Normal file
@@ -0,0 +1,537 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import os
import pickle
import signal
import sys
import threading
import time
import traceback
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from threading import Thread
from typing import Any, Callable, Optional, Union, cast

import cloudpickle

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                 MessageQueue)
from vllm.executor.multiproc_worker_utils import (
    _add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
                        get_open_port)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)

POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000

EXECUTE_MODEL_TIMEOUT_S = 300


class MultiprocExecutor(Executor):

    def _init_executor(self) -> None:
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        self._finalizer = weakref.finalize(self, self.shutdown)
        self.is_failed = False
        self.shutdown_event = threading.Event()
        self.failure_callback: Optional[FailureCallback] = None
        self.io_thread_pool: Optional[ThreadPoolExecutor] = None

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
        pp_parallel_size = self.parallel_config.pipeline_parallel_size
        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
            f"world_size ({self.world_size}) must be equal to the "
            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
            f"_parallel_size ({pp_parallel_size}). ")

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        # The multiprocessing-based executor does not support multi-node
        # settings. Since it only works on a single node, we can use the
        # loopback address 127.0.0.1 for communication.
        distributed_init_method = get_distributed_init_method(
            "127.0.0.1", get_open_port())

        # Initialize worker and set up message queues for SchedulerOutputs
        # and ModelRunnerOutputs
        max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
        self.rpc_broadcast_mq = MessageQueue(self.world_size,
                                             self.world_size,
                                             max_chunk_bytes=max_chunk_bytes)
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Create workers
        unready_workers: list[UnreadyWorkerProcHandle] = []
        success = False
        try:
            for rank in range(self.world_size):
                unready_workers.append(
                    WorkerProc.make_worker_process(
                        vllm_config=self.vllm_config,
                        local_rank=rank,
                        rank=rank,
                        distributed_init_method=distributed_init_method,
                        input_shm_handle=scheduler_output_handle,
                    ))

            # Workers must be created before wait_for_ready to avoid
            # deadlock, since worker.init_device() does a device sync.
            self.workers = WorkerProc.wait_for_ready(unready_workers)

            # Ensure message queues are ready. Will deadlock if re-ordered.
            # Must be kept consistent with the WorkerProc.
            self.rpc_broadcast_mq.wait_until_ready()
            for w in self.workers:
                w.worker_response_mq.wait_until_ready()

            self.start_worker_monitor()
            success = True
        finally:
            if not success:
                # Clean up the worker procs if there was a failure.
                self._ensure_worker_termination(
                    [w.proc for w in unready_workers])

        # For pipeline parallel, we use a thread pool for asynchronous
        # execute_model.
        if self.max_concurrent_batches > 1:
            # Note: must use only 1 IO thread to keep dequeue sequence
            # from the response queue
            self.io_thread_pool = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="mp_exec_io")

        self.output_rank = self._get_output_rank()

    def start_worker_monitor(self):
        workers = self.workers
        self_ref = weakref.ref(self)

        # Monitors worker process liveness. If any die unexpectedly,
        # logs an error, shuts down the executor and invokes the failure
        # callback to inform the engine.
        def monitor_workers():
            sentinels = [h.proc.sentinel for h in workers]
            died = multiprocessing.connection.wait(sentinels)
            _self = self_ref()
            if not _self or getattr(_self, 'shutting_down', False):
                return
            _self.is_failed = True
            proc_name = next(h.proc.name for h in workers
                             if h.proc.sentinel == died[0])
            logger.error(
                "Worker proc %s died unexpectedly, "
                "shutting down executor.", proc_name)
            _self.shutdown()
            callback = _self.failure_callback
            if callback is not None:
                _self.failure_callback = None
                callback()

        Thread(target=monitor_workers,
               daemon=True,
               name="MultiprocWorkerMonitor").start()

    def register_failure_callback(self, callback: FailureCallback):
        if self.is_failed:
            callback()
        else:
            self.failure_callback = callback

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        (output, ) = self.collective_rpc("execute_model",
                                         args=(scheduler_output, ),
                                         unique_reply_rank=self.output_rank,
                                         non_block=self.max_concurrent_batches
                                         > 1,
                                         timeout=EXECUTE_MODEL_TIMEOUT_S)
        return output

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False,
                       unique_reply_rank: Optional[int] = None) -> list[Any]:
        if self.is_failed:
            raise RuntimeError("Executor failed.")

        deadline = None if timeout is None else time.monotonic() + timeout
        kwargs = kwargs or {}

        # NOTE: If the args are heterogeneous, then we pack them into a list,
        # and unpack them in the method of every worker, because every worker
        # knows its own rank.
        try:
            if isinstance(method, str):
                send_method = method
            else:
                send_method = cloudpickle.dumps(
                    method, protocol=pickle.HIGHEST_PROTOCOL)
            self.rpc_broadcast_mq.enqueue(
                (send_method, args, kwargs, unique_reply_rank))

            workers = (self.workers[unique_reply_rank],
                       ) if unique_reply_rank is not None else self.workers
            responses = []

            def get_response(w: WorkerProcHandle,
                             dequeue_timeout: Optional[float] = None,
                             cancel_event: Optional[threading.Event] = None):
                status, result = w.worker_response_mq.dequeue(
                    timeout=dequeue_timeout, cancel=cancel_event)

                if status != WorkerProc.ResponseStatus.SUCCESS:
                    raise RuntimeError(
                        f"Worker failed with error '{result}', please check the"
                        " stack trace above for the root cause")
                return result

            for w in workers:
                dequeue_timeout = None if deadline is None else (
                    deadline - time.monotonic())

                if non_block:
                    result = self.io_thread_pool.submit(  # type: ignore
                        get_response, w, dequeue_timeout, self.shutdown_event)
                else:
                    result = get_response(w, dequeue_timeout)

                responses.append(result)

            return responses
        except TimeoutError as e:
            raise TimeoutError(f"RPC call to {method} timed out.") from e

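Note how collective_rpc converts one overall timeout into a fixed deadline and
recomputes the remaining budget before each worker's dequeue, so the total
wait is bounded no matter how many workers reply. The deadline pattern in
isolation, as a minimal sketch:

import time
from typing import Callable, Optional


def wait_all(waiters: "list[Callable[[Optional[float]], object]]",
             timeout: Optional[float] = None) -> list:
    # One deadline shared across sequential waits: each wait only gets the
    # budget that is left, so the total never exceeds `timeout`.
    deadline = None if timeout is None else time.monotonic() + timeout
    results = []
    for wait in waiters:
        remaining = None if deadline is None else deadline - time.monotonic()
        results.append(wait(remaining))
    return results


# Toy waiters that just report the budget they were given.
print(wait_all([lambda t: round(t, 2) for _ in range(3)], timeout=1.0))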
    @staticmethod
    def _ensure_worker_termination(worker_procs: list[BaseProcess]):
        """Ensure that all worker processes are terminated. Assumes workers
        have received termination requests. Waits for processing, then sends
        termination and kill signals if needed."""

        def wait_for_termination(procs, timeout):
            if not time:
                # If we are in late stage shutdown, the interpreter may replace
                # `time` with `None`.
                return all(not proc.is_alive() for proc in procs)
            start_time = time.time()
            while time.time() - start_time < timeout:
                if all(not proc.is_alive() for proc in procs):
                    return True
                time.sleep(0.1)
            return False

        # Send SIGTERM if still running
        active_procs = [proc for proc in worker_procs if proc.is_alive()]
        for p in active_procs:
            p.terminate()
        if not wait_for_termination(active_procs, 4):
            # Send SIGKILL if still running
            active_procs = [p for p in active_procs if p.is_alive()]
            for p in active_procs:
                p.kill()

    def shutdown(self):
        """Properly shut down the executor and its workers"""
        if not getattr(self, 'shutting_down', False):
            self.shutting_down = True
            self.shutdown_event.set()

            if self.io_thread_pool is not None:
                self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
                self.io_thread_pool = None

            if workers := getattr(self, 'workers', None):
                for w in workers:
                    w.worker_response_mq = None
                self._ensure_worker_termination([w.proc for w in workers])

        self.rpc_broadcast_mq = None

    def check_health(self) -> None:
        self.collective_rpc("check_health", timeout=10)
        return

    @property
    def max_concurrent_batches(self) -> int:
        return self.parallel_config.pipeline_parallel_size

    def _get_output_rank(self) -> int:
        # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
        # (the first TP worker of the last PP stage).
        # Example:
        # Assuming TP=8, PP=4, then the world_size=32
        # 0-7, PP rank 0
        # 8-15, PP rank 1
        # 16-23, PP rank 2
        # 24-31, PP rank 3
        # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
        return self.world_size - self.parallel_config.tensor_parallel_size

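The comment in _get_output_rank already walks through TP=8, PP=4; the same
arithmetic as a quick check (world_size - tp_size is the first TP worker of
the last pipeline stage):

def output_rank(world_size: int, tp_size: int) -> int:
    # Ranks are laid out stage by stage, so the last PP stage occupies the
    # final tp_size ranks; its first TP worker is world_size - tp_size.
    return world_size - tp_size


assert output_rank(32, 8) == 24  # TP=8, PP=4: ranks 24-31 form the last stage
assert output_rank(8, 8) == 0    # PP=1: the output comes from rank 0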
@dataclass
class UnreadyWorkerProcHandle:
    """WorkerProcess handle before READY."""
    proc: BaseProcess
    rank: int
    ready_pipe: Connection


@dataclass
class WorkerProcHandle:
    proc: BaseProcess
    rank: int
    worker_response_mq: MessageQueue  # The worker process writes to this MQ

    @classmethod
    def from_unready_handle(
            cls, unready_handle: UnreadyWorkerProcHandle,
            worker_response_mq: MessageQueue) -> "WorkerProcHandle":
        return cls(
            proc=unready_handle.proc,
            rank=unready_handle.rank,
            worker_response_mq=worker_response_mq,
        )


class WorkerProc:
    """Wrapper that runs one Worker in a separate process."""

    READY_STR = "READY"

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        input_shm_handle: Handle,
    ):
        self.rank = rank
        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
        # TODO: move `init_worker` to executor level as a collective rpc call
        all_kwargs: list[dict] = [
            {} for _ in range(vllm_config.parallel_config.world_size)
        ]
        is_driver_worker = (
            rank % vllm_config.parallel_config.tensor_parallel_size == 0)
        all_kwargs[rank] = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "is_driver_worker": is_driver_worker,
        }
        wrapper.init_worker(all_kwargs)
        self.worker = wrapper

        pid = os.getpid()
        _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
        _add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)

        # Initialize the MessageQueue for receiving SchedulerOutput
        self.rpc_broadcast_mq = MessageQueue.create_from_handle(
            input_shm_handle, self.worker.rank)

        # Initialize a message queue for sending the model output
        self.worker_response_mq = MessageQueue(1, 1)

        # Initialize the device and load weights
        self.worker.init_device()
        self.worker.load_model()

    @staticmethod
    def make_worker_process(
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            input_shm_handle,  # Receive SchedulerOutput
    ) -> UnreadyWorkerProcHandle:
        context = get_mp_context()
        # (reader, writer)
        reader, writer = context.Pipe(duplex=False)

        process_kwargs = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "input_shm_handle": input_shm_handle,
            "ready_pipe": (reader, writer),
        }
        # Run EngineCore busy loop in background process.
        proc = context.Process(target=WorkerProc.worker_main,
                               kwargs=process_kwargs,
                               name=f"VllmWorker-{rank}",
                               daemon=True)

        proc.start()
        writer.close()
        return UnreadyWorkerProcHandle(proc, rank, reader)

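make_worker_process above and wait_for_ready below form a classic pipe
handshake: the parent keeps the read end, closes its copy of the write end,
and either receives READY or sees EOF if the child dies first. A minimal
standalone sketch of the handshake, independent of vLLM:

import multiprocessing as mp


def child(writer):
    # ... heavy initialization would happen here ...
    writer.send({"status": "READY"})
    writer.close()


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    reader, writer = ctx.Pipe(duplex=False)
    proc = ctx.Process(target=child, args=(writer, ))
    proc.start()
    writer.close()  # drop the parent's write end; EOF now means "child died"
    try:
        assert reader.recv()["status"] == "READY"
    except EOFError:
        raise RuntimeError("worker died before signaling readiness") from None
    finally:
        reader.close()
        proc.join()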
    @staticmethod
    def wait_for_ready(
        unready_proc_handles: list[UnreadyWorkerProcHandle]
    ) -> list[WorkerProcHandle]:

        e = Exception("WorkerProc initialization failed due to "
                      "an exception in a background process. "
                      "See stack trace for root cause.")

        pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
        ready_proc_handles: list[Optional[WorkerProcHandle]] = (
            [None] * len(unready_proc_handles))
        while pipes:
            ready = multiprocessing.connection.wait(pipes.keys())
            for pipe in ready:
                assert isinstance(pipe, Connection)
                try:
                    # Wait until the WorkerProc is ready.
                    unready_proc_handle = pipes.pop(pipe)
                    response: dict[str, Any] = pipe.recv()
                    if response["status"] != "READY":
                        raise e

                    # Extract the message queue handle.
                    worker_response_mq = MessageQueue.create_from_handle(
                        response["handle"], 0)
                    ready_proc_handles[unready_proc_handle.rank] = (
                        WorkerProcHandle.from_unready_handle(
                            unready_proc_handle, worker_response_mq))

                except EOFError:
                    e.__suppress_context__ = True
                    raise e from None

                finally:
                    # Close connection.
                    pipe.close()

        return cast(list[WorkerProcHandle], ready_proc_handles)

    def shutdown(self):
        self.rpc_broadcast_mq = None
        self.worker_response_mq = None
        destroy_model_parallel()
        destroy_distributed_environment()

    @staticmethod
    def worker_main(*args, **kwargs):
        """Worker initialization and execution loops.
        This runs in a background process."""

        # Signal handler used for graceful termination.
        # The SystemExit exception is raised only once, so that this and the
        # worker processes can terminate without error.
        shutdown_requested = False

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the worker
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        worker = None
        # tuple[Connection, Connection]
        reader, ready_writer = kwargs.pop("ready_pipe")
        try:
            reader.close()
            worker = WorkerProc(*args, **kwargs)

            # Send READY once we know everything is loaded
            ready_writer.send({
                "status":
                WorkerProc.READY_STR,
                "handle":
                worker.worker_response_mq.export_handle(),
            })

            # Ensure message queues are ready. Will deadlock if re-ordered.
            # Must be kept consistent with the Executor
            worker.rpc_broadcast_mq.wait_until_ready()
            worker.worker_response_mq.wait_until_ready()
            ready_writer.close()
            ready_writer = None

            worker.worker_busy_loop()

        except Exception:
            # NOTE: if an Exception arises in busy_loop, we send
            # a FAILURE message over the MQ RPC to notify the Executor,
            # which triggers system shutdown.
            # TODO(rob): handle case where the MQ itself breaks.

            if ready_writer is not None:
                logger.exception("WorkerProc failed to start.")
            else:
                logger.exception("WorkerProc failed.")

            # The parent sends a SIGTERM to all worker processes if
            # any worker dies. Set this value so we don't re-throw
            # SystemExit() to avoid zmq exceptions in __del__.
            shutdown_requested = True

        finally:
            if ready_writer is not None:
                ready_writer.close()
            # Clean up once the worker exits the busy loop
            if worker is not None:
                worker.shutdown()

    class ResponseStatus(Enum):
        SUCCESS = auto()
        FAILURE = auto()

    def worker_busy_loop(self):
        """Main busy loop for Multiprocessing Workers"""
        while True:
            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()

            try:
                if isinstance(method, str):
                    func = getattr(self.worker, method)
                elif isinstance(method, bytes):
                    func = partial(cloudpickle.loads(method), self.worker)
                output = func(*args, **kwargs)
            except Exception as e:
                # Exception notes were introduced in Python 3.11
                if hasattr(e, "add_note"):
                    e.add_note(traceback.format_exc())
                logger.exception("WorkerProc hit an exception.")
                # The exception might not be serializable, so we convert it
                # to a string, for logging purposes only.
                if output_rank is None or self.rank == output_rank:
                    self.worker_response_mq.enqueue(
                        (WorkerProc.ResponseStatus.FAILURE, str(e)))
                continue

            if output_rank is None or self.rank == output_rank:
                self.worker_response_mq.enqueue(
                    (WorkerProc.ResponseStatus.SUCCESS, output))
62
vllm/v1/executor/ray_distributed_executor.py
Normal file
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from concurrent.futures import Future
from typing import Union

from vllm.executor.ray_distributed_executor import (  # noqa
    RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput


class FutureWrapper(Future):
    """A wrapper around a Ray output reference to meet the interface
    of .execute_model().
    """

    def __init__(self, ref):
        super().__init__()
        self.ref = ref

    def result(self, timeout=None):
        if timeout is not None:
            raise NotImplementedError("timeout is not supported")
        return self.ref.get()


class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    """Ray distributed executor using Ray Compiled Graphs."""

    @property
    def max_concurrent_batches(self) -> int:
        """Ray distributed executor supports pipeline parallelism,
        meaning that it allows PP size batches to be executed concurrently.
        """
        return self.parallel_config.pipeline_parallel_size

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Execute the model on the Ray workers.

        Args:
            scheduler_output: The scheduler output to execute.

        Returns:
            The model runner output.
        """
        # Build the compiled DAG for the first time.
        if self.forward_dag is None:  # type: ignore
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        refs = self.forward_dag.execute(scheduler_output)  # type: ignore

        # When PP is not used, we block here until the result is available.
        if self.max_concurrent_batches == 1:
            return refs[0].get()

        # When PP is used, we return a FutureWrapper immediately so that
        # the scheduler can yield to the next batch.
        return FutureWrapper(refs[0])
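FutureWrapper is a small adapter: anything with a blocking .get() can be
presented behind the Future interface that execute_model callers expect. A
usage sketch with a stand-in reference (FakeRef is hypothetical, not a real
Ray object):

from concurrent.futures import Future


class FakeRef:
    """Stand-in for a Ray object reference with a blocking .get()."""

    def get(self):
        return "model-output"


class FutureWrapper(Future):

    def __init__(self, ref):
        super().__init__()
        self.ref = ref

    def result(self, timeout=None):
        if timeout is not None:
            raise NotImplementedError("timeout is not supported")
        return self.ref.get()


assert FutureWrapper(FakeRef()).result() == "model-output"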
194
vllm/v1/kv_cache_interface.py
Normal file
@@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
from dataclasses import dataclass
from typing import Optional

import torch
from typing_extensions import Self

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size

logger = init_logger(__name__)


@dataclass
class KVCacheSpec:
    """
    A base class for specifying the KV cache format of one layer.
    """

    # number of tokens in a block
    block_size: int

    @property
    def type_id(self) -> str:
        """
        The type identifier of this KV cache.
        Return different strings for layers with different KV cache types
        (e.g., different numbers of tokens like full attention vs sliding
        window attention, or different KV cache sizes per token like layers
        with different numbers of heads).

        Returns:
            The type identifier of this KV cache.
        """
        raise NotImplementedError

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Returns:
            The KV cache size in bytes
        """
        raise NotImplementedError

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
        """
        assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
            "All layers in the same KV cache group must share the same "
            "type_id.")
        return copy.deepcopy(specs[0])


@dataclass
class AttentionSpec(KVCacheSpec):
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool

    @property
    def page_size_bytes(self) -> int:
        # For MLA we only store a single latent vector
        coef = 1 if self.use_mla else 2
        return coef * self.block_size * self.num_kv_heads * self.head_size \
            * get_dtype_size(self.dtype)


@dataclass
class FullAttentionSpec(AttentionSpec):
    sliding_window: Optional[int] = None
    """
    When the hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding window
    attention is regarded as full attention in the KV cache manager (blocks
    are allocated for all tokens), while it is computed as sliding window
    attention in the model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    Defaults to None when sliding window attention is not used.
    """

    @property
    def type_id(self) -> str:
        return f"full_attention_{self.block_size}_{self.page_size_bytes}"

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.
        """
        merged_spec = super().merge(specs)
        sliding_window = set(spec.sliding_window for spec in specs
                             if spec.sliding_window is not None)
        if len(sliding_window) == 0:
            merged_spec.sliding_window = None
        elif len(sliding_window) == 1:
            merged_spec.sliding_window = sliding_window.pop()
        else:
            raise ValueError(
                "All sliding window layers in the same KV cache group "
                "must have the same window size.")
        return merged_spec


@dataclass
class SlidingWindowSpec(AttentionSpec):
    sliding_window: int

    def __post_init__(self):
        assert not self.use_mla, "MLA is not supported for sliding window"

    @property
    def type_id(self) -> str:
        return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}"  # noqa

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window - 1` computed tokens plus the newly scheduled
        # tokens, and we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
                         max_model_len)

        # +1 here because the sliding window may not start from the beginning
        # of the block. For example, if the block size is 4 and num_tokens
        # is 4, we need two blocks [XXCD] [EF] to store the sliding
        # window [CDEF] of 4 tokens.
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes

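The +1 in the comment above is easy to check numerically. A standalone sketch
of the block-count formula, with a local cdiv (ceiling division, as in
vllm.utils):

def cdiv(a: int, b: int) -> int:
    return -(a // -b)


def sliding_window_blocks(sliding_window: int, max_num_batched_tokens: int,
                          max_model_len: int, block_size: int) -> int:
    num_tokens = min(sliding_window - 1 + max_num_batched_tokens,
                     max_model_len)
    # +1: the window may not start at a block boundary, so it can straddle
    # one extra block (see the [XXCD] [EF] example above).
    return cdiv(num_tokens, block_size) + 1


# block_size=4, num_tokens=4: cdiv(4, 4) + 1 = 2 blocks, matching [XXCD] [EF].
assert sliding_window_blocks(3, 2, 8192, 4) == 2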
@dataclass
class KVCacheTensor:
    """
    A class for specifying how the workers should initialize the KV cache.
    """
    size: int  # size of the KV cache tensor in bytes
    shared_by: list[str]  # layer names that share the same KV cache tensor


@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block
    table. These layers are regarded as one layer in the KV cache manager.
    """
    # The names of model layers in this group
    layer_names: list[str]
    # The KV cache spec of this manager layer
    kv_cache_spec: KVCacheSpec


@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """
    num_blocks: int
    """The number of KV cache blocks"""
    kv_cache_tensors: list[KVCacheTensor]
    """How the model runner should initialize the KV cache tensors for each
    layer"""
    kv_cache_groups: list[KVCacheGroupSpec]
    """
    The KV cache groups of the model.
    For models with only one type of attention, there is only one group that
    contains all layers.
    For models with multiple types of attention, there will be multiple groups;
    see `_get_kv_cache_config_uniform_page_size` for more details.
    """
0
vllm/v1/metrics/__init__.py
Normal file
523
vllm/v1/metrics/loggers.py
Normal file
@@ -0,0 +1,523 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
import time
from abc import ABC, abstractmethod
from typing import Callable, Optional

import numpy as np
import prometheus_client

from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm

logger = init_logger(__name__)

StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]


class StatLoggerBase(ABC):
    """Interface for logging metrics.

    API users may define custom loggers that implement this interface.
    However, note that the `SchedulerStats` and `IterationStats` classes
    are not considered stable interfaces and may change in future versions.
    """

    @abstractmethod
    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        ...

    @abstractmethod
    def record(self, scheduler_stats: Optional[SchedulerStats],
               iteration_stats: Optional[IterationStats]):
        ...

    @abstractmethod
    def log_engine_initialized(self):
        ...

    def log(self):  # noqa
        pass


class LoggingStatLogger(StatLoggerBase):

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        self.engine_index = engine_index
        self.vllm_config = vllm_config
        self._reset(time.monotonic())
        self.last_scheduler_stats = SchedulerStats()
        # Prefix cache metrics. This cannot be reset.
        # TODO: Make the interval configurable.
        self.prefix_caching_metrics = PrefixCachingMetrics()
        self.spec_decoding_logging = SpecDecodingLogging()
        self.last_prompt_throughput: float = 0.0
        self.last_generation_throughput: float = 0.0

    def _reset(self, now):
        self.last_log_time = now

        # Tracked stats over current local logging interval.
        self.num_prompt_tokens: list[int] = []
        self.num_generation_tokens: list[int] = []

    def _track_iteration_stats(self, iteration_stats: IterationStats):
        # Save tracked stats for token counters.
        self.num_prompt_tokens.append(iteration_stats.num_prompt_tokens)
        self.num_generation_tokens.append(
            iteration_stats.num_generation_tokens)

    def _get_throughput(self, tracked_stats: list[int], now: float) -> float:
        # Compute summary metrics for tracked stats
        return float(np.sum(tracked_stats) / (now - self.last_log_time))

    def record(self, scheduler_stats: Optional[SchedulerStats],
               iteration_stats: Optional[IterationStats]):
        """Log Stats to standard output."""

        if iteration_stats:
            self._track_iteration_stats(iteration_stats)

        if scheduler_stats is not None:
            self.prefix_caching_metrics.observe(
                scheduler_stats.prefix_cache_stats)

            if scheduler_stats.spec_decoding_stats is not None:
                self.spec_decoding_logging.observe(
                    scheduler_stats.spec_decoding_stats)

            self.last_scheduler_stats = scheduler_stats

    def log(self):
        now = time.monotonic()
        prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
        generation_throughput = self._get_throughput(
            self.num_generation_tokens, now)

        self._reset(now)

        scheduler_stats = self.last_scheduler_stats

        log_fn = logger.info
        if not any(
            (prompt_throughput, generation_throughput,
             self.last_prompt_throughput, self.last_generation_throughput)):
            # Avoid log noise on an idle production system
            log_fn = logger.debug
        self.last_generation_throughput = generation_throughput
        self.last_prompt_throughput = prompt_throughput

        # Format and print output.
        log_fn(
            "Engine %03d: "
            "Avg prompt throughput: %.1f tokens/s, "
            "Avg generation throughput: %.1f tokens/s, "
            "Running: %d reqs, Waiting: %d reqs, "
            "GPU KV cache usage: %.1f%%, "
            "Prefix cache hit rate: %.1f%%",
            self.engine_index,
            prompt_throughput,
            generation_throughput,
            scheduler_stats.num_running_reqs,
            scheduler_stats.num_waiting_reqs,
            scheduler_stats.gpu_cache_usage * 100,
            self.prefix_caching_metrics.hit_rate * 100,
        )
        self.spec_decoding_logging.log(log_fn=log_fn)

    def log_engine_initialized(self):
        if self.vllm_config.cache_config.num_gpu_blocks:
            logger.info(
                "Engine %03d: vllm cache_config_info with initialization "
                "after num_gpu_blocks is: %d", self.engine_index,
                self.vllm_config.cache_config.num_gpu_blocks)

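LoggingStatLogger measures interval throughput: tokens accumulated since the
last log divided by elapsed time, after which the window resets. The
windowed-throughput pattern in isolation, as a minimal sketch:

import time


class ThroughputWindow:

    def __init__(self):
        self._reset(time.monotonic())

    def _reset(self, now: float):
        self.last_log_time = now
        self.tokens: list[int] = []

    def track(self, n: int):
        self.tokens.append(n)

    def rate(self) -> float:
        # Tokens per second over the current window, then start a new window.
        now = time.monotonic()
        rate = sum(self.tokens) / (now - self.last_log_time)
        self._reset(now)
        return rate


w = ThroughputWindow()
w.track(100)
w.track(200)
time.sleep(0.1)
print(f"{w.rate():.0f} tokens/s")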
class PrometheusStatLogger(StatLoggerBase):
    _gauge_cls = prometheus_client.Gauge
    _counter_cls = prometheus_client.Counter
    _histogram_cls = prometheus_client.Histogram
    _spec_decoding_cls = SpecDecodingProm

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):

        unregister_vllm_metrics()
        self.vllm_config = vllm_config
        self.engine_index = engine_index
        # Use this flag to hide metrics that were deprecated in
        # a previous release and which will be removed in the future.
        self.show_hidden_metrics = \
            vllm_config.observability_config.show_hidden_metrics

        labelnames = ["model_name", "engine"]
        labelvalues = [
            vllm_config.model_config.served_model_name,
            str(engine_index)
        ]

        max_model_len = vllm_config.model_config.max_model_len

        self.spec_decoding_prom = self._spec_decoding_cls(
            vllm_config.speculative_config, labelnames, labelvalues)

        #
        # Scheduler state
        #
        self.gauge_scheduler_running = self._gauge_cls(
            name="vllm:num_requests_running",
            documentation="Number of requests in model execution batches.",
            multiprocess_mode="mostrecent",
            labelnames=labelnames).labels(*labelvalues)

        self.gauge_scheduler_waiting = self._gauge_cls(
            name="vllm:num_requests_waiting",
            documentation="Number of requests waiting to be processed.",
            multiprocess_mode="mostrecent",
            labelnames=labelnames).labels(*labelvalues)

        #
        # GPU cache
        #
        self.gauge_gpu_cache_usage = self._gauge_cls(
            name="vllm:gpu_cache_usage_perc",
            documentation="GPU KV-cache usage. 1 means 100 percent usage.",
            multiprocess_mode="mostrecent",
            labelnames=labelnames).labels(*labelvalues)

        self.counter_gpu_prefix_cache_queries = self._counter_cls(
            name="vllm:gpu_prefix_cache_queries",
            documentation=
            "GPU prefix cache queries, in terms of number of queried tokens.",
            labelnames=labelnames).labels(*labelvalues)

        self.counter_gpu_prefix_cache_hits = self._counter_cls(
            name="vllm:gpu_prefix_cache_hits",
            documentation=
            "GPU prefix cache hits, in terms of number of cached tokens.",
            labelnames=labelnames).labels(*labelvalues)

        #
        # Counters
        #
        self.counter_num_preempted_reqs = self._counter_cls(
            name="vllm:num_preemptions",
            documentation="Cumulative number of preemptions from the engine.",
            labelnames=labelnames).labels(*labelvalues)

        self.counter_prompt_tokens = self._counter_cls(
            name="vllm:prompt_tokens",
            documentation="Number of prefill tokens processed.",
            labelnames=labelnames).labels(*labelvalues)

        self.counter_generation_tokens = self._counter_cls(
            name="vllm:generation_tokens",
            documentation="Number of generation tokens processed.",
            labelnames=labelnames).labels(*labelvalues)

        self.counter_request_success: dict[FinishReason,
                                           prometheus_client.Counter] = {}
        counter_request_success_base = self._counter_cls(
            name="vllm:request_success",
            documentation="Count of successfully processed requests.",
            labelnames=labelnames + ["finished_reason"])
        for reason in FinishReason:
            self.counter_request_success[
                reason] = counter_request_success_base.labels(*(labelvalues +
                                                                [str(reason)]))

        #
        # Histograms of counts
        #
        self.histogram_num_prompt_tokens_request = \
            self._histogram_cls(
                name="vllm:request_prompt_tokens",
                documentation="Number of prefill tokens processed.",
                buckets=build_1_2_5_buckets(max_model_len),
                labelnames=labelnames).labels(*labelvalues)

        self.histogram_num_generation_tokens_request = \
            self._histogram_cls(
                name="vllm:request_generation_tokens",
                documentation="Number of generation tokens processed.",
                buckets=build_1_2_5_buckets(max_model_len),
                labelnames=labelnames).labels(*labelvalues)

        # TODO: This metric might be incorrect when multiple api_server
        # processes are used with prometheus multiprocessing.
        # See: https://github.com/vllm-project/vllm/pull/18053
        self.histogram_iteration_tokens = \
            self._histogram_cls(
                name="vllm:iteration_tokens_total",
                documentation="Histogram of number of tokens per engine_step.",
                buckets=[
                    1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,
                    16384
                ],
                labelnames=labelnames).labels(*labelvalues)

        self.histogram_max_num_generation_tokens_request = \
            self._histogram_cls(
                name="vllm:request_max_num_generation_tokens",
                documentation=
                "Histogram of maximum number of requested generation tokens.",
                buckets=build_1_2_5_buckets(max_model_len),
                labelnames=labelnames).labels(*labelvalues)

        self.histogram_n_request = \
            self._histogram_cls(
                name="vllm:request_params_n",
                documentation="Histogram of the n request parameter.",
                buckets=[1, 2, 5, 10, 20],
                labelnames=labelnames).labels(*labelvalues)

        self.histogram_max_tokens_request = \
            self._histogram_cls(
                name="vllm:request_params_max_tokens",
                documentation="Histogram of the max_tokens request parameter.",
                buckets=build_1_2_5_buckets(max_model_len),
                labelnames=labelnames).labels(*labelvalues)

        #
        # Histograms of timing intervals
        #
        self.histogram_time_to_first_token = \
            self._histogram_cls(
                name="vllm:time_to_first_token_seconds",
                documentation="Histogram of time to first token in seconds.",
                buckets=[
                    0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
                    0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
                    640.0, 2560.0
                ],
                labelnames=labelnames).labels(*labelvalues)

        self.histogram_time_per_output_token = \
            self._histogram_cls(
                name="vllm:time_per_output_token_seconds",
                documentation="Histogram of time per output token in seconds.",
                buckets=[
                    0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5,
                    0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
                ],
                labelnames=labelnames).labels(*labelvalues)

        request_latency_buckets = [
            0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
            40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
        ]
        self.histogram_e2e_time_request = \
            self._histogram_cls(
                name="vllm:e2e_request_latency_seconds",
                documentation="Histogram of e2e request latency in seconds.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_queue_time_request = \
            self._histogram_cls(
                name="vllm:request_queue_time_seconds",
                documentation=
                "Histogram of time spent in WAITING phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_inference_time_request = \
            self._histogram_cls(
                name="vllm:request_inference_time_seconds",
                documentation=
                "Histogram of time spent in RUNNING phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_prefill_time_request = \
            self._histogram_cls(
                name="vllm:request_prefill_time_seconds",
                documentation=
                "Histogram of time spent in PREFILL phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_decode_time_request = \
            self._histogram_cls(
                name="vllm:request_decode_time_seconds",
                documentation=
                "Histogram of time spent in DECODE phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)

        #
        # LoRA metrics
        #

        # TODO: This metric might be incorrect when multiple api_server
        # processes are used with prometheus multiprocessing.
        self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
        if vllm_config.lora_config is not None:
            self.labelname_max_lora = "max_lora"
            self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
            self.labelname_running_lora_adapters = "running_lora_adapters"
            self.max_lora = vllm_config.lora_config.max_loras
            self.gauge_lora_info = \
                self._gauge_cls(
                    name="vllm:lora_requests_info",
                    documentation="Running stats on lora requests.",
                    multiprocess_mode="sum",
                    labelnames=[
                        self.labelname_max_lora,
                        self.labelname_waiting_lora_adapters,
                        self.labelname_running_lora_adapters,
                    ],
                )

    def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):

        metrics_info = config_obj.metrics_info()
        metrics_info["engine"] = self.engine_index

        name, documentation = None, None
        if type == "cache_config":
            name = "vllm:cache_config_info"
            documentation = "Information of the LLMEngine CacheConfig"
        assert name is not None, f"Unknown metrics info type {type}"

        # Info-type metrics are syntactic sugar for a gauge permanently set
        # to 1. Since prometheus multiprocessing mode does not support Info,
        # emulate Info here with a gauge.
        info_gauge = self._gauge_cls(
            name=name,
            documentation=documentation,
            multiprocess_mode="mostrecent",
            labelnames=metrics_info.keys(),
        ).labels(**metrics_info)
        info_gauge.set(1)

    def record(self, scheduler_stats: Optional[SchedulerStats],
               iteration_stats: Optional[IterationStats]):
        """Log to prometheus."""
        if scheduler_stats is not None:
            self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
            self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)

            self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)

            self.counter_gpu_prefix_cache_queries.inc(
                scheduler_stats.prefix_cache_stats.queries)
            self.counter_gpu_prefix_cache_hits.inc(
                scheduler_stats.prefix_cache_stats.hits)

            if scheduler_stats.spec_decoding_stats is not None:
                self.spec_decoding_prom.observe(
                    scheduler_stats.spec_decoding_stats)

        if iteration_stats is None:
            return

        self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs)
        self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
        self.counter_generation_tokens.inc(
            iteration_stats.num_generation_tokens)
        self.histogram_iteration_tokens.observe(
            iteration_stats.num_prompt_tokens + \
            iteration_stats.num_generation_tokens)

        for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter:
            self.histogram_max_num_generation_tokens_request.observe(
                max_gen_tokens)
        for n_param in iteration_stats.n_params_iter:
            self.histogram_n_request.observe(n_param)
        for ttft in iteration_stats.time_to_first_tokens_iter:
            self.histogram_time_to_first_token.observe(ttft)
        for tpot in iteration_stats.time_per_output_tokens_iter:
            self.histogram_time_per_output_token.observe(tpot)

        for finished_request in iteration_stats.finished_requests:
            self.counter_request_success[finished_request.finish_reason].inc()
            self.histogram_e2e_time_request.observe(
                finished_request.e2e_latency)
            self.histogram_queue_time_request.observe(
                finished_request.queued_time)
            self.histogram_prefill_time_request.observe(
                finished_request.prefill_time)
            self.histogram_inference_time_request.observe(
                finished_request.inference_time)
            self.histogram_decode_time_request.observe(
                finished_request.decode_time)
            self.histogram_num_prompt_tokens_request.observe(
                finished_request.num_prompt_tokens)
            self.histogram_num_generation_tokens_request.observe(
                finished_request.num_generation_tokens)
            self.histogram_max_tokens_request.observe(
                finished_request.max_tokens_param)

        if self.gauge_lora_info is not None:
            running_lora_adapters = \
                ",".join(iteration_stats.running_lora_adapters.keys())
            waiting_lora_adapters = \
                ",".join(iteration_stats.waiting_lora_adapters.keys())
            lora_info_labels = {
                self.labelname_running_lora_adapters: running_lora_adapters,
                self.labelname_waiting_lora_adapters: waiting_lora_adapters,
                self.labelname_max_lora: self.max_lora,
            }
            self.gauge_lora_info.labels(**lora_info_labels)\
                .set_to_current_time()

    def log_engine_initialized(self):
        self.log_metrics_info("cache_config", self.vllm_config.cache_config)

def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
|
||||
"""
|
||||
Builds a list of buckets with increasing powers of 10 multiplied by
|
||||
mantissa values until the value exceeds the specified maximum.
|
||||
|
||||
"""
|
||||
exponent = 0
|
||||
buckets: list[int] = []
|
||||
while True:
|
||||
for m in mantissa_lst:
|
||||
value = m * 10**exponent
|
||||
if value <= max_value:
|
||||
buckets.append(value)
|
||||
else:
|
||||
return buckets
|
||||
exponent += 1
|
||||
|
||||
|
||||
def build_1_2_5_buckets(max_value: int) -> list[int]:
|
||||
"""
|
||||
Example:
|
||||
>>> build_1_2_5_buckets(100)
|
||||
[1, 2, 5, 10, 20, 50, 100]
|
||||
"""
|
||||
return build_buckets([1, 2, 5], max_value)
|
||||
|
||||
|
||||
def setup_default_loggers(
|
||||
vllm_config: VllmConfig,
|
||||
log_stats: bool,
|
||||
engine_num: int,
|
||||
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
) -> list[list[StatLoggerBase]]:
|
||||
"""Setup logging and prometheus metrics."""
|
||||
if not log_stats:
|
||||
return []
|
||||
|
||||
factories: list[StatLoggerFactory]
|
||||
if custom_stat_loggers is not None:
|
||||
factories = custom_stat_loggers
|
||||
else:
|
||||
factories = [PrometheusStatLogger]
|
||||
if logger.isEnabledFor(logging.INFO):
|
||||
factories.append(LoggingStatLogger)
|
||||
|
||||
stat_loggers: list[list[StatLoggerBase]] = []
|
||||
for i in range(engine_num):
|
||||
per_engine_stat_loggers: list[StatLoggerBase] = []
|
||||
for logger_factory in factories:
|
||||
per_engine_stat_loggers.append(logger_factory(vllm_config, i))
|
||||
stat_loggers.append(per_engine_stat_loggers)
|
||||
|
||||
return stat_loggers
|
||||
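As a quick illustration of the bucket helpers above, here is a sketch (not part of the commit) of the 1-2-5 series for a larger cap; the import path assumes this file's module, vllm/v1/metrics/loggers.py.

from vllm.v1.metrics.loggers import build_1_2_5_buckets

# Powers of 10 times 1, 2, 5, stopping once a value would exceed the cap.
print(build_1_2_5_buckets(2048))
# [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000]
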
82
vllm/v1/metrics/prometheus.py
Normal file
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import tempfile
from typing import Optional

from prometheus_client import REGISTRY, CollectorRegistry, multiprocess

from vllm.logger import init_logger

logger = init_logger(__name__)

# Global temporary directory for prometheus multiprocessing
_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None


def setup_multiprocess_prometheus():
    """Set up prometheus multiprocessing directory if not already configured.
    """
    global _prometheus_multiproc_dir

    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        # Make TemporaryDirectory for prometheus multiprocessing
        # Note: global TemporaryDirectory will be automatically
        # cleaned up upon exit.
        _prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name
        logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s",
                     _prometheus_multiproc_dir.name)
    else:
        logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                       "This directory must be wiped between vLLM runs or "
                       "you will find inaccurate metrics. Unset the variable "
                       "and vLLM will properly handle cleanup.")


def get_prometheus_registry():
    """Get the appropriate prometheus registry based on multiprocessing
    configuration.

    Returns:
        Registry: A prometheus registry
    """
    if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None:
        logger.debug("Using multiprocess registry for prometheus metrics")
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        return registry

    return REGISTRY


def unregister_vllm_metrics():
    """Unregister any existing vLLM collectors from the prometheus registry.

    This is useful for testing and CI/CD where metrics may be registered
    multiple times across test runs.

    Also, in case of multiprocess, we need to unregister the metrics from the
    global registry.
    """
    registry = REGISTRY
    # Unregister any existing vLLM collectors
    for collector in list(registry._collector_to_names):
        if hasattr(collector, "_name") and "vllm" in collector._name:
            registry.unregister(collector)


def shutdown_prometheus():
    """Shutdown prometheus metrics."""

    path = _prometheus_multiproc_dir
    if path is None:
        return
    try:
        pid = os.getpid()
        multiprocess.mark_process_dead(pid, path)
        logger.debug("Marked Prometheus metrics for process %d as dead", pid)
    except Exception as e:
        logger.error("Error during metrics cleanup: %s", str(e))
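A sketch (not part of the commit) of how the three helpers above are typically sequenced in a multiprocess deployment; the worker/serving steps in the comment are hypothetical, only the vllm.v1.metrics.prometheus calls come from this file.

import os

from vllm.v1.metrics.prometheus import (get_prometheus_registry,
                                        setup_multiprocess_prometheus,
                                        shutdown_prometheus)

setup_multiprocess_prometheus()  # sets PROMETHEUS_MULTIPROC_DIR if unset
assert "PROMETHEUS_MULTIPROC_DIR" in os.environ

registry = get_prometheus_registry()  # multiprocess-aware registry
# ... fork workers, record metrics, serve /metrics from `registry` ...
shutdown_prometheus()  # mark this PID's metric files as dead
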
131
vllm/v1/metrics/ray_wrappers.py
Normal file
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Optional, Union

from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm

try:
    from ray.util import metrics as ray_metrics
    from ray.util.metrics import Metric
except ImportError:
    ray_metrics = None


class RayPrometheusMetric:

    def __init__(self):
        if ray_metrics is None:
            raise ImportError(
                "RayPrometheusMetric requires Ray to be installed.")

        self.metric: Metric = None

    def labels(self, *labels, **labelskwargs):
        if labelskwargs:
            for k, v in labelskwargs.items():
                if not isinstance(v, str):
                    labelskwargs[k] = str(v)

            self.metric.set_default_tags(labelskwargs)

        if labels:
            if len(labels) != len(self.metric._tag_keys):
                raise ValueError(
                    "Number of labels must match the number of tag keys. "
                    f"Expected {len(self.metric._tag_keys)}, got {len(labels)}"
                )

            self.metric.set_default_tags(
                dict(zip(self.metric._tag_keys, labels)))

        return self


class RayGaugeWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Gauge to provide same API as
    prometheus_client.Gauge"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None):
        labelnames_tuple = tuple(labelnames) if labelnames else None
        self.metric = ray_metrics.Gauge(name=name,
                                        description=documentation,
                                        tag_keys=labelnames_tuple)

    def set(self, value: Union[int, float]):
        return self.metric.set(value)

    def set_to_current_time(self):
        # ray metrics doesn't have set_to_current_time, see
        # https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
        return self.metric.set(time.time())


class RayCounterWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Counter to provide same API as
    prometheus_client.Counter"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None):
        labelnames_tuple = tuple(labelnames) if labelnames else None
        self.metric = ray_metrics.Counter(name=name,
                                          description=documentation,
                                          tag_keys=labelnames_tuple)

    def inc(self, value: Union[int, float] = 1.0):
        if value == 0:
            return
        return self.metric.inc(value)


class RayHistogramWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Histogram to provide same API as
    prometheus_client.Histogram"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None,
                 buckets: Optional[list[float]] = None):
        labelnames_tuple = tuple(labelnames) if labelnames else None
        boundaries = buckets if buckets else []
        self.metric = ray_metrics.Histogram(name=name,
                                            description=documentation,
                                            tag_keys=labelnames_tuple,
                                            boundaries=boundaries)

    def observe(self, value: Union[int, float]):
        return self.metric.observe(value)


class RaySpecDecodingProm(SpecDecodingProm):
    """
    RaySpecDecodingProm is used by RayMetrics to log to Ray metrics.
    Provides the same metrics as SpecDecodingProm but uses Ray's
    util.metrics library.
    """

    _counter_cls = RayCounterWrapper


class RayPrometheusStatLogger(PrometheusStatLogger):
    """RayPrometheusStatLogger uses Ray metrics instead."""

    _gauge_cls = RayGaugeWrapper
    _counter_cls = RayCounterWrapper
    _histogram_cls = RayHistogramWrapper
    _spec_decoding_cls = RaySpecDecodingProm

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)

    @staticmethod
    def _unregister_vllm_metrics():
        # No-op on purpose
        pass
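A sketch (not part of the commit) showing that the wrappers above keep the prometheus_client call shapes, so logger code can stay backend-agnostic. It assumes Ray is installed and initialized; the metric names and labels are made up for illustration.

from vllm.v1.metrics.ray_wrappers import RayGaugeWrapper, RayHistogramWrapper

gauge = RayGaugeWrapper("vllm:num_requests_running",
                        documentation="Number of running requests",
                        labelnames=["model_name", "engine"])
gauge.labels("my-model", "0").set(3)

hist = RayHistogramWrapper("vllm:ttft_seconds",
                           documentation="Time to first token",
                           labelnames=["model_name", "engine"],
                           buckets=[0.1, 0.5, 1.0, 5.0])
hist.labels("my-model", "0").observe(0.42)
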
246
vllm/v1/metrics/reader.py
Normal file
@@ -0,0 +1,246 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional

from prometheus_client import REGISTRY
from prometheus_client import Metric as PromMetric
from prometheus_client.samples import Sample


@dataclass
class Metric:
    """A base class for prometheus metrics.

    Each metric may be associated with key=value labels, and
    in some cases a single vLLM instance may have multiple
    metrics with the same name but different sets of labels.
    """
    name: str
    labels: dict[str, str]


@dataclass
class Counter(Metric):
    """A monotonically increasing integer counter."""
    value: int


@dataclass
class Vector(Metric):
    """An ordered array of integer counters.

    This type - which doesn't exist in Prometheus - models one very
    specific metric, vllm:spec_decode_num_accepted_tokens_per_pos.
    """
    values: list[int]


@dataclass
class Gauge(Metric):
    """A numerical value that can go up or down."""
    value: float


@dataclass
class Histogram(Metric):
    """Observations recorded in configurable buckets.

    Buckets are represented by a dictionary. The key is
    the upper limit of the bucket, and the value is the
    observed count in that bucket. A '+Inf' key always
    exists.

    The count property is the total count across all
    buckets, identical to the count of the '+Inf' bucket.

    The sum property is the total sum of all observed
    values.
    """
    count: int
    sum: float
    buckets: dict[str, int]


def get_metrics_snapshot() -> list[Metric]:
    """An API for accessing in-memory Prometheus metrics.

    Example:
        >>> for metric in llm.get_metrics():
        ...     if isinstance(metric, Counter):
        ...         print(f"{metric} = {metric.value}")
        ...     elif isinstance(metric, Gauge):
        ...         print(f"{metric} = {metric.value}")
        ...     elif isinstance(metric, Histogram):
        ...         print(f"{metric}")
        ...         print(f" sum = {metric.sum}")
        ...         print(f" count = {metric.count}")
        ...         for bucket_le, value in metric.buckets.items():
        ...             print(f" {bucket_le} = {value}")
    """
    collected: list[Metric] = []
    for metric in REGISTRY.collect():
        if not metric.name.startswith("vllm:"):
            continue
        if metric.type == "gauge":
            samples = _get_samples(metric)
            for s in samples:
                collected.append(
                    Gauge(name=metric.name, labels=s.labels, value=s.value))
        elif metric.type == "counter":
            samples = _get_samples(metric, "_total")
            if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
                #
                # Ugly vllm:num_accepted_tokens_per_pos special case.
                #
                # This metric is a vector of counters - for each spec
                # decoding token position, we observe the number of
                # accepted tokens using a Counter labeled with 'position'.
                # We convert these into a vector of integer values.
                #
                for labels, values in _digest_num_accepted_by_pos_samples(
                        samples):
                    collected.append(
                        Vector(name=metric.name, labels=labels, values=values))
            else:
                for s in samples:
                    collected.append(
                        Counter(name=metric.name,
                                labels=s.labels,
                                value=int(s.value)))

        elif metric.type == "histogram":
            #
            # A histogram has a number of '_bucket' samples where
            # the 'le' label represents the upper limit of the bucket.
            # We convert these bucketized values into a dict of values
            # indexed by the value of the 'le' label. The 'le=+Inf'
            # label is a special case, catching all values observed.
            #
            bucket_samples = _get_samples(metric, "_bucket")
            count_samples = _get_samples(metric, "_count")
            sum_samples = _get_samples(metric, "_sum")
            for labels, buckets, count_value, sum_value in _digest_histogram(
                    bucket_samples, count_samples, sum_samples):
                collected.append(
                    Histogram(name=metric.name,
                              labels=labels,
                              buckets=buckets,
                              count=count_value,
                              sum=sum_value))
        else:
            raise AssertionError(f"Unknown metric type {metric.type}")

    return collected


def _get_samples(metric: PromMetric,
                 suffix: Optional[str] = None) -> list[Sample]:
    name = (metric.name + suffix) if suffix is not None else metric.name
    return [s for s in metric.samples if s.name == name]


def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]:
    labels_copy = labels.copy()
    labels_copy.pop(key_to_remove)
    return labels_copy


def _digest_histogram(
    bucket_samples: list[Sample], count_samples: list[Sample],
    sum_samples: list[Sample]
) -> list[tuple[dict[str, str], dict[str, int], int, float]]:
    #
    # In the case of DP, we have an indigestible
    # per-bucket-per-engine count as a list of labelled
    # samples, along with total and sum samples
    #
    # bucket_samples (in):
    #   labels = {bucket: 100, idx: 0}, value = 2
    #   labels = {bucket: 200, idx: 0}, value = 4
    #   labels = {bucket: Inf, idx: 0}, value = 10
    #   labels = {bucket: 100, idx: 1}, value = 1
    #   labels = {bucket: 200, idx: 1}, value = 5
    #   labels = {bucket: Inf, idx: 1}, value = 7
    # count_samples (in):
    #   labels = {idx: 0}, value = 10
    #   labels = {idx: 1}, value = 7
    # sum_samples (in):
    #   labels = {idx: 0}, value = 2000
    #   labels = {idx: 1}, value = 1200
    #
    # output: [
    #   {idx: 0}, {"100": 2, "200": 4, "Inf": 10}, 10, 2000
    #   {idx: 1}, {"100": 1, "200": 5, "Inf": 7}, 7, 1200
    # ]
    buckets_by_labels: dict[frozenset[tuple[str, str]], dict[str, int]] = {}
    for s in bucket_samples:
        bucket = s.labels["le"]
        labels_key = frozenset(_strip_label(s.labels, "le").items())
        if labels_key not in buckets_by_labels:
            buckets_by_labels[labels_key] = {}
        buckets_by_labels[labels_key][bucket] = int(s.value)

    counts_by_labels: dict[frozenset[tuple[str, str]], int] = {}
    for s in count_samples:
        labels_key = frozenset(s.labels.items())
        counts_by_labels[labels_key] = int(s.value)

    sums_by_labels: dict[frozenset[tuple[str, str]], float] = {}
    for s in sum_samples:
        labels_key = frozenset(s.labels.items())
        sums_by_labels[labels_key] = s.value

    assert set(buckets_by_labels.keys()) == set(
        counts_by_labels.keys()) == set(sums_by_labels.keys())

    output = []
    label_keys = list(buckets_by_labels.keys())
    for k in label_keys:
        labels = dict(k)
        output.append((labels, buckets_by_labels[k], counts_by_labels[k],
                       sums_by_labels[k]))
    return output


def _digest_num_accepted_by_pos_samples(
        samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]:
    #
    # In the case of DP, we have an indigestible
    # per-position-per-engine count as a list of
    # labelled samples
    #
    # samples (in):
    #   labels = {pos: 0, idx: 0}, value = 10
    #   labels = {pos: 1, idx: 0}, value = 7
    #   labels = {pos: 2, idx: 0}, value = 2
    #   labels = {pos: 0, idx: 1}, value = 5
    #   labels = {pos: 1, idx: 1}, value = 3
    #   labels = {pos: 2, idx: 1}, value = 1
    #
    # output: [
    #   {idx: 0}, [10, 7, 2]
    #   {idx: 1}, [5, 3, 1]
    # ]
    #
    max_pos = 0
    values_by_labels: dict[frozenset[tuple[str, str]], dict[int, int]] = {}

    for s in samples:
        position = int(s.labels["position"])
        max_pos = max(max_pos, position)

        labels_key = frozenset(_strip_label(s.labels, "position").items())
        if labels_key not in values_by_labels:
            values_by_labels[labels_key] = {}
        values_by_labels[labels_key][position] = int(s.value)

    output = []
    for labels_key, values_by_position in values_by_labels.items():
        labels = dict(labels_key)
        values = [0] * (max_pos + 1)
        for pos, val in values_by_position.items():
            values[pos] = val
        output.append((labels, values))
    return output
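A sketch (not part of the commit) that feeds hand-built samples through the per-position digester above to show the DP flattening; the helper is private and is called here purely for illustration, and the "engine" label is made up.

from prometheus_client.samples import Sample

from vllm.v1.metrics.reader import _digest_num_accepted_by_pos_samples

name = "vllm:spec_decode_num_accepted_tokens_per_pos_total"
samples = [
    Sample(name, {"position": "0", "engine": "0"}, 10.0),
    Sample(name, {"position": "1", "engine": "0"}, 7.0),
    Sample(name, {"position": "0", "engine": "1"}, 5.0),
    Sample(name, {"position": "1", "engine": "1"}, 3.0),
]
print(_digest_num_accepted_by_pos_samples(samples))
# [({'engine': '0'}, [10, 7]), ({'engine': '1'}, [5, 3])]
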
239
vllm/v1/metrics/stats.py
Normal file
@@ -0,0 +1,239 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional

from vllm.v1.spec_decode.metrics import SpecDecodingStats

if TYPE_CHECKING:
    from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
    from vllm.v1.engine.output_processor import RequestState


@dataclass
class PrefixCacheStats:
    """Stores prefix cache hit statistics."""
    # Whether reset_prefix_cache was invoked.
    reset: bool = False
    # The number of requests in this update.
    requests: int = 0
    # The number of queries in these requests. Note that "queries" here
    # means the number of tokens that were queried from the cache.
    queries: int = 0
    # The number of hits in these requests.
    hits: int = 0


@dataclass
class SchedulerStats:
    """Stats associated with the scheduler."""

    num_running_reqs: int = 0
    num_waiting_reqs: int = 0

    gpu_cache_usage: float = 0.0

    prefix_cache_stats: PrefixCacheStats = field(
        default_factory=PrefixCacheStats)

    spec_decoding_stats: Optional[SpecDecodingStats] = None


@dataclass
class LoRAStats:
    waiting_requests: set[str] = field(default_factory=set)
    running_requests: set[str] = field(default_factory=set)


@dataclass
class RequestStateStats:
    """Stats that need to be tracked across delta updates."""

    num_generation_tokens: int = 0

    # This is an engine frontend timestamp (wall-clock)
    arrival_time: float = 0.0

    # These are engine core timestamps (monotonic)
    queued_ts: float = 0.0
    scheduled_ts: float = 0.0
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0


@dataclass
class FinishedRequestStats:
    """Stats associated with a finished request."""

    finish_reason: "FinishReason"
    e2e_latency: float = 0.0
    num_prompt_tokens: int = 0
    num_generation_tokens: int = 0
    max_tokens_param: Optional[int] = None
    queued_time: float = 0.0
    prefill_time: float = 0.0
    inference_time: float = 0.0
    decode_time: float = 0.0


class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs."""

    def __init__(self):
        self.iteration_timestamp = time.time()
        self.num_generation_tokens = 0
        self.num_prompt_tokens = 0
        self.num_preempted_reqs = 0
        self.finished_requests: list[FinishedRequestStats] = []
        self.max_num_generation_tokens_iter: list[int] = []
        self.n_params_iter: list[int] = []
        self.time_to_first_tokens_iter: list[float] = []
        self.time_per_output_tokens_iter: list[float] = []
        self.waiting_lora_adapters: dict[str, int] = {}
        self.running_lora_adapters: dict[str, int] = {}

    def _time_since(self, start: float) -> float:
        """Calculate an interval relative to this iteration's timestamp."""
        return self.iteration_timestamp - start

    def update_from_output(self, output: "EngineCoreOutput",
                           engine_core_timestamp: float, is_prefilling: bool,
                           prompt_len: int, req_stats: RequestStateStats,
                           lora_stats: Optional[LoRAStats]):
        num_new_generation_tokens = len(output.new_token_ids)

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
            assert num_new_generation_tokens > 0
            self.num_prompt_tokens += prompt_len

            first_token_latency = self._time_since(req_stats.arrival_time)
            self.time_to_first_tokens_iter.append(first_token_latency)

        req_stats.num_generation_tokens += num_new_generation_tokens

        # Process request-level engine core events
        if output.events is not None:
            self.update_from_events(output.request_id, output.events,
                                    is_prefilling, req_stats, lora_stats)

        # Process the batch-level "new tokens" engine core event
        if is_prefilling:
            req_stats.first_token_ts = engine_core_timestamp
        else:
            tpot = engine_core_timestamp - req_stats.last_token_ts
            self.time_per_output_tokens_iter.append(tpot)

        req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
                           is_prefilling: bool, req_stats: RequestStateStats,
                           lora_stats: Optional[LoRAStats]):
        # Avoid circular dependency
        from vllm.v1.engine import EngineCoreEventType
        for event in events:
            if event.type == EngineCoreEventType.QUEUED:
                req_stats.queued_ts = event.timestamp
                if lora_stats is not None:
                    lora_stats.waiting_requests.add(req_id)
            elif event.type == EngineCoreEventType.SCHEDULED:
                if req_stats.scheduled_ts == 0.0:  # ignore preemptions
                    req_stats.scheduled_ts = event.timestamp
                LoRARequestStates.scheduled_request(lora_stats, req_id)
            elif event.type == EngineCoreEventType.PREEMPTED:
                self.num_preempted_reqs += 1
                LoRARequestStates.preempted_request(lora_stats, req_id)

    def update_from_finished_request(self, finish_reason: "FinishReason",
                                     num_prompt_tokens: int,
                                     max_tokens_param: Optional[int],
                                     req_stats: RequestStateStats):
        e2e_latency = self._time_since(req_stats.arrival_time)

        # Queued interval is from first QUEUED event to first SCHEDULED
        queued_time = req_stats.scheduled_ts - req_stats.queued_ts

        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
        # Any preemptions during prefill are included in the interval
        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts

        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
        # Any preemptions during decode are included
        decode_time = req_stats.last_token_ts - req_stats.first_token_ts

        # Inference interval is from first SCHEDULED to last NEW_TOKEN
        # Any preemptions during prefill or decode are included
        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts

        finished_req = \
            FinishedRequestStats(finish_reason=finish_reason,
                                 e2e_latency=e2e_latency,
                                 num_prompt_tokens=num_prompt_tokens,
                                 num_generation_tokens=req_stats.num_generation_tokens,
                                 max_tokens_param=max_tokens_param,
                                 queued_time=queued_time,
                                 prefill_time=prefill_time,
                                 inference_time=inference_time,
                                 decode_time=decode_time)
        self.finished_requests.append(finished_req)


class LoRARequestStates:
    """Per-LoRA request state stats."""

    def __init__(self):
        self.lora_name_to_stats: dict[str, LoRAStats] = {}

    def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
        if req_state.lora_name is None:
            return None
        if req_state.lora_name not in self.lora_name_to_stats:
            self.lora_name_to_stats[req_state.lora_name] = LoRAStats()
        return self.lora_name_to_stats[req_state.lora_name]

    def add_request(self, req_state: 'RequestState'):
        if (lora_stats := self.get_stats(req_state)) is not None:
            lora_stats.waiting_requests.add(req_state.request_id)

    def finish_request(self, req_state: 'RequestState'):
        if req_state.lora_name is None:
            return
        lora_stats = self.lora_name_to_stats[req_state.lora_name]
        lora_stats.running_requests.remove(req_state.request_id)

    def abort_request(self, req_state: 'RequestState'):
        if req_state.lora_name is None:
            return
        lora_stats = self.lora_name_to_stats[req_state.lora_name]
        lora_stats.waiting_requests.discard(req_state.request_id)
        lora_stats.running_requests.discard(req_state.request_id)

    # Break the pattern for these lifecycle methods so we can
    # call them from IterationStats.update_from_events()
    @staticmethod
    def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str):
        if lora_stats is None:
            return
        lora_stats.waiting_requests.remove(request_id)
        lora_stats.running_requests.add(request_id)

    @staticmethod
    def preempted_request(lora_stats: Optional[LoRAStats], request_id: str):
        if lora_stats is None:
            return
        lora_stats.running_requests.remove(request_id)
        lora_stats.waiting_requests.add(request_id)

    def update_iteration_stats(self,
                               iteration_stats: Optional[IterationStats]):
        if iteration_stats is None:
            return
        for lora_name, stats in self.lora_name_to_stats.items():
            if stats.waiting_requests:
                iteration_stats.waiting_lora_adapters[lora_name] = \
                    len(stats.waiting_requests)
            if stats.running_requests:
                iteration_stats.running_lora_adapters[lora_name] = \
                    len(stats.running_requests)
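A sketch (not part of the commit) of how the four intervals computed in update_from_finished_request relate, using made-up monotonic timestamps.

from vllm.v1.metrics.stats import RequestStateStats

req_stats = RequestStateStats(arrival_time=100.0,
                              queued_ts=10.0,
                              scheduled_ts=12.0,
                              first_token_ts=13.5,
                              last_token_ts=20.0)
queued_time = req_stats.scheduled_ts - req_stats.queued_ts         # 2.0
prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts   # 1.5
decode_time = req_stats.last_token_ts - req_stats.first_token_ts   # 6.5
inference_time = req_stats.last_token_ts - req_stats.scheduled_ts  # 8.0
assert inference_time == prefill_time + decode_time
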
116
vllm/v1/outputs.py
Normal file
@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import NamedTuple, Optional

import torch


class LogprobsLists(NamedTuple):

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: list[list[int]]
    # [num_reqs, max_num_logprobs + 1]
    logprobs: list[list[float]]
    # [num_reqs]
    sampled_token_ranks: list[int]

    def slice(self, start: int, end: int):
        return LogprobsLists(
            self.logprob_token_ids[start:end],
            self.logprobs[start:end],
            self.sampled_token_ranks[start:end],
        )


class LogprobsTensors(NamedTuple):

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: torch.Tensor
    # [num_reqs, max_num_logprobs + 1]
    logprobs: torch.Tensor
    # [num_reqs]
    selected_token_ranks: torch.Tensor

    def tolists(self):
        return LogprobsLists(
            self.logprob_token_ids.tolist(),
            self.logprobs.tolist(),
            self.selected_token_ranks.tolist(),
        )

    @staticmethod
    def empty_cpu(num_positions: int,
                  num_tokens_per_position: int) -> "LogprobsTensors":
        """Create empty LogprobsTensors on CPU."""

        logprob_token_ids = torch.empty(
            (num_positions, num_tokens_per_position),
            dtype=torch.int32,
            device="cpu")
        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
        selected_token_ranks = torch.empty(num_positions,
                                           dtype=torch.int32,
                                           device="cpu")
        return LogprobsTensors(
            logprob_token_ids=logprob_token_ids,
            logprobs=logprobs,
            selected_token_ranks=selected_token_ranks,
        )


@dataclass
class SamplerOutput:

    # [num_reqs, max_num_generated_tokens]
    # Different requests can have different numbers of generated tokens.
    # All requests are padded to max_num_generated_tokens.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: torch.Tensor
    logprobs_tensors: Optional[LogprobsTensors]


# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
class ModelRunnerOutput:

    # [num_reqs]
    req_ids: list[str]
    # req_id -> index
    req_id_to_index: dict[str, int]

    # num_reqs x num_generated_tokens
    # num_generated_tokens is the number of tokens
    # generated in the current step. It can be different for
    # each request due to speculative/jump decoding.
    sampled_token_ids: list[list[int]]

    # num_reqs x num_spec_tokens
    spec_token_ids: Optional[list[list[int]]]

    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs]
    logprobs: Optional[LogprobsLists]

    # req_id -> (token_ids, logprobs, ranks)
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len]
    prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]

    # [req_ids]
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                              req_id_to_index={},
                                              sampled_token_ids=[],
                                              spec_token_ids=None,
                                              logprobs=None,
                                              prompt_logprobs_dict={},
                                              finished_sending=None,
                                              finished_recving=None)
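A sketch (not part of the commit): allocating logprobs storage for a 4-position prompt with 3 logprobs per position, then converting the tensor view to the list form that is cheap to serialize.

from vllm.v1.outputs import LogprobsTensors

tensors = LogprobsTensors.empty_cpu(num_positions=4,
                                    num_tokens_per_position=3)
assert tensors.logprob_token_ids.shape == (4, 3)
assert tensors.logprobs.dtype.is_floating_point

lists = tensors.tolists()       # LogprobsLists of plain Python lists
first_two = lists.slice(0, 2)   # rows 0 and 1 only
assert len(first_two.logprob_token_ids) == 2
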
193
vllm/v1/request.py
Normal file
@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import enum
from typing import TYPE_CHECKING, Any, Optional, Union

from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                            EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList

if TYPE_CHECKING:
    from vllm.lora.request import LoRARequest


class Request:

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int],
        multi_modal_inputs: Optional[list[MultiModalKwargs]],
        multi_modal_hashes: Optional[list[str]],
        multi_modal_placeholders: Optional[list[PlaceholderRange]],
        sampling_params: SamplingParams,
        eos_token_id: Optional[int],
        client_index: int = 0,
        lora_request: Optional["LoRARequest"] = None,
        structured_output_request: Optional["StructuredOutputRequest"] = None,
        cache_salt: Optional[str] = None,
    ) -> None:
        self.request_id = request_id
        self.client_index = client_index
        self.sampling_params = sampling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.structured_output_request = structured_output_request

        self.status = (RequestStatus.WAITING_FOR_FSM
                       if sampling_params.guided_decoding is not None else
                       RequestStatus.WAITING)
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: Union[int, str, None] = None
        assert sampling_params.max_tokens is not None
        self.max_tokens = sampling_params.max_tokens

        self.prompt_token_ids = prompt_token_ids
        self.num_prompt_tokens = len(self.prompt_token_ids)
        self._output_token_ids: list[int] = []
        self._all_token_ids: list[int] = self.prompt_token_ids.copy()
        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: Optional[str] = cache_salt

        # Multi-modal related
        self.mm_positions = multi_modal_placeholders or []
        self.mm_inputs = multi_modal_inputs or []
        self.mm_hashes: list[str] = multi_modal_hashes or []
        self.num_encoder_inputs = len(self.mm_inputs)
        self.has_encoder_inputs = self.num_encoder_inputs > 0

        # P/D: Connector-specific KV transfer parameters.
        kv_params = (None if sampling_params.extra_args is None else
                     sampling_params.extra_args.get("kv_transfer_params"))
        self.kv_transfer_params: Optional[dict[str, Any]] = kv_params

        # Sanity check
        assert len(self.mm_inputs) == len(self.mm_positions)
        if self.mm_hashes:
            assert len(self.mm_inputs) == len(self.mm_hashes)

        # Read-only views
        # Prevent directly appending to these lists since
        # they should also be updated simultaneously.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

        # State
        # The number of tokens with prefix cache hits.
        self.num_cached_tokens = -1

    @classmethod
    def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
        if request.mm_inputs is not None:
            assert isinstance(request.mm_inputs, list)
            assert is_list_of(request.mm_inputs, MultiModalKwargs), (
                "mm_inputs was not updated in EngineCore.add_request")

        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            multi_modal_inputs=request.mm_inputs,
            multi_modal_hashes=request.mm_hashes,
            multi_modal_placeholders=request.mm_placeholders,
            sampling_params=request.sampling_params,
            eos_token_id=request.eos_token_id,
            lora_request=request.lora_request,
            structured_output_request=StructuredOutputRequest(
                sampling_params=request.sampling_params),
            cache_salt=request.cache_salt,
        )

    def append_output_token_ids(
        self,
        token_ids: Union[int, list[int]],
    ) -> None:
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

    @property
    def num_tokens(self) -> int:
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        return len(self._output_token_ids)

    def is_finished(self) -> bool:
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> Union[FinishReason, None]:
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_tokens(self, input_id: int) -> int:
        assert input_id < len(self.mm_positions)
        num_tokens = self.mm_positions[input_id].length
        return num_tokens

    @property
    def use_structured_output(self) -> bool:
        return self.sampling_params.guided_decoding is not None

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: Optional[float] = None,
    ) -> None:
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> Optional[list[EngineCoreEvent]]:
        if not self.events:
            return None
        events, self.events = self.events, []
        return events


class RequestStatus(enum.IntEnum):
    """Status of a request."""
    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(
            status: "RequestStatus") -> Union[FinishReason, None]:
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
}
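A sketch (not part of the commit): the IntEnum ordering is what makes is_finished() a single comparison, since every status declared after PREEMPTED is terminal.

from vllm.v1.request import RequestStatus

assert not RequestStatus.is_finished(RequestStatus.RUNNING)
assert not RequestStatus.is_finished(RequestStatus.PREEMPTED)
assert RequestStatus.is_finished(RequestStatus.FINISHED_STOPPED)
# Ignored requests map to "length", matching the OpenAI API convention.
assert RequestStatus.get_finished_reason(
    RequestStatus.FINISHED_IGNORED).name == "LENGTH"
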
0
vllm/v1/sample/__init__.py
Normal file
44
vllm/v1/sample/metadata.py
Normal file
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SamplingMetadata:

    temperature: Optional[torch.Tensor]
    all_greedy: bool
    all_random: bool

    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]
    min_p: Optional[torch.Tensor]

    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    output_token_ids: list[list[int]]

    # req_index -> (min_tokens, stop_token_ids)
    min_tokens: dict[int, tuple[int, set[int]]]

    logit_bias: list[Optional[dict[int, float]]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]
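A sketch (not part of the commit): a minimal all-greedy SamplingMetadata for a batch of two requests, with every optional feature disabled; the concrete values are made up.

import torch

from vllm.v1.sample.metadata import SamplingMetadata

metadata = SamplingMetadata(
    temperature=None,            # greedy requests carry no temperature
    all_greedy=True,
    all_random=False,
    top_p=None,
    top_k=None,
    min_p=None,
    generators={},               # no per-request seeds
    max_num_logprobs=None,       # logprobs disabled
    no_penalties=True,
    prompt_token_ids=None,
    frequency_penalties=torch.zeros(2),
    presence_penalties=torch.zeros(2),
    repetition_penalties=torch.ones(2),
    output_token_ids=[[], []],
    min_tokens={},
    logit_bias=[None, None],
    allowed_token_ids_mask=None,
    bad_words_token_ids={},
)
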
0
vllm/v1/sample/ops/__init__.py
Normal file
39
vllm/v1/sample/ops/bad_words.py
Normal file
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

_SMALLEST_LOGIT = float("-inf")


def _apply_bad_words_single_batch(
    logits: torch.Tensor,
    bad_words_token_ids: list[list[int]],
    past_tokens_ids: list[int],
) -> None:
    for bad_word_ids in bad_words_token_ids:
        if len(bad_word_ids) > len(past_tokens_ids) + 1:
            continue

        prefix_length = len(bad_word_ids) - 1
        last_token_id = bad_word_ids[-1]
        if prefix_length > 0:
            actual_prefix = past_tokens_ids[-prefix_length:]
        else:
            actual_prefix = []
        expected_prefix = bad_word_ids[:prefix_length]

        assert len(actual_prefix) == len(expected_prefix)

        if actual_prefix == expected_prefix:
            logits[last_token_id] = _SMALLEST_LOGIT


def apply_bad_words(
    logits: torch.Tensor,
    bad_words_token_ids: dict[int, list[list[int]]],
    past_tokens_ids: list[list[int]],
) -> None:
    for i, bad_words_ids in bad_words_token_ids.items():
        _apply_bad_words_single_batch(logits[i], bad_words_ids,
                                      past_tokens_ids[i])
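A worked sketch (not part of the commit): a two-token bad word is only masked once its one-token prefix matches the generated suffix.

import torch

from vllm.v1.sample.ops.bad_words import apply_bad_words

logits = torch.zeros(1, 8)             # batch of 1, vocab of 8
bad_words = {0: [[3, 5]]}              # "token 3 then token 5" is banned
apply_bad_words(logits, bad_words, past_tokens_ids=[[2, 3]])
assert logits[0, 5] == float("-inf")   # prefix [3] matched, token 5 masked

logits = torch.zeros(1, 8)
apply_bad_words(logits, bad_words, past_tokens_ids=[[3, 2]])
assert logits[0, 5] == 0.0             # prefix does not match, untouched
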
59
vllm/v1/sample/ops/penalties.py
Normal file
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.model_executor.layers.utils import apply_penalties
from vllm.utils import is_pin_memory_available, make_tensor_with_pad


def apply_min_token_penalties(
        logits: torch.Tensor, output_token_ids: list[list[int]],
        min_tokens: dict[int, tuple[int, set[int]]]) -> None:
    """
    Applies minimum token penalty by setting the logits of the stop tokens
    to -inf.
    """
    min_tokens_logits_to_penalize: list[tuple[int, int]] = []
    for index, (min_token, stop_token_ids) in min_tokens.items():
        if len(output_token_ids[index]) < min_token:
            for stop_token_id in stop_token_ids:
                min_tokens_logits_to_penalize.append((index, stop_token_id))
    if min_tokens_logits_to_penalize:
        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")


def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                          logits.device)
    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
                           presence_penalties, frequency_penalties,
                           repetition_penalties)


def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.
    """
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)
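A worked sketch (not part of the commit): a request that has produced only 2 of its required 4 tokens gets its stop token masked out by apply_min_token_penalties.

import torch

from vllm.v1.sample.ops.penalties import apply_min_token_penalties

logits = torch.zeros(2, 10)
output_token_ids = [[7, 7], [7, 7, 7, 7]]
min_tokens = {0: (4, {9}), 1: (4, {9})}  # both must emit >= 4 tokens
apply_min_token_penalties(logits, output_token_ids, min_tokens)
assert logits[0, 9] == float("-inf")     # request 0 is still too short
assert logits[1, 9] == 0.0               # request 1 already met min_tokens
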
293
vllm/v1/sample/ops/topk_topp_sampler.py
Normal file
@@ -0,0 +1,293 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch
import torch.nn as nn

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

try:
    import flashinfer.sampling
    is_flashinfer_available = True
except ImportError:
    is_flashinfer_available = False


class TopKTopPSampler(nn.Module):
    """
    Module that performs optional top-k and top-p filtering followed by
    weighted random sampling of logits.

    Implementations may update the logits tensor in-place.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda():
            if is_flashinfer_available:
                flashinfer_version = flashinfer.__version__
                if flashinfer_version < "0.2.3":
                    logger.warning(
                        "FlashInfer version >= 0.2.3 required. "
                        "Falling back to default sampling implementation.")
                    self.forward = self.forward_native
                elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    logger.info("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "top-p & top-k sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward = self.forward_native
        elif current_platform.is_tpu():
            self.forward = self.forward_tpu
        else:
            self.forward = self.forward_native

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        PyTorch-native implementation of top-k and top-p sampling.

        The logits tensor may be updated in-place.
        """
        logits = apply_top_k_top_p(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """More optimized implementation for top-k and top-p sampling."""
        if k is None and p is None:
            # We prefer `random_sample` over `flashinfer_sample` when sorting is
            # not needed. This is because `random_sample` does not require
            # CPU-GPU synchronization while `flashinfer_sample` does.
            probs = logits.softmax(dim=-1, dtype=torch.float32)
            return random_sample(probs, generators)
        if generators:
            logger.warning("FlashInfer 0.2.3+ does not support "
                           "per-request generators. Falling back to "
                           "PyTorch-native implementation.")
            return self.forward_native(logits, generators, k, p)
        return flashinfer_sample(logits, k, p, generators)

    def forward_tpu(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        logits = apply_top_k_top_p_tpu(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)


def apply_top_k_top_p_tpu(
    logits: torch.Tensor,
    k: torch.Tensor,
    p: torch.Tensor,
) -> torch.Tensor:
    """
    Apply top-k and top-p optimized for TPU.

    This algorithm avoids using torch.scatter which is extremely slow on TPU.
    This is achieved by finding a "cut-off" element in the original logit, and
    after thresholding the logit using this cut-off, the remaining elements
    shall constitute the top-p set.

    Note: in the case of a tie (i.e. multiple cut-off elements present in the
    logit), all tie elements are included in the top-p set. In other words,
    this function does not break ties. Instead, these tie tokens have an equal
    chance of being chosen during final sampling, so we can consider the tie
    being broken then.
    """
    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)

    if k is not None:
        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
        top_k_count = top_k_count.unsqueeze(dim=1)
        top_k_cutoff = probs_sort.gather(-1, top_k_count)

        # Make sure the no-top-k rows are no-ops.
        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))

        elements_to_discard = probs < top_k_cutoff
        logits.masked_fill_(elements_to_discard, -float("inf"))

    if p is not None:
        cumprob = torch.cumsum(probs_sort, dim=-1)
        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
        top_p_mask[:, -1] = False  # at least one

        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
        top_p_cutoff = probs_sort.gather(-1, top_p_count)
        elements_to_discard = probs < top_p_cutoff
        logits.masked_fill_(elements_to_discard, -float("inf"))

    return logits


def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    If a top-p is used, this function will sort the logits tensor,
    which can be slow for large batches.

    The logits tensor may be updated in-place.
    """
    if p is None:
        if k is None:
            return logits

        # Avoid sorting vocab for top-k only case.
        return apply_top_k_only(logits, k)

    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if p is not None:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Apply top-k mask to the logits.

    This implementation doesn't involve sorting the entire vocab.

    The logits tensor may be updated in-place.
    """
    no_top_k_mask = k == logits.shape[1]
    # Set non-top-k rows to 1 so that we can gather.
    k = k.masked_fill(no_top_k_mask, 1)
    max_top_k = k.max()
    # topk.values tensor has shape [batch_size, max_top_k].
    # Convert top k to 0-based index in range [0, max_top_k).
    k_index = k.sub_(1).unsqueeze(1)
    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
    # Handle non-topk rows.
    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < top_k_mask, -float("inf"))
    return logits


def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.
    """
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        q.exponential_()
    if generators:
        # TODO(woosuk): This can be slow because we handle each request
        # one by one. Optimize this.
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1)


def flashinfer_sample(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the logits using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it avoids sorting the logits
    tensor via rejection sampling.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.
    """
    assert not (k is None and p is None)
    if k is None:
        # Top-p only.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
            probs, p, deterministic=True)
    elif p is None:
        # Top-k only.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
            probs, k, deterministic=True)
    else:
        # Both top-k and top-p.
        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
            logits, k, p, deterministic=True)

    return next_token_ids.view(-1)
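A worked sketch (not part of the commit) of apply_top_k_only: it keeps the k highest logits per row without a full vocab sort, and a k equal to the vocab size leaves a row untouched. Note that the function mutates both `logits` and `k` in place.

import torch

from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_only

logits = torch.tensor([[1.0, 4.0, 2.0, 3.0],
                       [1.0, 4.0, 2.0, 3.0]])
k = torch.tensor([2, 4])  # top-2 for row 0, no-op for row 1
out = apply_top_k_only(logits, k)
assert out[0].tolist() == [-float("inf"), 4.0, -float("inf"), 3.0]
assert out[1].tolist() == [1.0, 4.0, 2.0, 3.0]
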
631
vllm/v1/sample/rejection_sampler.py
Normal file
@@ -0,0 +1,631 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

logger = init_logger(__name__)

PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32


class RejectionSampler(nn.Module):
    """
    The implementation strictly follows the algorithm described in
    https://arxiv.org/abs/2211.17192.
    However, we want to clarify the terminology used in the implementation:
    accepted tokens: tokens that are accepted based on the relationship
        between the "raw" draft and target probabilities.
    recovered tokens: tokens that are sampled based on the adjusted probability
        distribution, which is derived from both the draft and target
        probabilities.
    bonus tokens:
        If all proposed tokens are accepted, the bonus token is added to the
        end of the sequence. The bonus token is only sampled from the target
        probabilities. We pass in the bonus tokens instead of sampling them
        in the rejection sampler to allow for more flexibility in the
        sampling process. For example, we can use top_p, top_k sampling for
        bonus tokens, while spec decode does not support these sampling
        strategies.
    output tokens:
        Tokens finally generated by the rejection sampler.
        output tokens = accepted tokens + recovered tokens + bonus tokens
    """

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        '''
        Args:
            metadata:
                Metadata for spec decoding.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Shape is
                [num_tokens, vocab_size]. Can be None if probabilities are
                not provided, which is the case for ngram spec decode.
            target_logits (torch.Tensor):
                Logits of the target model's probability distribution.
                Shape is [num_tokens, vocab_size]. Here, probabilities from
                different requests are flattened into a single tensor because
                this is the shape of the output logits.
                NOTE: `target_logits` can be updated in place to save memory.
            bonus_token_ids (torch.Tensor):
                A tensor containing bonus tokens. Shape is [batch_size, 1].
                Bonus tokens are added to the end of the sequence if all
                proposed tokens are accepted. We generate the bonus tokens
                outside of the rejection sampler with the default sampling
                strategy. It allows for more flexibility in the sampling
                process such as top_p, top_k sampling.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Additional metadata needed for sampling, such as temperature,
                top-k/top-p parameters, or other relevant information.
        Returns:
            output_token_ids (torch.Tensor):
                A tensor containing the final output token IDs.
        '''
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        # [num_tokens, vocab_size]
        # NOTE(woosuk): `target_logits` can be updated in place inside the
        # `compute_probs` function.
        target_probs = compute_probs(
            target_logits,
            metadata.cu_num_draft_tokens,
            sampling_metadata,
        )

        output_token_ids = rejection_sample(
            metadata.draft_token_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
            metadata.cu_num_draft_tokens,
            draft_probs,
            target_probs,
            bonus_token_ids,
            sampling_metadata,
        )
        return output_token_ids

    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
        vocab_size: int,
    ) -> list[list[int]]:
        """Parse the output of the rejection sampler.

        Args:
            output_token_ids: The sampled token IDs in shape
                [batch_size, max_spec_len + 1]. The rejected tokens are
                replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
                and will be filtered out in this function.
            vocab_size: The size of the vocabulary.

        Returns:
            A list of lists of token IDs.
        """
        output_token_ids_np = output_token_ids.cpu().numpy()
        # Create mask for valid tokens.
        valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
                      (output_token_ids_np < vocab_size))
        outputs = [
            row[valid_mask[i]].tolist()
            for i, row in enumerate(output_token_ids_np)
        ]
        return outputs
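
# Illustration of `RejectionSampler.parse_output` with hypothetical values,
# for a batch of two requests and max_spec_len = 2 (PLACEHOLDER_TOKEN_ID = -1):
#
#     out = torch.tensor([[10, 11, -1],    # 2nd draft rejected -> recovered
#                         [20, 21, 22]])   # all drafts accepted + bonus token
#     RejectionSampler.parse_output(out, vocab_size=32000)
#     # -> [[10, 11], [20, 21, 22]]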


def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_probs.shape[-1]
    device = target_probs.device
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)

    # Create output buffer.
    output_token_ids = torch.empty(
        (batch_size, max_spec_len + 1),
        dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
        device=device,
    )
    output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)

    if sampling_metadata.all_greedy:
        is_greedy = None
    else:
        is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
    if not sampling_metadata.all_random:
        # Rejection sampling for greedy sampling requests.
        target_argmax = target_probs.argmax(dim=-1)
        rejection_greedy_sample_kernel[(batch_size, )](
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
            bonus_token_ids,
            is_greedy,
            max_spec_len,
            num_warps=1,
        )
        if sampling_metadata.all_greedy:
            return output_token_ids

    # Generate uniform probabilities for rejection sampling.
    # [num_tokens]
    uniform_probs = generate_uniform_probs(
        num_tokens,
        num_draft_tokens,
        sampling_metadata.generators,
        device,
    )

    # Sample recovered tokens for each position.
    # [num_tokens]
    recovered_token_ids = sample_recovered_tokens(
        max_spec_len,
        num_draft_tokens,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
        device,
    )

    # Rejection sampling for random sampling requests.
    rejection_random_sample_kernel[(batch_size, )](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_token_ids,
        uniform_probs,
        is_greedy,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS=draft_probs is None,
        num_warps=1,
    )
    return output_token_ids


def compute_probs(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Compute probability distribution from logits based on sampling metadata.

    This function applies temperature scaling to the logits and converts
    them to probabilities using softmax. For greedy decoding, it returns
    the original logits.

    Args:
        logits: Input logits tensor to be converted to probabilities.
        cu_num_draft_tokens: Cumulative number of draft tokens.
        sampling_metadata: Metadata containing sampling parameters such as
            temperature and whether greedy sampling is used.

    Returns:
        torch.Tensor: Probability distribution (softmax of scaled logits)
            if non-greedy sampling is used, otherwise returns the
            original logits.
    """
    assert logits.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    if sampling_metadata.all_greedy:
        return logits

    num_tokens = logits.shape[0]
    temperature = expand_batch_to_tokens(
        sampling_metadata.temperature,
        cu_num_draft_tokens,
        num_tokens,
        replace_from=GREEDY_TEMPERATURE,
        replace_to=1,
    )
    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
    logits.div_(temperature.unsqueeze(-1))

    # Get expanded top_k and top_p tensors.
    top_k = None
    if sampling_metadata.top_k is not None:
        top_k = expand_batch_to_tokens(
            sampling_metadata.top_k,
            cu_num_draft_tokens,
            num_tokens,
        )
    top_p = None
    if sampling_metadata.top_p is not None:
        top_p = expand_batch_to_tokens(
            sampling_metadata.top_p,
            cu_num_draft_tokens,
            num_tokens,
        )

    # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
    # which is slow for large vocab sizes. This may cause performance issues.
    logits = apply_top_k_top_p(logits, top_k, top_p)
    output_prob = logits.softmax(dim=-1, dtype=torch.float32)
    return output_prob


def expand_batch_to_tokens(
    x: torch.Tensor,  # [batch_size]
    cu_num_tokens: torch.Tensor,  # [batch_size]
    num_tokens: int,
    replace_from: int = 0,
    replace_to: int = 0,
) -> torch.Tensor:
    """Expand a [batch_size] tensor to a [num_tokens] tensor based on the
    number of tokens per batch in cu_num_tokens.

    For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
    num_tokens = 6, and expanded_x = [a, a, b, b, b, c].

    Args:
        x: [batch_size] tensor to expand.
        cu_num_tokens: [batch_size] tensor containing the cumulative number of
            tokens per batch. Each element represents the total number of
            tokens up to and including that batch.
        num_tokens: Total number of tokens.
        replace_from: Value to be replaced if it is found in x.
        replace_to: Value to replace with when replace_from is found.

    Returns:
        expanded_x: [num_tokens] tensor.
    """
    batch_size = x.shape[0]
    assert cu_num_tokens.shape[0] == batch_size
    expanded_x = x.new_empty(num_tokens)
    expand_kernel[(batch_size, )](
        expanded_x,
        x,
        cu_num_tokens,
        replace_from,
        replace_to,
        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
        num_warps=1,
    )
    return expanded_x
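
# A CPU reference for what `expand_batch_to_tokens` computes, sketched with
# `torch.repeat_interleave` (the Triton kernel above performs the same
# expansion without host-side per-request counts):
#
#     x = torch.tensor([0.7, 1.0, 0.3])        # e.g., per-request temperature
#     cu_num_tokens = torch.tensor([2, 5, 6])  # inclusive cumulative counts
#     counts = torch.diff(cu_num_tokens, prepend=torch.tensor([0]))
#     torch.repeat_interleave(x, counts)
#     # -> tensor([0.7, 0.7, 1.0, 1.0, 1.0, 0.3])
#
# The `replace_from`/`replace_to` arguments additionally map the greedy
# sentinel temperature to 1 during the expansion.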


def generate_uniform_probs(
    num_tokens: int,
    num_draft_tokens: list[int],
    generators: dict[int, torch.Generator],
    device: torch.device,
) -> torch.Tensor:
    """
    Generates a batch of uniform random samples, with optional seeding
    if available.

    This method creates a tensor of shape `(num_tokens, )` filled
    with uniform random values in the range [0, 1). If `generators` is
    provided, the requests with their own seeds will use the provided
    `torch.Generator` for reproducibility. The samples for the other
    requests will be generated without a seed.

    Args:
        num_tokens : int
            Total number of tokens.
        num_draft_tokens : list[int]
            Number of draft tokens per request.
        generators : dict[int, torch.Generator]
            A dictionary mapping indices in the batch to
            `torch.Generator` objects.
        device : torch.device
            The device on which to allocate the tensor.
    Returns:
        uniform_rand : torch.Tensor
            A tensor of shape `(num_tokens, )` containing uniform
            random values in the range [0, 1).
    """
    uniform_probs = torch.rand(
        (num_tokens, ),
        dtype=torch.float32,
        device=device,
    )
    start_idx = 0
    for req_idx, n in enumerate(num_draft_tokens):
        # Do not generate random numbers for requests with no draft tokens.
        # This can be important for reproducibility.
        if n == 0:
            continue
        end_idx = start_idx + n
        generator = generators.get(req_idx)
        if generator is not None:
            uniform_probs[start_idx:end_idx].uniform_(generator=generator)
        start_idx = end_idx
    return uniform_probs


def sample_recovered_tokens(
    max_spec_len: int,
    num_draft_tokens: list[int],
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    device: torch.device,
) -> torch.Tensor:
    # NOTE(woosuk): Create only one distribution for each request.
    batch_size = len(num_draft_tokens)
    vocab_size = target_probs.shape[-1]
    q = torch.empty(
        (batch_size, vocab_size),
        dtype=torch.float32,
        device=device,
    )
    q.exponential_()
    for i, generator in sampling_metadata.generators.items():
        # Do not generate random numbers for requests with no draft tokens.
        # This can be important for reproducibility.
        if num_draft_tokens[i] > 0:
            q[i].exponential_(generator=generator)

    recovered_token_ids = torch.empty_like(draft_token_ids)
    sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
        recovered_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        q,
        vocab_size,
        triton.next_power_of_2(vocab_size),
        NO_DRAFT_PROBS=draft_probs is None,
    )
    return recovered_token_ids
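
# The `q.exponential_()` / `argmax(prob / q)` pattern used above is the
# exponential-clocks form of the Gumbel-max trick: with independent
# q_i ~ Exp(1), argmax_i(prob_i / q_i) selects index i with probability
# prob_i / sum(prob), so no renormalization of `prob` is needed. A quick
# empirical check with hypothetical values:
#
#     prob = torch.tensor([0.1, 0.6, 0.3])
#     q = torch.empty(100_000, 3).exponential_()
#     samples = torch.argmax(prob / q, dim=-1)
#     torch.bincount(samples) / samples.numel()  # ~tensor([0.1, 0.6, 0.3])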


# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    max_spec_len,
):
    req_idx = tl.program_id(0)
    # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
    # re-compilation may happen during runtime when is_greedy_ptr is None.
    if is_greedy_ptr is None:
        is_greedy = True
    else:
        is_greedy = tl.load(is_greedy_ptr + req_idx)
    if not is_greedy:
        # Early exit for non-greedy sampling requests.
        return

    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     target_argmax_id)
            if draft_token_id != target_argmax_id:
                # Reject.
                rejected = True

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens, bonus_token_id)
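
# A pure-Python reference of what one program instance of the kernel above
# does for a single request (a readability sketch, not a drop-in
# replacement):
#
#     def greedy_rejection(draft_ids, target_argmax_ids, bonus_id):
#         out = []
#         for draft, argmax in zip(draft_ids, target_argmax_ids):
#             out.append(argmax)       # the kernel always stores the argmax
#             if draft != argmax:
#                 return out           # first mismatch: reject and truncate
#         return out + [bonus_id]      # all accepted: append the bonus token
#
#     greedy_rejection([5, 7], [5, 8], bonus_id=9)  # -> [5, 8]
#     greedy_rejection([5, 7], [5, 7], bonus_id=9)  # -> [5, 7, 9]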


# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    bonus_token_ids_ptr,  # [batch_size]
    recovered_token_ids_ptr,  # [num_tokens]
    uniform_probs_ptr,  # [num_tokens]
    is_greedy_ptr,  # [batch_size]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
):
    req_idx = tl.program_id(0)
    is_greedy = tl.load(is_greedy_ptr + req_idx)
    if is_greedy:
        # Early exit for greedy sampling requests.
        return

    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            if NO_DRAFT_PROBS:
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            # NOTE(woosuk): While the draft probability should never be 0,
            # we check it to avoid NaNs. If it happens to be 0, we reject.
            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
                # Accept.
                token_id = draft_token_id
            else:
                # Reject. Use recovered token.
                rejected = True
                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens, bonus_token_id)


# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
    output_ptr,  # [num_tokens]
    input_ptr,  # [batch_size]
    cu_num_tokens_ptr,  # [batch_size]
    replace_from,
    replace_to,
    MAX_NUM_TOKENS: tl.constexpr,
):
    req_idx = tl.program_id(0)
    if req_idx == 0:  # noqa: SIM108
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
    num_tokens = end_idx - start_idx

    src_val = tl.load(input_ptr + req_idx)
    src_val = tl.where(src_val == replace_from, replace_to, src_val)
    offset = tl.arange(0, MAX_NUM_TOKENS)
    tl.store(output_ptr + start_idx + offset,
             src_val,
             mask=offset < num_tokens)


@triton.jit
def sample_recovered_tokens_kernel(
    output_token_ids_ptr,  # [num_tokens]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    q_ptr,  # [batch_size, vocab_size]
    vocab_size,
    PADDED_VOCAB_SIZE: tl.constexpr,
    NO_DRAFT_PROBS: tl.constexpr,
):
    req_idx = tl.program_id(0)
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    # Early exit for out-of-range positions.
    pos = tl.program_id(1)
    if pos >= num_draft_tokens:
        return

    vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                            draft_token_id)
        # Temporarily zero out the probability of the draft token.
        # This is essentially the same as target_prob - draft_prob, except that
        # n-gram does not have draft_prob. We regard it as 1.
        tl.store(
            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
            0)
        prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                       vocab_offset,
                       mask=vocab_offset < vocab_size,
                       other=0)
    else:
        draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
                             vocab_offset,
                             mask=vocab_offset < vocab_size,
                             other=0)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size + vocab_offset,
                              mask=vocab_offset < vocab_size,
                              other=0)
        prob = tl.maximum(target_prob - draft_prob, 0)
        # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
        # `tl.argmax` will select the maximum value.

    q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
                mask=vocab_offset < vocab_size,
                other=float("-inf"))
    recovered_id = tl.argmax(prob / q, axis=-1)
    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)

    if NO_DRAFT_PROBS:
        # Restore the original probability.
        tl.store(
            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
            orig_prob)
286
vllm/v1/sample/sampler.py
Normal file
@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""

import torch
import torch.nn as nn

from vllm.utils import async_tensor_h2d, is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
                                          apply_min_token_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()
        self.pin_memory = is_pin_memory_available()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # are used for sampling (after penalties and temperature scaling).
        # TODO(rob): provide option for logprobs post sampling.
        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            raw_logprobs = self.compute_logprobs(logits)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)
        # Apply logits bias.
        logits = self.apply_logits_bias(logits, sampling_metadata)
        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)
        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)
        # Convert sampled token ids to int64 (long) type to ensure
        # compatibility with subsequent operations that may use these values
        # as indices. This conversion is necessary because FlashInfer sampling
        # operations return int32 (while PyTorch argmax and topk return
        # int64).
        sampled = sampled.long()

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested).
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to a 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """

        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                return greedy_sampled

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply min_p.
        if sampling_metadata.min_p is not None:
            logits = self.apply_min_p(logits, sampling_metadata.min_p)

        # Apply top_k and/or top_p.
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled

        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to
                          retain per token
            token_ids: prompt tokens (if prompt logprobs)
                       or sampled tokens (if sampled
                       logprobs); 1D token ID tensor
                       with (num tokens) elements
                       Must be int64.

        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the rank of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.min_tokens:
            apply_min_token_penalties(logits,
                                      sampling_metadata.output_token_ids,
                                      sampling_metadata.min_tokens)
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_min_p(
        self,
        logits: torch.Tensor,
        min_p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Filters logits using adaptive probability thresholding.
        """
        # Convert logits to probability distribution.
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence.
        max_probabilities = torch.amax(probability_values,
                                       dim=-1,
                                       keepdim=True)
        # Reshape min_p for broadcasting.
        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
        # Identify valid tokens using threshold comparison.
        valid_token_mask = probability_values >= adjusted_min_p
        # Apply mask using boolean indexing.
        logits[~valid_token_mask] = -float('inf')
        return logits
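
    # Worked example with hypothetical values: with min_p = 0.2 and token
    # probabilities [0.5, 0.3, 0.15, 0.05], the adaptive threshold is
    # 0.2 * max(probs) = 0.1, so only the last token falls below it and is
    # masked to -inf; the surviving three tokens are renormalized by the
    # downstream softmax/sampling.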

    def apply_logits_bias(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        # TODO(houseroad): this implementation is extremely inefficient.
        # One idea is to implement this as a PyTorch C++ op, and we may
        # even optimize the logit_bias layout.

        rows: list[int] = []
        cols: list[int] = []
        vals: list[float] = []

        # Get vocabulary size from logits.
        vocab_size = logits.shape[-1]

        for i, logit_bias in enumerate(sampling_metadata.logit_bias):
            if logit_bias:
                for token_id, bias in logit_bias.items():
                    # Check token_id bounds to ensure within vocabulary.
                    if token_id < 0 or token_id >= vocab_size:
                        raise ValueError(
                            f"token_id {token_id} in logit_bias contains "
                            f"out-of-vocab token id. Vocabulary size: "
                            f"{vocab_size}")
                    rows.append(i)
                    cols.append(token_id)
                    vals.append(bias)

        if rows:
            indices = async_tensor_h2d([rows, cols], torch.int64,
                                       logits.device, self.pin_memory)
            values = async_tensor_h2d(vals, torch.float, logits.device,
                                      self.pin_memory)
            logits.index_put_(tuple(indices), values=values, accumulate=True)
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits
0
vllm/v1/sample/tpu/__init__.py
Normal file
124
vllm/v1/sample/tpu/metadata.py
Normal file
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Optional

import torch

from vllm.v1.worker.gpu_input_batch import InputBatch

DEFAULT_SAMPLING_PARAMS = dict(
    temperature=-1.0,
    min_p=0.0,
    # strictly disabled for now
    top_k=0,
    top_p=1.0,
    # frequency_penalties=0.0,
    # presence_penalties=0.0,
    # repetition_penalties=0.0,
)


@dataclass
class TPUSupportedSamplingMetadata:
    # This class exposes a more xla-friendly interface than SamplingMetadata
    # on TPU. In particular, all arguments should be traceable and no
    # optionals are allowed, to avoid graph recompilation on Nones.
    temperature: torch.Tensor = None

    min_p: torch.Tensor = None
    top_k: torch.Tensor = None
    top_p: torch.Tensor = None

    all_greedy: bool = True

    # Whether logprobs are to be gathered in this batch of requests. To
    # balance out compile time and runtime, a fixed `max_number_logprobs`
    # value is used when gathering logprobs, regardless of the values
    # specified in the batch.
    logprobs: bool = False

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=lambda: list())

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(
        default_factory=lambda: list())

    allowed_token_ids_mask = None
    bad_words_token_ids = None

    # Generator not supported by xla
    _generators: dict[int,
                      torch.Generator] = field(default_factory=lambda: dict())

    @property
    def generators(self) -> dict[int, torch.Generator]:
        # Generator not supported by torch/xla. This field must be immutable.
        return self._generators

    @classmethod
    def from_input_batch(
        cls,
        input_batch: InputBatch,
        padded_num_reqs: int,
        xla_device: torch.device,
        generate_params_if_all_greedy: bool = False
    ) -> "TPUSupportedSamplingMetadata":
        """
        Copy sampling tensor slices from `input_batch` to on-device tensors.

        `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
        slices dynamic shapes on device tensors. This impl moves the dynamic
        ops to CPU and produces tensors of fixed `padded_num_reqs` size.

        Args:
            input_batch: The input batch containing sampling parameters.
            padded_num_reqs: The padded number of requests.
            xla_device: The XLA device.
            generate_params_if_all_greedy: If True, generate sampling
                parameters even if all requests are greedy. This is useful
                for cases where we want to pre-compile a graph with sampling
                parameters, even if they are not strictly needed for greedy
                decoding.
        """
        needs_logprobs = input_batch.max_num_logprobs > 0 if \
            input_batch.max_num_logprobs else False
        # Early return to avoid unnecessary cpu-to-tpu copy.
        if (input_batch.all_greedy is True
                and generate_params_if_all_greedy is False):
            return cls(all_greedy=True, logprobs=needs_logprobs)

        num_reqs = input_batch.num_reqs

        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
            # Pad value is the default one.
            cpu_tensor[num_reqs:padded_num_reqs] = fill_val

        fill_slice(input_batch.temperature_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["temperature"])
        fill_slice(input_batch.min_p_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["min_p"])
        fill_slice(input_batch.top_k_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["top_k"])
        fill_slice(input_batch.top_p_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["top_p"])

        # Slice persistent device tensors to a fixed pre-compiled padded shape.
        return cls(
            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
            to(xla_device),
            all_greedy=input_batch.all_greedy,
            # TODO enable more and avoid returning None values
            top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            logprobs=needs_logprobs)
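To make the fixed-shape contract concrete, a minimal sketch (hypothetical sizes; `padded_num_reqs` stands in for one of a few pre-compiled bucket sizes) of the CPU-side padding that `from_input_batch` performs before the device copy:

import torch

max_num_reqs, padded_num_reqs, num_reqs = 8, 4, 3
temperature_cpu = torch.empty(max_num_reqs)
temperature_cpu[:num_reqs] = torch.tensor([0.7, 1.0, -1.0])
# Pad the unused tail with the default (greedy sentinel) temperature, so the
# device-side slice always has the fixed shape (padded_num_reqs,).
temperature_cpu[num_reqs:padded_num_reqs] = -1.0
print(temperature_cpu[:padded_num_reqs])  # shape (4,), same for every batch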
145
vllm/v1/sample/tpu/sampler.py
Normal file
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampler layer implementing TPU supported operations."""

import torch
import torch.nn as nn

from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> SamplerOutput:
        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)

        # These are TPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to a 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=None)
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        greedy_sampled = self.greedy_sample(logits)

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply min_p.
        if sampling_metadata.min_p is not None:
            logits = self.apply_min_p(logits, sampling_metadata.min_p)

        # Apply top_k and/or top_p.
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
                              greedy_sampled, random_sampled)
        return sampled

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to
                          retain per token
            token_ids: prompt tokens (if prompt logprobs)
                       or sampled tokens (if sampled
                       logprobs); 1D token ID tensor
                       with (num tokens) elements

        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the rank of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_min_p(
        self,
        logits: torch.Tensor,
        min_p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Filters logits using adaptive probability thresholding.
        """
        # Convert logits to probability distribution.
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence.
        max_probabilities = torch.amax(probability_values,
                                       dim=-1,
                                       keepdim=True)
        # Reshape min_p for broadcasting.
        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
        # Identify valid tokens using threshold comparison.
        valid_token_mask = probability_values >= adjusted_min_p
        # Apply mask using boolean indexing (xla friendly).
        logits.masked_fill_(~valid_token_mask, -float("inf"))
        return logits
315
vllm/v1/serial_utils.py
Normal file
@@ -0,0 +1,315 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import pickle
from collections.abc import Sequence
from inspect import isclass
from types import FunctionType
from typing import Any, Optional, Union

import cloudpickle
import numpy as np
import torch
import zmq
from msgspec import msgpack

from vllm import envs
from vllm.logger import init_logger
from vllm.multimodal.inputs import (BaseMultiModalField,
                                    MultiModalBatchedField,
                                    MultiModalFieldConfig, MultiModalFieldElem,
                                    MultiModalFlatField, MultiModalKwargs,
                                    MultiModalKwargsItem,
                                    MultiModalSharedField, NestedTensors)

logger = init_logger(__name__)

CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3

# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}

bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]


def _log_insecure_serialization_warning():
    logger.warning_once("Allowing insecure serialization using pickle due to "
                        "VLLM_ALLOW_INSECURE_SERIALIZATION=1")

class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    By default, arrays below 256B are serialized inline; larger arrays are
    sent via dedicated messages. Note that this is a per-tensor limit.
    """

    def __init__(self, size_threshold: Optional[int] = None):
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        try:
            self.aux_buffers = bufs = [b'']
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step))

        if isinstance(obj, MultiModalKwargs):
            mm: MultiModalKwargs = obj
            if not mm.modalities:
                # just return the main dict if there are no modalities.
                return dict(mm)

            # ignore the main dict, it will be re-indexed.
            # Encode a list of MultiModalKwargsItems as plain dicts
            # + special handling for .field.
            # Any tensors *not* indexed by modality will be ignored.
            return [[{
                "modality": elem.modality,
                "key": elem.key,
                "data": self._encode_nested_tensors(elem.data),
                "field": self._encode_mm_field(elem.field),
            } for elem in item.values()]
                    for itemlist in mm._items_by_modality.values()
                    for item in itemlist]

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            raise TypeError(f"Object of type {type(obj)} is not serializable. "
                            "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                            "fallback to pickle-based serialization.")

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                           pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first.
        arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        assert self.aux_buffers is not None
        # View the tensor as a contiguous 1D array of bytes.
        arr = obj.flatten().contiguous().view(torch.uint8).numpy()
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr.data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order,
        # which will then be used to reconstruct the field.
        field_values = (getattr(field, f.name)
                        for f in dataclasses.fields(field))
        return name, *field_values


class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.
    """

    def __init__(self, t: Optional[Any] = None):
        args = () if t is None else (t, )
        self.decoder = msgpack.Decoder(*args,
                                       ext_hook=self.ext_hook,
                                       dec_hook=self.dec_hook)
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
            # TODO - This check can become `isinstance(bufs, bytestr)`
            # as of Python 3.10.
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargs):
                if isinstance(obj, list):
                    return MultiModalKwargs.from_items(
                        self._decode_mm_items(obj))
                return MultiModalKwargs({
                    k: self._decode_nested_tensors(v)
                    for k, v in obj.items()
                })
        return obj

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        dtype, shape, data = arr
        # Zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        return np.frombuffer(buffer, dtype=dtype).reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        dtype, shape, data = arr
        # Copy from the inline representation to decouple the memory storage
        # of the message from the original buffer, and also so that Torch
        # does not complain about a readonly memoryview.
        buffer = self.aux_buffers[data] if isinstance(data, int) \
            else bytearray(data)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array.
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Convert back to proper shape & type.
        return arr.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
        decoded_items = []
        for item in obj:
            elems = []
            for v in item:
                v["data"] = self._decode_nested_tensors(v["data"])
                # Reconstruct the field processor using MultiModalFieldConfig.
                factory_meth_name, *field_args = v["field"]
                factory_meth = getattr(MultiModalFieldConfig,
                                       factory_meth_name)

                # Special case: decode the union "slices" field of
                # MultiModalFlatField.
                if factory_meth_name == "flat":
                    field_args[0] = self._decode_nested_slices(field_args[0])

                v["field"] = factory_meth(None, *field_args).field
                elems.append(MultiModalFieldElem(**v))
            decoded_items.append(MultiModalKwargsItem.from_elems(elems))
        return decoded_items

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        assert isinstance(obj, (list, tuple))
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(
            f"Extension type code {code} is not supported")
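A minimal round-trip sketch for the two classes above (single-process; over ZMQ each element of `bufs` would travel as one frame of a multipart message). Typed decoding of the tensor values here assumes `msgspec` dispatches them to `dec_hook`, as the decoder above expects:

import torch

enc = MsgpackEncoder()
dec = MsgpackDecoder(dict[str, torch.Tensor])

payload = {"hidden": torch.randn(4, 8), "ids": torch.arange(4)}
bufs = enc.encode(payload)  # [top-level msgpack buffer, large-tensor buffers...]
restored = dec.decode(bufs)
assert torch.equal(restored["ids"], payload["ids"])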
0
vllm/v1/spec_decode/__init__.py
Normal file
434
vllm/v1/spec_decode/eagle.py
Normal file
@@ -0,0 +1,434 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn

from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                   FlashAttentionMetadata)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel

logger = init_logger(__name__)

PADDING_SLOT_ID = -1


class EagleProposer:

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner

        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = (
            self.speculative_config.num_speculative_tokens)
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target
        # model's hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE and
                               not self.vllm_config.model_config.enforce_eager)
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=device)
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)
        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
                                   1,
                                   device=device,
                                   dtype=torch.int32)

def propose(
|
||||
self,
|
||||
# [num_tokens]
|
||||
target_token_ids: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_positions: torch.Tensor,
|
||||
# [num_tokens, hidden_size]
|
||||
target_hidden_states: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_slot_mapping: torch.Tensor,
|
||||
# [batch_size]
|
||||
next_token_ids: torch.Tensor,
|
||||
# [batch_size + 1] starting with 0
|
||||
cu_num_tokens: torch.Tensor,
|
||||
# [batch_size, max_num_blocks_per_req]
|
||||
block_table: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
last_token_indices = cu_num_tokens[1:] - 1
|
||||
|
||||
if self.method == "eagle3":
|
||||
assert isinstance(self.model, Eagle3LlamaForCausalLM)
|
||||
target_hidden_states = self.model.combine_hidden_states(
|
||||
target_hidden_states)
|
||||
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||
|
||||
# Shift the input ids by one token.
|
||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
|
||||
# FA requires seq_len to have dtype int32.
|
||||
seq_lens = (target_positions[last_token_indices] + 1).int()
|
||||
|
||||
if self.method in ["eagle", "eagle3"]:
|
||||
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
|
||||
max_seq_len = seq_lens.max().item()
|
||||
max_num_tokens = (cu_num_tokens[1:] -
|
||||
cu_num_tokens[:-1]).max().item()
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_num_tokens,
|
||||
query_start_loc=cu_num_tokens,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table,
|
||||
slot_mapping=target_slot_mapping,
|
||||
# TODO(woosuk): Support cascade attention.
|
||||
use_cascade=False,
|
||||
common_prefix_len=0,
|
||||
cu_prefix_query_lens=None,
|
||||
prefix_kv_lens=None,
|
||||
suffix_kv_lens=None,
|
||||
)
|
||||
elif self.method == "deepseek_mtp":
|
||||
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
|
||||
max_query_len = query_lens.max().item()
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=cu_num_tokens,
|
||||
seq_lens=seq_lens,
|
||||
num_reqs=batch_size,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
)
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
# FIXME: need to consider multiple kv_cache_groups
|
||||
attn_metadata = self.runner.attn_metadata_builders[0].build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {self.method}")
|
||||
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
if self.use_cuda_graph and \
|
||||
num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
else:
|
||||
num_input_tokens = num_tokens
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.positions[:num_tokens] = target_positions
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
ret_hidden_states = self.model(
|
||||
self.input_ids[:num_input_tokens],
|
||||
self.positions[:num_input_tokens],
|
||||
self.hidden_states[:num_input_tokens],
|
||||
)
|
||||
if self.method == "deepseek_mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
if self.num_speculative_tokens == 1:
|
||||
# [batch_size, 1]
|
||||
return draft_token_ids.view(-1, 1)
|
||||
|
||||
# TODO: Currently, MTP module released by deepseek only has
|
||||
# one layer. Adapt this code to support multiple layers once
|
||||
# there's a multi-layer MTP module.
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
|
||||
positions = target_positions[last_token_indices]
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
if self.use_cuda_graph and \
|
||||
batch_size <= self.cudagraph_batch_sizes[-1]:
|
||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
||||
else:
|
||||
input_batch_size = batch_size
|
||||
attn_metadata.num_actual_tokens = batch_size
|
||||
attn_metadata.max_query_len = 1
|
||||
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||
for _ in range(self.num_speculative_tokens - 1):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
input_ids = draft_token_ids_list[-1].int()
|
||||
positions += 1
|
||||
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length. Since it is complex
|
||||
# to remove such requests from the batch, we keep them in the batch
|
||||
# but adjust the position ids and slot mappings to avoid the
|
||||
# out-of-range access during the model execution. The draft tokens
|
||||
# generated with this adjustment should be ignored.
|
||||
exceeds_max_model_len = positions >= self.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
||||
positions)
|
||||
|
||||
# Increment the sequence lengths.
|
||||
attn_metadata.max_seq_len += 1
|
||||
attn_metadata.seq_lens += 1
|
||||
# Consider max model length.
|
||||
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
|
||||
self.max_model_len)
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_numbers = clamped_positions // self.block_size
|
||||
block_ids = block_table.gather(dim=1,
|
||||
index=block_numbers.view(-1, 1))
|
||||
block_ids = block_ids.view(-1)
|
||||
attn_metadata.slot_mapping = (block_ids * self.block_size +
|
||||
clamped_positions % self.block_size)
|
||||
# Mask out the slot mappings that exceed the max model length.
|
||||
# Otherwise, the KV cache will be inadvertently updated with the
|
||||
# padding tokens.
|
||||
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
|
||||
PADDING_SLOT_ID)
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
self.hidden_states[:batch_size] = hidden_states
|
||||
|
||||
# Run the model.
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch_size):
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
self.input_ids[:input_batch_size],
|
||||
self.positions[:input_batch_size],
|
||||
self.hidden_states[:input_batch_size],
|
||||
)
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size],
|
||||
None)
|
||||
|
||||
# TODO(wenlong): get more than one token for tree attention
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
# [batch_size, num_speculative_tokens]
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
return draft_token_ids
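
    # Worked example of the shift-by-one construction above (illustrative):
    # with three requests holding [a1], [b1, b2], [c1, c2, c3] tokens,
    #   target_token_ids   = [a1, b1, b2, c1, c2, c3]
    #   cu_num_tokens      = [0, 1, 3, 6] -> last_token_indices = [0, 2, 5]
    #   shift left by one  -> [b1, b2, c1, c2, c3, c3]
    #   splice next tokens -> [a2, b2, b3, c2, c3, c4]
    # so each position i now holds the token *after* target_token_ids[i],
    # which is exactly what the draft model conditions on.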

    @staticmethod
    def prepare_inputs(
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
        num_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]

        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]
        num_tokens_per_req = query_len_per_req - num_rejected_tokens

        # [a - n1, b - n2, c - n3] ->
        # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        cu_num_tokens = torch.zeros_like(cu_target_query_lens)
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
        token_indices = torch.empty(
            num_tokens,
            dtype=torch.int32,
            device=cu_target_query_lens.device,
        )
        batch_size = num_rejected_tokens.shape[0]
        BLOCK_SIZE = 1024
        prepare_eagle_input_kernel[(batch_size, )](
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return cu_num_tokens, token_indices
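
    # Worked example (illustrative): with per-request scheduled lengths
    # [a, b, c] = [3, 2, 4] and num_rejected_tokens = [1, 0, 2]:
    #   cu_target_query_lens = [0, 3, 5, 9]
    #   num_tokens_per_req   = [2, 2, 2]
    #   cu_num_tokens        = [0, 2, 4, 6]
    #   token_indices        = [0, 1,  3, 4,  5, 6]
    # i.e. each request keeps a prefix of its slots in the target layout.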

    def load_model(self, target_model: nn.Module) -> None:
        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

        self.model = get_model(vllm_config=self.vllm_config,
                               model_config=draft_model_config)

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)

        self.attn_layer_names = list(draft_attn_layer_names)

        # Share embed_tokens with the target model if needed.
        if get_pp_group().world_size == 1 \
            and self.model.model.embed_tokens.weight.shape \
                == target_model.model.embed_tokens.weight.shape:
            logger.info(
                "Assuming the EAGLE head shares the same vocab embedding" \
                " with the target model."
            )
            del self.model.model.embed_tokens
            self.model.model.embed_tokens = target_model.model.embed_tokens
        else:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately" \
                " from the target model."
            )

        # Share lm_head with the target model if needed.
        # Some model definitions do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM.
        if self.vllm_config.speculative_config.method != "eagle3" and \
                hasattr(target_model, "lm_head"):
            logger.info("Loading EAGLE LM head weights from the target model.")
            if supports_multimodal(target_model):
                self.model.lm_head = target_model.get_language_model().lm_head
            else:
                self.model.lm_head = target_model.lm_head

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
    ) -> None:
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            self.model(
                self.input_ids[:num_tokens],
                self.positions[:num_tokens],
                self.hidden_states[:num_tokens],
            )

    def validate_same_kv_cache_group(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.
        """
        kv_cache_groups: dict[str, int] = {}
        for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
            for layer_name in kv_cache_group.layer_names:
                kv_cache_groups[layer_name] = id
        assert len(
            set([
                kv_cache_groups[layer_name]
                for layer_name in self.attn_layer_names
            ])
        ) == 1, "All eagle layers should belong to the same kv cache group"


# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    is_greedy = sampling_metadata.temperature == -1
    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs
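
# The `probs.div(q).argmax()` line above is the exponential-race (Gumbel-max
# style) trick: with q_i ~ Exp(1) i.i.d., argmax_i(p_i / q_i) selects index i
# with probability exactly p_i. A standalone empirical check (illustrative
# sketch, not part of the original file):
import torch

torch.manual_seed(0)
p = torch.tensor([0.1, 0.2, 0.7])
q = torch.empty(100_000, 3).exponential_()
freqs = torch.bincount((p / q).argmax(dim=-1), minlength=3) / 100_000
assert torch.allclose(freqs, p, atol=0.01)  # ~[0.1, 0.2, 0.7]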
62
vllm/v1/spec_decode/medusa.py
Normal file
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.sample.metadata import SamplingMetadata

# Initialize logger
logger = init_logger(__name__)


class MedusaProposer:
    """
    Medusa proposer class for generating draft token sequences.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        # Save config parameters
        self.vllm_config = vllm_config
        self.device = device
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        self.hidden_size = vllm_config.speculative_config.\
            draft_model_config.get_hidden_size()
        self.dtype = vllm_config.model_config.dtype

    def propose(
        self,
        target_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> list[list[int]]:
        # Generate blocks and compute logits (one set per Medusa head).
        blocks = self.model(target_hidden_states)
        logits = self.model.compute_logits(blocks, None)

        # Get the draft tokens from each head and transpose the result so
        # that each row holds one request's draft sequence.
        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
        return [list(row) for row in zip(*draft_tokens)]

    def load_model(self, target_model: nn.Module) -> None:
        self.model = get_model(vllm_config=self.vllm_config,
                               model_config=self.vllm_config.
                               speculative_config.draft_model_config)

    @torch.inference_mode()
    def dummy_run(self, num_tokens: int) -> None:
        hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size),
                                    dtype=self.dtype,
                                    device=self.device)
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            self.model(hidden_states)
62
vllm/v1/spec_decode/metadata.py
Normal file
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import numpy as np
import torch


@dataclass
class SpecDecodeMetadata:

    # [num_tokens]
    draft_token_ids: torch.Tensor
    # [batch_size]
    num_draft_tokens: list[int]
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor

    def __post_init__(self):
        self.max_spec_len = max(self.num_draft_tokens)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]
        flattened_draft_token_ids = sum(draft_token_ids, [])
        num_tokens = len(flattened_draft_token_ids)

        draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids,
                                              dtype=torch.int32,
                                              device=device)
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(
            device)

        target_logits_indices = torch.zeros(num_tokens,
                                            dtype=torch.int32,
                                            device=device)
        bonus_logits_indices = torch.zeros(batch_size,
                                           dtype=torch.int32,
                                           device=device)
        logits_indices = torch.zeros(num_tokens + batch_size,
                                     dtype=torch.int32,
                                     device=device)
        return cls(
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
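
# Quick usage sketch of make_dummy (illustrative, CPU-only): two requests
# with 2 and 3 draft tokens flatten into a single [num_tokens] tensor plus
# cumulative sums over the batch.
meta = SpecDecodeMetadata.make_dummy([[7, 8], [9, 10, 11]],
                                     device=torch.device("cpu"))
assert meta.draft_token_ids.tolist() == [7, 8, 9, 10, 11]
assert meta.num_draft_tokens == [2, 3]
assert meta.cu_num_draft_tokens.tolist() == [2, 5]
assert meta.max_spec_len == 3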
178
vllm/v1/spec_decode/metrics.py
Normal file
@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import prometheus_client

from vllm.config import SpeculativeConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclass
class SpecDecodingStats:
    """Per-step iteration decoding stats from the scheduler.

    Each scheduler step, statistics on spec decoding performance are
    aggregated across requests by the scheduler and returned to the
    frontend in EngineCoreOutputs->SchedulerStats.
    """

    num_spec_tokens: int
    num_drafts: int = 0
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0
    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)

    @classmethod
    def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
        return cls(num_spec_tokens=num_spec_tokens,
                   num_accepted_tokens_per_pos=[0] * num_spec_tokens)

    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
        self.num_drafts += 1
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens
        assert num_accepted_tokens <= self.num_spec_tokens
        for i in range(num_accepted_tokens):
            self.num_accepted_tokens_per_pos[i] += 1


class SpecDecodingLogging:
    """Aggregate and log spec decoding metrics.

    LoggingStatLogger aggregates per-iteration metrics over a set
    time interval using observe() and then logs them using log()
    before resetting to zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.num_drafts: list[int] = []
        self.num_draft_tokens: list[int] = []
        self.num_accepted_tokens: list[int] = []
        self.accepted_tokens_per_pos_lists: list[list[int]] = []

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        self.num_drafts.append(spec_decoding_stats.num_drafts)
        self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
        self.num_accepted_tokens.append(
            spec_decoding_stats.num_accepted_tokens)
        self.accepted_tokens_per_pos_lists.append(
            spec_decoding_stats.num_accepted_tokens_per_pos)

    def log(self, log_fn=logger.info):
        if not self.num_drafts:
            return
        num_drafts = np.sum(self.num_drafts)
        num_draft_tokens = np.sum(self.num_draft_tokens)
        num_accepted_tokens = np.sum(self.num_accepted_tokens)

        draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens *
                                 100 if num_draft_tokens > 0 else float("nan"))

        # Conventionally, the mean acceptance length includes the bonus token.
        mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts)

        pos_matrix = np.array(self.accepted_tokens_per_pos_lists)
        acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts
        rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates)

        log_fn(
            "SpecDecoding metrics: "
            "Draft acceptance rate: %.1f%%, "
            "Mean acceptance length: %.2f, "
            "Accepted: %d tokens, "
            "Drafted: %d tokens, "
            "Per-position acceptance rate: %s",
            draft_acceptance_rate,
            mean_acceptance_length,
            num_accepted_tokens,
            num_draft_tokens,
            rates_str,
        )
        self.reset()
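
# Worked example of the arithmetic above (illustrative sketch): four drafts
# of 3 tokens each with [3, 1, 0, 2] tokens accepted give a 50% draft
# acceptance rate and a mean acceptance length of 1 + 6/4 = 2.5 (the
# leading 1 is the bonus token sampled from the target model).
_stats = SpecDecodingStats.new(num_spec_tokens=3)
for _accepted in (3, 1, 0, 2):
    _stats.observe_draft(num_draft_tokens=3, num_accepted_tokens=_accepted)

_logging = SpecDecodingLogging()
_logging.observe(_stats)
_logging.log(log_fn=lambda msg, *args: print(msg % args))
# -> acceptance rate 50.0%, mean length 2.50, per-position 0.750, 0.500, 0.250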


class SpecDecodingProm:
    """Record spec decoding metrics in Prometheus.

    The acceptance rate can be calculated using a PromQL query:

      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
      rate(vllm:spec_decode_num_draft_tokens_total[$interval])

    The mean acceptance length (conventionally including bonus tokens)
    can be calculated using:

      1 + (
        rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
        rate(vllm:spec_decode_num_drafts[$interval]))

    A per-position acceptance rate vector can be computed using

      vllm:spec_decode_num_accepted_tokens_per_pos[$interval] /
      vllm:spec_decode_num_drafts[$interval]
    """

    _counter_cls = prometheus_client.Counter

    def __init__(
        self,
        speculative_config: Optional[SpeculativeConfig],
        labelnames: list[str],
        labelvalues: list[str],
    ):
        self.spec_decoding_enabled = speculative_config is not None
        if not self.spec_decoding_enabled:
            return

        self.counter_spec_decode_num_drafts = \
            self._counter_cls(
                name="vllm:spec_decode_num_drafts",
                documentation="Number of spec decoding drafts.",
                labelnames=labelnames).labels(*labelvalues)
        self.counter_spec_decode_num_draft_tokens = \
            self._counter_cls(
                name="vllm:spec_decode_num_draft_tokens",
                documentation="Number of draft tokens.",
                labelnames=labelnames).labels(*labelvalues)
        self.counter_spec_decode_num_accepted_tokens = \
            self._counter_cls(
                name="vllm:spec_decode_num_accepted_tokens",
                documentation="Number of accepted tokens.",
                labelnames=labelnames).labels(*labelvalues)

        assert speculative_config is not None
        num_spec_tokens = (speculative_config.num_speculative_tokens
                           if self.spec_decoding_enabled else 0)
        pos_labelnames = labelnames + ["position"]
        base_counter = self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens_per_pos",
            documentation="Accepted tokens per draft position.",
            labelnames=pos_labelnames,
        )
        self.counter_spec_decode_num_accepted_tokens_per_pos: list[
            prometheus_client.Counter] = []
        for pos in range(num_spec_tokens):
            pos_labelvalues = labelvalues + [str(pos)]
            self.counter_spec_decode_num_accepted_tokens_per_pos.append(
                base_counter.labels(*pos_labelvalues))

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        if not self.spec_decoding_enabled:
            return
        self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts)
        self.counter_spec_decode_num_draft_tokens.inc(
            spec_decoding_stats.num_draft_tokens)
        self.counter_spec_decode_num_accepted_tokens.inc(
            spec_decoding_stats.num_accepted_tokens)
        for pos, counter in enumerate(
                self.counter_spec_decode_num_accepted_tokens_per_pos):
            counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
132
vllm/v1/spec_decode/ngram_proposer.py
Normal file
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import numpy as np
from numba import jit

from vllm.config import VllmConfig


class NgramProposer:

    def __init__(self, vllm_config: VllmConfig):
        # Minimum length of the n-gram to match.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        # Maximum length of the n-gram to match.
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        # Number of tokens to return after the match. If fewer than k
        # tokens follow the match, we return as many tokens as are
        # available up to the end of the context.
        self.k = vllm_config.speculative_config.num_speculative_tokens
        # Maximum length of the model.
        self.max_model_len = vllm_config.model_config.max_model_len

        # Trigger Numba JIT compilation for the N-gram proposer.
        # This usually takes less than 1 second.
        self.propose(np.zeros(1024, dtype=np.int32))

    def propose(
        self,
        context_token_ids: np.ndarray,
    ) -> Optional[np.ndarray]:
        """Proposes the next sequence of tokens based on n-gram pattern
        matching in the context. The function finds matches of the last n
        tokens in the previous context, and returns k tokens that followed
        that match.

        Args:
            context_token_ids: Numpy array of token IDs representing the
                context sequence.

        Returns:
            np.ndarray: The sequence of tokens that followed
                the matched n-gram in the context.
            None: If no matching n-gram pattern is found.

        Example:
            If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
            k = 4:
            - The last 3 (= max_n) tokens [4,2,3] cannot find a match.
            - The last 2 tokens [2,3] will be matched against the previous
              4 tokens [1,2,3,4].
            - Finding a match of [2,3] would return the tokens that
              followed that pattern. Here we will return [4,2,3] because
              we only have three tokens after the match.
        """
        # Do not generate draft tokens beyond the max model length.
        k = min(self.k, self.max_model_len - context_token_ids.shape[0])
        if k <= 0:
            return None

        # TODO(woosuk): Optimize this.
        for n in range(self.max_n, self.min_n - 1, -1):
            result = _find_subarray_kmp(context_token_ids, n, k)
            if result is not None:
                return result
        return None

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass


@jit(nopython=True)
def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
    """
    Build the lps (longest proper prefix which is also a suffix)
    array for the pattern.
    """
    lps = np.zeros(len(pattern), dtype=np.int32)
    prev_lps = 0  # length of the previous longest prefix suffix
    i = 1

    while i < len(pattern):
        if pattern[i] == pattern[prev_lps]:
            prev_lps += 1
            lps[i] = prev_lps
            i += 1
        else:
            if prev_lps != 0:
                prev_lps = lps[prev_lps - 1]
            else:
                lps[i] = 0
                i += 1
    return lps


@jit(nopython=True)
def _find_subarray_kmp(
    context_token_ids: np.ndarray,
    n: int,
    k: int,
) -> Optional[np.ndarray]:
    context_len = context_token_ids.shape[0]
    assert n > 0

    pattern = context_token_ids[-n:]
    # Precompute the lps array for the pattern.
    lps = _kmp_lps_array(pattern)

    i = 0
    j = 0
    # -n because the last n tokens are used as the pattern.
    while i < context_len - n:
        if context_token_ids[i] == pattern[j]:
            i += 1
            j += 1

            # If we have matched the entire pattern, we have found an
            # occurrence in the context: gather the next k elements.
            if j == n:
                return context_token_ids[i:i + k]
        else:
            # Mismatch.
            if j != 0:
                # Use the lps array to avoid re-checking elements.
                j = lps[j - 1]
            else:
                i += 1

    # Pattern not found.
    return None
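
# Running the docstring example through the raw KMP helper (illustrative
# sketch; numba JIT-compiles on the first call):
_ctx = np.array([1, 2, 3, 4, 2, 3], dtype=np.int32)
assert _find_subarray_kmp(_ctx, n=3, k=4) is None  # [4, 2, 3] has no match
assert _find_subarray_kmp(_ctx, n=2, k=4).tolist() == [4, 2, 3]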
46
vllm/v1/spec_decode/utils.py
Normal file
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu_input_batch import InputBatch


def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
    if req_id in input_batch.min_p_reqs:
        # Spec decode doesn't support min_p sampling.
        return False
    elif (req_id in input_batch.frequency_penalties_reqs
          or req_id in input_batch.presence_penalties_reqs
          or req_id in input_batch.repetition_penalties_reqs):
        # Spec decode doesn't support penalties.
        return False
    elif req_id in input_batch.num_logprobs:
        # Spec decode doesn't support logprobs.
        return False

    return True


@triton.jit
def prepare_eagle_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)

    # [start_pos, end_pos)
    start_pos = tl.load(cu_num_tokens_ptr + pid)
    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
    num_tokens = end_pos - start_pos

    index_start = tl.load(cu_query_lens_ptr + pid)

    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
    for i in tl.range(num_blocks):
        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        tl.store(
            out_ptr + start_pos + offset,
            index_start + offset,
            mask=offset < num_tokens,
        )
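
# A NumPy reference for what the Triton kernel above computes, one program
# per request: out[start:end] = index_start + arange(end - start)
# (illustrative sketch, not used by vLLM itself).
import numpy as np


def _prepare_eagle_input_numpy(cu_query_lens, cu_num_tokens):
    out = np.empty(cu_num_tokens[-1], dtype=np.int32)
    for pid in range(len(cu_num_tokens) - 1):
        start, end = cu_num_tokens[pid], cu_num_tokens[pid + 1]
        out[start:end] = cu_query_lens[pid] + np.arange(end - start)
    return out


# Same numbers as the prepare_inputs example in eagle.py:
assert _prepare_eagle_input_numpy(
    np.array([0, 3, 5, 9]),
    np.array([0, 2, 4, 6]),
).tolist() == [0, 1, 3, 4, 5, 6]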
222
vllm/v1/structured_output/__init__.py
Normal file
@@ -0,0 +1,222 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar)
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import torch

    from vllm.reasoning import ReasoningParser
    from vllm.v1.request import Request
else:
    torch = LazyLoader("torch", globals(), "torch")

logger = init_logger(__name__)


class StructuredOutputManager:
    """Engine-level manager for structured output requests."""

    def __init__(self, vllm_config: VllmConfig):
        self.backend: Optional[StructuredOutputBackend] = None
        self.reasoner: Optional[ReasoningParser] = None
        self.vllm_config = vllm_config

        self._grammar_bitmask: Optional[torch.Tensor] = None
        self._full_mask = torch.tensor(-1, dtype=torch.int32)

        # The default max_workers if not specified is the number of CPUs * 5,
        # which is way too high since these tasks are CPU-bound, not I/O-bound.
        # We also know we would never dominate CPU usage with just grammar
        # compilation, so we set it to half the number of CPUs.
        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.vllm_config.model_config,
            scheduler_config=self.vllm_config.scheduler_config,
            lora_config=self.vllm_config.lora_config,
        ).get_lora_tokenizer(None)
        reasoning_backend = vllm_config.decoding_config.reasoning_backend
        if reasoning_backend:
            reasoner_cls = ReasoningParserManager.get_reasoning_parser(
                reasoning_backend)
            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

    def grammar_init(self, request: Request) -> None:
        if request.structured_output_request is None:
            return

        if TYPE_CHECKING:
            assert request.sampling_params.guided_decoding is not None

        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        if self.backend is None:
            backend = request.sampling_params.guided_decoding.backend
            vocab_size = self.vllm_config.model_config.get_vocab_size()
            if backend == "xgrammar":
                self.backend = XgrammarBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "guidance":
                self.backend = GuidanceBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            else:
                raise ValueError(
                    f"Unsupported structured output backend: {backend}")

        grammar = self.executor.submit(self._async_create_grammar, request)
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _async_create_grammar(
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> Optional[npt.NDArray[np.int32]]:
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = \
                self.vllm_config.speculative_config.num_speculative_tokens

        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = \
                self.backend.allocate_token_bitmask(
                    max_batch_size * (1 + max_num_spec_tokens))

        bitmask_tensor = self._grammar_bitmask
        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        cumulative_index = 0
        ordered_seq = sorted(structured_output_request_ids.items(),
                             key=lambda x: x[1])

        # Note that for thinking support, we will need to
        # reset the relevant part of the bitmask for subsequent
        # requests here.
        bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
            self._full_mask)

        # NOTE: This outer loop can likely be parallelized to improve
        # performance of bitmask generation for large batches.
        for req_id, _ in ordered_seq:
            request = requests[req_id]
            structured_output_request = request.structured_output_request

            if TYPE_CHECKING:
                assert structured_output_request is not None
                assert structured_output_request.grammar is not None
            apply_bitmask: bool = True
            if self.reasoner is not None:
                if structured_output_request.reasoning_ended is None:
                    structured_output_request.reasoning_ended = \
                        self.reasoner.is_reasoning_end(request.prompt_token_ids)
                apply_bitmask = structured_output_request.reasoning_ended

            state_advancements = 0
            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
            for i, token in enumerate(req_tokens):
                if apply_bitmask and not \
                        structured_output_request.grammar.is_terminated():
                    structured_output_request.grammar.fill_bitmask(
                        bitmask_tensor, cumulative_index)
                if token is not None:
                    # In order to generate the correct bitmask for each
                    # position in the speculative sequence, we advance
                    # the FSM state for each speculative token and roll back
                    # to restore the previous state when we are finished.
                    assert structured_output_request.grammar.accept_tokens(
                        req_id, [token])
                    state_advancements += 1
                cumulative_index += 1
            if state_advancements > 0:
                structured_output_request.grammar.rollback(state_advancements)

        if cumulative_index < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:cumulative_index]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        # (A sketch of how the packed mask is applied to logits follows
        # this class.)
        return bitmask_tensor.numpy()

    def should_advance(self, request: Request) -> bool:
        if not request.use_structured_output:
            return False

        # Determine whether we can advance the FSM.
        # Supports thinking usage where we skip the reasoning components.
        if TYPE_CHECKING:
            assert request.structured_output_request is not None
            assert request.structured_output_request.grammar is not None
        # By default, we should always advance for cases that
        # don't use thinking mode.
        if self.reasoner is not None:
            structured_req = request.structured_output_request

            if structured_req.reasoning_ended:
                return True

            # Check if reasoning ends in *this* step.
            if self.reasoner.is_reasoning_end(request.all_token_ids):
                # Reasoning just ended, so we shouldn't advance until
                # the next pass.
                structured_req.reasoning_ended = True

            return False
        else:
            return True

    def clear_backend(self) -> None:
        if self.backend is not None:
            self.backend.destroy()
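
# A minimal sketch of how a packed token bitmask constrains logits, assuming
# the xgrammar packing convention of 32 vocab entries per int32 word (the
# real unpacking happens in the GPU model runner, not in this file):
import torch as _torch

_vocab_size = 8
_bitmask = _torch.tensor([0b10100110], dtype=_torch.int32)  # allow ids 1,2,5,7
_bits = (_bitmask.unsqueeze(-1) >> _torch.arange(32)) & 1
_allowed = _bits.flatten()[:_vocab_size].bool()

_logits = _torch.randn(_vocab_size)
_logits[~_allowed] = float("-inf")
assert _allowed.nonzero().flatten().tolist() == [1, 2, 5, 7]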
245
vllm/v1/structured_output/backend_guidance.py
Normal file
@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import copy
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union

import torch

from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)
from vllm.v1.structured_output.request import get_structured_output_key

if TYPE_CHECKING:
    import llguidance
    import llguidance.hf as llguidance_hf
    import llguidance.torch as llguidance_torch
else:
    llguidance = LazyLoader("llguidance", globals(), "llguidance")
    llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
    llguidance_torch = LazyLoader("llguidance.torch", globals(),
                                  "llguidance.torch")

logger = init_logger(__name__)


def _walk_json_for_additional_properties(data: object):
    if isinstance(data, dict):
        for value in data.values():
            _walk_json_for_additional_properties(value)
        if 'additionalProperties' not in data and \
                ('properties' in data or 'patternProperties' in data):
            data['additionalProperties'] = False
    elif isinstance(data, list):
        for item in data:
            _walk_json_for_additional_properties(item)


def process_for_additional_properties(
        guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]:
    if isinstance(guide_json, str):
        guide_json_obj = json.loads(guide_json)
    else:
        # Copy for modifications.
        guide_json_obj = copy.deepcopy(guide_json)
    _walk_json_for_additional_properties(guide_json_obj)
    return guide_json_obj
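
# Example of the walk above (illustrative): every object schema that does
# not opt in to extra keys gets `additionalProperties: False`, recursively,
# while the caller's input object is left unmodified.
_schema = {"type": "object", "properties": {"name": {"type": "string"}}}
_processed = process_for_additional_properties(_schema)
assert _processed["additionalProperties"] is False
assert "additionalProperties" not in _schema  # deep-copied, not mutated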


@dataclass
class GuidanceBackend(StructuredOutputBackend):

    def __post_init__(self):
        self.disable_any_whitespace = \
            self.vllm_config.decoding_config.disable_any_whitespace
        self.disable_additional_properties = \
            self.vllm_config.decoding_config.disable_additional_properties

        self.ll_tokenizer = llguidance_hf.from_tokenizer(
            self.tokenizer, self.vocab_size)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        self.serialized_grammar = serialize_guidance_grammar(
            request_type, grammar_spec, self.disable_any_whitespace,
            self.disable_additional_properties)

        ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            self.serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )

        r = GuidanceGrammar(
            ll_matcher=ll_matcher,
            ll_tokenizer=self.ll_tokenizer,
            vocab_size=self.vocab_size,
        )

        r.check_error()
        return r

    def allocate_token_bitmask(self, max_num_seqs: int):
        return llguidance_torch.allocate_token_bitmask(
            max_num_seqs, self.ll_tokenizer.vocab_size)

    def destroy(self):
        pass


@dataclass
class GuidanceGrammar(StructuredOutputGrammar):
    ll_matcher: llguidance.LLMatcher
    ll_tokenizer: llguidance.LLTokenizer
    vocab_size: int
    printed_error: bool = False
    terminated: bool = False

    def check_error(self):
        if not self.printed_error:
            err = self.ll_matcher.get_error()
            if err:
                self.printed_error = True
                logger.warning("LLMatcher error: %s", err)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the parser.

        Returns True if the parser was advanced successfully.
        Returns False if the parser failed to advance.
        """
        if self.ll_tokenizer.eos_token in tokens:
            self.terminated = True

        if self.ll_matcher.is_stopped():
            return True

        # TODO - Add jump decoding support in the future:
        # self.ll_matcher.compute_ff_bytes() - this should always work
        # self.ll_matcher.compute_ff_tokens() - this only works for
        # "canonical" tokenizers
        # For conversion between the two, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md

        r = self.ll_matcher.consume_tokens(tokens)

        self.check_error()

        return r

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the parser in sequence.
        Will not advance the parser.

        Returns the prefix list of tokens that are accepted by the parser.
        """
        if len(tokens) == 0:
            return []
        if self.ll_matcher.is_stopped():
            return []

        num_tokens = self.ll_matcher.validate_tokens(tokens)

        self.check_error()

        return tokens[:num_tokens]

    def rollback(self, num_tokens: int) -> None:
        self.ll_matcher.rollback(num_tokens)
        self.check_error()

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        # This will automatically return the [EOS] mask if the matcher is
        # stopped or otherwise in an error state.
        llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
        self.check_error()

    def is_terminated(self) -> bool:
        return self.terminated

    def reset(self):
        # This method may not be needed anymore? TODO
        self.ll_matcher.reset()


def serialize_guidance_grammar(
    request_type: StructuredOutputOptions,
    grammar_spec: Union[str, dict[str, Any]],
    disable_any_whitespace: bool = False,
    disable_additional_properties: bool = False,
) -> str:

    def _process_schema(grammar_spec: Union[str, dict[str, Any]]) -> str:
        if disable_additional_properties:
            grammar_spec = process_for_additional_properties(grammar_spec)
        return llguidance.LLMatcher.grammar_from_json_schema(
            grammar_spec,
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })

    if request_type == StructuredOutputOptions.JSON:
        return _process_schema(grammar_spec)
    elif request_type == StructuredOutputOptions.JSON_OBJECT:
        return llguidance.LLMatcher.grammar_from_json_schema(
            '{"type": "object"}',
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })
    else:
        if request_type == StructuredOutputOptions.REGEX:
            tp = "regex"
        elif request_type == StructuredOutputOptions.GRAMMAR:
            tp = "grammar"
        elif request_type == StructuredOutputOptions.CHOICE:
            tp = "choice"
        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            if isinstance(grammar_spec, str):
                s_tag = json.loads(grammar_spec)
            else:
                s_tag = grammar_spec
            triggers: list[str] = s_tag["triggers"]
            tags: list[llguidance.StructTag] = []
            for s in s_tag["structures"]:
                begin: str = s["begin"]
                trig = next((t for t in triggers if begin.startswith(t)), None)
                if trig is None:
                    raise ValueError(
                        f"Trigger {begin} not found in triggers {triggers}")
                tags.append(
                    llguidance.StructTag(
                        trigger=trig,
                        begin=s["begin"],
                        grammar=_process_schema(s["schema"]),
                        end=s["end"],
                    ))
            if not tags:
                raise ValueError(
                    "No structural tags found in the grammar spec.")
            return llguidance.StructTag.to_grammar(tags)
        else:
            logger.error("Validation should have already occurred. "
                         "Please file an issue.")
            raise ValueError("grammar is not of valid supported types. "
                             f"({request_type!s})")
        return llguidance.grammar_from(tp, grammar_spec)


def validate_guidance_grammar(
        sampling_params: SamplingParams,
        tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
    tp, grm = get_structured_output_key(sampling_params)
    guidance_grm = serialize_guidance_grammar(tp, grm)
    err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
    if err:
        raise ValueError(f"Grammar error: {err}")
134
vllm/v1/structured_output/backend_types.py
Normal file
@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch

    from vllm.config import VllmConfig
    from vllm.transformers_utils.tokenizer import AnyTokenizer


class StructuredOutputOptions(enum.Enum):
    JSON = enum.auto()
    JSON_OBJECT = enum.auto()
    REGEX = enum.auto()
    GRAMMAR = enum.auto()
    CHOICE = enum.auto()
    STRUCTURAL_TAG = enum.auto()


StructuredOutputKey = tuple[StructuredOutputOptions, str]


class StructuredOutputGrammar(ABC):
    """Request-level backend for structured output requests."""

    @abstractmethod
    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """
        Determines whether the provided tokens are accepted for the
        given request.

        Args:
            request_id (str): The unique identifier for the request.
            tokens (list[int]): A list of token IDs to evaluate.

        Returns:
            bool: True if the tokens are accepted, False otherwise.
        """

    @abstractmethod
    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """
        Validates the provided tokens against the grammar.
        Will not advance the FSM.

        Args:
            tokens (list[int]): A list of token IDs to validate.

        Returns:
            list[int]: A list of accepted token IDs. Will be a prefix
                of the input tokens, and empty if none are accepted.
        """

    @abstractmethod
    def rollback(self, num_tokens: int) -> None:
        """
        Rolls back the state of the grammar by a specified number of tokens.
        Will also revert counters for the number of processed tokens.

        Args:
            num_tokens (int): The number of tokens to roll back.
        """

    @abstractmethod
    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        """
        Fills the bitmask for a specific batch index.

        Args:
            bitmask (torch.Tensor): The bitmask to fill.
            batch_index (int): The index in the bitmask to fill.
        """

    @abstractmethod
    def is_terminated(self) -> bool:
        """
        Checks whether the structured output process has terminated.

        Returns:
            bool: True if the process is terminated, False otherwise.
        """

    @abstractmethod
    def reset(self):
        """
        Resets the state of the structured output grammar.
        """


@dataclass
class StructuredOutputBackend(ABC):
    """Engine-level backend for structured output requests."""

    vllm_config: VllmConfig
    tokenizer: AnyTokenizer
    vocab_size: int

    @abstractmethod
    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """
        Compiles a grammar specification into a structured output grammar.

        Args:
            request_type (StructuredOutputOptions): The type of structured
                output request.
            grammar_spec (str): The grammar specification to compile.

        Returns:
            StructuredOutputGrammar: The compiled structured output grammar.
        """

    @abstractmethod
    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        """
        Allocates a token bitmask for the specified maximum number of
        sequences.

        Args:
            max_num_seqs (int): The maximum number of sequences for which
                to allocate the bitmask.
        """

    @abstractmethod
    def destroy(self):
        """
        Backend-specific cleanup.
        """
318
vllm/v1/structured_output/backend_xgrammar.py
Normal file
@@ -0,0 +1,318 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import torch

import vllm.envs
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)
from vllm.v1.structured_output.utils import (choice_as_grammar,
                                             convert_lark_to_ebnf,
                                             grammar_is_likely_lark)

if TYPE_CHECKING:
    import xgrammar as xgr
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)


@dataclass
class XgrammarBackend(StructuredOutputBackend):

    def __post_init__(self):
        self.disable_any_whitespace = \
            self.vllm_config.decoding_config.disable_any_whitespace

        if isinstance(self.tokenizer, MistralTokenizer):
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
            try:
                if self.tokenizer.is_tekken:
                    encoded_vocab = self.tokenizer._vocab
                else:
                    encoded_vocab = [
                        token for token, _ in sorted(
                            self.tokenizer.get_vocab().items(),
                            key=lambda x: x[1],
                        )
                    ]
                stop_token_ids = None
                if (hasattr(
                        self.tokenizer,
                        "eos_token_id",
                ) and self.tokenizer.eos_token_id is not None):
                    stop_token_ids = [self.tokenizer.eos_token_id]
            except AttributeError as e:
                raise ValueError(
                    f"Cannot get the vocabulary of the tokenizer "
                    f"{type(self.tokenizer)}. The tokenizer should have a "
                    "get_vocab method.") from e
            tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                encoded_vocab=encoded_vocab,
                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43  # noqa: E501
                vocab_type=xgr.VocabType.RAW
                if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
                vocab_size=self.vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=True,
            )
        else:
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                self.tokenizer,
                vocab_size=self.vocab_size,
            )
        self.compiler = xgr.GrammarCompiler(
            tokenizer_info,
            max_threads=8,
            cache_enabled=True,
            cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024,
        )

        self.num_speculative_tokens = 0
        if self.vllm_config.speculative_config is not None:
            self.num_speculative_tokens = \
                self.vllm_config.speculative_config.num_speculative_tokens

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        if request_type == StructuredOutputOptions.JSON:
            ctx = self.compiler.compile_json_schema(
                grammar_spec, any_whitespace=not self.disable_any_whitespace)
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            ctx = self.compiler.compile_json_schema(
                '{"type": "object"}',
                any_whitespace=not self.disable_any_whitespace)
        elif request_type == StructuredOutputOptions.GRAMMAR:
            ctx = self.compiler.compile_grammar(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            ctx = self.compiler.compile_regex(grammar_spec)
        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            s_tag = json.loads(grammar_spec)
            tags = [
                xgr.StructuralTagItem(
                    begin=s["begin"],
                    schema=json.dumps(s["schema"]),
                    end=s["end"],
                ) for s in s_tag["structures"]
            ]
            ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"])
        else:
            logger.error(
                "Validation should have already occurred. Please file an issue."
            )
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})")

        return XgrammarGrammar(
            matcher=xgr.GrammarMatcher(
                ctx,
                max_rollback_tokens=self.num_speculative_tokens,
            ),
            vocab_size=self.vocab_size,
            ctx=ctx,
        )

    def allocate_token_bitmask(self, max_num_seqs: int):
        return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)

    def destroy(self):
|
||||
del self.compiler
|
||||
|
||||
|
||||
@dataclass
|
||||
class XgrammarGrammar(StructuredOutputGrammar):
|
||||
# NOTE: This would be a generic-enough class for
|
||||
# supporting different backends, in the future.
|
||||
# For now, just xgrammar.
|
||||
#
|
||||
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
|
||||
# for jump-forward decoding
|
||||
|
||||
vocab_size: int
|
||||
matcher: xgr.GrammarMatcher = field(hash=False)
|
||||
ctx: xgr.CompiledGrammar = field(hash=False)
|
||||
num_processed_tokens: int = field(default_factory=lambda: 0,
|
||||
repr=False,
|
||||
hash=False,
|
||||
init=False)
|
||||
|
||||
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
|
||||
"""Accepts a list of tokens and advances the FSM.
|
||||
|
||||
Returns True if the FSM was advanced successfully.
|
||||
Returns False if the FSM failed to advance.
|
||||
"""
|
||||
for token in tokens:
|
||||
if not self.matcher.accept_token(token):
|
||||
logger.error(
|
||||
"Failed to advance FSM for request %s "
|
||||
"for tokens %s. Please file an issue.", request_id, token)
|
||||
return False
|
||||
self.num_processed_tokens += 1
|
||||
return True
|
||||
|
||||
def validate_tokens(self, tokens: list[int]) -> list[int]:
|
||||
"""Checks if the list of tokens are accepted by the FSM in sequence.
|
||||
Will not advance the FSM.
|
||||
|
||||
Returns the prefix list of tokens that are accepted by the FSM.
|
||||
"""
|
||||
accepted_tokens = []
|
||||
for token in tokens:
|
||||
if self.matcher.accept_token(token):
|
||||
accepted_tokens.append(token)
|
||||
else:
|
||||
break
|
||||
        if len(accepted_tokens) > 0:
            # Roll back the FSM to the state it was in before validation
            self.matcher.rollback(len(accepted_tokens))
        return accepted_tokens

    def rollback(self, num_tokens: int) -> None:
        self.matcher.rollback(num_tokens)
        self.num_processed_tokens -= num_tokens

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        self.matcher.fill_next_token_bitmask(bitmask, idx)

    def is_terminated(self) -> bool:
        return self.matcher.is_terminated()

    def reset(self):
        self.num_processed_tokens = 0
        self.matcher.reset()


def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar."""

    def check_object(obj: dict[str, Any]) -> bool:
        if not isinstance(obj, dict):
            return False

        # Check for numeric ranges
        if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
            return True

        # Check for array unsupported keywords
        if obj.get("type") == "array" and any(
                key in obj for key in ("uniqueItems", "contains",
                                       "minContains", "maxContains")):
            return True

        # Unsupported keywords for strings
        if obj.get("type") == "string" and "format" in obj:
            return True

        # Unsupported keywords for objects
        if obj.get("type") == "object" and any(
                key in obj for key in ("minProperties", "maxProperties",
                                       "propertyNames", "patternProperties")):
            return True

        # Recursively check all nested objects and arrays
        for value in obj.values():
            if isinstance(value, dict):
                if check_object(value):
                    return True
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict) and check_object(item):
                        return True

        return False

    return check_object(schema)


def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    Raises ValueError if the request is not supported.
    """
    if sampling_params.guided_decoding is None:
        return

    gd_params = sampling_params.guided_decoding

    if gd_params.regex:
        try:
            xgr.Grammar.from_regex(gd_params.regex)
        except Exception as err:
            raise ValueError("Failed to transform regex into a grammar: "
                             f"{err}") from err

    if gd_params.choice:
        choice_grammar = choice_as_grammar(gd_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
raise ValueError("Failed to transform choices into a grammar: "
|
||||
"{err}") from err
|
||||
        gd_params.choice = None
        gd_params.grammar = choice_grammar
        return

    if gd_params.json:
        if isinstance(gd_params.json, str):
            try:
                schema = json.loads(gd_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = gd_params.json

        try:
            xgr.Grammar.from_json_schema(schema)
        except Exception as err:
            raise ValueError("Failed to transform json schema into a grammar: "
                             f"{err}") from err

        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError("The provided JSON schema contains features not "
                             "supported by xgrammar.")
        return

    if gd_params.grammar:
        if grammar_is_likely_lark(gd_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar)
            except ValueError as e:
                raise ValueError(
                    "Failed to convert the grammar from Lark to EBNF.") from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(gd_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e
        return

    if gd_params.structural_tag:
        try:
            s_tag = json.loads(gd_params.structural_tag)
            tags = [
                xgr.StructuralTagItem(
                    begin=s["begin"],
                    schema=json.dumps(s["schema"]),
                    end=s["end"],
                ) for s in s_tag["structures"]
            ]
            xgr.Grammar.from_structural_tag(tags, s_tag["triggers"])
        except Exception as e:
            raise ValueError("Invalid structural tag specification.") from e
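At sampling time, the compiled matcher pairs with xgrammar's bitmask helpers. A minimal sketch of that flow on a dummy logits tensor; the vocab size is an assumed illustration value, and apply_token_bitmask_inplace is xgrammar's documented helper for consuming the mask:

import torch
import xgrammar as xgr

vocab_size = 32000                                   # assumed; must match the tokenizer
bitmask = xgr.allocate_token_bitmask(1, vocab_size)  # int32, one row per sequence

# `grammar` is an XgrammarGrammar returned by XgrammarBackend.compile_grammar.
grammar.fill_bitmask(bitmask, idx=0)

logits = torch.randn(1, vocab_size)
xgr.apply_token_bitmask_inplace(logits, bitmask)     # disallowed tokens -> -inf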
86
vllm/v1/structured_output/request.py
Normal file
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import dataclasses
import functools
import json
from concurrent.futures import Future
from concurrent.futures._base import TimeoutError
from typing import Optional, Union, cast

from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar,
                                                     StructuredOutputKey,
                                                     StructuredOutputOptions)


@dataclasses.dataclass
class StructuredOutputRequest:

    sampling_params: SamplingParams
    _grammar: Optional[Union[Future[StructuredOutputGrammar],
                             StructuredOutputGrammar]] = None
    reasoning_ended: Optional[bool] = None

    def _check_grammar_completion(self) -> bool:
        # NOTE: We have to lazy import to gate circular imports
        from vllm.v1.request import RequestStatus

        if isinstance(self._grammar, Future):
            try:
                # We will check whether the future is ready within 100 us
                self._grammar = self._grammar.result(timeout=0.0001)
                self.status = RequestStatus.WAITING
            except TimeoutError:
                return False
        return True

    @property
    def is_grammar_ready(self) -> bool:
        return self._check_grammar_completion()

    @property
    def grammar(self) -> Optional[StructuredOutputGrammar]:
        completed = self._check_grammar_completion()
        return cast(Optional[StructuredOutputGrammar],
                    self._grammar) if completed else None

    @grammar.setter
    def grammar(
        self, grammar: Union[StructuredOutputGrammar,
                             Future[StructuredOutputGrammar]]
    ) -> None:
        self._grammar = grammar

    @functools.cached_property
    def structured_output_key(self) -> StructuredOutputKey:
        return get_structured_output_key(self.sampling_params)


def get_structured_output_key(
        sampling_params: SamplingParams) -> StructuredOutputKey:
    params = sampling_params.guided_decoding
    assert params is not None, "params can't be None."
    if params.json is not None:
        if not isinstance(params.json, str):
            json_str = json.dumps(params.json)
        else:
            json_str = params.json
        return (StructuredOutputOptions.JSON, json_str)
    elif params.json_object:
        return (StructuredOutputOptions.JSON_OBJECT, "")
    elif params.regex is not None:
        return (StructuredOutputOptions.REGEX, params.regex)
    elif params.choice is not None:
        if not isinstance(params.choice, str):
            json_str = json.dumps(params.choice)
        else:
            json_str = params.choice
        return (StructuredOutputOptions.CHOICE, json_str)
    elif params.grammar is not None:
        return (StructuredOutputOptions.GRAMMAR, params.grammar)
    elif params.structural_tag is not None:
        return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag)
    else:
        raise ValueError("No valid structured output parameter found")
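get_structured_output_key normalizes the guided-decoding parameters into a hashable (option, spec) tuple, which is what lets compiled grammars be cached and shared across requests. A small sketch of that behavior, assuming the GuidedDecodingParams constructor from vllm.sampling_params:

from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Two requests with an identical schema map to the same key, so a compiled
# grammar can be reused from a dict instead of being recompiled.
p1 = SamplingParams(guided_decoding=GuidedDecodingParams(json='{"type": "object"}'))
p2 = SamplingParams(guided_decoding=GuidedDecodingParams(json='{"type": "object"}'))
assert get_structured_output_key(p1) == get_structured_output_key(p2)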
175
vllm/v1/structured_output/utils.py
Normal file
@@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import regex as re


def grammar_is_likely_lark(grammar_str: str) -> bool:
    """
    Check if grammar appears to use Lark syntax.

    Args:
        grammar_str: Input grammar string

    Returns:
        bool: True if grammar appears to be in Lark format, False otherwise

    Examples:
        >>> grammar_is_likely_lark("rule: 'abc'")
        True
        >>> grammar_is_likely_lark("rule ::= 'abc'")
        False
    """
    if not grammar_str or not isinstance(grammar_str, str):
        return False

    for line in grammar_str.split('\n'):
        # Remove both comment styles
        line = re.sub(r'(#|//).*$', '', line).strip()
        if not line:
            continue

        # Look for EBNF rule definition
        if '::=' in line:
            return False

    return True


def convert_lark_to_ebnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to EBNF format.

    EBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in EBNF format

    Examples:
        >>> print(convert_lark_to_ebnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    defined_rules = set()
    referenced_rules = set()
    output_lines = []

    def clean_line(line: str) -> str:
        """Remove comments and whitespace from line."""
        return re.sub(r'(#|//).*$', '', line).strip()

    def check_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Validate quote matching in text."""
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(
                f"Mismatched quotes in {rule_name} on line {line_num}")

    def extract_references(text: str) -> set:
        """Extract rule references from text."""
        # Remove quoted strings and special characters
        text = re.sub(r'"[^"]*"', '', text)
        text = re.sub(r'[+*?()|\[\]{}]', ' ', text)
        return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text))

    # First pass: Find root rule and validate rule definitions
    lines = [clean_line(line) for line in grammar_str.split('\n')]
    first_rule = None

    for line_num, line in enumerate(lines, 1):
        if not line or line.startswith('|'):
            continue

        if ':' in line:
            try:
                name = line.split(':', 1)[0].strip().strip('?')
                defined_rules.add(name)
                if first_rule is None:
                    first_rule = name
                if name == 'start':
                    first_rule = 'start'
            except IndexError as e:
                raise ValueError(f"Invalid rule format on line {line_num}. "
                                 "Expected 'rule_name: definition'") from e

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # Add root rule
    output_lines.append(f"root ::= {first_rule}")

    # Second pass: Process rule definitions and alternatives
    current_rule = None
    current_definition = []

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue

        try:
            if ':' in line and not line.startswith('|'):
                # Save previous rule if exists
                if current_rule:
                    output_lines.append(
                        f"{current_rule} ::= {' | '.join(current_definition)}")

                # Process new rule
                name, definition = line.split(':', 1)
                current_rule = name.strip().strip('?')

                check_quotes(definition, f"rule '{current_rule}'", line_num)
                definition = re.sub(r"'([^']*)'", r'"\1"', definition)
                referenced_rules.update(extract_references(definition))
                current_definition = [definition.strip()]

            elif line.startswith('|'):
                if not current_rule:
                    raise ValueError(f"Alternative '|' on line {line_num} "
                                     "without a preceding rule definition")

                alt_def = line[1:].strip()
                check_quotes(alt_def, f"alternative for rule '{current_rule}'",
                             line_num)
                alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
                referenced_rules.update(extract_references(alt_def))
                current_definition.append(alt_def)

        except ValueError as e:
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Add final rule if exists
    if current_rule:
        output_lines.append(
            f"{current_rule} ::= {' | '.join(current_definition)}")

    # Validate all rules are defined
    undefined_rules = referenced_rules - defined_rules - {'root'}
    if undefined_rules:
        raise ValueError("Referenced rules are not defined: "
                         f"{', '.join(sorted(undefined_rules))}")

    return '\n'.join(output_lines)

def choice_as_grammar(choice: list[str]) -> str:

    def escape_ebnf_string(s: str) -> str:
        """Escape special characters in an EBNF string."""
        # Escape double quotes and backslashes
        return re.sub(r'(["\\])', r'\\\1', s)

    escaped_choices = (escape_ebnf_string(c) for c in choice)
    grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
    return grammar
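To make the conversions concrete, here is what these helpers produce for small inputs (outputs traced from the code above):

print(choice_as_grammar(["yes", "no"]))
# root ::= "yes" | "no"

print(convert_lark_to_ebnf("start: greeting\ngreeting: 'hello' | 'hi'"))
# root ::= start
# start ::= greeting
# greeting ::= "hello" | "hi"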
743
vllm/v1/utils.py
Normal file
@@ -0,0 +1,743 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import multiprocessing
import time
import weakref
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
                    Union, overload)

import msgspec
import torch
import zmq

from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
                        get_tcp_uri, kill_process_tree)
from vllm.v1.executor.abstract import Executor

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

    from vllm.attention.layer import Attention
    from vllm.v1.engine.coordinator import DPCoordinator

logger = init_logger(__name__)

T = TypeVar("T")

STARTUP_POLL_PERIOD_MS = 10000


class ConstantList(Generic[T], Sequence):

    def __init__(self, x: list[T]) -> None:
        self._x = x

    def append(self, item):
        raise Exception("Cannot append to a constant list")

    def extend(self, item):
        raise Exception("Cannot extend a constant list")

    def insert(self, item):
        raise Exception("Cannot insert into a constant list")

    def pop(self, item):
        raise Exception("Cannot pop from a constant list")

    def remove(self, item):
        raise Exception("Cannot remove from a constant list")

    def clear(self):
        raise Exception("Cannot clear a constant list")

    def index(self,
              item: T,
              start: int = 0,
              stop: Optional[int] = None) -> int:
        return self._x.index(item, start,
                             stop if stop is not None else len(self._x))

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: T, /):
        ...

    def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
        raise Exception("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise Exception("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self):
        return f"ConstantList({self._x})"
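ConstantList hands callers a read-only view over a list that its owner keeps mutating, without copying. A short usage sketch:

blocks = [1, 2, 3]
view = ConstantList(blocks)

assert view[0] == 1 and list(view) == [1, 2, 3]
blocks.append(4)           # the owner can still mutate the backing list
assert len(view) == 4      # and the view reflects the change
try:
    view.append(5)         # but consumers cannot mutate through the view
except Exception:
    pass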

def get_engine_client_zmq_addr(local_only: bool,
                               host: str,
                               port: int = 0) -> str:
    return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
        host, port or get_open_port()))


class CoreEngineState(Enum):
    NEW = auto()
    CONNECTED = auto()
    READY = auto()


class CoreEngine:
    """One per data parallel rank."""

    def __init__(self, index: int = 0, local: bool = True):
        self.local = local
        self.index = index
        self.identity = index.to_bytes(2, "little")

        self.state = CoreEngineState.NEW


@dataclass
class EngineZmqAddresses:
    # ZMQ input socket addresses for each front-end client (requests)
    inputs: list[str]
    # ZMQ output socket addresses for each front-end client (responses)
    outputs: list[str]
    # ZMQ input socket address of DP coordinator if applicable
    coordinator_input: Optional[str] = None
    # ZMQ output socket address of DP coordinator if applicable
    coordinator_output: Optional[str] = None


@dataclass
class EngineHandshakeMetadata:
    """Metadata sent to each engine process during startup handshake,
    including addresses of the front-end ZMQ queues that they should
    connect to.
    """
    addresses: EngineZmqAddresses
    parallel_config: dict[str, Union[int, str]]


class APIServerProcessManager:
    """Manages a group of API server processes.

    Handles creation, monitoring, and termination of API server worker
    processes. Also monitors extra processes to check if they are healthy.
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: Optional[str] = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        # Start API servers
        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
                                        output_addresses):
            client_config = {
                "input_address": in_addr,
                "output_address": out_addr,
                "client_index": i
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

            proc = spawn_context.Process(target=target_server_fn,
                                         name=f"ApiServer_{i}",
                                         args=(listen_address, sock, args,
                                               client_config))
            self.processes.append(proc)
            proc.start()

        logger.info("Started %d API server processes", len(self.processes))

        # Shutdown only the API server processes on garbage collection
        # The extra processes are managed by their owners
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

    def close(self) -> None:
        self._finalizer()


class CoreEngineProcManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of background processes used by the AsyncLLM and LLMEngine.
    """

    def __init__(
        self,
        target_fn: Callable,
        local_engine_count: int,
        start_index: int,
        local_start_index: int,
        vllm_config: VllmConfig,
        on_head_node: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        context = get_mp_context()
        common_kwargs = {
            "vllm_config": vllm_config,
            "on_head_node": on_head_node,
            "handshake_address": handshake_address,
            "executor_class": executor_class,
            "log_stats": log_stats,
        }

        self.processes: list[BaseProcess] = []
        for index in range(local_engine_count):
            local_index = local_start_index + index
            global_index = start_index + index
            # Start EngineCore in background process.
            self.processes.append(
                context.Process(target=target_fn,
                                name=f"EngineCore_{global_index}",
                                kwargs=common_kwargs | {
                                    "dp_rank": global_index,
                                    "local_dp_rank": local_index,
                                }))

        self._finalizer = weakref.finalize(self, shutdown, self.processes)
        try:
            for proc in self.processes:
                proc.start()
        finally:
            # Kill other procs if not all are running.
            if self.finished_procs():
                self.close()

    def close(self):
        """Shutdown all procs."""
        self._finalizer()

    def join_first(self):
        """Wait for any process to exit."""
        connection.wait(proc.sentinel for proc in self.processes)

    def sentinels(self) -> list:
        return [proc.sentinel for proc in self.processes]

    def finished_procs(self) -> dict[str, int]:
        """Returns dict of proc name -> exit code for any finished procs."""
        return {
            proc.name: proc.exitcode
            for proc in self.processes if proc.exitcode is not None
        }


class CoreEngineActorManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of core engine Ray actors used by the AsyncLLM and LLMEngine.

    Different from CoreEngineProcManager, this class manages
    core engines for both local and remote nodes.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        placement_groups: Optional[list["PlacementGroup"]] = None,
        local_dp_ranks: Optional[list[int]] = None,
    ):
        import copy

        import ray
        from ray.util.scheduling_strategies import (
            PlacementGroupSchedulingStrategy)

        from vllm.v1.engine.core import DPEngineCoreActor

        self.local_engine_actors: list[ray.ActorHandle] = []
        self.remote_engine_actors: list[ray.ActorHandle] = []
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local
        world_size = vllm_config.parallel_config.world_size

        if ray.is_initialized():
            logger.info(
                "Ray is already initialized. Skipping Ray initialization.")
        else:
            ray.init()

        if placement_groups is not None:
            assert local_dp_ranks is not None, (
                "local_dp_ranks must be provided if "
                "placement_groups is provided")
            assert len(placement_groups) == len(local_dp_ranks), (
                "placement_groups and local_dp_ranks must "
                "have the same length")
            logger.info("Using provided placement groups")
            # TODO(rui): validate passed-in placement groups
            self.created_placement_groups = []
        else:
            placement_groups, local_dp_ranks = \
                CoreEngineActorManager.create_dp_placement_groups(vllm_config)
            self.created_placement_groups = placement_groups
        assert len(placement_groups) == dp_size, (
            "Number of placement groups must match data parallel size")

        refs = []
        for index in range(dp_size):
            local_index = local_dp_ranks[index]
            dp_vllm_config = copy.deepcopy(vllm_config)
            pg = placement_groups[index]
            dp_vllm_config.parallel_config.placement_group = pg
            on_head_node = index < local_engine_count
            actor = ray.remote(DPEngineCoreActor).options(
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=pg,
                    placement_group_bundle_index=world_size,
                )).remote(vllm_config=dp_vllm_config,
                          executor_class=executor_class,
                          log_stats=log_stats,
                          on_head_node=on_head_node,
                          addresses=addresses,
                          dp_rank=index,
                          local_dp_rank=local_index)
            if on_head_node:
                self.local_engine_actors.append(actor)
            else:
                self.remote_engine_actors.append(actor)
            refs.append(actor.wait_for_init.remote())

        ray.get(refs)
        self.run_refs = []
        for actor in self.local_engine_actors + self.remote_engine_actors:
            self.run_refs.append(actor.run.remote())

    @staticmethod
    def create_dp_placement_groups(
        vllm_config: VllmConfig
    ) -> tuple[list["PlacementGroup"], list[int]]:

        import ray
        from ray._private.state import available_resources_per_node
        from ray.util.state import list_nodes

        logger.info("Creating placement groups for data parallel")
        dp_master_ip = \
            vllm_config.parallel_config.data_parallel_master_ip
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local

        nodes = list_nodes()
        nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip)
        assert nodes[0].node_ip == dp_master_ip, (
            "The first node must be the head node")
        assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
            "There can only be one head node")

        available_resources = available_resources_per_node()
        world_size = vllm_config.parallel_config.world_size
        placement_groups: list[PlacementGroup] = []
        local_dp_ranks: list[int] = []

        for node in nodes:
            node_ip = node.node_ip
            node_resources = available_resources[node.node_id]
            # For now, each DP rank can only be assigned to one node
            # TODO(rui): support allocating a single DP rank
            # to multiple nodes
            available_engine_count = int(node_resources["GPU"]) // world_size
            if node_ip == dp_master_ip:
                assert available_engine_count >= local_engine_count, (
                    "Not enough resources to allocate DP ranks "
                    f"on DP master node {node_ip}")
                for i in range(local_engine_count):
                    bundles = [{
                        "GPU": 1.0,
                        "node:" + dp_master_ip: 0.001
                    }] * world_size + [{
                        "CPU": 1.0
                    }]
                    pg = ray.util.placement_group(
                        name=f"dp_rank_{len(placement_groups)}",
                        strategy="STRICT_PACK",
                        bundles=bundles,
                    )
                    placement_groups.append(pg)
                    local_dp_ranks.append(i)
            else:
                for i in range(available_engine_count):
                    if len(placement_groups) == dp_size:
                        break
                    bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
                    pg = ray.util.placement_group(
                        name=f"dp_rank_{len(placement_groups)}",
                        strategy="STRICT_PACK",
                        bundles=bundles,
                    )
                    placement_groups.append(pg)
                    local_dp_ranks.append(i)
        return placement_groups, local_dp_ranks

    def get_run_refs(self):
        return self.run_refs

    def close(self):
        import ray
        for actor in self.local_engine_actors + self.remote_engine_actors:
            ray.kill(actor)
        for pg in self.created_placement_groups:
            ray.util.remove_placement_group(pg)

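For concreteness, a small sketch of the bundle list that create_dp_placement_groups builds for one DP rank on the head node; world_size and the master IP are assumed illustration values:

world_size = 2
dp_master_ip = "10.0.0.1"  # assumed for illustration

# One GPU bundle per worker in the DP rank, pinned to the head node via the
# tiny "node:<ip>" custom resource, plus one CPU bundle for the engine-core
# actor; STRICT_PACK keeps them all on the same node.
bundles = [{"GPU": 1.0, "node:" + dp_master_ip: 0.001}] * world_size + [{"CPU": 1.0}]
assert bundles == [
    {"GPU": 1.0, "node:10.0.0.1": 0.001},
    {"GPU": 1.0, "node:10.0.0.1": 0.001},
    {"CPU": 1.0},
]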
def wait_for_engine_startup(
    handshake_socket: zmq.Socket,
    addresses: EngineZmqAddresses,
    core_engines: list[CoreEngine],
    parallel_config: ParallelConfig,
    cache_config: CacheConfig,
    proc_manager: Optional[CoreEngineProcManager],
    coord_process: Optional[Process],
):

    # Wait for engine core process(es) to send ready messages.
    local_count = parallel_config.data_parallel_size_local
    remote_count = len(core_engines) - local_count
    # [local, remote] counts
    conn_pending, start_pending = [local_count, remote_count], [0, 0]
    poller = zmq.Poller()
    poller.register(handshake_socket, zmq.POLLIN)

    if proc_manager is not None:
        for sentinel in proc_manager.sentinels():
            poller.register(sentinel, zmq.POLLIN)
    if coord_process is not None:
        poller.register(coord_process.sentinel, zmq.POLLIN)
    while any(conn_pending) or any(start_pending):
        events = poller.poll(STARTUP_POLL_PERIOD_MS)
        if not events:
            if any(conn_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to connect.", *conn_pending)
            if any(start_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to start.", *start_pending)
            continue
        if len(events) > 1 or events[0][0] != handshake_socket:
            # One of the local core processes exited.
            finished = proc_manager.finished_procs() if proc_manager else {}
            if coord_process is not None and coord_process.exitcode is not None:
                finished[coord_process.name] = coord_process.exitcode
            raise RuntimeError("Engine core initialization failed. "
                               "See root cause above. "
                               f"Failed core proc(s): {finished}")

        # Receive HELLO and READY messages from the input socket.
        eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
        eng_index = int.from_bytes(eng_identity, "little")
        engine = next((e for e in core_engines if e.identity == eng_identity),
                      None)
        if engine is None:
            raise RuntimeError(f"Message from engine with unexpected data "
                               f"parallel rank: {eng_index}")
        msg = msgspec.msgpack.decode(ready_msg_bytes)
        status, local = msg["status"], msg["local"]
        if local != engine.local:
            raise RuntimeError(f"{status} message from "
                               f"{'local' if local else 'remote'} "
                               f"engine {eng_index}, expected it to be "
                               f"{'local' if engine.local else 'remote'}")

        if status == "HELLO" and engine.state == CoreEngineState.NEW:

            # Send init message with DP config info.
            init_message = msgspec.msgpack.encode(
                EngineHandshakeMetadata(
                    addresses=addresses,
                    parallel_config={
                        "data_parallel_master_ip":
                        parallel_config.data_parallel_master_ip,
                        "data_parallel_master_port":
                        parallel_config.data_parallel_master_port,
                        "data_parallel_size":
                        parallel_config.data_parallel_size,
                    }))
            handshake_socket.send_multipart((eng_identity, init_message),
                                            copy=False)
            conn_pending[0 if local else 1] -= 1
            start_pending[0 if local else 1] += 1
            engine.state = CoreEngineState.CONNECTED
        elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
            # Setup KV cache config with initialization state from
            # engine core process. Sum values from all engines in DP case.
            num_gpu_blocks = cache_config.num_gpu_blocks or 0
            num_gpu_blocks += msg["num_gpu_blocks"]
            cache_config.num_gpu_blocks = num_gpu_blocks

            start_pending[0 if local else 1] -= 1
            engine.state = CoreEngineState.READY
        else:
            raise RuntimeError(f"Unexpected {status} message for "
                               f"{'local' if local else 'remote'} engine "
                               f"{eng_index} in {engine.state} state.")

        logger.debug("%s from %s core engine process %s.", status,
                     "local" if local else "remote", eng_index)

def wait_for_completion_or_failure(
        api_server_manager: APIServerProcessManager,
        engine_manager: Optional[Union[CoreEngineProcManager,
                                       CoreEngineActorManager]] = None,
        coordinator: Optional["DPCoordinator"] = None) -> None:
    """Wait for all processes to complete or detect if any fail.

    Raises an exception if any process exits with a non-zero status.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.
    """

    try:
        logger.info("Waiting for API servers to complete ...")
        # Create a mapping of sentinels to their corresponding processes
        # for efficient lookup
        sentinel_to_proc: dict[Any, BaseProcess] = {
            proc.sentinel: proc
            for proc in api_server_manager.processes
        }

        if coordinator:
            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

        actor_run_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            for proc in engine_manager.processes:
                sentinel_to_proc[proc.sentinel] = proc
        elif isinstance(engine_manager, CoreEngineActorManager):
            actor_run_refs = engine_manager.get_run_refs()

        # Check if any process terminates
        while sentinel_to_proc or actor_run_refs:
            # Wait for any process to terminate
            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
                                                         timeout=5)

            # Process any terminated processes
            for sentinel in ready_sentinels:
                proc = sentinel_to_proc.pop(sentinel)

                # Check if process exited with error
                if proc.exitcode != 0:
                    raise RuntimeError(
                        f"Process {proc.name} (PID: {proc.pid}) "
                        f"died with exit code {proc.exitcode}")

            if actor_run_refs:
                import ray
                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s",
                         str(e))
        raise
    finally:
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if engine_manager:
            engine_manager.close()

# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(procs: list[BaseProcess]):
    # Shutdown the process.
    for proc in procs:
        if proc.is_alive():
            proc.terminate()

    # Allow 5 seconds for remaining procs to terminate.
    deadline = time.monotonic() + 5
    for proc in procs:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        if proc.is_alive():
            proc.join(remaining)

    for proc in procs:
        if proc.is_alive() and (pid := proc.pid) is not None:
            kill_process_tree(pid)

def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is the encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # have different layer_names but the same layer_index.
            raise NotImplementedError
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]

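A minimal sketch of the effect of bind_kv_cache; the layer names and the SimpleNamespace stand-ins for Attention modules are assumptions for illustration:

import torch
from types import SimpleNamespace

kv_caches = {
    "model.layers.0.self_attn.attn": torch.zeros(1),
    "model.layers.1.self_attn.attn": torch.zeros(1),
}
forward_context = {name: SimpleNamespace(kv_cache=None) for name in kv_caches}
runner_kv_caches: list[torch.Tensor] = []

bind_kv_cache(kv_caches, forward_context, runner_kv_caches)

# runner_kv_caches is now ordered by layer index, and every layer holds a
# one-element list wrapping its cache tensor.
assert runner_kv_caches[0] is kv_caches["model.layers.0.self_attn.attn"]
assert forward_context["model.layers.1.self_attn.attn"].kv_cache == \
    [kv_caches["model.layers.1.self_attn.attn"]]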
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    """
    Copy the first length elements of a tensor into another tensor in a
    non-blocking manner.

    Used to copy pinned CPU tensor data to pre-allocated GPU tensors.

    Returns the sliced target tensor.
    """
    return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)

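A short usage sketch; plain CPU tensors stand in for the pinned-CPU source and GPU destination of the intended use:

import torch

src = torch.arange(8, dtype=torch.int64)   # in practice: pin_memory=True
dst = torch.zeros(16, dtype=torch.int64)   # in practice: on the GPU
view = copy_slice(src, dst, length=4)
assert view.tolist() == [0, 1, 2, 3]
assert dst[4:].sum() == 0                  # elements beyond `length` untouched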
def report_usage_stats(
        vllm_config,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
    """Report usage statistics if enabled."""

    if not is_usage_stats_enabled():
        return

    from vllm.model_executor.model_loader import get_architecture_class_name

    usage_message.report_usage(
        get_architecture_class_name(vllm_config.model_config),
        usage_context,
        extra_kvs={
            # Common configuration
            "dtype":
            str(vllm_config.model_config.dtype),
            "tensor_parallel_size":
            vllm_config.parallel_config.tensor_parallel_size,
            "block_size":
            vllm_config.cache_config.block_size,
            "gpu_memory_utilization":
            vllm_config.cache_config.gpu_memory_utilization,

            # Quantization
            "quantization":
            vllm_config.model_config.quantization,
            "kv_cache_dtype":
            str(vllm_config.cache_config.cache_dtype),

            # Feature flags
            "enable_lora":
            bool(vllm_config.lora_config),
            "enable_prompt_adapter":
            bool(vllm_config.prompt_adapter_config),
            "enable_prefix_caching":
            vllm_config.cache_config.enable_prefix_caching,
            "enforce_eager":
            vllm_config.model_config.enforce_eager,
            "disable_custom_all_reduce":
            vllm_config.parallel_config.disable_custom_all_reduce,
        })
0
vllm/v1/worker/__init__.py
Normal file
142
vllm/v1/worker/block_table.py
Normal file
@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from vllm.logger import init_logger
from vllm.utils import cdiv

logger = init_logger(__name__)


class BlockTable:

    def __init__(
        self,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device

        self.block_table = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device=self.device,
            dtype=torch.int32,
        )
        self.block_table_cpu = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.block_table_np = self.block_table_cpu.numpy()
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
                                            dtype=torch.int64,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                        dtype=torch.int64,
                                        device=self.device)

    def append_row(
        self,
        block_ids: list[int],
        row_idx: int,
    ) -> None:
        if not block_ids:
            return
        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]
        self.num_blocks_per_row[row_idx] += num_blocks
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
            src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        num_blocks_src = self.num_blocks_per_row[src]
        num_blocks_tgt = self.num_blocks_per_row[tgt]
        self.num_blocks_per_row[src] = num_blocks_tgt
        self.num_blocks_per_row[tgt] = num_blocks_src

        self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]

    def commit(self, num_reqs: int) -> None:
        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                          non_blocking=True)

    def clear(self) -> None:
        self.block_table.fill_(0)
        self.block_table_cpu.fill_(0)

    def get_device_tensor(self) -> torch.Tensor:
        """Returns the device tensor of the block table."""
        return self.block_table

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
        return self.block_table_cpu

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table_np


class MultiGroupBlockTable:
    """The BlockTables for each KV cache group."""

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 max_num_batched_tokens: int, pin_memory: bool,
                 device: torch.device, block_sizes: list[int]) -> None:
        self.block_tables = [
            BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
                       max_num_batched_tokens, pin_memory, device)
            for block_size in block_sizes
        ]

    def append_row(self, block_ids: tuple[list[int], ...],
                   row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        for block_table in self.block_tables:
            block_table.commit(num_reqs)

    def clear(self) -> None:
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]
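To illustrate the row bookkeeping, a small CPU-only sketch of BlockTable (all sizes are arbitrary illustration values):

import torch

bt = BlockTable(max_num_reqs=4, max_num_blocks_per_req=8,
                max_num_batched_tokens=64, pin_memory=False,
                device=torch.device("cpu"))

bt.add_row([10, 11], row_idx=0)    # request 0 starts with two blocks
bt.append_row([12], row_idx=0)     # a third block is allocated later
assert bt.get_numpy_array()[0, :3].tolist() == [10, 11, 12]

bt.move_row(0, 1)                  # e.g. after the batch is reordered
assert bt.get_numpy_array()[1, :3].tolist() == [10, 11, 12]
bt.commit(num_reqs=2)              # flush rows 0..1 to the device tensor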
86
vllm/v1/worker/cpu_model_runner.py
Normal file
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from typing import Any

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)


class CPUModelRunner(GPUModelRunner):

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tensors()

    def _postprocess_tensors(self) -> None:
        # Note: replace device tensors with cpu tensors
        def replace_tensor(obj: Any, cpu_attr_name: str,
                           device_attr_name: str) -> None:
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            if cpu_tensor is not None and device_tensor is not None:
                assert isinstance(cpu_tensor, torch.Tensor)
                assert isinstance(device_tensor, torch.Tensor)
                setattr(obj, device_attr_name, cpu_tensor)

        for k, v in vars(self).items():
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self, k, k[:-4])

        for k, v in vars(self.input_batch).items():
            if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch, k, k[:-11])

        for k, v in vars(self.input_batch.block_table).items():
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch.block_table, k, k[:-4])

    def load_model(self) -> None:
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.model_config,
                                              self.scheduler_config,
                                              self.lora_config, self.device)

    def warming_up_model(self) -> None:
        logger.info("Warming up model for the compilation...")
        # Only generate graph for the generic shape
        self._dummy_run(max(16, self.max_num_reqs))
        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        pass

    def _sync_device(self) -> None:
        pass


@contextmanager
def _set_global_compilation_settings():
    import torch._inductor.config

    # Note: The CPPGEMM backend requires freezing parameters.
    freezing_value = torch._inductor.config.freezing
    # Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
    # including object type dict"
    force_disable_caches = torch._inductor.config.force_disable_caches
    try:
        torch._inductor.config.freezing = True
        torch._inductor.config.force_disable_caches = True
        yield
    finally:
        # Restore the global settings even if the caller raised.
        torch._inductor.config.freezing = freezing_value
        torch._inductor.config.force_disable_caches = force_disable_caches
152
vllm/v1/worker/cpu_worker.py
Normal file
@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
import os
from importlib import util
from typing import Optional

import torch

from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import (Worker,
                                       init_worker_distributed_environment)

logger = init_logger(__name__)


class CPUWorker(Worker):

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False):
        super().__init__(vllm_config,
                         local_rank,
                         rank,
                         distributed_init_method,
                         is_driver_worker=is_driver_worker)

        self.parallel_config.disable_custom_all_reduce = True

    def init_device(self):
        # Set up OpenMP thread affinity.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        self.local_omp_cpuid = "all"
        if omp_cpuids == "auto":
            self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
            )
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]

        if self.local_omp_cpuid != "all":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            if ret:
                logger.info(ret)

        # Note: unique identifier for creating allreduce shared memory
        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
            ":")[-1]
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank, "gloo")
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner.
        self.model_runner: CPUModelRunner = CPUModelRunner(
            self.vllm_config, torch.device("cpu"))

    def sleep(self, level: int = 1) -> None:
        logger.warning("Sleep mode is not supported on CPU; ignoring it.")

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        logger.warning("Sleep mode is not supported on CPU; ignoring it.")

    def determine_available_memory(self) -> int:
        return self.cache_config.cpu_kvcache_space_bytes  # type: ignore

    def compile_or_warm_up_model(self) -> None:
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        self.model_runner.warming_up_model()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)

        if not get_pp_group().is_last_rank:
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return None

        assert isinstance(output, ModelRunnerOutput)
        return output if self.is_driver_worker else None

    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
        """Return the CPU id binding for this rank based on NUMA nodes."""
        rank_to_cpus = self.local_omp_cpuid
        # Set up OpenMP thread affinity based on NUMA nodes automatically.
        world_size = self.vllm_config.parallel_config.world_size
        libnuma_found = util.find_spec("numa") is not None
        psutil_found = util.find_spec("psutil") is not None
        if libnuma_found and psutil_found:
            import psutil
            from numa import info
            cpu_count = psutil.cpu_count(logical=False)
            cpus_allow_list = psutil.Process().cpu_affinity()
            numa_size = info.get_num_configured_nodes()
            cpu_count_per_numa = cpu_count // numa_size
            num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                      cpu_count_per_numa // 2)

            # Check the allowed NUMA-node-to-CPUs list.
            node_to_cpus = []
            for i in range(numa_size):
                node_intersect = set(
                    info.node_to_cpus(i)).intersection(cpus_allow_list)
                if bool(node_intersect):
                    node_to_cpus.append(list(node_intersect))

            if world_size > len(node_to_cpus):
                logger.error(
                    "Auto thread-binding failed because the world size (%d) "
                    "is larger than the number of allowed NUMA nodes (%d). "
                    "Please try to bind threads manually.", world_size,
                    len(node_to_cpus))
            else:
                end = cpu_count_per_numa - num_of_reserved_cpu
                rank_to_cpus_list = node_to_cpus[self.rank][:end]
                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
                logger.info("Auto thread-binding list: %s", rank_to_cpus)
        else:
            logger.warning(
                "Auto thread-binding is not supported due to the lack of "
                "the numa and psutil packages; falling back to no "
                "thread-binding. To get better performance, please try to "
                "bind threads manually.")
        return rank_to_cpus
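A standalone sketch of the NUMA binding arithmetic above, with hypothetical topology values (not measured): two NUMA nodes, 16 physical cores per node, and two reserved cores per node leave 14 cores per rank.

# Hypothetical sketch of the per-rank CPU list computation above.
cpu_count = 32
numa_size = 2
reserved = 2
cpu_count_per_numa = cpu_count // numa_size                        # 16
end = cpu_count_per_numa - min(reserved, cpu_count_per_numa // 2)  # 14
node_to_cpus = [list(range(0, 16)), list(range(16, 32))]
rank = 1
binding = ','.join(str(c) for c in node_to_cpus[rank][:end])
# binding -> "16,17,...,29"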
681
vllm/v1/worker/gpu_input_batch.py
Normal file
@@ -0,0 +1,681 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Data structures defining an input batch

from dataclasses import dataclass
from typing import Optional, cast

import numpy as np
import torch

from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable

_SAMPLING_EPS = 1e-5


@dataclass
class CachedRequestState:

    req_id: str
    prompt_token_ids: list[int]
    mm_inputs: list[MultiModalKwargs]
    mm_positions: list[PlaceholderRange]
    sampling_params: SamplingParams
    generator: Optional[torch.Generator]

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    def __post_init__(self):
        self.num_prompt_tokens = len(self.prompt_token_ids)

    @property
    def num_tokens(self) -> int:
        return self.num_prompt_tokens + len(self.output_token_ids)

    def get_token_id(self, idx: int) -> int:
        if idx < self.num_prompt_tokens:
            return self.prompt_token_ids[idx]
        else:
            return self.output_token_ids[idx - self.num_prompt_tokens]
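CachedRequestState exposes a single flat token index that spans the prompt/output boundary; a small sketch of that indexing with hypothetical token ids:

# Hypothetical sketch of get_token_id's flat indexing: indices 0..2 hit
# the prompt, indices 3.. hit the generated output.
prompt = [101, 7592, 102]
output = [2023, 2003]
flat = lambda i: prompt[i] if i < len(prompt) else output[i - len(prompt)]
assert flat(1) == 7592 and flat(4) == 2003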
class InputBatch:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs, ),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = \
            self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
        )

        # Sampling-related.
        self.temperature = torch.empty((max_num_reqs, ),
                                       dtype=torch.float32,
                                       device=device)
        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                                  dtype=torch.float32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs, ),
                                 dtype=torch.int32,
                                 device=device)
        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        self.min_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
        self.min_p_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty((max_num_reqs, ),
                                               dtype=torch.float,
                                               device=device)
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty((max_num_reqs, ),
                                              dtype=torch.float,
                                              device=device)
        self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
                                                         dtype=torch.float,
                                                         device="cpu",
                                                         pin_memory=pin_memory)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
        )
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty((max_num_reqs, ),
                                                dtype=torch.float,
                                                device=device)
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # req_index -> (min_tokens, stop_token_ids)
        self.min_tokens: dict[int, tuple[int, set[int]]] = {}

        # LoRA related
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        self.logit_bias: list[Optional[dict[int,
                                            float]]] = [None] * max_num_reqs
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token is
        # allowed, the value is False, since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

    @property
    def req_ids(self) -> list[str]:
        # None elements should only be present transiently
        # while performing state updates to the batch.
        return cast(list[str], self._req_ids)
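The per-parameter pattern above (a device tensor paired with a pinned CPU tensor plus a NumPy view) allows cheap scalar writes on the host and asynchronous host-to-device copies of just the active prefix; a minimal standalone sketch with hypothetical names, assuming a CUDA device is available:

import torch

# Hypothetical sketch of the paired-buffer pattern used above.
n = 8
vals_gpu = torch.empty(n, dtype=torch.float32, device="cuda")
vals_cpu_t = torch.empty(n, dtype=torch.float32, device="cpu",
                         pin_memory=True)
vals_cpu = vals_cpu_t.numpy()  # zero-copy view for fast scalar writes

vals_cpu[3] = 0.7  # host-side update, no tensor-op overhead
# Pinned memory allows a non-blocking copy of just the active prefix.
vals_gpu[:4].copy_(vals_cpu_t[:4], non_blocking=True)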
    def add_request(
        self,
        request: "CachedRequestState",
        req_index: Optional[int] = None,
    ) -> None:
        if req_index is None:
            req_index = self.num_reqs
        assert req_index < self.max_num_reqs

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        # Number of token ids in token_ids_cpu.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        sampling_params = request.sampling_params
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # Avoid later division by zero.
            self.temperature_cpu[req_index] = -1.0
            self.greedy_reqs.add(req_id)
        else:
            self.temperature_cpu[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        top_k = sampling_params.top_k
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            top_k = self.vocab_size
        self.top_k_cpu[req_index] = top_k
        self.min_p_cpu[req_index] = sampling_params.min_p
        self.frequency_penalties_cpu[
            req_index] = sampling_params.frequency_penalty
        if sampling_params.min_p > _SAMPLING_EPS:
            self.min_p_reqs.add(req_id)
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties_cpu[
            req_index] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties_cpu[
            req_index] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)
        if sampling_params.min_tokens:
            self.min_tokens[req_index] = (sampling_params.min_tokens,
                                          sampling_params.all_stop_token_ids)

        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator

        if sampling_params.logprobs is not None:
            self.num_logprobs[req_id] = sampling_params.logprobs
        if sampling_params.prompt_logprobs is not None:
            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
        if sampling_params.logit_bias is not None:
            self.logit_bias[req_index] = sampling_params.logit_bias

        if sampling_params.allowed_token_ids:
            self.has_allowed_token_ids.add(req_id)
            if self.allowed_token_ids_mask_cpu_tensor is None:
                # Lazy allocation for this tensor, which can be large.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
                                                          self.vocab_size,
                                                          dtype=torch.bool,
                                                          device=self.device)
                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device="cpu")
            self.allowed_token_ids_mask_cpu_tensor[req_index] = True
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index][
                sampling_params.allowed_token_ids] = False

        if sampling_params.bad_words_token_ids:
            self.bad_words_token_ids[
                req_index] = sampling_params.bad_words_token_ids

        # Add the request's LoRA ID.
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0
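A small sketch of the allowed-token mask convention used above (True positions are later filled with -inf; sizes are hypothetical):

import torch

# Hypothetical sketch of the allowed_token_ids mask convention above:
# True = disallowed (filled with -inf), False = allowed.
vocab_size = 10
logits = torch.zeros(vocab_size)
mask = torch.ones(vocab_size, dtype=torch.bool)  # start all-disallowed
mask[[2, 5, 7]] = False                          # allow tokens 2, 5, 7
logits = logits.masked_fill(mask, float("-inf"))
assert torch.isfinite(logits[[2, 5, 7]]).all()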
    def remove_request(self, req_id: str) -> Optional[int]:
        """This method must always be followed by a call to condense()."""

        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.min_p_reqs.discard(req_id)
        self.min_tokens.pop(req_index, None)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.num_prompt_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            self.lora_id_to_request_ids[lora_id].discard(req_id)
            if len(self.lora_id_to_request_ids[lora_id]) == 0:
                self.lora_id_to_request_ids.pop(lora_id)
                self.lora_id_to_lora_request.pop(lora_id)
            self.request_lora_mapping[req_index] = 0

        self.logit_bias[req_index] = None
        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index
    def swap_states(self, i1: int, i2: int) -> None:
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] =\
            self._req_ids[i2], self._req_ids[i1]  # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
        self.num_tokens[i1], self.num_tokens[i2] =\
            self.num_tokens[i2], self.num_tokens[i1]
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
            self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
            self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
            self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
        self.temperature_cpu[i1], self.temperature_cpu[i2] =\
            self.temperature_cpu[i2], self.temperature_cpu[i1]
        self.top_p_cpu[i1], self.top_p_cpu[i2] =\
            self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] =\
            self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
        self.min_p_cpu[i1], self.min_p_cpu[i2] =\
            self.min_p_cpu[i2], self.min_p_cpu[i1]

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # Instead, we need to temporarily copy the data for one of the
        # indices.
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.min_tokens, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
        self.logit_bias[i1], self.logit_bias[i2] =\
            self.logit_bias[i2], self.logit_bias[i1]

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            self.allowed_token_ids_mask_cpu_tensor[i1], \
                self.allowed_token_ids_mask_cpu_tensor[i2] =\
                self.allowed_token_ids_mask_cpu_tensor[i2], \
                self.allowed_token_ids_mask_cpu_tensor[i1]
        self.block_table.swap_row(i1, i2)
    def condense(self, empty_req_indices: list[int]) -> None:
        num_reqs = self.num_reqs
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                break

            # Swap the states.
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
            self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            min_token = self.min_tokens.pop(last_req_index, None)
            if min_token is not None:
                self.min_tokens[empty_index] = min_token

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

            self.logit_bias[empty_index] = self.logit_bias[last_req_index]

            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[
                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                        last_req_index]

            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[self.num_reqs:]
        del self.req_output_token_ids[self.num_reqs:]
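The compaction above pairs the smallest hole with the largest live index; a toy sketch of the same two-pointer idea on a plain list (values are hypothetical):

# Toy sketch of condense(): fill holes (None) from the back so that live
# entries occupy a contiguous prefix. empty_indices is sorted descending.
states = ["a", None, "c", None, "e"]
empty_indices = [3, 1]
last = len(states) - 1
while empty_indices:
    while states[last] is None:
        last -= 1
    hole = empty_indices.pop()
    if hole >= last:
        break
    states[hole], states[last] = states[last], None
    last -= 1
assert states[:3] == ["a", "e", "c"]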
    def refresh_sampling_metadata(self):
        self.sampling_metadata = self._make_sampling_metadata()

    def _make_sampling_metadata(self) -> SamplingMetadata:
        num_reqs = self.num_reqs
        if not self.all_greedy:
            temperature = copy_slice(self.temperature_cpu_tensor,
                                     self.temperature, num_reqs)
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
        if not self.no_min_p:
            copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive, only copy them
            # if necessary, i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(self.frequency_penalties_cpu_tensor,
                       self.frequency_penalties, num_reqs)
            copy_slice(self.presence_penalties_cpu_tensor,
                       self.presence_penalties, num_reqs)
            copy_slice(self.repetition_penalties_cpu_tensor,
                       self.repetition_penalties, num_reqs)

            # The prompt tokens are used only for applying penalties during
            # the sampling process. Hence copy these tensors only when
            # there are requests which need penalties to be applied.
            prompt_token_ids = self._make_prompt_token_ids_tensor()
        else:
            prompt_token_ids = None

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, num_reqs)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            min_p=None if self.no_min_p else self.min_p[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
            min_tokens=self.min_tokens,
            no_penalties=self.no_penalties,
            logit_bias=self.logit_bias[:num_reqs],
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
        )
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:self.
                                                 num_reqs, :max_prompt_len]
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        for i in range(self.num_reqs):
            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)
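Padding with vocab_size works because real token ids range over 0..vocab_size-1; a toy illustration with hypothetical shapes:

import numpy as np

# Toy sketch of the vocab_size padding above: positions beyond each
# prompt's length are set to vocab_size, an id no real token can have.
vocab_size = 100
prompts = np.array([[5, 9, 2], [7, 1, 0]])
prompt_lens = np.array([3, 2])
for i in range(len(prompts)):
    prompts[i, prompt_lens[i]:] = vocab_size
# prompts -> [[5, 9, 2], [7, 1, 100]]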
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        the data structures used to activate the current LoRAs.
        Returns:
            1. prompt_lora_mapping: A tuple of size self.num_reqs where
               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where token_lora_mapping[i] is the LoRA id to use for the ith
               token.
            3. lora_requests: Set of relevant LoRA requests.
        """

        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping)
        token_lora_mapping = tuple(
            req_lora_mapping.repeat(num_scheduled_tokens))
        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values())

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests
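The per-request LoRA ids are expanded to per-token ids with NumPy's repeat; a quick illustration with hypothetical values:

import numpy as np

# Toy sketch of the repeat() expansion above: request-level LoRA ids are
# broadcast to one id per scheduled token.
req_lora_mapping = np.array([1, 0, 2], dtype=np.int32)
num_scheduled_tokens = np.array([2, 3, 1])
token_lora_mapping = req_lora_mapping.repeat(num_scheduled_tokens)
# token_lora_mapping -> [1, 1, 0, 0, 0, 2]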
    @property
    def num_reqs(self) -> int:
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        return len(self.random_reqs) == 0

    @property
    def all_random(self) -> bool:
        return len(self.greedy_reqs) == 0

    @property
    def no_top_p(self) -> bool:
        return len(self.top_p_reqs) == 0

    @property
    def no_top_k(self) -> bool:
        return len(self.top_k_reqs) == 0

    @property
    def no_min_p(self) -> bool:
        return len(self.min_p_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        return (len(self.presence_penalties_reqs) == 0
                and len(self.frequency_penalties_reqs) == 0
                and len(self.repetition_penalties_reqs) == 0)

    @property
    def max_num_logprobs(self) -> Optional[int]:
        return max(self.num_logprobs.values()) if self.num_logprobs else None

    @property
    def no_prompt_logprob(self) -> bool:
        return not self.num_prompt_logprobs

    @property
    def no_allowed_token_ids(self) -> bool:
        return len(self.has_allowed_token_ids) == 0
2344
vllm/v1/worker/gpu_model_runner.py
Normal file
File diff suppressed because it is too large
393
vllm/v1/worker/gpu_worker.py
Normal file
@@ -0,0 +1,393 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional

import torch
import torch.distributed
import torch.nn as nn

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput


class Worker(WorkerBase):

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):

        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        if self.model_config.trust_remote_code:
            # Note: lazy import to avoid importing torch before initializing.
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Buffers saved before sleep
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None
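The profiler hooks above are driven entirely by the environment; a hedged usage sketch (the trace path is a placeholder, not from this diff):

import os

# Hypothetical sketch: enable the torch profiler hooks above by setting the
# env var before the worker process starts; /tmp/vllm_traces is a placeholder.
os.environ["VLLM_TORCH_PROFILER_DIR"] = "/tmp/vllm_traces"
# Traces are then started/stopped via Worker.profile(is_start=True/False).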
    def sleep(self, level: int = 1) -> None:
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone()
                for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers after level 2 sleep
        if len(self._sleep_saved_buffers):
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}
    def init_device(self):
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor
            # until the synchronization point. This causes the memory usage
            # to grow as the number of all_reduce calls increases. This env
            # var disables this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()

            # Take the current memory snapshot.
            self.init_snapshot = MemorySnapshot()
            self.requested_memory = (self.init_snapshot.total_memory *
                                     self.cache_config.gpu_memory_utilization)
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = lambda b: round(b / GiB_bytes, 2)
                raise ValueError(
                    f"Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner.
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)
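The startup check above compares free memory against total_memory * gpu_memory_utilization; a quick arithmetic sketch with hypothetical numbers:

# Hypothetical arithmetic for the startup check above: on an 80 GiB device
# with gpu_memory_utilization=0.9, 72 GiB is requested; startup fails if
# less than that is currently free.
total_memory_gib = 80.0
gpu_memory_utilization = 0.9
requested_gib = total_memory_gib * gpu_memory_utilization  # 72.0
free_gib = 68.0
assert free_gib < requested_gib  # -> ValueError at init_device time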
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
            context = allocator.use_memory_pool(tag="weights")
        else:
            from contextlib import nullcontext
            context = nullcontext()
        with context:
            self.model_runner.load_model()
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory
        usage. Then, it calculates the free memory that can be used for KV
        cache in bytes.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        GiB = lambda b: b / GiB_bytes

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.init_snapshot,
                weights_memory=int(
                    self.model_runner.model_memory_usage)) as profile_result:
            self.model_runner.profile_run()

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {GiB(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container.")
        available_kv_cache_memory = self.requested_memory \
            - profile_result.non_kv_cache_memory

        logger.debug(
            "Initial free memory: %.2f GiB, free memory: %.2f GiB, "
            "requested GPU memory: %.2f GiB",
            GiB(self.init_snapshot.free_memory), GiB(free_gpu_memory),
            GiB(self.requested_memory))
        logger.debug(profile_result)
        logger.info("Available KV cache memory: %.2f GiB",
                    GiB(available_kv_cache_memory))
        gc.collect()

        return int(available_kv_cache_memory)
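In short, the KV cache budget is the requested memory minus everything the profiled forward pass needed besides KV cache; a one-line sketch with hypothetical numbers:

# Hypothetical sketch of the budget computed above, in GiB:
requested = 72.0     # total_memory * gpu_memory_utilization
non_kv_cache = 18.5  # weights + activations + CUDA overheads (profiled)
available_kv_cache = requested - non_kv_cache  # 53.5 GiB for KV blocks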
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            context = allocator.use_memory_pool(tag="kv_cache")
        else:
            from contextlib import nullcontext
            context = nullcontext()
        with context:
            self.model_runner.initialize_kv_cache(kv_cache_config)
    def compile_or_warm_up_model(self) -> None:
        # Warm up sizes that are not in the cudagraph capture sizes, but
        # which users still want to compile for better performance,
        # e.g. the max-num-batched-tokens size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compiling and warming up model for size %d", size)
            self.model_runner._dummy_run(size)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

        # Warm up the sampler and preallocate the memory buffer for logits
        # and other sampling-related tensors of the max possible shape, to
        # avoid memory fragmentation issues.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(self.scheduler_config.max_num_seqs,
                               self.scheduler_config.max_num_batched_tokens)
            self.model_runner._dummy_sampler_run(
                hidden_states=self.model_runner._dummy_run(
                    num_tokens=max_num_reqs))

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
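A toy sketch of the warmup-size filtering above, with hypothetical size lists: sizes already covered by cudagraph capture are skipped, and the rest are compiled largest-first.

# Hypothetical size lists for the filtering above.
compile_sizes = [1, 8, 64, 512]
cudagraph_capture_sizes = [1, 8]
warmup_sizes = [x for x in compile_sizes
                if x not in cudagraph_capture_sizes]
for size in sorted(warmup_sizes, reverse=True):
    print(f"warm up size {size}")  # 512, then 64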
    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)
        parallel_config = self.vllm_config.parallel_config
        if parallel_config.distributed_executor_backend != "external_launcher" \
                and not get_pp_group().is_last_rank:
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return None
        assert isinstance(output, ModelRunnerOutput)
        return output if self.is_driver_worker else None

    def profile(self, is_start: bool = True):
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            print(self.profiler.key_averages().table(
                sort_by="self_cuda_time_total"))
    def execute_dummy_batch(self) -> None:
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        # The worker will always be healthy as long as it's running.
        return

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        from vllm.model_executor.model_loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config)
def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment."""
    parallel_config = vllm_config.parallel_config
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank, backend)

    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    ensure_kv_transfer_initialized(vllm_config)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:  # noqa: SIM102
        if not current_platform.has_device_capability(80):
            capability = current_platform.get_device_capability()
            gpu_name = current_platform.get_device_name()

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half.")
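The bf16 guard keys off compute capability 8.0 (Ampere or newer); a standalone sketch of the same check using plain torch instead of vLLM's platform abstraction (hedged, not the code in this diff):

import torch

# Standalone sketch of the bf16 guard above: capability (8, 0) == Ampere.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    supports_bf16 = (major, minor) >= (8, 0)
    dtype = torch.bfloat16 if supports_bf16 else torch.float16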
173
vllm/v1/worker/lora_model_runner_mixin.py
Normal file
@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define LoRA functionality mixin for model runners.
"""

from contextlib import contextmanager

import numpy as np
import torch.nn as nn

from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.v1.worker.gpu_input_batch import InputBatch

logger = init_logger(__name__)


# Defined as a mixin for GPUModelRunner
class LoRAModelRunnerMixin:

    LORA_WARMUP_RANK = 8

    def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
                        scheduler_config: SchedulerConfig,
                        lora_config: LoRAConfig, device: str) -> nn.Module:

        if not supports_lora(model):
            raise ValueError(
                f"{model.__class__.__name__} does not support LoRA yet.")

        if supports_multimodal(model):
            logger.warning("For multimodal models, vLLM currently only "
                           "supports adding LoRA to the language model.")

        # Use get_text_config() in case of multimodal models
        text_config = model_config.hf_config.get_text_config()

        # Add the LoRA manager to the model runner.
        self.lora_manager = LRUCacheWorkerLoRAManager(
            scheduler_config.max_num_seqs,
            scheduler_config.max_num_batched_tokens,
            model_config.get_vocab_size(),
            lora_config,
            device,
            model.embedding_modules,
            model.embedding_padding_modules,
            max_position_embeddings=text_config.max_position_embeddings,
        )
        return self.lora_manager.create_lora_manager(model)
    def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...],
                          token_lora_mapping: tuple[int, ...],
                          lora_requests: set[LoRARequest]) -> None:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")

        # Set is_prefill to True, so we always use the SGMV kernels on
        # non-CUDA platforms.
        # On CUDA platforms we use the same kernels for prefill and
        # decode, and this flag is generally ignored.
        lora_mapping = LoRAMapping(token_lora_mapping,
                                   prompt_lora_mapping,
                                   is_prefill=True)
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

    def set_active_loras(self, input_batch: InputBatch,
                         num_scheduled_tokens: np.ndarray) -> None:

        prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
        token_lora_mapping: tuple[int,
                                  ...]  # of size np.sum(num_scheduled_tokens)
        lora_requests: set[LoRARequest]
        prompt_lora_mapping, token_lora_mapping, lora_requests = \
            input_batch.make_lora_inputs(num_scheduled_tokens)
        return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                      lora_requests)

    @contextmanager
    def maybe_setup_dummy_loras(self, lora_config):
        if lora_config is None:
            yield
        else:
            # __enter__ code
            assert self.lora_manager is not None, "LoRA is not enabled"

            num_loras = lora_config.max_loras

            # Make dummy LoRA requests.
            lora_requests: set[LoRARequest] = {
                LoRARequest(lora_name=f"warmup_{lora_id}",
                            lora_int_id=lora_id,
                            lora_path="/not/a/real/path")
                for lora_id in range(1, num_loras + 1)
            }

            with self.lora_manager.dummy_lora_cache():
                # Add the dummy LoRAs here so _set_active_loras doesn't try to
                # load from disk.
                for lr in lora_requests:
                    self.lora_manager.add_dummy_lora(
                        lr, rank=self.LORA_WARMUP_RANK)

                yield

            # __exit__ code
            self.lora_manager.remove_all_adapters()
    @contextmanager
    def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
                                 num_scheduled_tokens: np.ndarray):
        if lora_config is None:
            yield
        else:
            # __enter__ code
            assert self.lora_manager is not None, "LoRA is not enabled"

            num_reqs = len(num_scheduled_tokens)
            num_loras = lora_config.max_loras

            # Make the prompt LoRA mapping.
            # Assign LoRA IDs cyclically to simulate a worst-case scenario.
            prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) %
                                   num_loras) + 1

            # Make the token LoRA mapping.
            token_lora_mapping = np.repeat(prompt_lora_mapping,
                                           num_scheduled_tokens)

            # Make dummy LoRA requests.
            lora_requests: set[LoRARequest] = {
                LoRARequest(lora_name=f"warmup_{lora_id}",
                            lora_int_id=lora_id,
                            lora_path="/not/a/real/path")
                for lora_id in range(1, num_loras + 1)
            }

            self._set_active_loras(tuple(prompt_lora_mapping),
                                   tuple(token_lora_mapping), lora_requests)

            yield
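The cyclic assignment spreads warmup requests across every adapter slot; a quick illustration with hypothetical sizes:

import numpy as np

# Hypothetical sketch of the cyclic worst-case assignment above: 5 requests
# across 3 LoRA slots get ids 1, 2, 3, 1, 2 (id 0 means "no LoRA").
num_reqs, num_loras = 5, 3
prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
# prompt_lora_mapping -> [1, 2, 3, 1, 2]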
    @contextmanager
    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
                                  num_scheduled_tokens: np.ndarray):
        with self.maybe_setup_dummy_loras(
                lora_config), self.maybe_select_dummy_loras(
                    lora_config, num_scheduled_tokens):
            yield

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.add_adapter(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_adapter(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.pin_adapter(lora_id)

    def list_loras(self) -> set[int]:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_adapters()
1673
vllm/v1/worker/tpu_model_runner.py
Normal file
File diff suppressed because it is too large
299
vllm/v1/worker/tpu_worker.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A TPU worker class."""
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import bind_kv_cache, report_usage_stats
|
||||
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUWorker:

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        self.is_driver_worker = is_driver_worker
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        self.original_parallel_config = None
        if self.use_spmd:
            # Under SPMD mode, the distributed env is initialized as if
            # there were only one worker/device.
            self.original_parallel_config = self.parallel_config
            self.parallel_config.tensor_parallel_size = 1
            self.parallel_config.pipeline_parallel_size = 1
            self.parallel_config.world_size = 1
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        if self.model_config.trust_remote_code:
            # Note: lazy import to avoid importing torch before
            # initialization.
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Delay profiler initialization to the start of profiling: in
        # vLLM V1, the MP runtime is initialized before the TPU worker,
        # and the profiler server must start after the MP runtime does.
        self.profiler = None
        self.profile_dir = None
        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
            # For TPU, there can be only one active profiler session per
            # profiler server, so we only profile on rank 0.
            self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        self.profile_dir)

        if self.model_config.seed is None:
            self.model_config.seed = 0
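
    # Illustrative sketch (not part of this diff) of the cache-dtype
    # resolution above. The two-entry mapping is an assumed stand-in for
    # vllm.utils.STR_DTYPE_TO_TORCH_DTYPE; the helper name is hypothetical.
    @staticmethod
    def _example_resolve_cache_dtype(cache_dtype: str,
                                     model_dtype: torch.dtype) -> torch.dtype:
        str_to_dtype = {"half": torch.half, "bfloat16": torch.bfloat16}
        # "auto" falls back to the model's own dtype; any other string is a
        # lookup into the string-to-dtype table.
        if cache_dtype == "auto":
            return model_dtype
        return str_to_dtype[cache_dtype]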
    def init_device(self):
        os.environ["PJRT_DEVICE"] = "TPU"
        # Note: the XLA compiler currently (and wrongly) uses a 2D ring
        # strategy on a 1D ring; the flag
        # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary
        # workaround that will be removed once the XLA compiler bug is
        # fixed. `--xla_jf_conv_input_fusion=False` improves the
        # performance of quantized matmuls.
        os.environ["LIBTPU_INIT_ARGS"] = (
            os.environ.get("LIBTPU_INIT_ARGS", "") +
            " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
            " --xla_jf_conv_input_fusion=False")
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        self._init_tpu_worker_distributed_environment(
            self.parallel_config, self.rank, self.distributed_init_method,
            self.local_rank)

        # Device initialization should happen after initializing
        # the distributed runtime.
        self.device = xm.xla_device()
        self.device_config.device = self.device

        # Set random seed.
        set_random_seed(self.model_config.seed)
        if self.model_config.seed is not None:
            xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # TODO (NickLucche) On gsm we compile 80+ graphs.
        # Re-evaluate the limit; with MM we may get close to it.
        torch._dynamo.config.cache_size_limit = 128
        # Use a persistent cache to avoid XLA recompilation.
        # NOTE(woosuk): Set a per-rank cache path, since different ranks
        # can have slightly different XLA graphs.
        world_size = self.parallel_config.world_size
        rank = xr.global_ordinal()
        # The PyTorch/XLA compilation cache uses the Torch IR to generate
        # keys. Consequently, changes in optimization flags, which affect
        # compilation results, don't change the cache key. This can result
        # in the wrong compilation being used. To prevent this, disable the
        # XLA compilation cache during development via
        # `export VLLM_XLA_CACHE_PATH=`.
        if envs.VLLM_XLA_CACHE_PATH:
            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
                                         f"tp{world_size}_rank{rank}")
            xr.initialize_cache(per_rank_path, readonly=False)

        # Init the ModelRunner here, so that we have access to self.device.
        self.model_runner = \
            TPUModelRunner(self.vllm_config, self.device,
                           self.original_parallel_config)

        if rank == 0:
            # If usage stats are enabled, collect relevant info.
            report_usage_stats(self.vllm_config)
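
    # Illustrative sketch (not part of this diff) of the per-rank XLA cache
    # layout produced above: with VLLM_XLA_CACHE_PATH="/tmp/xla_cache" and
    # tp=4, rank 2 keeps its compile artifacts under
    # "/tmp/xla_cache/tp4_rank2". The root path is assumed; the helper name
    # is hypothetical.
    @staticmethod
    def _example_per_rank_cache_path(root: str, world_size: int,
                                     rank: int) -> str:
        return os.path.join(root, f"tp{world_size}_rank{rank}")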
    def determine_available_memory(self) -> int:
        kv_caches: dict[str, torch.Tensor] = {}
        kv_cache_spec = self.model_runner.get_kv_cache_spec()
        for layer_name, layer_spec in kv_cache_spec.items():
            if isinstance(layer_spec, AttentionSpec):
                dtype = layer_spec.dtype

                # Use an empty tensor instead of `None` to force Dynamo to
                # pass it by reference, rather than specializing on the
                # value `None`.
                tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
                kv_caches[layer_name] = tpu_kv_cache
            else:
                raise NotImplementedError(
                    f"Unsupported KV cache spec '{type(layer_spec)}'")

        runner_kv_caches: list[torch.Tensor] = []
        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            runner_kv_caches)

        # `max_num_tokens >= max_num_batched_tokens` due to padding.
        with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
            self.model_runner.profile_run(self.model_runner.max_num_tokens)

        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()

        # During the profiling run, the model runs without the KV cache;
        # afterwards, it always runs with it. Clear the dynamo cache and
        # cached bytecode here so that the model always has exactly one
        # compiled bytecode. Having one FX graph/cached bytecode per
        # compiled model is required for the `support_torch_compile`
        # decorator to skip the dynamo guard.
        self.model_runner.reset_dynamo_cache()

        # Get the maximum amount of memory used by the model weights and
        # intermediate activations.
        if self.use_spmd:
            # Workaround for TPU SPMD mode: the get_memory_info API does
            # not work with SPMD mode in PyTorch/XLA.
            # TODO: use xm.get_memory_info for SPMD once it is supported in
            # PyTorch/XLA.
            import tpu_info
            chip_type, _ = tpu_info.device.get_local_chips()
            device_usage = tpu_info.metrics.get_chip_usage(chip_type)
            total_memory_size = device_usage[0].total_memory
            current_mem = device_usage[0].memory_usage
        else:
            m = xm.get_memory_info(self.device)
            total_memory_size = m["bytes_limit"]
            current_mem = m["bytes_used"]
        # Ideally we would use profiled = m["peak_bytes_used"] to get
        # weights + activations, but memory used during compilation and
        # weight loading inflates the peak, and there is no way to reset
        # peak memory in XLA. So we use the heuristic of weights + 2%.
        profiled = current_mem * 1.02

        # Calculate the TPU KV cache size based on profiling.
        usable_memory_size = int(total_memory_size *
                                 self.cache_config.gpu_memory_utilization)
        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)

        return int(tpu_kv_cache_bytes)
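
    # Worked example (illustrative, assumed numbers; helper name is
    # hypothetical) of the sizing heuristic above: 16 GiB of HBM, 6 GiB in
    # use after the profile run, and gpu_memory_utilization=0.9 leave
    # roughly 8.9e9 bytes (~8.3 GiB) for the KV cache.
    @staticmethod
    def _example_kv_cache_sizing() -> int:
        total_memory_size = 16 * 1024**3    # total HBM on the chip
        current_mem = 6 * 1024**3           # bytes used after profiling
        profiled = current_mem * 1.02       # +2% headroom heuristic
        usable = int(total_memory_size * 0.9)  # gpu_memory_utilization=0.9
        return int(max(usable - profiled, 0))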
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        output = self.model_runner.execute_model(scheduler_output)
        return output if self.is_driver_worker else None
    def profile(self, is_start: bool = True):
        if self.rank < 1:
            if self.profile_dir is None:
                raise RuntimeError("Profiler is not enabled.")
            if is_start:
                if self.profiler is None:
                    self.profiler = xp.start_server(9012)
                xp.start_trace(self.profile_dir)
            else:
                xp.stop_trace()
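
    # Illustrative driver (not part of this diff) for the profiling hook
    # above; it assumes rank 0 with VLLM_TORCH_PROFILER_DIR set, and
    # `run_steps` is a hypothetical callable that executes model steps.
    @staticmethod
    def _example_capture_profile(worker: "TPUWorker", run_steps) -> None:
        worker.profile(is_start=True)   # starts the xp server and a trace
        run_steps()                     # run the workload to be traced
        worker.profile(is_start=False)  # stops the trace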
    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def load_model(self) -> None:
        self.model_runner.load_model()
    def compile_or_warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate the TPU KV cache with the specified kv_cache_config."""
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def check_health(self) -> None:
        # The worker is always healthy as long as it is running.
        return
    def _init_tpu_worker_distributed_environment(
        self,
        parallel_config: ParallelConfig,
        rank: int,
        distributed_init_method: Optional[str] = None,
        local_rank: int = -1,
    ) -> None:
        """Initialize the distributed environment."""
        if self.use_spmd:
            xr.use_spmd()
        # NOTE(woosuk): This is just to initialize the TP group and
        # broadcast the input objects on CPU. The all-reduce and all-gather
        # ops on TPU are invoked by `xm.all_reduce` and `xm.all_gather`,
        # which use their own context.
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            local_rank=local_rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)
try:
    from tpu_commons.worker import TPUWorker as TPUCommonsWorker
    TPUWorker = TPUCommonsWorker  # type: ignore
except ImportError:
    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
111
vllm/v1/worker/utils.py
Normal file
@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch

from vllm.v1.kv_cache_interface import KVCacheGroupSpec
def sanity_check_mm_encoder_outputs(
    mm_embeddings: object,
    expected_num_items: int,
) -> None:
    """
    Perform sanity checks for the result of
    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
    """
    assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
        f"or a single 3D tensor, but got {type(mm_embeddings)} "
        "instead. This is most likely due to an incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")

    assert len(mm_embeddings) == expected_num_items, (
        "Expected the number of multimodal embeddings to match the number "
        f"of input items: {expected_num_items}, but got "
        f"{len(mm_embeddings)=} instead. This is most likely due to an "
        "incorrect implementation of the model's "
        "`get_multimodal_embeddings` method.")

    assert all(e.ndim == 2 for e in mm_embeddings), (
        "Expected multimodal embeddings to be a sequence of 2D tensors, "
        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
        "instead. This is most likely due to an incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")
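
# Illustrative usage of the check above (assumed shapes; function name is
# hypothetical): three image items, each encoded to a 2D
# (num_tokens, hidden_size) tensor, pass all three assertions.
def _example_sanity_check() -> None:
    embeddings = [torch.zeros(16, 1024) for _ in range(3)]
    sanity_check_mm_encoder_outputs(embeddings, expected_num_items=3)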
def scatter_mm_placeholders(
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Scatter the multimodal embeddings into a contiguous tensor that
    represents the placeholder tokens.

    See [`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].

    Args:
        embeds: The multimodal embeddings.
            Shape: `(num_embeds, embed_dim)`
        is_embed: A boolean mask indicating which positions in the
            placeholder tokens need to be filled with multimodal
            embeddings.
            Shape: `(num_placeholders,)`
    """
    if is_embed is None:
        return embeds

    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed] = embeds
    return placeholders
def gather_mm_placeholders(
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Reconstructs the embeddings from the placeholder tokens.

    This is the inverse of [scatter_mm_placeholders][].
    """
    if is_embed is None:
        return placeholders

    return placeholders[is_embed]
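
# Illustrative round trip through the two helpers above (assumed sizes;
# function name is hypothetical): scatter three embeddings into five
# placeholder slots, then gather them back out unchanged.
def _example_scatter_gather_roundtrip() -> None:
    embeds = torch.arange(6, dtype=torch.float32).reshape(3, 2)
    is_embed = torch.tensor([True, False, True, False, True])
    scattered = scatter_mm_placeholders(embeds, is_embed)  # (5, 2), NaN gaps
    recovered = gather_mm_placeholders(scattered, is_embed)
    assert torch.equal(recovered, embeds)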
def initialize_kv_cache_for_kv_sharing(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: list[KVCacheGroupSpec],
    kv_caches: dict[str, torch.Tensor],
) -> None:
    """
    Sets up KV cache sharing by reusing the allocated KV caches in
    `kv_caches` for layers that do not allocate their own KV cache, based
    on the mapping in `shared_kv_cache_layers`. Adds these layers to the
    corresponding KV cache group, which is needed to ensure that attention
    metadata is assigned later.

    Args:
        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
            If an Attention layer `layer_name` is in the keys of this dict,
            it means this layer will perform attention using the keys and
            values from the KV cache of
            `shared_kv_cache_layers[layer_name]`.
        kv_cache_groups: The KV cache groups of the model.
        kv_caches: The allocated kv_caches with layer names as keys.
            Note that layers in `shared_kv_cache_layers.keys()` are not
            originally included, as the dict only contains layers that
            have their own KV cache allocation.
    """
    # Record the KV cache group index of each layer that allocates its own
    # KV cache.
    layer_to_kv_cache_group_idx: dict[str, int] = {}
    for i, kv_cache_group in enumerate(kv_cache_groups):
        for layer_name in kv_cache_group.layer_names:
            layer_to_kv_cache_group_idx[layer_name] = i

    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        kv_caches[layer_name] = kv_caches[target_layer_name]
        group_idx = layer_to_kv_cache_group_idx[target_layer_name]
        kv_cache_groups[group_idx].layer_names.append(layer_name)
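
# Minimal sketch of the sharing above (illustrative; function and layer
# names are hypothetical): layer 2 reuses layer 0's cache tensor and joins
# its KV cache group. SimpleNamespace stands in for KVCacheGroupSpec, of
# which only the `layer_names` field is used here.
def _example_kv_sharing() -> None:
    from types import SimpleNamespace
    group = SimpleNamespace(layer_names=["model.layers.0.attn"])
    kv_caches = {"model.layers.0.attn": torch.zeros(1)}
    initialize_kv_cache_for_kv_sharing(
        shared_kv_cache_layers={
            "model.layers.2.attn": "model.layers.0.attn"
        },
        kv_cache_groups=[group],  # type: ignore[list-item]
        kv_caches=kv_caches,
    )
    assert (kv_caches["model.layers.2.attn"]
            is kv_caches["model.layers.0.attn"])
    assert group.layer_names == [
        "model.layers.0.attn", "model.layers.2.attn"
    ]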
65
vllm/v1/worker/worker_base.py
Normal file
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0

logger = init_logger(__name__)
class WorkerBase(WorkerBaseV0):
    """
    Abstract base class for v1 workers; it mainly defines v1-specific
    methods. Methods shared by v0 and v1 are defined in the v0 WorkerBase.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """
        Initialize common worker components.

        Args:
            vllm_config: Complete vLLM configuration.
            local_rank: Local device index.
            rank: Global rank in the distributed setup.
            distributed_init_method: Distributed initialization method.
            is_driver_worker: Whether this worker handles driver
                responsibilities.
        """
        # Configuration storage
        super().__init__(vllm_config=vllm_config)

        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        # Device and model state
        self.device: Optional[torch.device] = None
        self.model_runner: Optional[nn.Module] = None

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Get specifications for the KV cache implementation."""
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> None:
        """Prepare the model for execution through compilation/warmup."""
        raise NotImplementedError

    def check_health(self) -> None:
        """Basic health check (override for device-specific checks)."""
        return
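
# Minimal sketch of a concrete v1 worker built on the base class above
# (illustrative only; the class name is hypothetical, and a real worker,
# such as the TPUWorker earlier in this commit, also sets up the device,
# distributed state, and a model runner).
class _ExampleWorker(WorkerBase):

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return {}  # a real worker returns per-layer specs from its runner

    def compile_or_warm_up_model(self) -> None:
        pass  # eager-only toy worker: nothing to capture or warm up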