add qwen3

This commit is contained in:
Chranos
2026-02-04 17:22:39 +08:00
parent d1c0f68ab4
commit 8511fe8530
1932 changed files with 300426 additions and 0 deletions

View File

View File

@@ -0,0 +1,250 @@
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.vllm_flash_attn import flash_attn_varlen_func
class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "flash-attn-vllm-v1"
@staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]:
return FlashAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@dataclass
class FlashAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_start_loc: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
class FlashAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"FlashAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
output = torch.empty_like(query)
torch.ops.vllm.unified_v1_flash_attention(
output,
query,
key,
value,
self.num_heads,
self.head_size,
self.num_kv_heads,
kv_cache,
self.kv_cache_dtype,
k_scale,
v_scale,
self.scale,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
)
return output
def unified_v1_flash_attention(
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
current_metadata = get_forward_context()
if current_metadata is None:
# Profiling run.
return
assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata
num_actual_tokens = attn_metadata.num_actual_tokens
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
# Reshape the input keys and values and store them in the cache.
key_cache = kv_cache[0]
value_cache = kv_cache[1]
torch.ops._C_cache_ops.reshape_and_cache_flash(
key[:num_actual_tokens],
value[:num_actual_tokens],
key_cache,
value_cache,
attn_metadata.slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
attn_output = flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=alibi_slopes,
window_size=window_size,
block_table=attn_metadata.block_table,
softcap=logits_soft_cap,
)
attn_output = attn_output.view(num_actual_tokens, -1)
# TODO(woosuk): Optimize this.
output[:num_actual_tokens].copy_(attn_output)
def unified_v1_flash_attention_fake(
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
return
direct_register_custom_op(
op_name="unified_v1_flash_attention",
op_func=unified_v1_flash_attention,
mutates_args=["kv_cache", "output"],
fake_impl=unified_v1_flash_attention_fake,
)

View File

View File

@@ -0,0 +1,48 @@
from typing import Dict, List, Set, Tuple
from vllm.v1.request import Request
class EncoderCacheManager:
def __init__(self, cache_size: int):
self.cache_size = cache_size
self.num_free_slots = cache_size
# req_id -> cached input ids
self.cached: Dict[str, Set[int]] = {}
# List of [req_id, input_id]
self.freed: List[Tuple[str, int]] = []
def has_cache(self, request: Request, input_id: int) -> bool:
req_id = request.request_id
return req_id in self.cached and input_id in self.cached[req_id]
def can_allocate(self, request: Request, input_id: int) -> bool:
num_tokens = request.get_num_encoder_tokens(input_id)
return num_tokens <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None:
req_id = request.request_id
if req_id not in self.cached:
self.cached[req_id] = set()
self.cached[req_id].add(input_id)
self.num_free_slots -= request.get_num_encoder_tokens(input_id)
def get_cached_input_ids(self, request: Request) -> Set[int]:
return self.cached.get(request.request_id, set())
def free(self, request: Request, input_id: int) -> None:
req_id = request.request_id
if req_id not in self.cached:
return
self.cached[req_id].discard(input_id)
if len(self.cached[req_id]) == 0:
del self.cached[req_id]
self.num_free_slots += request.get_num_encoder_tokens(input_id)
self.freed.append((req_id, input_id))
def get_freed_ids(self) -> List[Tuple[str, int]]:
freed = self.freed
self.freed = []
return freed

View File

@@ -0,0 +1,397 @@
from collections import defaultdict
from typing import Dict, List, Optional
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, hash_block_tokens,
hash_request_tokens)
from vllm.v1.request import Request
logger = init_logger(__name__)
class KVCacheManager:
def __init__(
self,
block_size: int,
num_gpu_blocks: int,
sliding_window: Optional[int] = None,
enable_caching: bool = True,
num_preallocate_tokens: int = 64,
) -> None:
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.sliding_window = sliding_window
self.enable_caching = enable_caching
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
# blocks for each request. For example, when a request reaches the end
# of its block table, we preallocate N blocks in advance. This way, we
# reduce the overhead of updating free_block_ids and ref_cnts for each
# request every step (at the cost of some memory waste).
# NOTE(woosuk): This is different from the "lookahead" slots since this
# does not guarantee that the request always has N empty blocks. After
# the request gets N empty blocks, it starts to use the blocks without
# further allocation. When it uses up all the N empty blocks, it gets
# N new empty blocks.
self.num_preallocate_tokens = num_preallocate_tokens
self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)
# A Block pool of all kv-cache blocks.
self.block_pool: List[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool)
# {block_hash: {block ID: block}}. A cached block is
# a full block with a block hash that can be used for prefix caching.
# The cached block may be used by running requests or in the
# free_block_queue that could potentially be evicted.
# NOTE: We currently don't de-duplicate the blocks in the cache,
# meaning that if a block becomes full and is cached, we don't check
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
int, KVCacheBlock]] = defaultdict(dict)
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}
def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.
Args:
request: The request to get the computed blocks.
Returns:
A list of blocks that are computed for the request.
"""
if not self.enable_caching:
# Prefix caching is disabled.
return []
computed_blocks = []
block_hashes = hash_request_tokens(self.block_size,
request.all_token_ids)
for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if cached_block := self._get_cached_block(block_hash):
computed_blocks.append(cached_block)
else:
break
return computed_blocks
def append_slots(
self,
request: Request,
num_tokens: int,
) -> Optional[List[KVCacheBlock]]:
"""Append slots to the block table of the request.
We first append slots to already allocated blocks. If the allocated
blocks are not enough, we allocate new blocks.
Args:
request: The request to append slots.
num_tokens: The number of tokens to append.
Returns:
A list of new blocks if new blocks are allocated, or None
if new blocks are required but cannot be allocated.
"""
num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
self.block_size)
req_blocks = self.req_to_blocks[request.request_id]
num_new_blocks = num_required_blocks - len(req_blocks)
if num_new_blocks > self.free_block_queue.num_free_blocks:
# Need to allocate new blocks due to insufficient pre-allocated
# slots, but we cannot allocate new blocks due to the limit.
return None
# When caching is enabled, assign token IDs to already allocated blocks.
new_token_ids = None
parent_block = None
if self.enable_caching:
# Figure out the token IDs to add to the blocks.
new_token_ids = request.all_token_ids[
request.num_computed_tokens:request.num_computed_tokens +
num_tokens]
# Find the last full block index.
# TODO: This may be optimized by calculating the computed tokens.
last_full_block_idx = len(req_blocks) - 1
while (last_full_block_idx >= 0
and req_blocks[last_full_block_idx].block_hash is None):
last_full_block_idx -= 1
parent_block = (req_blocks[last_full_block_idx]
if last_full_block_idx >= 0 else None)
token_id_idx = self._add_token_ids_to_blocks(
blocks=req_blocks[last_full_block_idx + 1:],
token_ids=new_token_ids,
parent_block=parent_block)
new_token_ids = new_token_ids[token_id_idx:]
parent_block = req_blocks[-1]
# No new block is needed. When caching is enabled, we make sure
# token_id_idx is equal to len(new_token_ids), meaning that all tokens
# are added to allocated blocks.
if num_required_blocks <= len(req_blocks):
assert not self.enable_caching or token_id_idx == num_tokens, \
f"{token_id_idx=} != {num_tokens=}"
return []
# Allocate new blocks considering preallocated blocks, and
# add token IDs to them if caching is enabled.
num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks)
new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
parent_block)
req_blocks.extend(new_blocks)
return new_blocks
def allocate_slots(
self,
request: Request,
num_tokens: int,
computed_blocks: List[KVCacheBlock],
) -> Optional[List[KVCacheBlock]]:
"""Allocate slots for a new request.
Args:
request: The request to allocate slots.
num_tokens: The number of tokens to allocate. Note that this does
not include the tokens that have already been computed.
computed_blocks: The blocks that have already been computed.
Returns:
A list of new allocated blocks.
"""
if num_tokens == 0:
raise ValueError(
f"num_tokens must be greater than 0, got {num_tokens}")
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = len(
[blk for blk in computed_blocks if blk.ref_cnt == 0])
num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks):
# Cannot allocate new blocks.
return None
# Determine the number of new blocks to allocate considering
# preallocated blocks.
num_new_blocks = min(
num_required_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks)
num_computed_tokens = len(computed_blocks) * self.block_size
# When caching is enabled, get the new token IDs and the parent block
# ID to generate cache keys.
new_token_ids = None
parent_block = None
if self.enable_caching:
# Touch the computed blocks to make sure they won't be evicted.
self._touch(computed_blocks)
# Get the token IDs for the blocks being allocated for hashing.
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_tokens]
if not new_token_ids:
raise RuntimeError(
"Failed to infer the token IDs for allocation. "
f"#all_tokens={len(request.all_token_ids)} < "
f"#computed_tokens={num_computed_tokens}")
# Get the parent block ID to construct the block chain.
parent_block = computed_blocks[-1] if computed_blocks else None
new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
parent_block)
# Concatenate the computed block IDs and the new block IDs.
self.req_to_blocks[request.request_id] = computed_blocks + new_blocks
return new_blocks
def free(self, request: Request) -> None:
"""Free the blocks allocated for the request.
When caching is enabled, we free the blocks in reverse order so that
the tail blocks are evicted first.
Args:
request: The request to free the blocks.
"""
# Default to [] in case a request is freed (aborted) before alloc.
blocks = self.req_to_blocks.pop(request.request_id, [])
if self.enable_caching:
# Free blocks in reverse order so that the tail blocks are
# freed first.
blocks = reversed(blocks)
for block in blocks:
block.ref_cnt -= 1
if block.ref_cnt == 0:
self.free_block_queue.append(block)
def _get_new_blocks(
self,
num_blocks: int,
token_ids: Optional[List[int]] = None,
parent_block: Optional[int] = None) -> List[KVCacheBlock]:
"""Get new blocks from the free block pool, and add token IDs to
allocated blocks if caching is enabled.
Note that we do not check block cache in this function.
Args:
num_blocks: The number of blocks to allocate.
token_ids: The token IDs in the blocks. None if caching is disabled.
parent_block: The parent block. Used to include block chain
in the block hash.
Returns:
A list of new block.
"""
if num_blocks > self.free_block_queue.num_free_blocks:
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
# First allocate blocks.
ret: List[KVCacheBlock] = []
idx = 0
while idx < num_blocks:
curr_block = self.free_block_queue.popleft()
assert curr_block.ref_cnt == 0
# Evict blocks from the cache.
if self.enable_caching:
block_hash = curr_block.block_hash
if (block_hash is not None
and block_hash in self.cached_block_hash_to_block):
if len(self.cached_block_hash_to_block[block_hash]) == 1:
del self.cached_block_hash_to_block[block_hash]
else:
del self.cached_block_hash_to_block[block_hash][
curr_block.block_id]
curr_block.reset()
curr_block.ref_cnt = 1
ret.append(curr_block)
idx += 1
# Then assign token IDs to the allocated blocks.
if self.enable_caching:
assert token_ids is not None
token_id_idx = self._add_token_ids_to_blocks(
blocks=ret, token_ids=token_ids, parent_block=parent_block)
assert token_id_idx == len(token_ids)
return ret
def _cache_full_block(self,
block: KVCacheBlock,
parent_block: Optional[KVCacheBlock] = None) -> None:
"""Cache a full block for prefix caching.
Args:
block: The block to cache.
parent_block: The parent block. None if this is the first block.
"""
parent_block_hash = (parent_block.block_hash
if parent_block is not None else None)
assert len(block.token_ids) == self.block_size
block.token_ids = tuple(block.token_ids)
block_hash = hash_block_tokens(parent_block_hash, block.token_ids)
block.block_hash = block_hash
block.num_hashed_tokens = self.block_size + (
parent_block.num_hashed_tokens if parent_block is not None else 0)
self.cached_block_hash_to_block[block_hash][block.block_id] = block
def _get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss.
If there are duplicated blocks, we return the first block in the cache.
Args:
block_hash: The hash value of the block.
Returns:
The cached block if it exists, or None.
"""
if block_hash in self.cached_block_hash_to_block:
first_block_id = list(
self.cached_block_hash_to_block[block_hash].keys())[0]
return self.cached_block_hash_to_block[block_hash][first_block_id]
return None
def _touch(self, blocks: List[KVCacheBlock]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
Args:
blocks: A list of blocks to touch.
"""
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0:
self.free_block_queue.remove(block)
block.ref_cnt += 1
def _add_token_ids_to_blocks(
self,
blocks: List[KVCacheBlock],
token_ids: List[int],
parent_block: Optional[KVCacheBlock] = None) -> int:
"""Add token IDs to a list of allocated blocks.
If a block becomes full after adding token IDs, cache it.
Return the token ID index that has not been added to the blocks
if the blocks are not enough to hold all the token IDs.
Args:
blocks: A list of blocks to add token IDs.
token_ids: A list of token IDs to add.
parent_block: The parent block. None if this is the
first block.
Returns:
The starting token ID index that has not been added to the blocks
due to insufficient given blocks.
"""
token_id_start = 0
for curr_block in blocks:
# If all token IDs are added, then the rest of the blocks are
# preallocated blocks, so we only need to update the
# parent_block_id. FIXME
if token_id_start == len(token_ids):
continue
# Add token IDs to the empty slots in the block.
empty_slots = self.block_size - len(curr_block.token_ids)
token_id_end = min(token_id_start + empty_slots, len(token_ids))
curr_block.token_ids.extend(token_ids[token_id_start:token_id_end])
# Cache the block if it becomes full.
if len(curr_block.token_ids) == self.block_size:
self._cache_full_block(curr_block, parent_block)
parent_block = curr_block
token_id_start = token_id_end
return token_id_start

View File

@@ -0,0 +1,194 @@
"""KV-Cache Utilities."""
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union
from vllm.logger import init_logger
logger = init_logger(__name__)
BlockHashType = Tuple[int, Tuple[int]]
@dataclass
class KVCacheBlock:
"""KV-cache block metadata."""
# Block ID, ranging from 0 to num_gpu_blocks - 1.
block_id: int
# Reference count.
ref_cnt: int = 0
# Token IDs in the block. When the block is full, the type of token_ids
# should be Tuple[int] for fast matching.
token_ids: Union[List[int], Tuple[int]] = field(default_factory=list)
# The hash of the block composed of (block hash, tuple of token IDs).
# It is only available when the block is full.
block_hash: Optional[BlockHashType] = None
# The number of hashed tokens. More hashed tokens means the block
# is closer to the end of a prompt and more likely to be evicted.
num_hashed_tokens: int = 0
# Used to construct a doubly linked list for free blocks.
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
prev_free_block: Optional["KVCacheBlock"] = None
next_free_block: Optional["KVCacheBlock"] = None
def reset(self):
"""Reset the block metadata."""
self.ref_cnt = 0
self.token_ids = []
self.block_hash = None
self.num_hashed_tokens = 0
class FreeKVCacheBlockQueue:
"""This class organizes a list of KVCacheBlock objects to a doubly linked
list of free blocks. We implement this class instead of using Python
builtin deque to support removing a block in the middle of the queue
in O(1) time. To close the performance gap to the builtin deque which is
implemented in C++, this class does not allocate any Python objects when
manipulating the linked list. Instead, this class manipulates the
prev_free_block and next_free_block attributes of the given blocks.
The queue is ordered by block ID in the beginning. When a block is allocated
and then freed, it will be appended back with the eviction order:
1. The least recent used block is at the front (LRU).
2. If two blocks have the same last accessed time (allocated by the
same sequence), the one with more hash tokens (the tail of a block
chain) is at the front.
Note that we maintain this order by reversing the block order when free
blocks of a request. This operation is outside of this class.
Args:
blocks: A list of KVCacheBlock objects.
"""
def __init__(self, blocks: List[KVCacheBlock]) -> None:
self.num_free_blocks = len(blocks)
# Initialize the doubly linked list of free blocks.
self.free_list_head = blocks[0]
self.free_list_tail = blocks[-1]
for i in range(self.num_free_blocks):
if i > 0:
blocks[i].prev_free_block = blocks[i - 1]
if i < self.num_free_blocks - 1:
blocks[i].next_free_block = blocks[i + 1]
def popleft(self) -> KVCacheBlock:
"""Pop the first free block and reduce num_free_blocks by 1.
Returns:
The first free block.
"""
if not self.free_list_head:
raise ValueError("No free blocks available")
block = self.free_list_head
self.remove(block)
return block
def remove(self, block: KVCacheBlock) -> None:
"""Remove a block in the free list and reduce num_free_blocks by 1.
Args:
block: The block to remove.
"""
if block.prev_free_block is not None:
# Link the previous block to the next block.
block.prev_free_block.next_free_block = block.next_free_block
if block.next_free_block is not None:
# Link the next block to the previous block.
block.next_free_block.prev_free_block = block.prev_free_block
if block == self.free_list_head:
# Update the head if the block is the head.
self.free_list_head = block.next_free_block
if block == self.free_list_tail:
# Update the tail if the block is the tail.
self.free_list_tail = block.prev_free_block
# Remove the block from the linked list.
block.prev_free_block = block.next_free_block = None
self.num_free_blocks -= 1
def append(self, block: KVCacheBlock) -> None:
"""Put a block back into the free list and increase
num_free_blocks by 1.
Args:
block: The block to append.
"""
if self.free_list_tail is not None:
# Link the last block to the new block.
self.free_list_tail.next_free_block = block
block.prev_free_block = self.free_list_tail
self.free_list_tail = block
else:
# The free list is empty.
assert self.free_list_head is None
self.free_list_head = self.free_list_tail = block
block.next_free_block = None
self.num_free_blocks += 1
def get_all_free_blocks(self) -> List[KVCacheBlock]:
"""Get all free blocks in the free list. Mainly used for testing.
Returns:
A list of free blocks.
"""
ret = []
curr_block = self.free_list_head
while curr_block is not None:
ret.append(curr_block)
curr_block = curr_block.next_free_block
return ret
def hash_block_tokens(parent_block_hash: Optional[int],
curr_block_token_ids: Tuple[int]) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
hash values for the same block contents.
TODO: Support arbitrary metadata so that we could support more
features such as LoRA adapter.
Args:
parent_block_hash: The hash of the parent block. None
if this is the first block.
curr_block_token_ids: A tuple of token ids in the current
block. The current block is assumed to be full.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return (hash(
(parent_block_hash, *curr_block_token_ids)), curr_block_token_ids)
def hash_request_tokens(block_size: int,
token_ids: List[int]) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
Args:
block_size: The size of each block.
token_ids: A sequence of token ids in the request.
Returns:
The list of computed hash values.
"""
ret = []
parent_block_hash = None
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = tuple(token_ids[start:end])
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
block_hash = hash_block_tokens(parent_block_hash, block_token_ids)
ret.append(block_hash)
parent_block_hash = block_hash
return ret

View File

@@ -0,0 +1,591 @@
from collections import deque
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
if TYPE_CHECKING:
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange
logger = init_logger(__name__)
class Scheduler:
def __init__(
self,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
# TODO: Support LoRA.
assert lora_config is None, "V1 does not support LoRA yet."
num_gpu_blocks = cache_config.num_gpu_blocks
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
# Create the block space manager.
self.kv_cache_manager = KVCacheManager(
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
self.block_size = self.cache_config.block_size
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = \
self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len
# req_id -> Request
self.requests: Dict[str, Request] = {}
# Priority queues for requests.
self.waiting: Deque[Request] = deque()
self.running: List[Request] = []
# The request IDs that are finished in between the previous and the
# current steps. This is used to notify the workers about the finished
# requests so that they can free the cached states for those requests.
# This is flushed at the end of each scheduling step.
self.finished_req_ids: Set[str] = set()
# OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> RunningRequestData
self.running_reqs_data: Dict[str, RunningRequestData] = {}
# Encoder-related.
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
# FIXME(woosuk): Below are placeholder values. We need to calculate the
# actual values from the configurations.
self.max_num_encoder_input_tokens = 2048
# NOTE(woosuk): For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
self.encoder_cache_manager = EncoderCacheManager(cache_size=2048)
def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and num_tokens,
# which is equal to len(prompt_token_ids) + len(output_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens. This is general enough to cover chunked prefills,
# prefix caching, and the "jump decoding" optimization in the future.
scheduled_new_reqs: List[Request] = []
scheduled_resumed_reqs: List[Request] = []
scheduled_running_reqs: List[Request] = []
preempted_reqs: List[Request] = []
req_to_new_block_ids: Dict[str, List[int]] = {}
num_scheduled_tokens: Dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
scheduled_encoder_inputs: Dict[str, List[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
# First, schedule the RUNNING requests.
# NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be
# in the "partial" state, where the request has some tokens computed
# but not all. The constraint is due to the persistent batch in the
# V1 model runner.
# TODO(woosuk): Remove this constraint after refactoring model runner.
has_partial_request = False
req_index = 0
while req_index < len(self.running):
# Only the last request in the RUNNING queue can be "partial".
assert not has_partial_request
assert token_budget > 0
request = self.running[req_index]
num_new_tokens = request.num_tokens - request.num_computed_tokens
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = (
self._try_schedule_encoder_inputs(request,
request.num_computed_tokens,
num_new_tokens,
encoder_budget))
assert num_new_tokens > 0
while True:
new_blocks = self.kv_cache_manager.append_slots(
request, num_new_tokens)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
self.waiting.appendleft(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt.
can_schedule = False
break
else:
# The request can be scheduled.
can_schedule = True
break
if not can_schedule:
break
# Schedule the request.
scheduled_running_reqs.append(request)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
has_partial_request = (request.num_computed_tokens + num_new_tokens
< request.num_tokens)
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting:
if has_partial_request:
break
if len(self.running) == self.max_num_running_reqs:
break
if token_budget == 0:
break
request = self.waiting[0]
# Get already-cached tokens.
computed_blocks = self.kv_cache_manager.get_computed_blocks(
request)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if num_new_tokens == 0:
# The happens when prompt length is divisible by the block
# size and all blocks are cached. Now we force to recompute
# the last token.
num_computed_tokens -= 1
num_new_tokens = 1
computed_blocks.pop()
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, computed_blocks)
if new_blocks is None:
# The request cannot be scheduled.
break
self.waiting.popleft()
self.running.append(request)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(
f"Invalid request status: {request.status}")
req_to_new_block_ids[request.request_id] = [
b.block_id for b in computed_blocks + new_blocks
]
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
has_partial_request = (num_computed_tokens + num_new_tokens <
request.num_tokens)
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
len(scheduled_running_reqs) == len(self.running))
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
req_to_new_block_ids[req.request_id],
req.num_computed_tokens)
for req in scheduled_new_reqs
]
resumed_reqs_data = [
ResumedRequestData.from_request(
req, req_to_new_block_ids[req.request_id],
req.num_computed_tokens) for req in scheduled_resumed_reqs
]
running_reqs_data = [
self._make_running_request_data(
req, req_to_new_block_ids[req.request_id],
req.num_computed_tokens) for req in scheduled_running_reqs
]
preempted_req_ids = {req.request_id for req in preempted_reqs}
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_resumed_reqs=resumed_reqs_data,
scheduled_running_reqs=running_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
preempted_req_ids=preempted_req_ids,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
)
self.finished_req_ids = set()
return scheduler_output
def _make_running_request_data(
self,
request: Request,
new_block_ids: List[int],
num_computed_tokens: int,
) -> "RunningRequestData":
# OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
# them at each scheduling step.
if request.request_id in self.running_reqs_data:
req_data = self.running_reqs_data[request.request_id]
req_data.new_block_ids = new_block_ids
req_data.num_computed_tokens = num_computed_tokens
else:
req_data = RunningRequestData.from_request(request, new_block_ids,
num_computed_tokens)
self.running_reqs_data[request.request_id] = req_data
return req_data
def _try_schedule_encoder_inputs(
self,
request: Request,
num_computed_tokens: int,
num_new_tokens: int,
encoder_budget: int,
) -> Tuple[List[int], int, int]:
"""
Determine which encoder inputs need to be scheduled in the current step,
and update `num_new_tokens` and encoder token budget accordingly.
An encoder input will be scheduled if:
- Its output tokens overlap with the range of tokens being computed
in this step, i.e.,
[num_computed_tokens, num_computed_tokens + num_new_tokens).
- It is not already computed and stored in the encoder cache.
- There is sufficient encoder token budget to process it.
- The encoder cache has space to store it.
If an encoder input cannot be scheduled due to cache or budget
limitations, the method adjusts `num_new_tokens` to schedule only the
decoder tokens up to just before the unschedulable encoder input.
"""
if not request.has_encoder_inputs():
return [], num_new_tokens, encoder_budget
encoder_inputs_to_schedule: List[int] = []
mm_positions = request.mm_positions
assert mm_positions is not None
assert len(mm_positions) > 0
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if start_pos >= num_computed_tokens + num_new_tokens:
# The encoder input is not needed in this step.
break
if start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder input is already computed and stored
# in the decoder's KV cache.
continue
if self.encoder_cache_manager.has_cache(request, i):
# The encoder input is already computed and cached.
continue
if not self.encoder_cache_manager.can_allocate(request, i):
# The encoder cache is full. We can only schedule the decoder
# tokens just before the encoder input.
num_new_tokens = start_pos - num_computed_tokens
break
if num_encoder_tokens > encoder_budget:
# The encoder budget is exhausted. We can only schedule the
# decoder tokens up until the encoder input.
# NOTE(woosuk): We assume that the encoder tokens should be
# processed altogether, as the encoder usually uses
# bidirectional attention.
num_new_tokens = start_pos - num_computed_tokens
break
encoder_budget -= num_encoder_tokens
encoder_inputs_to_schedule.append(i)
return encoder_inputs_to_schedule, num_new_tokens, encoder_budget
def update_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> List[EngineCoreOutput]:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
engine_core_outputs: List[EngineCoreOutput] = []
for request in self.running:
req_id = request.request_id
request.num_computed_tokens += num_scheduled_tokens[req_id]
# When the request's num_computed_tokens catches up its num_tokens,
# the request generates output tokens. Otherwise, we ignore the
# sampler output for the request.
assert request.num_computed_tokens <= request.num_tokens
cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
for input_id in list(cached_encoder_input_ids):
start_pos = request.mm_positions[input_id]["offset"]
num_tokens = request.mm_positions[input_id]["length"]
if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
self.encoder_cache_manager.free(request, input_id)
if request.num_computed_tokens == request.num_tokens:
req_index = model_runner_output.req_id_to_index[req_id]
# NOTE(woosuk): Currently, we assume that each request
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
request.append_output_token_ids(token_id)
num_new_tokens = 1
# TODO: Update the KV cache manager for prefix caching.
# Check for stop and update request state.
# This must be called before me make the EngineCoreOutput.
stopped = self._check_stop(request)
# Add EngineCoreOutput for this Request.
output = EngineCoreOutput(
request_id=req_id,
new_token_ids=request.output_token_ids[-num_new_tokens:],
finished=request.is_finished(),
finish_reason=request.get_finished_reason(),
stop_reason=request.stop_reason)
engine_core_outputs.append(output)
# Breakout of the loop.
if stopped:
continue
new_running.append(request)
self.running = new_running
return engine_core_outputs
def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
self._free_request(request)
return True
sampling_params = request.sampling_params
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
request.status = RequestStatus.FINISHED_STOPPED
self._free_request(request)
return True
if last_token_id in (sampling_params.stop_token_ids or ()):
request.status = RequestStatus.FINISHED_STOPPED
request.stop_reason = last_token_id
self._free_request(request)
return True
return False
def add_request(self, request: Request) -> None:
self.waiting.append(request)
self.requests[request.request_id] = request
def finish_requests(
self,
request_ids: Union[str, Iterable[str]],
finished_status: RequestStatus,
) -> None:
"""Handles the finish signal from outside the scheduler.
For example, the API server can abort a request when the client
disconnects.
"""
assert RequestStatus.is_finished(finished_status)
if isinstance(request_ids, str):
request_ids = (request_ids, )
request_ids = set(request_ids)
for req_id in request_ids:
request = self.requests.get(req_id)
if request is None:
# Invalid request ID.
continue
if request.status == RequestStatus.RUNNING:
self.running.remove(request)
else:
self.waiting.remove(request)
request.status = finished_status
self._free_request(request)
def _free_request(self, request: Request) -> None:
assert request.is_finished()
self.kv_cache_manager.free(request)
self.running_reqs_data.pop(request.request_id, None)
del self.requests[request.request_id]
self.finished_req_ids.add(request.request_id)
def get_num_unfinished_requests(self) -> int:
return len(self.waiting) + len(self.running)
def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0
@dataclass
class NewRequestData:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
mm_inputs: List["MultiModalKwargs"]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
block_ids: List[int]
num_computed_tokens: int
@classmethod
def from_request(
cls,
request: Request,
block_ids: List[int],
num_computed_tokens: int,
) -> "NewRequestData":
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
mm_inputs=request.mm_inputs,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
)
@dataclass
class ResumedRequestData:
req_id: str
block_ids: List[int]
num_computed_tokens: int
@classmethod
def from_request(
cls,
request: Request,
block_ids: List[int],
num_computed_tokens: int,
) -> "ResumedRequestData":
return cls(
req_id=request.request_id,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
)
@dataclass
class RunningRequestData:
req_id: str
new_block_ids: List[int]
num_computed_tokens: int
@classmethod
def from_request(
cls,
request: Request,
new_block_ids: List[int],
num_computed_tokens: int,
) -> "RunningRequestData":
return cls(
req_id=request.request_id,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
)
@dataclass
class SchedulerOutput:
scheduled_new_reqs: List[NewRequestData]
scheduled_resumed_reqs: List[ResumedRequestData]
scheduled_running_reqs: List[RunningRequestData]
num_scheduled_tokens: Dict[str, int]
total_num_scheduled_tokens: int
scheduled_encoder_inputs: Dict[str, List[int]]
preempted_req_ids: Set[str]
finished_req_ids: Set[str]
free_encoder_input_ids: List[Tuple[str, int]]

View File

@@ -0,0 +1,77 @@
import enum
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import msgspec
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind, SamplingParams
@dataclass
class DetokenizerRequest:
request_id: str
prompt: Optional[str]
prompt_token_ids: List[int]
skip_special_tokens: bool
spaces_between_special_tokens: bool
output_kind: RequestOutputKind
stop: List[str]
include_stop_str_in_output: bool
@dataclass
class EngineCoreRequest:
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec
# due to circular imports and typing we have in data.py
request_id: str
#NOTE(Nick): I don't think we need to pass prompt here since it should
# always be tokenized?
prompt: Optional[str]
prompt_token_ids: List[int]
mm_data: Optional[MultiModalDataDict]
mm_placeholders: Optional[MultiModalPlaceholderDict]
mm_processor_kwargs: Optional[Dict[str, Any]]
sampling_params: SamplingParams
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]
class EngineCoreOutput(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
request_id: str
new_token_ids: List[int]
finished: bool
finish_reason: Optional[str] = None
stop_reason: Union[int, str, None] = None
class EngineCoreOutputs(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout and using an int enum for finish/stop reason
# [num_reqs]
outputs: List[EngineCoreOutput]
class EngineCoreRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so it can be sent over sockets
without separate encoding step.
"""
ADD = b'\x00'
ABORT = b'\x01'

View File

@@ -0,0 +1,372 @@
import asyncio
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_stream import AsyncStream
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.gpu_executor import GPUExecutor
logger = init_logger(__name__)
class AsyncLLM(EngineClient):
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
log_requests: bool = True,
start_engine_loop: bool = True,
) -> None:
assert start_engine_loop
self.log_requests = log_requests
self.log_stats = log_stats
self.stat_loggers = stat_loggers
self.model_config = vllm_config.model_config
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
enable_lora=bool(vllm_config.lora_config))
self.tokenizer.ping()
# Request streams (map of request_id -> AsyncStream).
self.request_streams: Dict[str, AsyncStream] = {}
# List of cancelled request ids to be aborted.
self.client_aborted_requests: List[str] = []
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(vllm_config.model_config,
vllm_config.lora_config, self.tokenizer,
input_registry)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer)
# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_client(
vllm_config=vllm_config,
executor_class=executor_class,
usage_context=usage_context,
multiprocess_mode=True,
asyncio_mode=True,
)
self.output_handler = None
def __del__(self):
self.shutdown()
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Create an AsyncLLM from the EngineArgs."""
# Create the engine configs.
if engine_config is None:
vllm_config = engine_args.create_engine_config()
else:
vllm_config = engine_config
executor_class = cls._get_executor_cls(vllm_config)
# Create the AsyncLLM.
return cls(
vllm_config=vllm_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
def shutdown(self):
"""Shutdown, cleaning up the background proc and IPC."""
self.engine_core.shutdown()
if handler := getattr(self, "output_handler", None):
handler.cancel()
@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
return GPUExecutor
async def add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
"""Add new request to the AsyncLLM."""
if self.detokenizer.is_request_active(request_id):
raise KeyError(f"Request {request_id} already exists.")
# 1) Create a new AsyncStream for the request.
stream = self._add_request_to_streams(request_id)
# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
detokenizer_req, engine_core_req = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
trace_headers, prompt_adapter_request, priority)
# 3) Add the request to Detokenizer (this process).
self.detokenizer.add_request(detokenizer_req)
# 4) Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(engine_core_req)
# 5) Return the generator.
return stream.generator()
# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
# requests we don't need to send multiple messages to core proc,
# and so we don't need multiple streams which then get
# re-multiplexed in the API server anyhow.
async def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
# 2) Processing the Input.
* 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process).
A separate output_handler loop runs in a background AsyncIO task,
pulling outputs from EngineCore and putting them into the
per-request AsyncStream.
The caller of generate() iterates the returned AsyncGenerator,
returning the RequestOutput back to the caller.
"""
# We start the output_handler on the first call to generate() so that
# we can call __init__ before the event loop starts, which enables us
# to handle startup failure gracefully in the OpenAI server.
if self.output_handler is None:
self.output_handler = asyncio.create_task(
self._run_output_handler())
async for output in await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
):
yield output
def _finish_stream(self, request_id: str):
stream = self.request_streams.pop(request_id, None)
if stream is not None:
stream.finish()
def _add_request_to_streams(
self,
request_id: str,
) -> AsyncStream:
if request_id in self.request_streams:
raise ValueError(f"Request id {request_id} already running.")
# Avoid streams having circular ref to parent AsyncLLM object.
aborted_reqs = self.client_aborted_requests
stream = AsyncStream(request_id, aborted_reqs.append)
self.request_streams[request_id] = stream
if self.log_requests:
logger.info("Added request %s.", request_id)
return stream
async def _process_cancellations(self) -> None:
"""
Process requests cancelled from user disconnecting.
When a client disconnects, AsyncStream._cancel() is called.
We passed a callback to AsyncStream(), which appends to
self.client_aborted_requests.
As a result, if any requests are canceled from the user side
the request_id will show up in self.client_aborted_requests.
"""
# Avoid streams having circular ref to parent AsyncLLM object.
if not self.client_aborted_requests:
return
reqs_to_abort = self.client_aborted_requests.copy()
self.client_aborted_requests.clear()
# Remove from Detokenizer.
self.detokenizer.abort_requests(reqs_to_abort)
# Remove from RequestStreams.
for request_id in reqs_to_abort:
if self.log_requests:
logger.info("User-cancelled request %s.", request_id)
self._finish_stream(request_id)
# Remove from EngineCore.
await self.engine_core.abort_requests_async(reqs_to_abort)
def _process_request_outputs(self, request_outputs: List[RequestOutput]):
"""Process outputs by putting them into per-request AsyncStreams."""
for request_output in request_outputs:
request_id = request_output.request_id
assert request_id in self.request_streams
# Each request in the API server pulls from the per-request stream.
stream = self.request_streams.get(request_id)
if stream is not None:
stream.put(request_output)
# If finished, remove from the tracker.
if request_output.finished:
if self.log_requests:
logger.info("Finished request %s.", request_id)
self._finish_stream(request_id)
async def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
try:
while True:
# 1) Pull EngineCoreOutput from the EngineCore.
outputs = await self.engine_core.get_output_async()
# 2) Detokenize based on the output.
request_outputs, reqs_to_abort = self.detokenizer.step(outputs)
# 3) Put the RequestOutputs into the per-request AsyncStreams.
self._process_request_outputs(request_outputs)
# 4) Abort any requests that finished due to stop strings.
await self.engine_core.abort_requests_async(reqs_to_abort)
# 5) Abort any requests due to client cancellations.
await self._process_cancellations()
except BaseException as e:
logger.error(e)
raise e
# TODO: can we eliminate these?
async def abort(self, request_id: str) -> None:
# Note: Who Calls this? I dont think this is actually used.
raise ValueError("Not Supported on V1 yet.")
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
):
raise ValueError("Not Supported on V1 yet.")
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def get_decoding_config(self):
raise ValueError("Not Supported on V1 yet.")
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.processor.input_preprocessor
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
assert lora_request is None
return self.detokenizer.tokenizer
async def is_tracing_enabled(self) -> bool:
return False
async def do_log_stats(
self,
scheduler_outputs=None,
model_output=None,
) -> None:
logger.debug("Called do_log_stats.")
async def check_health(self) -> None:
logger.debug("Called check_health.")
async def start_profile(self) -> None:
raise ValueError("Not supported on V1 yet.")
async def stop_profile(self) -> None:
raise ValueError("Not supported on V1 yet.")
@property
def is_running(self) -> bool:
return True
@property
def is_stopped(self) -> bool:
return False
@property
def errored(self) -> bool:
return False
@property
def dead_error(self) -> BaseException:
return Exception
# Retain V0 name for backwards compatibility.
AsyncLLMEngine = AsyncLLM

View File

@@ -0,0 +1,55 @@
import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
class AsyncStream:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
that can be iterated over asynchronously via an async generator."""
STOP_ITERATION = Exception() # Sentinel
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self.request_id = request_id
self._cancel = cancel
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
Exception]) -> None:
if not self._finished:
self._queue.put_nowait(item)
def finish(
self,
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
) -> None:
if not self._finished:
self._finished = True
self._queue.put_nowait(exception if self._is_raisable(exception)
else AsyncStream.STOP_ITERATION)
async def generator(
self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
finished = False
try:
while True:
result = await self._queue.get()
if self._is_raisable(result):
finished = True
if result == AsyncStream.STOP_ITERATION:
return
raise result
yield result
finally:
self._finished = True
if not finished:
self._cancel(self.request_id)
@staticmethod
def _is_raisable(value: Any):
return isinstance(value, BaseException) or \
(isinstance(value, type) and \
issubclass(value, BaseException))

View File

@@ -0,0 +1,363 @@
import multiprocessing
import queue
import threading
import time
from contextlib import contextmanager
from multiprocessing.process import BaseProcess
from multiprocessing.sharedctypes import Synchronized
from typing import Any, Iterator, List, Tuple, Type, Union
import zmq
import zmq.asyncio
from msgspec import msgpack
from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapper
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
LOGGING_TIME_S = 5000
class EngineCore:
"""Inner loop of vLLM's Engine."""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
usage_context: UsageContext,
):
# Override the configs for V1.
# FIXME
if usage_context == UsageContext.LLM_CLASS:
vllm_config.scheduler_config.max_num_seqs = 1024
vllm_config.scheduler_config.max_num_batched_tokens = 8192
elif usage_context == UsageContext.OPENAI_API_SERVER:
vllm_config.scheduler_config.max_num_seqs = 1024
vllm_config.scheduler_config.max_num_batched_tokens = 2048
# TODO (ywang96): Enable APC by default when VLM supports it.
if not vllm_config.model_config.is_multimodal_model:
vllm_config.cache_config.enable_prefix_caching = True
assert vllm_config.model_config.task != "embedding"
logger.info("Initializing an LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)
# Setup Model.
self.model_executor = executor_class(vllm_config)
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
vllm_config.cache_config)
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
# Set up multimodal input mapper (e.g., convert PIL images to tensors).
self.mm_input_mapper = MMInputMapper(vllm_config.model_config)
# Setup scheduler.
self.scheduler = Scheduler(vllm_config.scheduler_config,
vllm_config.cache_config,
vllm_config.lora_config)
self._last_logging_time = time.time()
def _initialize_kv_caches(self,
cache_config: CacheConfig) -> Tuple[int, int]:
num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
)
if cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_gpu_blocks,
num_gpu_blocks_override)
num_gpu_blocks = num_gpu_blocks_override
num_cpu_blocks = 0
self.model_executor.initialize_cache(num_gpu_blocks)
return num_gpu_blocks, num_cpu_blocks
def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler."""
req = Request.from_engine_core_request(request)
# FIXME(woosuk): The input mapping (e.g., PIL images to tensors) may
# take 10-50 ms, which can cause a spike in the latency. We should
# consider moving this to a separate thread.
if req.mm_data:
req.mm_inputs = self.mm_input_mapper.process_inputs(
req.mm_data, req.mm_processor_kwargs)
self.scheduler.add_request(req)
def abort_requests(self, request_ids: List[str]):
"""Abort requests from the scheduler."""
# TODO: The scheduler doesn't really need to know the
# specific finish reason, TBD whether we propagate that
# (i.e. client-aborted vs stop criteria met).
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)
def step(self) -> List[EngineCoreOutput]:
"""Schedule, execute, and make output."""
if not self.scheduler.has_unfinished_requests():
return []
scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output)
return engine_core_outputs
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
READY_STR = "READY"
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
usage_context: UsageContext,
input_path: str,
output_path: str,
ready_path: str,
should_shutdown: Synchronized,
):
super().__init__(vllm_config, executor_class, usage_context)
# Signal from main process to shutdown (multiprocessing.Value).
self.should_shutdown = should_shutdown
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = queue.Queue()
self.output_queue = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
daemon=True).start()
threading.Thread(target=self.process_output_socket,
args=(output_path, ),
daemon=True).start()
# Send Readiness signal to EngineClient.
with self.make_socket(ready_path, zmq.constants.PUSH) as ready_socket:
ready_socket.send_string(EngineCoreProc.READY_STR)
@contextmanager
def make_socket(self, path: str, type: Any) -> Iterator[zmq.Socket]:
"""Context manager for use """
ctx = zmq.Context()
try:
socket = ctx.socket(type)
if type == zmq.constants.PULL:
socket.connect(path)
elif type == zmq.constants.PUSH:
socket.bind(path)
else:
raise ValueError(f"Unknown Socket Type: {type}")
yield socket
except KeyboardInterrupt:
logger.debug("EngineCore had Keyboard Interrupt.")
finally:
ctx.destroy(linger=0)
@staticmethod
def wait_for_startup(
proc: BaseProcess,
ready_path: str,
) -> None:
"""Wait until the EngineCore is ready."""
try:
sync_ctx = zmq.Context() # type: ignore[attr-defined]
socket = sync_ctx.socket(zmq.constants.PULL)
socket.connect(ready_path)
# Wait for EngineCore to send EngineCoreProc.READY_STR.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for EngineCoreProc to startup.")
if not proc.is_alive():
raise RuntimeError("EngineCoreProc failed to start.")
message = socket.recv_string()
assert message == EngineCoreProc.READY_STR
except BaseException as e:
logger.exception(e)
raise e
finally:
sync_ctx.destroy(linger=0)
@staticmethod
def make_engine_core_process(
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
usage_context: UsageContext,
input_path: str,
output_path: str,
ready_path: str,
should_shutdown: Synchronized,
) -> BaseProcess:
# The current process might have CUDA context,
# so we need to spawn a new process.
# NOTE(rob): this is a problem for using EngineCoreProc w/
# LLM, since we need a if __name__ == "__main__" guard.
context = multiprocessing.get_context("spawn")
process_kwargs = {
"input_path": input_path,
"output_path": output_path,
"ready_path": ready_path,
"vllm_config": vllm_config,
"executor_class": executor_class,
"usage_context": usage_context,
"should_shutdown": should_shutdown
}
# Run EngineCore busy loop in background process.
proc = context.Process(target=EngineCoreProc.run_engine_core,
kwargs=process_kwargs)
proc.start()
# Wait for startup
EngineCoreProc.wait_for_startup(proc, ready_path)
return proc
@staticmethod
def run_engine_core(*args, **kwargs):
"""Launch EngineCore busy loop in background process."""
try:
engine_core = EngineCoreProc(*args, **kwargs)
engine_core.run_busy_loop()
except KeyboardInterrupt:
logger.debug("EngineCore interrupted.")
except BaseException as e:
logger.exception(e)
raise e
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
# Loop until we get a shutdown signal.
while not self.should_shutdown:
# 1) Poll the input queue until there is work to do.
if not self.scheduler.has_unfinished_requests():
while True:
try:
req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
self._handle_client_request(req)
break
except queue.Empty:
self._log_stats()
logger.debug("EngineCore busy loop waiting.")
if self.should_shutdown:
return
# 2) Handle any new client requests (Abort or Add).
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(req)
# 3) Step the engine core.
outputs = self.step()
# 4) Put EngineCoreOutputs into the output queue.
self.output_queue.put_nowait(outputs)
self._log_stats()
def _log_stats(self):
"""Log basic stats every LOGGING_TIME_S"""
now = time.time()
if now - self._last_logging_time > LOGGING_TIME_S:
logger.info(
"RUNNING: %s | WAITING: %s",
len(self.scheduler.running),
len(self.scheduler.waiting),
)
self._last_logging_time = now
def _handle_client_request(
self, request: Union[EngineCoreRequest, List[str]]) -> None:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
if isinstance(request, EngineCoreRequest):
self.add_request(request)
else:
# TODO: make an EngineCoreAbort wrapper
assert isinstance(request, list)
self.abort_requests(request)
def process_input_socket(self, input_path: str):
"""Input socket IO thread."""
# Msgpack serialization decoding.
decoder_add_req = PickleEncoder()
decoder_abort_req = PickleEncoder()
with self.make_socket(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
request_type = type_frame.buffer
request_data = data_frame.buffer
# Deserialize the request data.
if request_type == EngineCoreRequestType.ADD.value:
request = decoder_add_req.decode(request_data)
elif request_type == EngineCoreRequestType.ABORT.value:
request = decoder_abort_req.decode(request_data)
else:
raise ValueError(f"Unknown RequestType: {request_type}")
# Push to input queue for core busy loop.
self.input_queue.put_nowait(request)
def process_output_socket(self, output_path: str):
"""Output socket IO thread."""
# Msgpack serialization encoding.
encoder = msgpack.Encoder()
# Reuse send buffer.
buffer = bytearray()
with self.make_socket(output_path, zmq.constants.PUSH) as socket:
while True:
engine_core_outputs = self.output_queue.get()
outputs = EngineCoreOutputs(outputs=engine_core_outputs)
encoder.encode_into(outputs, buffer)
socket.send_multipart((buffer, ), copy=False)

View File

@@ -0,0 +1,219 @@
import multiprocessing
import time
from typing import List, Union
import msgspec
import zmq
import zmq.asyncio
from vllm.logger import init_logger
from vllm.utils import get_open_zmq_ipc_path
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.serial_utils import PickleEncoder
logger = init_logger(__name__)
class EngineCoreClient:
"""
EngineCoreClient: subclasses handle different methods for pushing
and pulling from the EngineCore for asyncio / multiprocessing.
Subclasses:
* InprocClient: In process EngineCore (for V0-style LLMEngine use)
* SyncMPClient: ZMQ + background proc EngineCore (for LLM)
* AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM)
"""
@staticmethod
def make_client(
*args,
multiprocess_mode: bool,
asyncio_mode: bool,
**kwargs,
) -> "EngineCoreClient":
# TODO: support this for debugging purposes.
if asyncio_mode and not multiprocess_mode:
raise NotImplementedError(
"Running EngineCore in asyncio without multiprocessing "
"is not currently supported.")
if multiprocess_mode and asyncio_mode:
return AsyncMPClient(*args, **kwargs)
if multiprocess_mode and not asyncio_mode:
return SyncMPClient(*args, **kwargs)
return InprocClient(*args, **kwargs)
def shutdown(self):
pass
def get_output(self) -> List[EngineCoreOutput]:
raise NotImplementedError
def add_request(self, request: EngineCoreRequest) -> None:
raise NotImplementedError
def abort_requests(self, request_ids: List[str]) -> None:
raise NotImplementedError
async def get_output_async(self) -> List[EngineCoreOutput]:
raise NotImplementedError
async def add_request_async(self, request: EngineCoreRequest) -> None:
raise NotImplementedError
async def abort_requests_async(self, request_ids: List[str]) -> None:
raise NotImplementedError
class InprocClient(EngineCoreClient):
"""
InprocClient: client for in-process EngineCore. Intended
for use in LLMEngine for V0-style add_request() and step()
EngineCore setup in this process (no busy loop).
* pushes EngineCoreRequest directly into the EngineCore
* pulls EngineCoreOutputs by stepping the EngineCore
TODO: support asyncio-mode for debugging.
"""
def __init__(self, *args, **kwargs):
self.engine_core = EngineCore(*args, **kwargs)
def get_output(self) -> List[EngineCoreOutput]:
return self.engine_core.step()
def add_request(self, request: EngineCoreRequest) -> None:
self.engine_core.add_request(request)
def abort_requests(self, request_ids: List[str]) -> None:
self.engine_core.abort_requests(request_ids)
class MPClient(EngineCoreClient):
"""
MPClient: base client for multi-proc EngineCore.
EngineCore runs in a background process busy loop, getting
new EngineCoreRequests and returning EngineCoreOutputs
* pushes EngineCoreRequests via input_socket
* pulls EngineCoreOutputs via output_socket
* AsyncMPClient subclass for AsyncLLM usage
* SyncMPClient subclass for LLM usage
"""
def __init__(
self,
*args,
asyncio_mode: bool,
**kwargs,
):
# Serialization setup.
self.encoder = PickleEncoder()
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
# ZMQ setup.
self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context())
# Path for IPC.
ready_path = get_open_zmq_ipc_path()
output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
# Get output (EngineCoreOutput) from EngineCore.
self.output_socket = self.ctx.socket(zmq.constants.PULL)
self.output_socket.connect(output_path)
# Send input (EngineCoreRequest) to EngineCore.
self.input_socket = self.ctx.socket(zmq.constants.PUSH)
self.input_socket.bind(input_path)
# Start EngineCore in background process.
self.should_shutdown = multiprocessing.Value('b', False, lock=False)
self.proc = EngineCoreProc.make_engine_core_process(
*args,
input_path=input_path,
output_path=output_path,
ready_path=ready_path,
should_shutdown=self.should_shutdown,
**kwargs,
)
def shutdown(self):
# Send shutdown signal to background process.
self.should_shutdown = True
# Shut down the zmq context.
self.ctx.destroy(linger=0)
# Shutdown the process if needed.
if hasattr(self, "proc") and self.proc.is_alive():
self.proc.terminate()
time.sleep(5)
if self.proc.is_alive():
self.proc.kill()
def __del__(self):
self.shutdown()
class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore."""
def __init__(self, *args, **kwargs):
super().__init__(*args, asyncio_mode=False, **kwargs)
def get_output(self) -> List[EngineCoreOutput]:
(frame, ) = self.output_socket.recv_multipart(copy=False)
engine_core_outputs = self.decoder.decode(frame.buffer).outputs
return engine_core_outputs
def _send_input(self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, List[str]]) -> None:
# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False)
def add_request(self, request: EngineCoreRequest) -> None:
self._send_input(EngineCoreRequestType.ADD, request)
def abort_requests(self, request_ids: List[str]) -> None:
self._send_input(EngineCoreRequestType.ABORT, request_ids)
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
def __init__(self, *args, **kwargs):
super().__init__(*args, asyncio_mode=True, **kwargs)
async def get_output_async(self) -> List[EngineCoreOutput]:
frames = await self.output_socket.recv_multipart(copy=False)
engine_core_outputs = self.decoder.decode(frames[0].buffer).outputs
return engine_core_outputs
async def _send_input(
self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, List[str]]) -> None:
msg = (request_type.value, self.encoder.encode(request))
await self.input_socket.send_multipart(msg, copy=False)
async def add_request_async(self, request: EngineCoreRequest) -> None:
await self._send_input(EngineCoreRequestType.ADD, request)
async def abort_requests_async(self, request_ids: List[str]) -> None:
if len(request_ids) > 0:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)

View File

@@ -0,0 +1,272 @@
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput
logger = init_logger(__name__)
@dataclass
class IncrementalDetokenizer:
# Generation data
output_text: str
tokens: List[str]
token_ids: List[int]
# Stop strings
stop: List[str]
include_stop_str_in_output: bool
# Metadata for incremental detokenization
prefix_offset: int
read_offset: int
# Parameters for detokenization
skip_special_tokens: bool
spaces_between_special_tokens: bool
output_kind: RequestOutputKind
# TODO: Probably decouple these
request_id: str
prompt: Optional[str]
prompt_token_ids: List[int]
# Tokenizer for this request
tokenizer: AnyTokenizer
# Accounting for stop string buffering
stop_buffer_length: int
_last_output_text_offset: int = 0
@property
def output_token_ids(self) -> List[int]:
assert len(self.token_ids) >= len(self.prompt_token_ids)
return self.token_ids[len(self.prompt_token_ids):]
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
request: DetokenizerRequest,
) -> "IncrementalDetokenizer":
tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=request.skip_special_tokens,
)
stops = request.stop
# Number of chars to hold back when stop strings are to be excluded
# from streamed output.
if stops and not request.include_stop_str_in_output:
stop_buffer_length = max(len(s) for s in stops) - 1
else:
stop_buffer_length = 0
return cls(
output_text="",
tokens=tokens,
# Detokenizer mutates this list, so need a unique copy.
# NOTE(Nick): could we take ownership of it though?
token_ids=request.prompt_token_ids.copy(),
stop=stops,
include_stop_str_in_output=request.include_stop_str_in_output,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=request.
spaces_between_special_tokens,
output_kind=request.output_kind,
request_id=request.request_id,
prompt=request.prompt,
prompt_token_ids=request.prompt_token_ids,
tokenizer=tokenizer,
stop_buffer_length=stop_buffer_length,
)
def add_tokens(
self,
new_token_ids: List[int],
finish_reason: Optional[str],
stop_reason: Optional[str],
) -> Optional[RequestOutput]:
"""
Update RequestState for the request_id by:
1) Detokenize the new token ids incrementally.
2) Update the RequestOutput with the new text.
"""
# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.
decoded_text = ""
for new_token_id in new_token_ids:
self.token_ids.append(new_token_id)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=self.token_ids,
prev_tokens=self.tokens,
prefix_offset=self.prefix_offset,
read_offset=self.read_offset,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.
spaces_between_special_tokens,
)
self.tokens.extend(new_tokens)
self.prefix_offset = prefix_offset
self.read_offset = read_offset
self.output_text += new_decoded_token_text
decoded_text += new_decoded_token_text
# 2) Evaluate stop criteria.
if self.stop:
stop = StopChecker.check_stop_strings(
output_text=self.output_text,
new_char_count=len(decoded_text),
stop=self.stop,
include_in_output=self.include_stop_str_in_output,
)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
self.output_text = self.output_text[:truncate_to]
finish_reason = "stop" # TODO: use constant
stop_reason = stop_str
# TODO: handle stop_token_ids here too?
# 3) Update the RequestOutput object with the new text.
finished = bool(finish_reason)
if self.output_kind == RequestOutputKind.FINAL_ONLY \
and not finished:
return None
delta = self.output_kind == RequestOutputKind.DELTA
output_text = self._get_next_output_text(finished, delta)
token_ids = new_token_ids if delta else self.output_token_ids
request_output = RequestOutput.new(
self.request_id,
self.prompt,
self.prompt_token_ids,
output_text,
token_ids,
finished,
)
if finished:
completion_output = request_output.outputs[0]
completion_output.finish_reason = finish_reason
completion_output.stop_reason = stop_reason
return request_output
def _get_next_output_text(self, finished: bool, delta: bool) -> str:
"""If delta is True, only new text since the last call to
this method is returned"""
# We return the full output text if the sequence is finished.
buffer_length = 0 if finished else self.stop_buffer_length
if not delta:
return self.output_text[:-buffer_length] if buffer_length else (
self.output_text)
length = len(self.output_text) - buffer_length
last_offset = self._last_output_text_offset
if last_offset < length:
self._last_output_text_offset = length
return self.output_text[last_offset:length]
return ""
class Detokenizer:
def __init__(self,
tokenizer_name: str,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
revision: Optional[str] = None):
# TODO: once we support LoRA, we should should pass the tokenizer
# here. We currently have two copies (this + in the LLMEngine).
self.tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
revision=revision)
# Request id -> IncrementalDetokenizer
self.request_states: Dict[str, IncrementalDetokenizer] = {}
def is_request_active(self, request_id: str):
return request_id in self.request_states
def get_num_unfinished_requests(self):
return len(self.request_states)
def has_unfinished_requests(self) -> bool:
return len(self.request_states) > 0
def abort_requests(
self,
request_ids: Iterable[str],
) -> None:
"""Remove the request_ids from the Detokenizer."""
for request_id in request_ids:
self.request_states.pop(request_id, None)
def add_request(
self,
request: DetokenizerRequest,
):
"""Add new request to the Detokenizer."""
assert (request.request_id not in self.request_states)
request_state = IncrementalDetokenizer.from_new_request(
self.tokenizer, request)
self.request_states[request.request_id] = request_state
def step(
self, encore_core_outputs: List[EngineCoreOutput]
) -> Tuple[List[RequestOutput], List[str]]:
"""Update state and request the RequestOutputs to the LLMEngine."""
request_outputs: List[RequestOutput] = []
requests_to_abort: List[str] = []
for engine_core_output in encore_core_outputs:
request_id = engine_core_output.request_id
detokenizer = self.request_states.get(request_id)
if detokenizer is None:
# Ignore output for already-aborted request.
continue
# Detokenize and update state.
request_output = detokenizer.add_tokens(
new_token_ids=engine_core_output.new_token_ids,
finish_reason=engine_core_output.finish_reason,
stop_reason=engine_core_output.stop_reason,
)
if request_output is not None:
# Add to RequestOutputs list.
request_outputs.append(request_output)
# Free completed requests.
if request_output.finished:
self.request_states.pop(request_id)
if not engine_core_output.finished:
requests_to_abort.append(request_id)
# Return to EngineClient.
return request_outputs, requests_to_abort

View File

@@ -0,0 +1,173 @@
from typing import Dict, List, Mapping, Optional, Type, Union
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.gpu_executor import GPUExecutor
logger = init_logger(__name__)
class LLMEngine:
"""Legacy LLMEngine for backwards compatibility."""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
multiprocess_mode: bool = False,
) -> None:
# TODO: Can we avoid this?
self.model_config = vllm_config.model_config
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
enable_lora=bool(vllm_config.lora_config))
self.tokenizer.ping()
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config.model_config,
vllm_config.lora_config, self.tokenizer,
input_registry, mm_registry)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self.detokenizer = Detokenizer(
tokenizer_name=vllm_config.model_config.tokenizer,
tokenizer_mode=vllm_config.model_config.tokenizer_mode,
trust_remote_code=vllm_config.model_config.trust_remote_code,
revision=vllm_config.model_config.tokenizer_revision,
)
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self.engine_core = EngineCoreClient.make_client(
vllm_config,
executor_class,
usage_context,
multiprocess_mode=multiprocess_mode,
asyncio_mode=False,
)
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
enable_multiprocessing: bool = False,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
vllm_config = engine_args.create_engine_config()
executor_class = cls._get_executor_cls(vllm_config)
if VLLM_ENABLE_V1_MULTIPROCESSING:
logger.debug("Enabling multiprocessing for LLMEngine.")
enable_multiprocessing = True
# Create the LLMEngine.
return cls(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
multiprocess_mode=enable_multiprocessing)
@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
return GPUExecutor
def stop_remote_worker_execution_loop(self) -> None:
raise NotImplementedError("TP not implemented yet.")
def get_num_unfinished_requests(self) -> int:
return self.detokenizer.get_num_unfinished_requests()
def has_unfinished_requests(self) -> bool:
return self.detokenizer.has_unfinished_requests()
@classmethod
def validate_outputs(cls, outputs, output_type):
return outputs
def abort_request(self, request_ids: List[str]) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""
self.engine_core.abort_requests(request_ids)
self.detokenizer.abort_requests(request_ids)
def add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
# 1) Process raw inputs into the request.
detokenizer_req, engine_core_req = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
trace_headers, prompt_adapter_request, priority)
# 2) Add the request to Detokenizer.
self.detokenizer.add_request(detokenizer_req)
# 3) Add the request to EngineCore.
self.engine_core.add_request(engine_core_req)
def step(self) -> List[RequestOutput]:
# 1) Get EngineCoreOutput from the EngineCore.
engine_core_outputs = self.engine_core.get_output()
# 2) Detokenizer the EngineCoreOutput.
request_outputs, requests_to_abort = self.detokenizer.step(
engine_core_outputs)
# 3) Abort requests that finished due to stopping criteria.
if requests_to_abort:
self.abort_request(requests_to_abort)
return request_outputs
# TODO(rob): Can we get rid of these?
def get_model_config(self):
pass
def start_profile(self):
pass
def stop_profile(self):
pass
def get_tokenizer_group(self, group_type):
pass

View File

@@ -0,0 +1,39 @@
from typing import Any, Dict, List, Optional
from vllm.config import ModelConfig
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry)
class MMInputMapper:
def __init__(
self,
model_config: ModelConfig,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry.create_input_mapper(
model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config)
def process_inputs(
self,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Dict[str, Any]],
) -> List[MultiModalKwargs]:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]
# Process each image input separately so that later we can schedule
# them in a fine-grained manner.
mm_inputs: List[MultiModalKwargs] = []
num_images = len(image_inputs)
for i in range(num_images):
mm_input = self.multi_modal_input_mapper(
{"image": [image_inputs[i]]},
mm_processor_kwargs=mm_processor_kwargs,
)
mm_inputs.append(mm_input)
return mm_inputs

View File

@@ -0,0 +1,168 @@
import time
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from vllm.config import LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
class Processor:
def __init__(
self,
model_config: ModelConfig,
lora_config: Optional[LoRAConfig],
tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.model_config = model_config
self.lora_config = lora_config
self.tokenizer = tokenizer
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer,
mm_registry)
self.input_processor = input_registry.create_input_processor(
model_config)
# TODO: run in an ThreadpoolExecutor or BackgroundProcess.
# This ideally should releases the GIL, so we should not block the
# asyncio loop while this is running.
def process_inputs(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Tuple[DetokenizerRequest, EngineCoreRequest]:
# TODO(woosuk): Support embedding mode.
# TODO(woosuk): Check max_logprobs
# TODO(woosuk): Support encoder-decoder models.
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
assert priority == 0, "vLLM V1 does not support priority at the moment."
assert trace_headers is None, "vLLM V1 does not support tracing yet."
# Process inputs.
preprocessed_inputs = self.input_preprocessor.preprocess(
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
processed_inputs = self.input_processor(preprocessed_inputs)
self._validate_model_inputs(processed_inputs)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = SingletonInputsAdapter(
processed_inputs["decoder"])
encoder_inputs = SingletonInputsAdapter(
processed_inputs["encoder"])
else:
decoder_inputs = SingletonInputsAdapter(processed_inputs)
encoder_inputs = None
# TODO: Impl encoder-decoder
if encoder_inputs is not None:
raise NotImplementedError
assert isinstance(params, SamplingParams)
# TODO: can we avoid cloning here in multiproc case
sampling_params = params.clone()
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
# Make Request for Detokenizer.
detokenizer_request = DetokenizerRequest(
request_id,
decoder_inputs.prompt,
decoder_inputs.prompt_token_ids,
sampling_params.skip_special_tokens,
sampling_params.spaces_between_special_tokens,
sampling_params.output_kind,
sampling_params.stop,
sampling_params.include_stop_str_in_output,
)
# Make Request for EngineCore.
engine_core_request = EngineCoreRequest(
request_id,
decoder_inputs.prompt,
decoder_inputs.prompt_token_ids,
decoder_inputs.multi_modal_data,
decoder_inputs.multi_modal_placeholders,
decoder_inputs.mm_processor_kwargs,
sampling_params,
eos_token_id,
arrival_time,
lora_request,
)
return detokenizer_request, engine_core_request
def _validate_model_inputs(self, inputs: ProcessorInputs):
if is_encoder_decoder_inputs(inputs):
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_inputs = inputs["decoder" if self.model_config.
is_multimodal_model else "encoder"]
else:
prompt_inputs = inputs
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
if self.model_config.is_multimodal_model:
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:
raise ValueError(
f"The prompt (total length {len(prompt_ids)}) is too long "
f"to fit into the model (context length {max_prompt_len}). "
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.")
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
config = try_get_generation_config(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.revision,
)
if config is None:
return {}
return config.to_diff_dict()

View File

View File

@@ -0,0 +1,77 @@
import os
from typing import Optional, Tuple
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker
logger = init_logger(__name__)
class GPUExecutor:
def __init__(self, vllm_config: VllmConfig) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.worker = self._create_worker()
self.worker.initialize()
self.worker.load_model()
def _create_worker(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Worker:
"""Return worker init args for a given rank."""
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return Worker(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
)
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# GPU blocks: %d", num_gpu_blocks)
self.worker.initialize_cache(num_gpu_blocks)
self.worker.compile_or_warm_up_model()
def execute_model(
self,
scheduler_output,
) -> ModelRunnerOutput:
output = self.worker.execute_model(scheduler_output)
return output
def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return

View File

@@ -0,0 +1,37 @@
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
@dataclass
class SamplerOutput:
# [num_reqs]
sampled_token_ids: torch.Tensor
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids: Optional[torch.Tensor]
# [num_reqs, max_num_logprobs + 1]
logprobs: Optional[torch.Tensor]
# TODO: Support prompt logprobs.
prompt_logprob_token_ids: Optional[torch.Tensor]
prompt_logprobs: Optional[torch.Tensor]
@dataclass
class ModelRunnerOutput:
# [num_reqs]
req_ids: List[str]
# req_id -> index
req_id_to_index: Dict[str, int]
# [num_reqs]
sampled_token_ids_cpu: torch.Tensor
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids_cpu: Optional[torch.Tensor]
# [num_reqs, max_num_logprobs + 1]
logprobs_cpu: Optional[torch.Tensor]

View File

@@ -0,0 +1,155 @@
import enum
from typing import List, Optional, Union
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList
class Request:
def __init__(
self,
request_id: str,
inputs: DecoderOnlyInputs,
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.request_id = request_id
self.inputs = SingletonInputsAdapter(inputs)
self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.metrics = RequestMetrics(arrival_time=arrival_time,
last_token_time=arrival_time,
first_scheduled_time=None,
first_token_time=None,
time_in_queue=None)
self.lora_request = lora_request
self.status = RequestStatus.WAITING
self.stop_reason: Union[int, str, None] = None
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
self.prompt = self.inputs.prompt
self.prompt_token_ids = self.inputs.prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0
# Raw multimodal data before the mm input mapper (e.g., PIL images).
self.mm_data = self.inputs.multi_modal_data
self.mm_processor_kwargs = self.inputs.mm_processor_kwargs
mm_positions = self.inputs.multi_modal_placeholders
if mm_positions:
# FIXME(woosuk): Support other modalities.
self.mm_positions = mm_positions.get("image", [])
else:
self.mm_positions = []
# Output of the mm input mapper (e.g., image tensors).
self.mm_inputs: List[MultiModalKwargs] = []
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls(
request_id=request.request_id,
inputs=token_inputs(
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
multi_modal_data=request.mm_data,
multi_modal_placeholders=request.mm_placeholders,
mm_processor_kwargs=request.mm_processor_kwargs,
),
sampling_params=request.sampling_params,
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
)
@property
def output_token_ids(self) -> ConstantList[int]:
# Prevent directly appending to the output_token_ids since
# all_token_ids should also be updated simultaneously.
return ConstantList(self._output_token_ids)
@property
def all_token_ids(self) -> ConstantList[int]:
# Prevent directly appending to the all_token_ids since
# output_token_ids should also be updated simultaneously
return ConstantList(self._all_token_ids)
def append_output_token_ids(
self,
token_ids: Union[int, List[int]],
) -> None:
if isinstance(token_ids, int):
token_ids = [token_ids]
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
@property
def num_tokens(self) -> int:
return len(self._all_token_ids)
@property
def num_output_tokens(self) -> int:
return len(self._output_token_ids)
def is_finished(self) -> bool:
return RequestStatus.is_finished(self.status)
def get_finished_reason(self) -> Union[str, None]:
return RequestStatus.get_finished_reason(self.status)
def has_encoder_inputs(self) -> bool:
return len(self.mm_data) > 0
@property
def num_encoder_inputs(self) -> int:
return len(self.mm_positions)
def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_positions)
num_tokens = self.mm_positions[input_id]["length"]
return num_tokens
class RequestStatus(enum.IntEnum):
"""Status of a request."""
WAITING = 0
RUNNING = 1
PREEMPTED = 2
# Note: anything after PREEMPTED (2) will be considered
# as a finished status.
FINISHED_STOPPED = 3
FINISHED_LENGTH_CAPPED = 4
FINISHED_ABORTED = 5
FINISHED_IGNORED = 6
@staticmethod
def is_finished(status: "RequestStatus") -> bool:
return status > RequestStatus.PREEMPTED
@staticmethod
def get_finished_reason(status: "RequestStatus") -> Union[str, None]:
return _FINISHED_REASON_MAP.get(status)
# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP = {
RequestStatus.FINISHED_STOPPED: "stop",
RequestStatus.FINISHED_LENGTH_CAPPED: "length",
RequestStatus.FINISHED_ABORTED: "abort",
RequestStatus.FINISHED_IGNORED: "length",
}

View File

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import Dict
import torch
@dataclass
class SamplingMetadata:
temperature: torch.Tensor
all_greedy: bool
all_random: bool
top_p: torch.Tensor
top_k: torch.Tensor
no_top_p: bool
no_top_k: bool
generators: Dict[int, torch.Generator]
max_num_logprobs: int

View File

@@ -0,0 +1,158 @@
"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict
import torch
import torch.nn as nn
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
logits = self.apply_temperature(logits, sampling_metadata.temperature)
logits = self.apply_top_k_top_p(logits, sampling_metadata)
probs = self.get_probs(logits)
sampled = self.sample(probs, sampling_metadata)
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)
if sampling_metadata.max_num_logprobs > 0:
logprobs = self.get_logprobs(logits)
# FIXME: Mask the sampled token_id, get topk logprobs,
# and concatenate the topk with the sampled token_id.
topk_logprobs, topk_indices = torch.topk(
logprobs, sampling_metadata.max_num_logprobs, dim=-1)
# Use int32 to reduce the tensor size.
topk_indices = topk_indices.to(torch.int32)
else:
topk_logprobs = None
topk_indices = None
sampler_output = SamplerOutput(
sampled_token_ids=sampled,
logprob_token_ids=topk_indices,
logprobs=topk_logprobs,
prompt_logprob_token_ids=None,
prompt_logprobs=None,
)
return sampler_output
def apply_temperature(
self,
logits: torch.Tensor,
temp: torch.Tensor,
) -> torch.Tensor:
# Use float32 to apply temperature scaling.
logits = logits.to(torch.float32)
# Avoid division by zero.
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
# Use in-place division to avoid creating a new tensor.
logits.div_(temp.unsqueeze(dim=1))
return logits
def apply_top_k_top_p(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
return _apply_top_k_top_p(
logits,
sampling_metadata.no_top_k,
sampling_metadata.top_k,
sampling_metadata.no_top_p,
sampling_metadata.top_p,
)
def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
return torch.softmax(logits, dim=-1, dtype=torch.float32)
def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return torch.log_softmax(logits, dim=-1, dtype=torch.float32)
def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
return probs.argmax(dim=-1).view(-1)
def random_sample(
self,
probs: torch.Tensor,
generators: Dict[int, torch.Generator],
) -> torch.Tensor:
q = torch.empty_like(probs)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
if len(generators) != probs.shape[0]:
# This might still be done here unnecessarily if there are greedies
q.exponential_()
if generators:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1)
def sample(
self,
probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random)
if sampling_metadata.all_greedy:
return self.greedy_sample(probs)
if sampling_metadata.all_random:
return self.random_sample(probs, sampling_metadata.generators)
greedy_sampled = self.greedy_sample(probs)
random_sampled = self.random_sample(probs,
sampling_metadata.generators)
sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled,
random_sampled,
)
return sampled
# TODO(woosuk): Optimize this with a custom kernel.
def _apply_top_k_top_p(
logits: torch.Tensor,
no_top_k: bool,
k: torch.Tensor,
no_top_p: bool,
p: torch.Tensor,
) -> torch.Tensor:
if no_top_k and no_top_p:
return logits
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if not no_top_k:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
if not no_top_p:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits

View File

@@ -0,0 +1,10 @@
import pickle
class PickleEncoder:
def encode(self, obj):
return pickle.dumps(obj)
def decode(self, data):
return pickle.loads(data)

View File

@@ -0,0 +1,64 @@
from typing import Generic, List, TypeVar, overload
T = TypeVar("T")
class ConstantList(Generic[T]):
def __init__(self, x: List[T]) -> None:
self._x = x
def append(self, item):
raise Exception("Cannot append to a constant list")
def extend(self, item):
raise Exception("Cannot extend a constant list")
def insert(self, item):
raise Exception("Cannot insert into a constant list")
def pop(self, item):
raise Exception("Cannot pop from a constant list")
def remove(self, item):
raise Exception("Cannot remove from a constant list")
def clear(self):
raise Exception("Cannot clear a constant list")
def index(self, item):
return self._x.index(item)
@overload
def __getitem__(self, item) -> T:
...
@overload
def __getitem__(self, s: slice, /) -> List[T]:
...
def __getitem__(self, item):
return self._x[item]
@overload
def __setitem__(self, item, value):
...
@overload
def __setitem__(self, s: slice, value, /):
...
def __setitem__(self, item, value):
raise Exception("Cannot set item in a constant list")
def __delitem__(self, item):
raise Exception("Cannot delete item from a constant list")
def __iter__(self):
return iter(self._x)
def __contains__(self, item):
return item in self._x
def __len__(self):
return len(self._x)

View File

View File

@@ -0,0 +1,879 @@
import os
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from vllm import envs
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalKwargs
from vllm.plugins import set_compilation_config
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
if TYPE_CHECKING:
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
class GPUModelRunner:
def __init__(
self,
vllm_config: VllmConfig,
input_registry: InputRegistry = INPUT_REGISTRY,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
# Model-related.
self.num_attn_layers = model_config.get_num_attention_layers(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
# Multi-modal data support
self.input_registry = input_registry
# Lazy initialization
# self.model: nn.Module # Set after load_model
self.kv_caches: List[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.scheduler_config.max_num_seqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
)
self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
scheduler_output.preempted_req_ids,
scheduler_output.finished_req_ids,
)
removed_req_indices: List[int] = []
for req_id in stopped_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Update the states of the running requests.
for req_data in scheduler_output.scheduled_running_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Update the num_computed_tokens.
req_state.num_computed_tokens = req_data.num_computed_tokens
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
# Update the block table.
num_new_blocks = len(req_data.new_block_ids)
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
end_index = start_index + num_new_blocks
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table_cpu[
req_index, start_index:end_index] = req_data.new_block_ids
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for req_data in scheduler_output.scheduled_new_reqs:
req_id = req_data.req_id
sampling_params = req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=req_data.prompt_token_ids,
prompt=req_data.prompt,
mm_inputs=req_data.mm_inputs,
mm_positions=req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=req_data.block_ids,
num_computed_tokens=req_data.num_computed_tokens,
output_token_ids=[],
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for req_data in scheduler_output.scheduled_resumed_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = req_data.block_ids
req_state.num_computed_tokens = req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table[:num_reqs].copy_(
self.input_batch.block_table_cpu_tensor[:num_reqs],
non_blocking=True)
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = []
max_num_scheduled_tokens = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens.append(num_tokens)
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
assert max_num_scheduled_tokens > 0
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
indices = np.arange(num_reqs)
req_indices = np.repeat(indices, num_scheduled_tokens)
# Get batched arange.
# E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange_matrix = np.tile(np.arange(max_num_scheduled_tokens),
(num_reqs, 1))
mask = arange_matrix < num_scheduled_tokens[:, np.newaxis]
arange = arange_matrix[mask]
# Get positions.
positions = torch.empty((total_num_scheduled_tokens, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
positions_np = positions.numpy()
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
token_indices = positions_np + req_indices * self.max_model_len
token_indices = torch.from_numpy(token_indices)
input_ids = torch.empty((total_num_scheduled_tokens, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
torch.index_select(torch.from_numpy(
self.input_batch.token_ids_cpu).flatten(),
0,
token_indices,
out=input_ids)
# Calculate the slot mapping.
block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[
token_indices // self.block_size]
block_offsets = token_indices % self.block_size
slot_mapping = torch.empty((total_num_scheduled_tokens, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
torch.add(block_numbers * self.block_size,
block_offsets,
out=slot_mapping)
# Prepare the attention metadata.
query_start_loc = torch.empty((num_reqs + 1, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
query_start_loc_np = query_start_loc.numpy()
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
max_seq_len = seq_lens.max()
seq_start_loc = torch.empty((num_reqs + 1, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
seq_start_loc_np = seq_start_loc.numpy()
seq_start_loc_np[0] = 0
np.cumsum(seq_lens, out=seq_start_loc_np[1:])
input_ids = input_ids.to(self.device, non_blocking=True)
self.positions[:total_num_scheduled_tokens].copy_(positions,
non_blocking=True)
query_start_loc = query_start_loc.to(self.device, non_blocking=True)
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_start_loc=seq_start_loc,
block_table=self.input_batch.block_table[:num_reqs],
slot_mapping=slot_mapping,
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
# partial request, we do so for simplicity. We will ignore the sampled
# token from the partial request.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return input_ids, attn_metadata, logits_indices
def _prepare_sampling(
self,
scheduler_output: "SchedulerOutput",
) -> SamplingMetadata:
skip_copy = True
if (scheduler_output.finished_req_ids
or scheduler_output.preempted_req_ids):
skip_copy = False
if (scheduler_output.scheduled_new_reqs
or scheduler_output.scheduled_resumed_reqs):
skip_copy = False
# Create the sampling metadata.
sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
return sampling_metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
mm_inputs: List[MultiModalKwargs] = []
req_input_ids: List[Tuple[int, int]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id])
req_input_ids.append((req_id, input_id))
batched_mm_inputs = MultiModalKwargs.batch(mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)
# Run the encoder.
# `encoder_outputs` is either of the following:
# 1. A tensor of shape [num_images, feature_size, hidden_size]
# in case when feature_size is fixed across all images.
# 2. A list (length: num_images) of tensors, each of shape
# [feature_size, hidden_size] in case when the feature size is
# dynamic depending on input images.
encoder_outputs = self.model.process_mm_inputs(**batched_mm_inputs)
# Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = output
def _gather_encoder_outputs(
self,
scheduler_output: "SchedulerOutput",
) -> List[torch.Tensor]:
encoder_outputs: List[torch.Tensor] = []
num_reqs = self.input_batch.num_reqs
for req_id in self.input_batch.req_ids[:num_reqs]:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
# num_computed_tokens + num_scheduled_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if start_pos >= num_computed_tokens + num_scheduled_tokens:
# The encoder output is not needed in this step.
break
if start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
self._update_states(scheduler_output)
# Run the encoder.
self._execute_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
# Prepare the decoder inputs.
input_ids, attn_metadata, logits_indices = self._prepare_inputs(
scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self._get_padded_batch_size(
num_scheduled_tokens)
else:
# Eager mode.
num_input_tokens = num_scheduled_tokens
# Get the inputs embeds.
if encoder_outputs:
inputs_embeds = self.model.get_input_embeddings(
input_ids, encoder_outputs)
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
# NOTE(woosuk): To unify token ids and soft tokens (vision embeddings),
# always use embeddings (rather than token ids) as input to the model.
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata):
hidden_states = self.model(
input_ids=None,
positions=self.positions[:num_input_tokens],
kv_caches=self.kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds[:num_input_tokens],
)
hidden_states = hidden_states[:num_scheduled_tokens]
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None)
# Sample the next token and get logprobs if needed.
sampling_metadata = self._prepare_sampling(scheduler_output)
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
# NOTE: CPU-GPU synchronization happens here.
sampled_token_ids = sampler_output.sampled_token_ids.cpu()
sampled_token_ids_list = sampled_token_ids.tolist()
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs
for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
assert seq_len <= req_state.num_tokens
if seq_len == req_state.num_tokens:
# Append the sampled token to the output token ids.
token_id = sampled_token_ids_list[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id)
else:
# Ignore the sampled token from the partial request.
# Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i)
if generator is not None:
# This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
if sampler_output.logprob_token_ids is None:
logprob_token_ids = None
else:
logprob_token_ids = sampler_output.logprob_token_ids.cpu()
if sampler_output.logprobs is None:
logprobs = None
else:
logprobs = sampler_output.logprobs.cpu()
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids[:num_reqs],
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids_cpu=sampled_token_ids,
logprob_token_ids_cpu=logprob_token_ids,
logprobs_cpu=logprobs,
)
return model_runner_output
def load_model(self) -> None:
if self.use_cuda_graph:
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
os.environ["VLLM_CUSTOM_OPS"] = "none"
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["vllm.unified_v1_flash_attention"],
use_inductor=True,
enable_fusion=False,
))
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
dummy_kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
with set_forward_context(None): # noqa: SIM117
with set_compile_context(self.cudagraph_batch_sizes):
# Trigger compilation for general shape.
model(input_ids=None,
positions=self.positions,
kv_caches=dummy_kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds)
@torch.inference_mode()
def profile_run(self) -> None:
# TODO(woosuk): Profile the max memory usage of the encoder and
# the encoder cache.
self._dummy_run(self.model, self.max_num_tokens)
torch.cuda.synchronize()
@torch.inference_mode()
def capture_model(self) -> None:
if not self.use_cuda_graph:
logger.warning(
"Skipping CUDA graph capture. Please set "
"VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.",
CompilationLevel.PIECEWISE)
return
start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with set_forward_context(None):
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for num_tokens in reversed(self.cudagraph_batch_sizes):
self.model(
input_ids=None,
positions=self.positions[:num_tokens],
kv_caches=self.kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds[:num_tokens],
)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
def initialize_kv_cache(self, num_blocks: int) -> None:
assert len(self.kv_caches) == 0
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
for _ in range(self.num_attn_layers):
self.kv_caches.append(
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))
def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
# TODO: Optimize this?
for size in self.cudagraph_batch_sizes:
if batch_size <= size:
return size
return None
@dataclass
class CachedRequestState:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
mm_inputs: List[MultiModalKwargs]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
generator: Optional[torch.Generator]
block_ids: List[int]
num_computed_tokens: int
output_token_ids: List[int]
@property
def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids)
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.device = device
self.pin_memory = pin_memory
self.req_ids: List[Optional[str]] = [None] * max_num_reqs
self.req_id_to_index: Dict[str, int] = {}
self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
# Attention-related.
self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32)
self.block_table_cpu_tensor = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_cpu = self.block_table_cpu_tensor.numpy()
# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: Set[str] = set()
self.random_reqs: Set[str] = set()
self.top_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: Set[str] = set()
self.top_k = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device=device)
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: Set[str] = set()
# req_index -> generator
self.generators: Dict[int, torch.Generator] = {}
self.num_logprobs: Dict[str, int] = {}
self.prompt_logprob_reqs: Set[str] = set()
def add_request(
self,
request: "CachedRequestState",
req_index: Optional[int] = None,
) -> None:
if req_index is None:
req_index = self.num_reqs
assert req_index < self.max_num_reqs
req_id = request.req_id
self.req_ids[req_index] = req_id
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids)
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
num_blocks = len(request.block_ids)
self.block_table_cpu[req_index, :num_blocks] = request.block_ids
sampling_params = request.sampling_params
self.temperature_cpu[req_index] = sampling_params.temperature
if sampling_params.sampling_type == SamplingType.GREEDY:
self.greedy_reqs.add(req_id)
else:
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
self.top_k_cpu[req_index] = sampling_params.top_k
if sampling_params.top_k > 0:
self.top_k_reqs.add(req_id)
self.generators[req_index] = request.generator
num_logprobs = sampling_params.logprobs
if num_logprobs is not None and num_logprobs > 0:
self.num_logprobs[req_id] = num_logprobs
if sampling_params.prompt_logprobs:
self.prompt_logprob_reqs.add(req_id)
def remove_request(self, req_id: str) -> Optional[int]:
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self.req_ids[req_index] = None
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id)
return req_index
def clear(self) -> None:
self.req_ids = [None] * self.max_num_reqs
self.req_id_to_index.clear()
self.greedy_reqs.clear()
self.random_reqs.clear()
self.top_p_reqs.clear()
self.top_k_reqs.clear()
self.generators.clear()
self.num_logprobs.clear()
self.prompt_logprob_reqs.clear()
def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0:
# The batched states are empty.
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index = self.num_reqs + len(empty_req_indices) - 1
while empty_req_indices:
# Find the largest non-empty index.
while last_req_index in empty_req_indices:
last_req_index -= 1
# Find the smallest empty index.
empty_index = empty_req_indices.pop()
if empty_index >= last_req_index:
break
# Swap the states.
req_id = self.req_ids[last_req_index]
self.req_ids[empty_index] = req_id
self.req_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
# TODO(woosuk): Optimize the copy of token_ids_cpu and
# block_table_cpu.
self.token_ids_cpu[empty_index] = self.token_ids_cpu[
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table_cpu[empty_index] = self.block_table_cpu[
last_req_index]
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
# Decrement last_req_index since it is now empty.
last_req_index -= 1
def make_sampling_metadata(
self,
skip_copy: bool = False,
) -> SamplingMetadata:
if not skip_copy:
self.temperature[:self.num_reqs].copy_(
self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_p[:self.num_reqs].copy_(
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_k[:self.num_reqs].copy_(
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
return SamplingMetadata(
temperature=self.temperature[:self.num_reqs],
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=self.top_p[:self.num_reqs],
top_k=self.top_k[:self.num_reqs],
no_top_p=self.no_top_p,
no_top_k=self.no_top_k,
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def max_num_logprobs(self) -> int:
return max(self.num_logprobs.values()) if self.num_logprobs else 0
@property
def no_logprob(self) -> bool:
return len(self.num_logprobs) == 0
@property
def no_prompt_logprob(self) -> bool:
return len(self.prompt_logprob_reqs) == 0

View File

@@ -0,0 +1,229 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.distributed
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
class Worker:
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
):
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = GPUModelRunner(vllm_config)
def initialize(self):
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self) -> None:
self.model_runner.load_model()
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = _get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
# if self.model_runner.lora_manager:
# self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache()
return num_gpu_blocks, 0
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks."""
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_gpu_blocks
max_model_len = self.model_config.max_model_len
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
self.model_runner.initialize_kv_cache(num_gpu_blocks)
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
output = self.model_runner.execute_model(scheduler_output)
# TODO(woosuk): Send the output to the engine process.
return output
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: # noqa: SIM102
if not current_platform.has_device_capability(80):
capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")
def _get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_attention_layers = model_config.get_num_attention_layers(
parallel_config)
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_attention_layers * (key_cache_block + value_cache_block)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
dtype_size = get_dtype_size(dtype)
return dtype_size * total