2026-01-09 13:34:11 +08:00
parent dfa6476b58
commit b2ef04d792
538 changed files with 105693 additions and 2 deletions

25 vllm/__init__.py Normal file

@@ -0,0 +1,25 @@
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
__version__ = "0.4.2"
__all__ = [
"LLM",
"ModelRegistry",
"SamplingParams",
"RequestOutput",
"CompletionOutput",
"LLMEngine",
"EngineArgs",
"AsyncLLMEngine",
"AsyncEngineArgs",
"initialize_ray_cluster",
]
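
A minimal usage sketch for the names exported above (not part of this commit), assuming a working install; the checkpoint name is only a placeholder and any HF-compatible model path works the same way:

from vllm import LLM, SamplingParams

# "facebook/opt-125m" is just an example checkpoint.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)
# generate() returns one RequestOutput per prompt; each holds
# CompletionOutput objects for the sampled sequences.
for request_output in llm.generate(["The capital of France is"], params):
    print(request_output.outputs[0].text)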

251 vllm/_custom_ops.py Normal file

@@ -0,0 +1,251 @@
from typing import Dict, Optional, Tuple
import torch
try:
from vllm._C import cache_ops as vllm_cache_ops
from vllm._C import ops as vllm_ops
except ImportError:
pass
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.silu_and_mul(out, x)
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_and_mul(out, x)
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_tanh_and_mul(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_fast(out, x)
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_new(out, x)
# page attention ops
def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, kv_scale)
def paged_attention_v2(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size,
max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale)
# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
is_neox)
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)
# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
vllm_ops.rms_norm(out, input, weight, epsilon)
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx,
thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor:
return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
vllm_ops.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
lookup_table: torch.Tensor) -> None:
vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)
# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
size_n, size_k)
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return vllm_ops.aqlm_gemm(input, codes, codebooks, scales,
codebook_partition_sizes, bias)
def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes)
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits)
def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int, size_k: int,
is_k_full: bool) -> torch.Tensor:
return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, num_bits, size_m, size_n,
size_k, is_k_full)
# fp8
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
if scale is None:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
else:
vllm_ops.static_scaled_fp8_quant(output, input, scale)
return output, scale
# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor) -> None:
vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
sorted_token_ids, experts_ids,
num_tokens_post_pad)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, kv_scale)
def reshape_and_cache_flash(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
) -> None:
vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
block_mapping: torch.Tensor) -> None:
vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: Dict[int, int]) -> None:
vllm_cache_ops.swap_blocks(src, dst, block_mapping)
def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
vllm_cache_ops.convert_fp8(output, input)
# TODO: cuda_utils, custom_ar
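
All of these wrappers follow a destination-passing convention: the caller allocates the output buffer and the extension kernel writes into it. A sketch of what that looks like for silu_and_mul, assuming a GPU build of the extension is importable; the last dimension of x is split in half, so out has half its width:

import torch
from vllm import _custom_ops as ops

# Computes SiLU(x[..., :d]) * x[..., d:] into a caller-allocated buffer.
x = torch.randn(4, 2 * 128, device="cuda", dtype=torch.float16)
out = torch.empty(4, 128, device="cuda", dtype=torch.float16)
ops.silu_and_mul(out, x)

# Eager reference for comparison.
d = x.shape[-1] // 2
ref = torch.nn.functional.silu(x[..., :d]) * x[..., d:]
assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2)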

13 vllm/attention/__init__.py Normal file

@@ -0,0 +1,13 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
__all__ = [
"AttentionBackend",
"AttentionMetadata",
"Attention",
"get_attn_backend",
"AttentionMetadataPerStage",
]

127 vllm/attention/backends/abstract.py Normal file

@@ -0,0 +1,127 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
TypeVar)
import torch
class AttentionBackend(ABC):
"""Abstract class for attention backends."""
@staticmethod
@abstractmethod
def get_impl_cls() -> Type["AttentionImpl"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
raise NotImplementedError
@staticmethod
@abstractmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
@abstractmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
raise NotImplementedError
@staticmethod
@abstractmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
raise NotImplementedError
@dataclass
class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if skip_fields is None:
skip_fields = set()
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self) if field.name not in skip_fields
}
T = TypeVar("T", bound=AttentionMetadataPerStage)
@dataclass
class AttentionMetadata(Generic[T]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# The attention metadata for prefill requests in a batch.
# None if there are no prefill requests in the batch.
prefill_metadata: Optional[T]
# The attention metadata for decode requests in a batch.
# None if there are no decode requests in the batch.
decode_metadata: Optional[T]
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# The kv cache's data type.
kv_cache_dtype: str
def __post_init__(self):
if self.num_prefill_tokens > 0:
assert self.num_prefills > 0
assert self.prefill_metadata is not None
if self.num_decode_tokens > 0:
assert self.decode_metadata is not None
class AttentionImpl(ABC):
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
raise NotImplementedError
@abstractmethod
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_scale: float,
) -> torch.Tensor:
raise NotImplementedError
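
The asdict_zerocopy helper above matters because per-stage metadata gets broadcast to tensor-parallel workers, and deepcopying the tensors it carries would be wasted work. A toy sketch of the behavior, using a hypothetical subclass that exists only for illustration:

from dataclasses import dataclass
from typing import List, Optional
import torch

from vllm.attention.backends.abstract import AttentionMetadataPerStage

@dataclass
class ToyMetadata(AttentionMetadataPerStage):
    # Hypothetical fields, not a real backend's metadata.
    seq_lens: Optional[List[int]] = None
    block_tables: Optional[torch.Tensor] = None

meta = ToyMetadata(seq_lens=[4, 6], block_tables=torch.zeros(2, 8))
# The returned dict holds the live objects (no tensor copies), and
# skip_fields drops entries that cannot or should not be broadcast.
d = meta.asdict_zerocopy(skip_fields={"block_tables"})
assert d == {"seq_lens": [4, 6]}
assert meta.asdict_zerocopy()["block_tables"] is meta.block_tables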

283 vllm/attention/backends/flash_attn.py Normal file

@@ -0,0 +1,283 @@
"""Attention layer with Flash and PagedAttention.
NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
import torch_musa
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]:
return FlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
return FlashAttentionMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, store them in a tensor. The tensor has to be
updated from the `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
class FlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.need_mask = (self.alibi_slopes is not None
or sliding_window is not None)
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# enable musa flash attention
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(True)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
query = query.movedim(0, query.dim() - 2).unsqueeze(0)
key = key.movedim(0, key.dim() - 2).unsqueeze(0)
value = value.movedim(0, value.dim() - 2).unsqueeze(0)
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
tensor = torch.full(
(1, 1, num_tokens, num_tokens),
dtype=torch.bool,
fill_value=1,
device=query.device)
att_mask = torch.tril(tensor, diagonal=0)
# Prompt run.
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
attn_output = scaled_dot_product_attention(
query.contiguous(),
key.contiguous(),
value.contiguous(),
attn_mask=att_mask.contiguous(),
dropout_p=0.0,
is_causal=False,
)
attn_output = attn_output.squeeze(0).permute(1, 0, 2).contiguous()
assert output[:num_prefill_tokens].shape == attn_output.shape
output[:num_prefill_tokens] = attn_output
else:
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
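
The movedim/unsqueeze/permute dance in this forward pass converts between vLLM's token-major layout and the batch-first layout that scaled_dot_product_attention expects. A self-contained sketch of the round trip, with toy sizes:

import torch

# Toy shapes: 5 tokens, 2 heads, head_size 4 (no vLLM dependency).
num_tokens, num_heads, head_size = 5, 2, 4
q = torch.randn(num_tokens, num_heads, head_size)

# movedim(0, -2) + unsqueeze(0) turns [tokens, heads, head_size] into
# the [batch=1, heads, tokens, head_size] layout that SDPA expects.
q_sdpa = q.movedim(0, q.dim() - 2).unsqueeze(0)
assert q_sdpa.shape == (1, num_heads, num_tokens, head_size)

# The inverse applied to SDPA output recovers the token-major layout.
back = q_sdpa.squeeze(0).permute(1, 0, 2)
assert torch.equal(back, q)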

220 vllm/attention/backends/flashinfer.py Normal file

@@ -0,0 +1,220 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type
try:
import flashinfer
from flash_attn import flash_attn_varlen_func
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
except ImportError:
flashinfer = None
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
class FlashInferBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["FlashInferImpl"]:
return FlashInferImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
return FlashInferMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
raise NotImplementedError
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
raise NotImplementedError
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 128, 256]
@dataclass
class FlashInferMetadata(AttentionMetadataPerStage):
is_prompt: bool
use_cuda_graph: bool = False
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
# Metadata for the prefill stage since we still
# use flash attention for prefill.
seq_start_loc: Optional[torch.Tensor] = None
max_seq_len: Optional[int] = None
block_tables: Optional[torch.Tensor] = None
# Metadata for the decode stage
# Workspace buffer required by the kernel; the buffer should not be
# allocated or deallocated by the FlashInferMetadata object.
workspace_buffer: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
# The page indices of the paged kv cache
paged_kv_indices: Optional[torch.Tensor] = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: Optional[torch.Tensor] = None
# The number of query/output heads
num_qo_heads: Optional[int] = None
# The number of key/value heads
num_kv_heads: Optional[int] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
# Block size of vllm
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: Optional[torch.dtype] = None
def __post_init__(self):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
# When using flashinfer, we are also creating the FlashInferMetadata,
# which calls __post_init__ by default; skip the decode-wrapper setup
# during the prefill phase.
if not self.is_prompt:
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD")
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indices,
self.paged_kv_last_page_len,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
data_type=self.data_type)
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
if skip_fields is None:
skip_fields = set()
# We need to skip the decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)
class FlashInferImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
if sliding_window is not None:
raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1)
self.alibi_slopes = alibi_slopes
self.scale = scale
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[FlashInferMetadata],
kv_scale: float):
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if attn_metadata.num_prefill_tokens > 0:
assert attn_metadata.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if kv_cache is not None:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
attn_metadata.kv_cache_dtype,
)
if prefill_meta := attn_metadata.prefill_metadata:
assert prefill_meta.block_tables is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
raise NotImplementedError(
"Prefix caching is not supported with flashinfer yet.")
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
# FlashInfer requires the query tensor to be contiguous.
query = query.contiguous()
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
sm_scale=self.scale,
)
return output.view(num_tokens, hidden_size)
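
The paged_kv_indptr/paged_kv_indices pair documented above is a standard CSR-style encoding of per-request page lists. A small sketch reproducing the example from the comments in FlashInferMetadata:

import torch

# Page indices for three requests, as in the comment above.
pages_per_request = [[0, 5, 8], [1, 6, 7], [3, 4]]

paged_kv_indices = torch.tensor(
    [p for pages in pages_per_request for p in pages])
# indptr[i]:indptr[i + 1] spans request i's pages in paged_kv_indices.
indptr = [0]
for pages in pages_per_request:
    indptr.append(indptr[-1] + len(pages))
paged_kv_indptr = torch.tensor(indptr)

assert paged_kv_indices.tolist() == [0, 5, 8, 1, 6, 7, 3, 4]
assert paged_kv_indptr.tolist() == [0, 3, 6, 8]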

374 vllm/attention/backends/rocm_flash_attn.py Normal file

@@ -0,0 +1,374 @@
"""Attention layer ROCm GPUs."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
logger = init_logger(__name__)
class ROCmFlashAttentionBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
return ROCmFlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
return ROCmFlashAttentionMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
class ROCmFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
|<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend")
else:
# If not using Triton, Navi3x does not use flash-attn either.
if torch.cuda.get_device_capability()[0] == 11:
self.use_naive_attn = True
else:
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func
logger.debug("Using CK FA in ROCmBackend")
except ModuleNotFoundError:
self.use_naive_attn = True
if self.use_naive_attn:
self.attn_func = _naive_attention
logger.debug("Using naive attention in ROCmBackend")
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens, n_kv_heads, head_dim = x.shape
return (x[:, :,
None, :].expand(tokens, n_kv_heads, n_rep,
head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim))
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale,
)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.seq_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
if self.use_triton_flash_attn:
out, _ = self.attn_func(
query,
key,
value,
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_seq_len,
prefill_meta.max_seq_len,
True,
self.scale,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
out = self.attn_func(
query,
key,
value,
prefill_meta.seq_lens,
self.scale,
)
else:
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
)
# common code for prefill
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
def _naive_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
seq_lens: List[int],
scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
start = 0
for seq_len in seq_lens:
end = start + seq_len
out = _naive_masked_attention(
query[start:end],
key[start:end],
value[start:end],
scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += seq_len
return output
def _naive_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
) -> torch.Tensor:
seq_len, num_heads, head_dim = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
dtype=query.dtype,
device=query.device),
diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out
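
A quick way to sanity-check _naive_masked_attention is to compare it against torch's built-in SDPA with a causal mask; the two should agree up to floating-point noise. A sketch, runnable on CPU next to the definitions above:

import torch

seq_len, num_heads, head_dim = 6, 2, 8
q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)
v = torch.randn(seq_len, num_heads, head_dim)
scale = head_dim ** -0.5

out = _naive_masked_attention(q, k, v, scale)
# SDPA wants [heads, seq, dim]; transpose in and out of that layout.
ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1),
    is_causal=True).transpose(0, 1)
assert torch.allclose(out, ref, atol=1e-4)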

253 vllm/attention/backends/torch_sdpa.py Normal file

@@ -0,0 +1,253 @@
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
import torch_musa
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
class TorchSDPABackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
return TorchSDPABackendImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
return TorchSDPAMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class TorchSDPAMetadata(PagedAttentionMetadata, AttentionMetadata):
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
seq_lens: Optional[List[int]]
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list so that a bias can be set per prompt when
# alibi slopes are used, due to a limitation of the xformers API.
# It will not appear in __repr__ or __init__.
self.attn_bias: Optional[List[torch.Tensor]] = None
class TorchSDPABackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
assert len(alibi_slopes) == num_heads
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(True)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=1)
if attn_metadata.attn_bias is None:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens)  # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype)  # type: ignore
else:
att_masks = [None] * len(attn_metadata.seq_lens)
attn_metadata.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2).unsqueeze(0)
key = key.movedim(0, key.dim() - 2).unsqueeze(0)
value = value.movedim(0, value.dim() - 2).unsqueeze(0)
start = 0
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype,
device=query.device)
for seq_len, mask in zip(attn_metadata.seq_lens,
attn_metadata.attn_bias):
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[:, :, start:end, :],
key[:, :, start:end, :],
value[:, :, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).squeeze(0).movedim(1, 0)
output[start:end, :, :] = sub_out
start = end
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
else:
# Decoding run.
output = PagedAttention.forward_decode(
query,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.seq_lens_tensor,
attn_metadata.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: List[int],
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
def _make_sliding_window_bias(
seq_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = torch.log(mask)
attn_biases.append(mask.to(dtype))
return attn_biases
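
_make_sliding_window_bias relies on log() to turn a 0/1 band matrix into an additive float mask: log(1) = 0 keeps a position and log(0) = -inf blocks it, which is exactly what SDPA's attn_mask argument expects. For seq_len 4 and window_size 2 the per-prompt bias looks like this (a sketch of the expected output, run next to the definition above):

import torch

masks = _make_sliding_window_bias([4], window_size=2,
                                  dtype=torch.float32)
print(masks[0])
# tensor([[[0., -inf, -inf, -inf],
#          [0., 0., -inf, -inf],
#          [-inf, 0., 0., -inf],
#          [-inf, -inf, 0., 0.]]])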

393 vllm/attention/backends/xformers.py Normal file

@@ -0,0 +1,393 @@
"""Attention layer with xFormers and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (AttentionBias,
BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
logger = init_logger(__name__)
class XFormersBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["XFormersImpl"]:
return XFormersImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "XFormersMetadata":
return XFormersMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
max_query_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list so that a bias can be set per prompt when
# alibi slopes are used, due to a limitation of the xformers API.
# It will not appear in __repr__ or __init__.
self.attn_bias: Optional[List[AttentionBias]] = None
class XFormersImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[XFormersMetadata],
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
out = self._run_memory_efficient_xformers_forward(
query, key, value, prefill_meta)
assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
out = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
if decode_meta := attn_metadata.decode_metadata:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _run_memory_efficient_xformers_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: XFormersMetadata,
) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
See https://facebookresearch.github.io/xformers/components/ops.html
for API spec.
Args:
output: shape = [num_prefill_tokens, num_heads, head_size]
query: shape = [num_prefill_tokens, num_heads, head_size]
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert attn_metadata.seq_lens is not None
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
# Note that the output also has the same shape (which differs
# from the spec in the doc).
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at
# the first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if attn_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.seq_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
attn_metadata.attn_bias = [attn_bias]
else:
attn_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype,
attn_metadata.seq_lens)
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_metadata.attn_bias[0],
p=0.0,
scale=self.scale)
return out.view_as(original_query)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(original_query)
start = 0
for i, seq_len in enumerate(attn_metadata.seq_lens):
end = start + seq_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=attn_metadata.attn_bias[i],
p=0.0,
scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.view_as(original_query[start:end]))
start += seq_len
return output
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
seq_lens: List[int],
) -> List[LowerTriangularMaskWithTensorBias]:
attn_biases = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where element (i, j) holds the relative
# position j - i between the i-th and j-th tokens.
bias = bias[None, :] - bias[:, None]
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
return attn_biases
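To make the construction above concrete, here is a minimal standalone sketch of the same distance-times-slope bias for a single sequence. The slope values and sequence length are illustrative, not taken from this code, and the padding step is omitted.

import torch

# Illustrative values: two heads, short sequence.
alibi_slopes = torch.tensor([0.5, 0.25])
seq_len = 4

# Element (i, j) holds j - i, mirroring `bias[None, :] - bias[:, None]` above.
pos = torch.arange(seq_len, dtype=torch.float32)
dist = pos[None, :] - pos[:, None]                      # [seq_len, seq_len]

# Scale by the per-head slope to get the additive attention bias.
bias = dist[None, :, :] * alibi_slopes[:, None, None]   # [num_heads, seq_len, seq_len]
print(bias.shape)  # torch.Size([2, 4, 4])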

56
vllm/attention/layer.py Normal file
View File

@@ -0,0 +1,56 @@
"""Attention layer."""
from typing import List, Optional
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.selector import get_attn_backend
class Attention(nn.Module):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
self.backend = get_attn_backend(torch.get_default_dtype())
impl_cls = self.backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float = 1.0,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
kv_scale)
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
s += f", num_heads={self.impl.num_heads}" # type: ignore
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
s += f", scale={self.impl.scale}" # type: ignore
return s
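A rough usage sketch for this layer follows. It assumes a working install and a supported device (the constructor resolves a backend via get_attn_backend); `from vllm.attention import Attention` is the assumed public import path, and the shapes are illustrative Llama-like values, not taken from this diff.

import torch
from vllm.attention import Attention  # assumed import path

num_heads, head_size = 32, 128
attn = Attention(
    num_heads=num_heads,
    head_size=head_size,
    scale=head_size ** -0.5,  # the usual 1/sqrt(head_size) scaling
    num_kv_heads=8,           # grouped-query attention: 4 query heads per KV head
)
print(attn)  # extra_repr() shows head_size, num_heads, num_kv_heads, scale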

View File

@@ -0,0 +1,216 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of each sequence (all tokens seen so far).
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
# of the kv cache. Each block can hold up to block_size tokens.
# The 2nd dimension is padded up to max_blocks_per_seq when the batch is
# captured with CUDA graphs.
block_tables: Optional[torch.Tensor]
class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
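The layout math in get_kv_cache_shape and split_kv_cache can be checked with a quick sketch. The sizes below are illustrative; x = 16 // element_size packs 16 bytes of key elements together for vectorized loads, so x = 8 for fp16.

import torch

num_blocks, block_size, num_kv_heads, head_size = 10, 16, 8, 128
kv_cache = torch.empty(2, num_blocks, block_size * num_kv_heads * head_size,
                       dtype=torch.float16)  # shape from get_kv_cache_shape

x = 16 // kv_cache.element_size()  # 8 for fp16
key_cache = kv_cache[0].view(num_blocks, num_kv_heads, head_size // x, -1, x)
value_cache = kv_cache[1].view(num_blocks, num_kv_heads, head_size, -1)
print(key_cache.shape)    # torch.Size([10, 8, 16, 16, 8])
print(value_cache.shape)  # torch.Size([10, 8, 128, 16])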
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
kv_scale: float,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
kv_scale,
)
@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
kv_scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
return output
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
block_tables,
# subquery_start_loc is (batch_size + 1,)
subquery_start_loc[:-1],
seq_lens_tensor,
context_lens,
max_query_len,
alibi_slopes,
sliding_window,
)
return output
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
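The V1/V2 dispatch heuristic in forward_decode can be isolated for inspection. This is a standalone sketch with the same thresholds, not part of the class:

_PARTITION_SIZE = 512

def use_paged_attention_v1(max_seq_len: int, num_seqs: int, num_heads: int) -> bool:
    # Same logic as forward_decode: V1 when a single partition suffices or
    # there is already enough parallel work; V2 otherwise.
    max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    return (max_seq_len <= 8192
            and (max_num_partitions == 1 or num_seqs * num_heads > 512))

print(use_paged_attention_v1(max_seq_len=512, num_seqs=1, num_heads=32))   # True
print(use_paged_attention_v1(max_seq_len=4096, num_seqs=1, num_heads=32))  # False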

View File

@@ -0,0 +1,792 @@
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
import torch
import triton
import triton.language as tl
if triton.__version__ >= "2.1.0":
@triton.jit
def _fwd_kernel(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
BLOCK_N: tl.constexpr,
SLIDING_WINDOW: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
# start position inside the query
# generally, N goes over kv, while M goes over query_len
block_start_loc = BLOCK_M * start_m
# initialize offsets
# [N]; starts at 0
offs_n = tl.arange(0, BLOCK_N)
# [D]; starts at 0
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
# [M]; starts at current position in query
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# [M,D]
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
dim_mask = tl.where(
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
0).to(tl.int1) # [D]
q = tl.load(Q + off_q,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_query_len),
other=0.0) # [M,D]
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M]
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M]
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
dtype=tl.float32) # [M,D]
# compute query against context (no causal mask here)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0) # [N]
# [D,N]
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
# [N,D]
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N]
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
(start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
-10000)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1) # [M]
p = tl.exp(qk - m_ij[:, None]) # [M,N]
l_ij = tl.sum(p, 1) # [M]
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij) # [M]
alpha = tl.exp(m_i - m_i_new) # [M]
beta = tl.exp(m_ij - m_i_new) # [M]
l_i_new = alpha * l_i + beta * l_ij # [M]
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
p = p.to(v.dtype)
acc += tl.dot(p, v)
# # update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
# block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_query_len),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
# apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
if SLIDING_WINDOW > 0:
qk = tl.where(
offs_m[:, None] -
(start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_query_len),
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_query_len))
return
@triton.jit
def _fwd_kernel_flash_attn_v2(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(
Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# acc /= l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
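The m_i / l_i / acc recurrences in these kernels implement the online (streaming) softmax from FlashAttention: a running max, normalizer, and accumulator are rescaled as each key block arrives so that the streamed result matches a full softmax. Below is a minimal PyTorch sketch of the deferred-normalization variant (illustrative shapes only, not part of this file):

import torch

torch.manual_seed(0)
scores = torch.randn(1, 8)   # one query row, 8 keys
values = torch.randn(8, 4)   # 8 keys, head dim 4

# Reference: softmax over all keys at once.
ref = torch.softmax(scores, dim=-1) @ values

# Streaming over key blocks of 2, mirroring the m_i / l_i / acc recurrences.
m_i = torch.full((1,), float("-inf"))
l_i = torch.zeros(1)
acc = torch.zeros(1, 4)
for start in range(0, 8, 2):
    qk = scores[:, start:start + 2]
    m_new = torch.maximum(m_i, qk.max(dim=-1).values)
    p = torch.exp(qk - m_new[:, None])
    alpha = torch.exp(m_i - m_new)   # rescales past contributions
    l_i = alpha * l_i + p.sum(dim=-1)
    acc = acc * alpha[:, None] + p @ values[start:start + 2]
    m_i = m_new

out = acc / l_i[:, None]  # normalize once at the end
assert torch.allclose(out, ref, atol=1e-6)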
@triton.jit
def _fwd_kernel_alibi(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
Alibi_slopes,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# attn_bias[]
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
# cur_batch_seq_len: the length of prompts
# cur_batch_ctx_len: the length of prefix
# cur_batch_in_all_start_index: the start id of the dim=0
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(
Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = 0
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v, allow_tf32=False)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
# init alibi
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = cur_batch_ctx_len
# # init debugger
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
# offset_db_k = tl.arange(0, BLOCK_N)
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, allow_tf32=False)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v, allow_tf32=False)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@torch.inference_mode()
def context_attention_fwd(q,
k,
v,
o,
k_cache,
v_cache,
b_loc,
b_start_loc,
b_seq_len,
b_ctx_len,
max_input_len,
alibi_slopes=None,
sliding_window=None):
cap = torch.musa.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = triton.next_power_of_2(Lk)
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # (batch, head, query blocks)
num_warps = 8
if alibi_slopes is not None:
assert Lk == Lk_padded
_fwd_kernel_alibi[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
alibi_slopes,
v_cache.shape[3],
8,
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4
), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
_fwd_kernel[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
v_cache.shape[3],
8,
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
num_warps=num_warps,
num_stages=1,
)
return
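For reference, the launch grid above assigns one program per (sequence, head, query-block) triple. A quick sketch of the grid arithmetic with illustrative sizes:

import triton

# Illustrative sizes, not taken from this code.
batch, head, max_input_len, BLOCK = 2, 32, 1000, 128
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
print(grid)  # (2, 32, 8)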

View File

@@ -0,0 +1,810 @@
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
Not currently supported:
1) Non power of two head dims
2) Dropout or bias through the nested tensor API
import torch
import triton
import triton.language as tl
torch_dtype: tl.constexpr = torch.float16
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def max_fn(x, y):
return tl.math.max(x, y)
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
stride).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
stride)
rng_keep = rng_output > dropout_p
return rng_keep
@triton.jit
def load_fn(block_ptr, first, second, pad):
if first and second:
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
elif first:
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit
def _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
actual_seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
PADDED_HEAD: tl.constexpr,
):
# loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn(
K_block_ptr,
PADDED_HEAD,
MASK_STEPS and (n_extra_tokens != 0),
"zero",
)
if PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
# if not is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M],
actual_seqlen_k,
dtype=tl.int32)
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))
if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ----
qk += tl.dot(q, k)
if bias_ptr is not None:
bias = load_fn(bias_ptr, False, MASK_STEPS
and (n_extra_tokens != 0), "zero")
# While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk += bias * 1.44269504089
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = (batch_philox_offset +
start_m * BLOCK_M * actual_seqlen_k + start_n -
BLOCK_N)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p,
-p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
if not PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, BLOCK_N))
return acc, l_i, m_i
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 64,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": True,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 32,
"BLOCK_N": 32,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config(
{
"BLOCK_M": 16,
"BLOCK_N": 16,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
)
@triton.jit
def attn_fwd(
Q,
K,
V,
bias,
sm_scale,
L,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
stride_bz,
stride_bh,
stride_bm,
stride_bn,
cu_seqlens_q,
cu_seqlens_k,
dropout_p,
philox_seed,
philox_offset_base,
encoded_softmax,
HQ: tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q:
return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
else:
cu_seqlens_q_start = 0
cu_seqlens_k_start = 0
seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if IS_CAUSAL:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn
# matrix
n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is
# part of the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
off_h_q * stride_oh)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0
# for these masked blocks.
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
# tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset = (off_z * stride_qz + off_h_q * stride_qh +
cu_seqlens_q_start * stride_qm)
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
cu_seqlens_k_start * stride_kn)
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
cu_seqlens_k_start * stride_vk)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
if BIAS_TYPE != 0:
bias_ptr = tl.make_block_ptr(
base=bias + off_h_q * stride_bh,
shape=(seqlen_q, seqlen_k),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = philox_offset_base \
+ (off_z * HQ + off_h_q) \
* seqlen_q * seqlen_k
else:
batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout.
# In this case, we return an invalid pointer to indicate the mask is not
# valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.make_block_ptr(
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
shape=(seqlen_q, seqlen_k),
strides=(seqlen_k, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
encoded_softmax_block_ptr = 0
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale = sm_scale * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q = load_fn(Q_block_ptr, True, padded_head, "zero")
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
# block. In this case we might exceed n_blocks so pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min,
block_max,
0,
0,
0,
bias_ptr,
# IS_CAUSAL, ....
False,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
False,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
padded_head,
)
block_min = block_max
block_max = n_blocks * BLOCK_N
tl.debug_barrier()
# Remaining blocks, if any, are masked.
if masked_blocks > 0:
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, n_full_blocks))
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
True,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
padded_head,
)
# epilogue
acc = acc / l_i[:, None]
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
causal_start_idx,
dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None] >=
out_mask_boundary[None, :])
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q is not a multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be negative
# overflow_size = end_m_idx - seqlen_q
# if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
# else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
off_h_q * stride_oh)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
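Both the qk_scale and bias handling above fold a factor of log2(e) into the scores so the kernel can use exp2, which maps directly to hardware. The identity being exploited is e^x == 2^(x * log2(e)); a one-line check:

import math

# 1.44269504089 is log2(e) truncated, the constant used in the kernel above.
x = 0.7
assert abs(math.exp(x) - 2.0 ** (x * 1.44269504089)) < 1e-9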
def check_args(
q,
k,
v,
o,
varlen=True,
max_seqlens=None,
cu_seqlens_q=None,
cu_seqlens_k=None,
):
assert q.dim() == k.dim() and q.dim() == v.dim()
if varlen:
assert q.dim() == 3
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
assert cu_seqlens_q is not None
assert cu_seqlens_k is not None
assert len(cu_seqlens_q) == len(cu_seqlens_k)
else:
assert q.dim() == 4
batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
assert max_seqlens > 0
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
):
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
check_args(
q,
k,
v,
o,
varlen=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Round head_size up to the nearest supported size (powers of 2, >= 32).
unpadded_head_dims = {32, 64, 128, 256}
if head_size not in unpadded_head_dims:
padded_d_model = None
for i in sorted(unpadded_head_dims):  # sort: set iteration order is arbitrary
if i > head_size:
padded_d_model = i
break
assert padded_d_model is not None
else:
padded_d_model = head_size
grid = lambda META: (
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
nheads_q,
batch,
)
encoded_softmax = None
# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
if bias is not None:
bias_strides = (
bias.stride(0),
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else:
bias_strides = (0, 0, 0, 0)
attn_fwd[grid](
q,
k,
v,
bias,
sm_scale,
None,
o,
*q_strides,
*k_strides,
*v_strides,
*o_strides,
*bias_strides,
cu_seqlens_q,
cu_seqlens_k,
dropout_p=0.0,
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
IS_CAUSAL=causal,
VARLEN=True,
BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if bias is None else 1,
ENABLE_DROPOUT=False,
RETURN_ENCODED_SOFTMAX=False,
)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = head_size
ctx.causal = causal
ctx.dropout_p = 0.0
ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = False
return o, encoded_softmax
triton_attention = _attention.apply
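A rough usage sketch for this varlen entry point follows. It requires a GPU with Triton support; the tensors follow the varlen convention in forward() (q/k/v are [total_tokens, num_heads, head_size], cu_seqlens_* hold cumulative sequence starts), and all sizes are illustrative.

import torch

total_tokens, num_heads, head_size = 6, 4, 64
q = torch.randn(total_tokens, num_heads, head_size,
                device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
# Two sequences of lengths 2 and 4.
cu_seqlens = torch.tensor([0, 2, 6], device="cuda", dtype=torch.int32)

out, _ = triton_attention(q, k, v, None, cu_seqlens, cu_seqlens,
                          4, 4,               # max_seqlens_q, max_seqlens_k
                          True,               # causal
                          head_size ** -0.5)  # sm_scale
print(out.shape)  # torch.Size([6, 4, 64])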

View File

@@ -0,0 +1,94 @@
import enum
from functools import lru_cache
from typing import Type
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip, is_musa
logger = init_logger(__name__)
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
backend = _which_attn_to_use(dtype)
if backend == _Backend.FLASH_ATTN:
logger.info("Using FlashAttention-2 backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
elif backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import ( # noqa: F401
XFormersBackend)
return XFormersBackend
elif backend == _Backend.ROCM_FLASH:
logger.info("Using ROCmFlashAttention backend.")
from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401
ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is enforced for the Flashinfer backend. ")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
else:
raise ValueError("Invalid attention backend.")
def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
"""Returns which flash attention backend to use."""
if is_cpu():
return _Backend.TORCH_SDPA
if is_musa():
return _Backend.FLASH_ATTN
if is_hip():
# AMD GPUs.
if torch.cuda.get_device_capability()[0] != 9:
# not Instinct series GPUs.
logger.info("flash_atten is not supported on NAVI GPUs.")
return _Backend.ROCM_FLASH
# NVIDIA GPUs.
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
return _Backend.XFORMERS
if dtype not in (torch.float16, torch.bfloat16):
logger.info("Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
return _Backend.XFORMERS
try:
import flash_attn # noqa: F401
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the flash_attn "
"package is not found. Please install it for better performance.")
return _Backend.XFORMERS
backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
return _Backend[backend_by_env_var]
# Default case.
return _Backend.FLASH_ATTN
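A short usage sketch for the selector (the backend actually chosen depends on the device, dtype, installed packages, and the VLLM_ATTENTION_BACKEND environment variable):

import torch
from vllm.attention.selector import get_attn_backend  # assumed import path

backend_cls = get_attn_backend(torch.float16)
print(backend_cls.__name__)  # e.g. FlashAttentionBackend on supported GPUs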

84
vllm/block.py Normal file
View File

@@ -0,0 +1,84 @@
"""Token blocks."""
from typing import List
from vllm.utils import Device
_BLANK_TOKEN_ID = -1
DEFAULT_LAST_ACCESSED_TIME = -1
class LogicalTokenBlock:
"""A block that stores a contiguous chunk of tokens from left to right.
Logical blocks are used to represent the states of the corresponding
physical blocks in the KV cache.
"""
def __init__(
self,
block_number: int,
block_size: int,
) -> None:
self.block_number = block_number
self.block_size = block_size
self.token_ids = [_BLANK_TOKEN_ID] * block_size
self.num_tokens = 0
def is_empty(self) -> bool:
return self.num_tokens == 0
def get_num_empty_slots(self) -> int:
return self.block_size - self.num_tokens
def is_full(self) -> bool:
return self.num_tokens == self.block_size
def append_tokens(self, token_ids: List[int]) -> None:
assert len(token_ids) <= self.get_num_empty_slots()
curr_idx = self.num_tokens
self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
self.num_tokens += len(token_ids)
def get_token_ids(self) -> List[int]:
return self.token_ids[:self.num_tokens]
def get_last_token_id(self) -> int:
assert self.num_tokens > 0
return self.token_ids[self.num_tokens - 1]
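A quick illustration of the bookkeeping above (token values are arbitrary):

block = LogicalTokenBlock(block_number=0, block_size=4)
block.append_tokens([11, 12, 13])
print(block.is_full())              # False
print(block.get_num_empty_slots())  # 1
block.append_tokens([14])
print(block.is_full())              # True
print(block.get_token_ids())        # [11, 12, 13, 14]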
class PhysicalTokenBlock:
"""Represents the state of a block in the KV cache."""
def __init__(
self,
device: Device,
block_number: int,
block_size: int,
block_hash: int,
num_hashed_tokens: int,
) -> None:
self.device = device
self.block_number = block_number
self.block_size = block_size
self.block_hash = block_hash
self.num_hashed_tokens = num_hashed_tokens
self.ref_count = 0
self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
self.computed = False
def __repr__(self) -> str:
return (f'PhysicalTokenBlock(device={self.device}, '
f'block_number={self.block_number}, '
f'num_hashed_tokens={self.num_hashed_tokens}, '
f'ref_count={self.ref_count}, '
f'last_accessed={self.last_accessed}, '
f'computed={self.computed})')
# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]

1225
vllm/config.py Normal file

File diff suppressed because it is too large

0
vllm/core/__init__.py Normal file
View File

View File

@@ -0,0 +1,295 @@
from typing import List, Optional
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
from vllm.utils import Device, cdiv, chunk_list
class BlockTable:
"""A class to manage blocks for a specific sequence.
The BlockTable maps a sequence of tokens to a list of blocks, where each
block represents a contiguous memory allocation for a portion of the
sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
responsible for allocating and freeing memory for the blocks.
Args:
block_size (int): The maximum number of tokens that can be stored in a
single block.
block_allocator (DeviceAwareBlockAllocator): The block allocator used to
manage memory for the blocks.
_blocks (Optional[List[Block]], optional): An optional list of existing
blocks to initialize the BlockTable with. If not provided, an empty
BlockTable is created.
Attributes:
_block_size (int): The maximum number of tokens that can be stored in a
single block.
_allocator (DeviceAwareBlockAllocator): The block allocator used to
manage memory for the blocks.
_blocks (Optional[List[Block]]): The list of blocks managed by this
BlockTable.
_num_full_slots (int): The number of tokens currently stored in the
blocks.
"""
def __init__(
self,
block_size: int,
block_allocator: DeviceAwareBlockAllocator,
_blocks: Optional[List[Block]] = None,
):
self._block_size = block_size
self._allocator = block_allocator
if _blocks is None:
_blocks = []
self._blocks: List[Block] = _blocks
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
self._num_full_slots = len(self._get_all_token_ids())
@staticmethod
def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
"""Calculates the minimum number of blocks required to store a given
sequence of token IDs.
This assumes worst-case scenario, where every block requires a new
allocation (e.g. ignoring prefix caching).
Args:
token_ids (List[int]): The sequence of token IDs to be stored.
block_size (int): The maximum number of tokens that can be stored in
a single block.
Returns:
int: The minimum number of blocks required to store the given
sequence of token IDs.
"""
return cdiv(len(token_ids), block_size)
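This assumes vllm.utils.cdiv is plain ceiling division; a quick check with illustrative numbers, using a local equivalent:

def cdiv(a: int, b: int) -> int:
    # Ceiling division, assumed equivalent to vllm.utils.cdiv.
    return -(a // -b)

# 10 tokens with block_size=4 require ceil(10 / 4) = 3 blocks.
print(cdiv(10, 4))  # 3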
def allocate(self,
token_ids: List[int],
device: Device = Device.GPU) -> None:
"""Allocates memory blocks for storing the given sequence of token IDs.
This method allocates the required number of blocks to store the given
sequence of token IDs.
Args:
token_ids (List[int]): The sequence of token IDs to be stored.
device (Device, optional): The device on which the blocks should be
allocated. Defaults to Device.GPU.
"""
assert not self._is_allocated
assert token_ids
self._blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids,
device=device)
self._num_full_slots = len(token_ids)
def append_token_ids(self,
token_ids: List[int],
num_lookahead_slots: int = 0) -> None:
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
This method appends the given sequence of token IDs to the existing
blocks in the BlockTable. If there is not enough space in the existing
blocks, new blocks are allocated using the `ensure_num_empty_slots`
method to accommodate the additional tokens.
The token IDs are divided into chunks of size `block_size` (except for
the first chunk, which may be smaller), and each chunk is appended to a
separate block.
Args:
token_ids (List[int]): The sequence of token IDs to be appended.
num_lookahead_slots (int): Number of extra empty slots to reserve
beyond the appended tokens. Defaults to 0.
"""
assert self._is_allocated
assert len(self._blocks) > 0
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)
blocks = self._blocks[self._num_full_slots // self._block_size:]
token_blocks = self._chunk_token_blocks_for_append(token_ids)
for block, token_block in zip(blocks, token_blocks):
block.append_token_ids(token_block)
self._num_full_slots += len(token_ids)
def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
"""Ensures that the BlockTable has at least the specified number of
empty slots available.
This method checks if the BlockTable has enough empty slots (i.e.,
available space) to accommodate the requested number of tokens. If not,
it allocates additional blocks on the GPU to ensure that the required
number of empty slots is available.
Args:
num_empty_slots (int): The minimum number of empty slots required.
"""
# Currently the block table only supports
# appending tokens to GPU blocks.
device = Device.GPU
assert self._is_allocated
if self._num_empty_slots >= num_empty_slots:
return
slots_to_allocate = num_empty_slots - self._num_empty_slots
blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
for _ in range(blocks_to_allocate):
assert len(self._blocks) > 0
self._blocks.append(
self._allocator.allocate_mutable(prev_block=self._blocks[-1],
device=device))
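# Worked example (hypothetical values): with block_size=16 and 3 empty
# slots remaining, a request for 20 empty slots allocates
# cdiv(20 - 3, 16) == 2 additional mutable blocks.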
def fork(self) -> "BlockTable":
"""Creates a new BlockTable instance with a copy of the blocks from the
current instance.
This method creates a new BlockTable instance with the same block size,
block allocator, and a copy of the blocks from the current instance. The
new BlockTable has its own independent set of blocks, but shares the
same underlying memory allocation with the original BlockTable.
Returns:
BlockTable: A new BlockTable instance with a copy of the blocks from
the current instance.
"""
assert self._is_allocated
assert len(self._blocks) > 0
forked_blocks = self._allocator.fork(self._blocks[-1])
return BlockTable(
block_size=self._block_size,
block_allocator=self._allocator,
_blocks=forked_blocks,
)
def free(self) -> None:
"""Frees the memory occupied by the blocks in the BlockTable.
This method iterates over all the blocks in the `_blocks` list and calls
the `free` method of the `_allocator` object to release the memory
occupied by each block. After freeing all the blocks, the `_blocks` list
is set to `None`.
"""
assert self._is_allocated
for block in self._blocks:
self._allocator.free(block)
self._blocks = []
@property
def physical_block_ids(self) -> List[Optional[int]]:
"""Returns a list of physical block indices for the blocks in the
BlockTable.
This property returns a list of integers, where each integer represents
the physical block index of a corresponding block in the `_blocks` list.
The physical block index is a unique identifier for the memory location
occupied by the block.
Returns:
List[int]: A list of physical block indices for the blocks in the
BlockTable.
"""
assert self._is_allocated
return [block.block_id for block in self._blocks]
def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
"""Get the number of "unseen" tokens in the sequence.
Unseen tokens are tokens in the sequence corresponding to this block
table, but are not yet appended to this block table.
Args:
sequence_token_ids (List[int]): The list of token ids in the
sequence.
Returns:
List[int]: The postfix of sequence_token_ids that has not yet been
appended to the block table.
"""
# Since the block table is append-only, the unseen token ids are the
# ones after the appended ones.
return sequence_token_ids[self.num_full_slots:]
def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> List[Block]:
blocks = []
for block_token_ids in chunk_list(token_ids, self._block_size):
if len(block_token_ids) == self._block_size:
# If the block is full, create an immutable block.
prev_block = self._allocator.allocate_immutable(
prev_block, token_ids=block_token_ids, device=device)
else:
# Else, partially fill a mutable block with token ids.
prev_block = self._allocator.allocate_mutable(
prev_block=prev_block, device=device)
prev_block.append_token_ids(block_token_ids)
blocks.append(prev_block)
return blocks
def _get_all_token_ids(self) -> List[int]:
# NOTE: This function is O(seq_len); use sparingly.
token_ids: List[int] = []
if not self._is_allocated:
return token_ids
for block in self._blocks:
token_ids.extend(block.token_ids)
return token_ids
@property
def _is_allocated(self) -> bool:
return len(self._blocks) > 0
@property
def _num_empty_slots(self) -> int:
assert self._is_allocated
return len(self._blocks) * self._block_size - self._num_full_slots
@property
def num_full_slots(self) -> int:
"""Returns the total number of tokens currently stored in the
BlockTable.
Returns:
int: The total number of tokens currently stored in the BlockTable.
"""
return self._num_full_slots
def get_num_blocks_touched_by_append_slots(
self, token_ids: List[int], num_lookahead_slots: int) -> int:
"""Determine how many blocks will be "touched" by appending the token
ids.
This is required for the scheduler to determine whether a sequence can
continue generation, or if it must be preempted.
"""
all_token_ids = token_ids + [-1] * num_lookahead_slots
token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
return len(token_blocks)
def _chunk_token_blocks_for_append(
self, token_ids: List[int]) -> List[List[int]]:
"""Split the token ids into block-sized chunks so they can be easily
appended to blocks. The first such "token block" may have fewer token ids
than the block size, since the last allocated block may be partially
full.
"""
first_chunk_size = self._block_size - (self._num_full_slots %
self._block_size)
token_blocks = [token_ids[:first_chunk_size]] + chunk_list(
token_ids[first_chunk_size:], self._block_size)
return token_blocks
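# Worked example (hypothetical values): with block_size=16 and
# _num_full_slots=30, first_chunk_size == 16 - (30 % 16) == 2, so
# appending 5 tokens yields chunks of sizes [2, 3]: the first chunk
# tops off the partially full block, the rest starts a new one.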

199
vllm/core/block/common.py Normal file
View File

@@ -0,0 +1,199 @@
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Protocol
from vllm.core.block.interfaces import Block, BlockAllocator
BlockId = int
RefCount = int
class RefCounterProtocol(Protocol):
def incr(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
def decr(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
def get(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
class RefCounter(RefCounterProtocol):
"""A class for managing reference counts for a set of block indices.
The RefCounter class maintains a dictionary that maps block indices to their
corresponding reference counts. It provides methods to increment, decrement,
and retrieve the reference count for a given block index.
Args:
all_block_indices (Iterable[BlockId]): An iterable of block indices
to initialize the reference counter with.
"""
def __init__(self, all_block_indices: Iterable[BlockId]):
deduped = set(all_block_indices)
self._refcounts: Dict[BlockId,
RefCount] = {index: 0
for index in deduped}
def incr(self, block_id: BlockId) -> RefCount:
assert block_id in self._refcounts
pre_incr_refcount = self._refcounts[block_id]
assert pre_incr_refcount >= 0
post_incr_refcount = pre_incr_refcount + 1
self._refcounts[block_id] = post_incr_refcount
return post_incr_refcount
def decr(self, block_id: BlockId) -> RefCount:
assert block_id in self._refcounts
refcount = self._refcounts[block_id]
assert refcount > 0
refcount -= 1
self._refcounts[block_id] = refcount
return refcount
def get(self, block_id: BlockId) -> RefCount:
assert block_id in self._refcounts
return self._refcounts[block_id]
def as_readonly(self) -> "ReadOnlyRefCounter":
return ReadOnlyRefCounter(self)
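# A minimal usage sketch (hypothetical block ids) showing the intended
# incr/decr contract:
#
#   counter = RefCounter(all_block_indices=[0, 1, 2])
#   assert counter.incr(0) == 1   # first reference
#   assert counter.incr(0) == 2   # shared
#   assert counter.decr(0) == 1   # one reference remains
#   assert counter.get(0) == 1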
class ReadOnlyRefCounter(RefCounterProtocol):
"""A read-only view of the RefCounter class.
The ReadOnlyRefCounter class provides a read-only interface to access the
reference counts maintained by a RefCounter instance. It does not allow
modifications to the reference counts.
Args:
refcounter (RefCounter): The RefCounter instance to create a read-only
view for.
"""
def __init__(self, refcounter: RefCounter):
self._refcounter = refcounter
def incr(self, block_id: BlockId) -> RefCount:
raise ValueError("Incr not allowed")
def decr(self, block_id: BlockId) -> RefCount:
raise ValueError("Decr not allowed")
def get(self, block_id: BlockId) -> RefCount:
return self._refcounter.get(block_id)
class CopyOnWriteTracker:
"""A class for tracking and managing copy-on-write operations for blocks.
The CopyOnWriteTracker class maintains a mapping of source block indices to
their corresponding copy-on-write destination block indices. It works in
conjunction with a RefCounter and a BlockAllocator to handle reference
counting and block allocation.
Args:
refcounter (RefCounter): The reference counter used to track block
reference counts.
allocator (BlockAllocator): The block allocator used to allocate and
free blocks.
"""
def __init__(
self,
refcounter: RefCounterProtocol,
allocator: BlockAllocator,
):
self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
self._refcounter = refcounter
self._allocator = allocator
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
This method checks the reference count of the given block. If the
reference count is greater than 1, indicating that the block is shared,
a copy-on-write operation is performed. The original block is freed,
and a new block is allocated with the same content. The new block index
is returned.
Args:
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
no copy-on-write was necessary.
"""
block_id = block.block_id
if block_id is None:
return block_id
refcount = self._refcounter.get(block_id)
assert refcount != 0
if refcount > 1:
src_block_id = block_id
# Decrement refcount of the old block.
self._allocator.free(block)
# Allocate a fresh new block.
block_id = self._allocator.allocate_mutable(
prev_block=block.prev_block).block_id
# Track src/dst copy.
assert src_block_id is not None
assert block_id is not None
self._copy_on_writes[src_block_id].append(block_id)
return block_id
def clear_cows(self) -> Dict[BlockId, List[BlockId]]:
"""Clears the copy-on-write tracking information and returns the current
state.
This method returns a dictionary mapping source block indices to lists
of destination block indices for the current copy-on-write operations.
It then clears the internal tracking information.
Returns:
Dict[BlockId, List[BlockId]]: A dictionary mapping source
block indices to lists of destination block indices for the
current copy-on-write operations.
"""
cows = dict(self._copy_on_writes)
self._copy_on_writes.clear()
return cows
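# Sketch of the expected CoW bookkeeping (hypothetical block ids): if a
# shared block 3 is copied to a fresh block 7, clear_cows() returns
# {3: [7]} and resets the tracker, so an immediately repeated call
# returns {}.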
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
"""Retrieves all the blocks in a sequence starting from the last block.
This function recursively traverses the sequence of blocks in reverse order,
starting from the given last block, and returns a list of all the blocks in
the sequence.
Args:
last_block (Block): The last block in the sequence.
Returns:
List[Block]: A list of all the blocks in the sequence, in the order they
appear.
"""
def recurse(block: Block, lst: List[Block]) -> None:
if block.prev_block is not None:
recurse(block.prev_block, lst)
lst.append(block)
all_blocks: List[Block] = []
recurse(last_block, all_blocks)
return all_blocks
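# A minimal iterative sketch (hypothetical alternative, not part of this
# module) that avoids recursion for very long block chains:
#
#   def get_all_blocks_iteratively(last_block: Block) -> List[Block]:
#       chain: List[Block] = []
#       block: Optional[Block] = last_block
#       while block is not None:
#           chain.append(block)
#           block = block.prev_block
#       chain.reverse()
#       return chain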

View File

@@ -0,0 +1,228 @@
from typing import Dict, FrozenSet, List, Optional
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
from vllm.utils import Device
class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
"""A block allocator that can allocate blocks on both CPU and GPU memory.
This class implements the `DeviceAwareBlockAllocator` interface and provides
functionality for allocating and managing blocks of memory on both CPU and
GPU devices.
The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU
blocks, and allows for allocation, deallocation, forking, and swapping of
blocks across these memory pools.
"""
@staticmethod
def create(
allocator_type: str,
num_gpu_blocks: int,
num_cpu_blocks: int,
block_size: int,
) -> DeviceAwareBlockAllocator:
"""Creates a CpuGpuBlockAllocator instance with the specified
configuration.
This static method creates and returns a CpuGpuBlockAllocator instance
based on the provided parameters. It initializes the CPU and GPU block
allocators with the specified number of blocks, block size, and
allocator type.
Args:
allocator_type (str): The type of block allocator to use for CPU
and GPU blocks. Currently supported values are "naive" and
"prefix_caching".
num_gpu_blocks (int): The number of blocks to allocate for GPU
memory.
num_cpu_blocks (int): The number of blocks to allocate for CPU
memory.
block_size (int): The size of each block in number of tokens.
Returns:
DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the
specified configuration.
Notes:
- The block IDs are assigned contiguously, with GPU block IDs coming
before CPU block IDs.
"""
block_ids = list(range(num_gpu_blocks + num_cpu_blocks))
gpu_block_ids = block_ids[:num_gpu_blocks]
cpu_block_ids = block_ids[num_gpu_blocks:]
if allocator_type == "naive":
gpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, # type: ignore
num_blocks=num_gpu_blocks,
block_size=block_size,
block_ids=gpu_block_ids,
)
cpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, # type: ignore
num_blocks=num_cpu_blocks,
block_size=block_size,
block_ids=cpu_block_ids,
)
elif allocator_type == "prefix_caching":
gpu_allocator = PrefixCachingBlockAllocator(
num_blocks=num_gpu_blocks,
block_size=block_size,
block_ids=gpu_block_ids,
)
cpu_allocator = PrefixCachingBlockAllocator(
num_blocks=num_cpu_blocks,
block_size=block_size,
block_ids=cpu_block_ids,
)
else:
raise ValueError(f"Unknown allocator type {allocator_type=}")
return CpuGpuBlockAllocator(
cpu_block_allocator=cpu_allocator,
gpu_block_allocator=gpu_allocator,
)
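# Block-id partitioning example (hypothetical sizes): create("naive",
# num_gpu_blocks=3, num_cpu_blocks=2, block_size=16) assigns GPU ids
# [0, 1, 2] and CPU ids [3, 4], so the two pools never overlap.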
def __init__(
self,
cpu_block_allocator: BlockAllocator,
gpu_block_allocator: BlockAllocator,
):
assert not (
cpu_block_allocator.all_block_ids
& gpu_block_allocator.all_block_ids
), "cpu and gpu block allocators can't have intersection of block ids"
self._allocators = {
Device.CPU: cpu_block_allocator,
Device.GPU: gpu_block_allocator,
}
self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
for _, allocator in self._allocators.items():
for block_id in allocator.all_block_ids:
self._block_ids_to_allocator[block_id] = allocator
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
"""Allocates a new mutable block on the specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
device (Device): The device on which to allocate the new block.
Returns:
Block: The newly allocated mutable block.
"""
return self._allocators[device].allocate_mutable(prev_block)
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
"""Allocates a new immutable block with the provided token IDs on the
specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
token_ids (List[int]): The list of token IDs to be stored in the new
block.
device (Device): The device on which to allocate the new block.
Returns:
Block: The newly allocated immutable block containing the provided
token IDs.
"""
return self._allocators[device].allocate_immutable(
prev_block, token_ids)
def free(self, block: Block) -> None:
"""Frees the memory occupied by the given block.
Args:
block (Block): The block to be freed.
"""
block_id = block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.free(block)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
Args:
last_block (Block): The last block in the original sequence.
Returns:
List[Block]: A new list of blocks that shares the same memory as the
original sequence.
"""
block_id = last_block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.fork(last_block)
def get_num_free_blocks(self, device: Device) -> int:
"""Returns the number of free blocks available on the specified device.
Args:
device (Device): The device for which to query the number of free
blocks. AssertionError is raised if None is passed.
Returns:
int: The number of free blocks available on the specified device.
"""
return self._allocators[device].get_num_free_blocks()
def get_num_total_blocks(self, device: Device) -> int:
return self._allocators[device].get_num_total_blocks()
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
"""Clears the copy-on-write (CoW) state and returns the mapping of
source to destination block IDs.
Returns:
Dict[int, List[int]]: A dictionary mapping source block IDs to lists
of destination block IDs.
"""
# CoW only supported on GPU
device = Device.GPU
return self._allocators[device].clear_copy_on_writes()
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
"""Mark blocks as accessed, only use for prefix caching."""
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].mark_blocks_as_accessed(block_ids, now)
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as accessed, only use for prefix caching."""
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].mark_blocks_as_computed(block_ids)
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].get_common_computed_block_ids(
seq_block_ids)
@property
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys())
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
raise NotImplementedError

View File

@@ -0,0 +1,205 @@
from abc import ABC, abstractmethod
from typing import Dict, FrozenSet, List, Optional, Protocol
from vllm.utils import Device
BlockId = int
class Block(ABC):
@abstractmethod
def append_token_ids(self, token_ids: List[int]) -> None:
pass
@property
@abstractmethod
def block_id(self) -> Optional[int]:
pass
@block_id.setter
@abstractmethod
def block_id(self, value: Optional[int]) -> None:
"""NOTE: Do not use this API outside Block."""
self._block_id = value
@property
@abstractmethod
def token_ids(self) -> List[int]:
pass
@property
@abstractmethod
def num_empty_slots(self) -> int:
pass
@property
@abstractmethod
def is_full(self) -> bool:
pass
@property
@abstractmethod
def prev_block(self) -> Optional["Block"]:
pass
@property
@abstractmethod
def computed(self) -> bool:
raise NotImplementedError
@computed.setter
@abstractmethod
def computed(self, value) -> None:
"""Should only be used by PrefixCachingAllocator."""
raise NotImplementedError
@property
@abstractmethod
def last_accessed(self) -> float:
raise NotImplementedError
@last_accessed.setter
@abstractmethod
def last_accessed(self, last_accessed_ts: float):
raise NotImplementedError
class Factory(Protocol):
@abstractmethod
def __call__(
self,
prev_block: Optional["Block"],
token_ids: List[int],
block_size: int,
allocator: "BlockAllocator",
block_id: Optional[int] = None,
) -> "Block":
pass
@property
@abstractmethod
def content_hash(self) -> Optional[int]:
"""Return the content-based hash of the current block, or None if it is
not yet defined or not supported.
For the content-based hash to be defined, the current block must be
full.
"""
return None
class BlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
pass
@abstractmethod
def free(self, block: Block) -> None:
pass
@abstractmethod
def fork(self, last_block: Block) -> List[Block]:
pass
@abstractmethod
def get_num_total_blocks(self) -> int:
pass
@abstractmethod
def get_num_free_blocks(self) -> int:
pass
@property
@abstractmethod
def all_block_ids(self) -> FrozenSet[int]:
pass
@abstractmethod
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
pass
@abstractmethod
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
pass
@abstractmethod
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
pass
@abstractmethod
def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]:
"""NOTE: This should not be used besides Block"""
pass
@abstractmethod
def promote_to_immutable_block(self, block: Block) -> BlockId:
"""NOTE: This should not be used besides Block"""
pass
class NoFreeBlocksError(ValueError):
pass
class DeviceAwareBlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
pass
@abstractmethod
def get_num_free_blocks(self, device: Device) -> int:
pass
@abstractmethod
def get_num_total_blocks(self, device: Device) -> int:
pass
@abstractmethod
def free(self, block: Block) -> None:
pass
@abstractmethod
def fork(self, last_block: Block) -> List[Block]:
pass
@property
@abstractmethod
def all_block_ids(self) -> FrozenSet[int]:
pass
@abstractmethod
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
pass
@abstractmethod
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
pass
@abstractmethod
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
pass

View File

@@ -0,0 +1,318 @@
from typing import Dict, FrozenSet, Iterable, List, Optional, Set
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
Refcount = int
class NaiveBlockAllocator(BlockAllocator):
"""A simple block allocator that manages blocks of memory without prefix
caching.
Args:
create_block (Block.Factory): A factory function for creating new
blocks. This is used when a NaiveBlockAllocator is composed within
a prefix caching allocator -- the naive block allocator must
construct prefix caching blocks (but shouldn't know anything else
about them).
num_blocks (int): The total number of blocks to manage.
block_size (int): The size of each block in tokens.
block_ids (Optional[Iterable[int]], optional): An optional iterable of
block IDs. If not provided, block IDs will be assigned sequentially
from 0 to num_blocks - 1.
"""
def __init__(
self,
create_block: Block.Factory,
num_blocks: int,
block_size: int,
block_ids: Optional[Iterable[int]] = None,
):
if block_ids is None:
block_ids = range(num_blocks)
self._free_block_indices: Set[BlockId] = set(block_ids)
self._all_block_indices = frozenset(block_ids)
assert len(self._all_block_indices) == num_blocks
self._refcounter = RefCounter(
all_block_indices=self._free_block_indices)
self._create_block = create_block
self._block_size = block_size
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly(),
allocator=self,
)
def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the given token IDs, linked to
the previous block.
Args:
prev_block (Optional[Block]): The previous block in the sequence. If
None, then the block to be allocated is the first block in the
sequence.
token_ids (List[int]): The token IDs to be stored in the new block.
Returns:
Block: The newly allocated immutable block.
"""
assert device is None
block = self.allocate_mutable(prev_block=prev_block)
block.append_token_ids(token_ids)
return block
def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block, linked to the previous block.
Args:
prev_block (Optional[Block]): The previous block in the sequence. If
None, then the block to be allocated is the first block in the
sequence.
Returns:
Block: The newly allocated mutable block.
"""
assert device is None
block_id = self._allocate_new_block_id()
return self._create_block(
prev_block=prev_block,
token_ids=[],
block_id=block_id,
block_size=self._block_size,
allocator=self,
)
def free(self, block: Block) -> None:
assert block.block_id is not None
self._free_block_id(block.block_id)
# Mark the block as having no allocation.
block.block_id = None
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
Args:
last_block (Block): The last block in the original sequence.
Returns:
List[Block]: The new sequence of blocks that shares the same memory
as the original sequence.
"""
source_blocks = get_all_blocks_recursively(last_block)
forked_blocks = []
prev_block = None
for block in source_blocks:
# Increment refcount for each block.
assert block.block_id is not None
refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block"
forked_blocks.append(
self._create_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_id=block.block_id,
block_size=self._block_size,
allocator=self,
))
prev_block = forked_blocks[-1]
return forked_blocks
def get_num_free_blocks(self) -> int:
return len(self._free_block_indices)
def get_num_total_blocks(self) -> int:
return len(self._all_block_indices)
def _allocate_new_block_id(self) -> BlockId:
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError()
block_id = next(iter(self._free_block_indices))
self._refcounter.incr(block_id)
self._free_block_indices.remove(block_id)
return block_id
def _free_block_id(self, block_id: BlockId) -> None:
refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices.add(block_id)
@property
def refcounter(self):
return self._refcounter
@property
def all_block_ids(self) -> FrozenSet[int]:
return self._all_block_indices
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
Args:
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
no copy-on-write was necessary.
"""
return self._cow_tracker.cow_block_if_not_appendable(block)
def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it.
Returns:
Dict[BlockId, List[BlockId]]: A dictionary mapping source
block indices to lists of destination block indices.
"""
return self._cow_tracker.clear_cows()
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
"""Mark blocks as accessed, used in prefix caching.
Since the naive allocator does not implement prefix caching, we do
nothing.
"""
pass
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as computed, used in prefix caching.
Since the naive allocator does not implement prefix caching, we do
nothing.
"""
pass
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
"""Determine blocks that can be skipped in prefill.
Since the naive allocator does not support prefix caching, always return
an empty list.
"""
return []
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix
caching.
The NaiveBlock class represents a block of token IDs with a fixed size. It
provides methods for appending token IDs to the block and manages copy-on
-write operations when necessary.
Args:
prev_block (Block): The previous block in the sequence.
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
the block.
allocator (BlockAllocator): The block allocator associated with this
block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None, which means no allocation has been
made.
_cow_target (Optional[Block], optional): The copy-on-write target block.
If not provided, it defaults to self.
"""
def __init__(self,
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
allocator: BlockAllocator,
block_id: Optional[int] = None,
_cow_target: Optional[Block] = None):
self._token_ids: List[int] = []
self._block_size = block_size
self._prev_block = prev_block
self._block_id = block_id
self._allocator = allocator
self._cow_target = _cow_target if _cow_target is not None else self
self._append_token_ids_no_cow(token_ids)
def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends the given token IDs to the block, instructing the allocator
to perform a copy-on-write if necessary.
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
self._append_token_ids_no_cow(token_ids)
if self._block_id is not None:
self._block_id = (self._allocator.cow_block_if_not_appendable(
self._cow_target))
def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
assert self.num_empty_slots >= len(token_ids)
self._token_ids.extend(token_ids)
@property
def computed(self) -> bool:
raise NotImplementedError
@computed.setter
def computed(self, value) -> None:
raise NotImplementedError
@property
def last_accessed(self) -> float:
raise NotImplementedError
@last_accessed.setter
def last_accessed(self, last_accessed_ts: float):
raise NotImplementedError
@property
def block_id(self) -> Optional[int]:
return self._block_id
@block_id.setter
def block_id(self, value: Optional[int]) -> None:
self._block_id = value
@property
def is_full(self) -> bool:
return self.num_empty_slots == 0
@property
def num_empty_slots(self) -> int:
return self._block_size - len(self._token_ids)
@property
def token_ids(self) -> List[int]:
return self._token_ids
@property
def block_size(self) -> int:
return self._block_size
@property
def prev_block(self) -> Optional["Block"]:
return self._prev_block
@property
def content_hash(self) -> Optional[int]:
return None
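# A minimal usage sketch (hypothetical sizes): allocate a mutable block
# and append tokens; no copy-on-write triggers because the block is
# unshared (refcount 1).
#
#   allocator = NaiveBlockAllocator(create_block=NaiveBlock,
#                                   num_blocks=4, block_size=16)
#   block = allocator.allocate_mutable(prev_block=None)
#   block.append_token_ids([1, 2, 3])
#   assert block.num_empty_slots == 13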

View File

@@ -0,0 +1,606 @@
"""Token blocks."""
from itertools import takewhile
from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional
from vllm.core.block.common import (CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
PrefixHash = int
# By default, we initialize a block's access time to
# _DEFAULT_LAST_ACCESSED_TIME, so if a block still holds this value we
# know it has not been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME = -1
class PrefixCachingBlockAllocator(BlockAllocator):
"""A block allocator that implements prefix caching.
The PrefixCachingBlockAllocator maintains a cache of blocks based on their
content hash. It reuses blocks with the same content hash to avoid redundant
memory allocation. The allocator also supports copy-on-write operations.
Args:
num_blocks (int): The total number of blocks to manage.
block_size (int): The size of each block in tokens.
block_ids(Optional[Iterable[int]], optional): An optional iterable of
block IDs. If not provided, block IDs will be assigned sequentially
from 0 to num_blocks - 1.
"""
def __init__(
self,
num_blocks: int,
block_size: int,
block_ids: Optional[Iterable[int]] = None,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
):
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0.
self._cached_blocks: Dict[PrefixHash, BlockId] = {}
# A mapping of block id to Block, used to track cached blocks.
self._blocks: Dict[BlockId, Block] = {}
# An allocator for blocks that do not have prefix hashes.
self._hashless_allocator = NaiveBlockAllocator(
create_block=self._create_block, # type: ignore
num_blocks=num_blocks,
block_size=block_size,
block_ids=block_ids,
)
self._block_size = block_size
# Evictor that decides how to handle computed blocks when memory
# pressure is high.
self.evictor: Evictor = make_evictor(eviction_policy)
# We share the refcounter between allocators. This allows us to promote
# blocks originally allocated in the hashless allocator to immutable
# blocks.
self._refcounter = self._hashless_allocator.refcounter
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly(),
allocator=self,
)
# Implements Block.Factory.
def _create_block(
self,
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: bool = False,
) -> Block:
# Bind block to self.
allocator = self
return PrefixCachingBlock(
prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
prefix_caching_allocator=allocator,
computed=computed,
)
def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates an immutable block with the given token IDs, reusing cached
blocks if possible.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
token_ids (List[int]): The token IDs to be stored in the block.
Returns:
Block: The allocated immutable block.
"""
assert device is None
assert_prefix_caching_block_or_none(prev_block)
block = self._create_block(
prev_block=prev_block,
token_ids=token_ids,
block_size=self._block_size,
allocator=self,
)
assert block.content_hash is not None
cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None:
block.block_id = cached_block_id
self._incr_refcount_cached_block(block, block.block_id)
return block
block = self.allocate_mutable(prev_block)
block.append_token_ids(token_ids)
assert block.content_hash is not None
return block
def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks.
Args:
prev_block (Block): The previous block in the sequence.
Unlike in the superclass, None is not allowed.
Returns:
Block: The allocated mutable block.
"""
assert device is None
assert_prefix_caching_block_or_none(prev_block)
try:
block = self._hashless_allocator.allocate_mutable(
prev_block=prev_block)
assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block
return block
except BlockAllocator.NoFreeBlocksError:
# We must check the unused cached blocks before raising OOM.
pass
# If the evictor has blocks available for eviction, evict a block
# and return it.
if self.evictor.num_blocks > 0:
block_id, content_hash_to_evict = self.evictor.evict()
# Several blocks can share the same content hash: when a later block
# arrives via the mutable-to-immutable path, its physical block is
# added to the evictor as a duplicate.
# In that case we must not pop the entry from _cached_blocks, since
# the same content is still referenced elsewhere; check the refcount
# before deciding whether to pop.
_block_id = self._cached_blocks[content_hash_to_evict]
refcount = self._refcounter.get(_block_id)
if refcount == 1:
self._cached_blocks.pop(content_hash_to_evict)
assert _block_id == block_id
self._refcounter.incr(block_id)
# A block coming from the evictor already contains a computed result.
block = self._create_block(
prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
allocator=self,
block_id=block_id,
computed=True,
)
assert block.content_hash is None
assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block
return block
# No block available in the hashless allocator, nor in unused cached blocks.
raise BlockAllocator.NoFreeBlocksError()
def _incr_refcount_cached_block(self, block: Block,
block_id: BlockId) -> None:
# Since the block is already computed, mark it as such.
block.computed = True
refcount = self._refcounter.incr(block_id)
if refcount == 1:
# Once a block is referenced, it must not remain in the evictor;
# remove it and track it in _blocks.
if block_id in self.evictor:
self.evictor.remove(block_id)
self._blocks[block_id] = block
def free(self, block: Block) -> None:
"""Decrement the refcount of the block. If the decremented refcount is
zero, store the block in the freelist.
If the block has a content hash (meaning it is immutable), then we will
keep the block around in case future allocations require it.
"""
assert (block.block_id
is not None), "freeing unallocated block is undefined"
self._free_block_id_for_block(block.block_id, block)
block.block_id = None
def _free_block_id_for_block(self, block_id: BlockId,
block: Block) -> None:
assert isinstance(block, PrefixCachingBlock)
if block.content_hash is None:
refcount = self._refcounter.get(block_id)
# In the fork case a block can hold more than one reference, so we
# cannot remove it from tracking while its refcount is greater than 1.
if refcount <= 1:
assert block.block_id is not None
del self._blocks[block.block_id]
return self._hashless_allocator.free(block)
refcount = self._refcounter.decr(block_id)
# If no longer used, add the block to the evictor.
if refcount == 0:
assert block.content_hash in self._cached_blocks
assert block.block_id is not None
del self._blocks[block.block_id]
self.evictor.add(block.block_id, block.content_hash,
block.num_tokens_total, block.last_accessed)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
Args:
last_block (Block): The last block in the original sequence.
Returns:
List[Block]: The new sequence of blocks that shares the same memory
as the original sequence.
"""
source_blocks = get_all_blocks_recursively(last_block)
forked_blocks = []
prev_block = None
for block in source_blocks:
refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block"
forked_blocks.append(
self._create_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_id=block.block_id,
block_size=self._block_size,
allocator=self,
))
prev_block = forked_blocks[-1]
return forked_blocks
def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
assert device is None
# The number of free blocks is the number of hashless free blocks
# plus the number of blocks evictor could free from its list.
return self._hashless_allocator.get_num_free_blocks(
) + self.evictor.num_blocks
def get_num_total_blocks(self) -> int:
return self._hashless_allocator.get_num_total_blocks()
@property
def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids
def promote_to_immutable_block(self, block: Block) -> BlockId:
"""Once a mutable block is full, it can be promoted to an immutable
block. This means that its content can be referenced by future blocks
having the same prefix.
Note that if we already have a cached block with the same content, we
will replace the newly-promoted block's mapping with the existing cached
block.
Args:
block: The mutable block to be promoted.
Returns:
BlockId: Either the original block index, or the block index of
the previously cached block matching the same content.
"""
assert block.content_hash is not None
assert block.block_id is not None
assert self._refcounter.get(block.block_id) > 0
# If the content hash does not have a corresponding cached block,
# set this block as the cached block.
if block.content_hash not in self._cached_blocks:
self._cached_blocks[block.content_hash] = block.block_id
else:
self._free_block_id_for_block(block.block_id, block)
self._incr_refcount_cached_block(
block, self._cached_blocks[block.content_hash])
return self._cached_blocks[block.content_hash]
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
Args:
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
no copy-on-write was necessary.
"""
return self._cow_tracker.cow_block_if_not_appendable(block)
def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it.
Returns:
Dict[BlockId, List[BlockId]]: A dictionary mapping source
block indices to lists of destination block indices.
"""
return self._cow_tracker.clear_cows()
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
"""Mark blocks as accessed, used in prefix caching.
If the block is added into evictor, we need to update corresponding
info in evictor's metadata.
"""
for block_id in block_ids:
if block_id in self._blocks:
self._blocks[block_id].last_accessed = now
elif block_id in self.evictor:
self.evictor.update(block_id, now)
else:
raise ValueError(
"Mark block as accessed which is not belonged to GPU")
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as computed, used in prefix caching."""
for block_id in block_ids:
if block_id in self._blocks:
# Only full blocks are valid for prefix caching.
if self._blocks[block_id].is_full:
self._blocks[block_id].computed = True
elif block_id not in self.evictor:
raise ValueError(f"Mark {block_id=} as computed which "
"is not belonged to GPU")
def block_is_computed(self, block_id: int) -> bool:
if block_id in self._blocks:
return self._blocks[block_id].computed
else:
return block_id in self.evictor
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
"""Return the block ids that are common for a given sequence group.
Only blocks that are immutable and already marked as computed are
taken into consideration.
"""
# NOTE We exclude the last block to avoid the case where the entire
# prompt is cached. This would cause erroneous behavior in the model
# runner.
ids_list = [
list(
takewhile(lambda block_id: self.block_is_computed(block_id),
seq[:-1])) for seq in seq_block_ids
]
# commonprefix returns a list of ints here, although its type annotation says list of strings.
return commonprefix([
ids for ids in ids_list # type: ignore
if ids != []
])
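# Worked example (hypothetical block ids): for seq_block_ids
# [[0, 1, 2], [0, 1, 3]] with blocks 0 and 1 computed, each sequence
# drops its last block and contributes [0, 1], so the common computed
# prefix is [0, 1].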
class PrefixCachingBlock(Block):
"""A block implementation that supports prefix caching.
The PrefixCachingBlock class represents a block of token IDs with prefix
caching capabilities. It wraps a NaiveBlock internally and provides
additional functionality for content hashing and promoting immutable blocks
with the prefix caching allocator.
Args:
prev_block (Optional[PrefixCachingBlock]): The previous block in the
sequence.
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
the block.
prefix_caching_allocator (BlockAllocator): The prefix
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
"""
def __init__(
self,
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
prefix_caching_allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: bool = False,
):
assert isinstance(prefix_caching_allocator,
PrefixCachingBlockAllocator), (
"Currently this class is only tested with "
"PrefixCachingBlockAllocator.")
assert_prefix_caching_block_or_none(prev_block)
self._prev_block = prev_block
self._cached_content_hash: Optional[int] = None
self._cached_num_tokens_total: Optional[int] = None
self._prefix_caching_allocator = prefix_caching_allocator
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self._computed = computed
self._block = NaiveBlock(
prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
allocator=prefix_caching_allocator,
_cow_target=self,
)
@property
def computed(self) -> bool:
return self._computed
@computed.setter
def computed(self, value) -> None:
self._computed = value
@property
def last_accessed(self) -> float:
return self._last_accessed
@last_accessed.setter
def last_accessed(self, last_accessed_ts: float):
self._last_accessed = last_accessed_ts
def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends the given token IDs to the block and registers the block as
immutable if the block becomes full.
Internally, the naive block handles CoW.
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
assert token_ids
# naive block handles CoW.
self._block.append_token_ids(token_ids)
# If the content hash is present, then the block can be made immutable.
# Register ourselves with the allocator, potentially replacing the
# physical block index.
if self.content_hash is not None:
self.block_id = (self._prefix_caching_allocator.
promote_to_immutable_block(self))
@property
def block_id(self) -> Optional[int]:
return self._block.block_id
@block_id.setter
def block_id(self, value) -> None:
self._block.block_id = value
@property
def is_full(self) -> bool:
return self._block.is_full
@property
def num_empty_slots(self) -> int:
return self._block.num_empty_slots
@property
def num_tokens_total(self) -> int:
"""return the total tokens so far.
Here we iterate the block chain till to the first block, while
cache the result in local to prevent repeated computations.
"""
if self._cached_num_tokens_total is not None:
return self._cached_num_tokens_total
_block: Optional[Block] = self
self._cached_num_tokens_total = 0
# TODO: The current implementation is O(N^2) over a sequence; ideally
# this would be O(1).
while _block is not None:
self._cached_num_tokens_total += len(_block.token_ids)
_block = _block.prev_block
return self._cached_num_tokens_total
@property
def block_size(self) -> int:
return self._block.block_size
@property
def token_ids(self) -> List[int]:
return self._block.token_ids
@property
def prev_block(self) -> Optional[Block]:
return self._prev_block
@property
def content_hash(self) -> Optional[int]:
"""Return the content-based hash of the current block, or None if it is
not yet defined.
For the content-based hash to be defined, the current block must be
full.
"""
# If the hash is already computed, return it.
if self._cached_content_hash is not None:
return self._cached_content_hash
# We cannot compute a hash for the current block because it is not full.
if not self.is_full:
return None
is_first_block = self._prev_block is None
prev_block_hash = (
None if is_first_block else
self._prev_block.content_hash # type: ignore
)
# Previous block exists but does not yet have a hash.
# Return no hash in this case.
if prev_block_hash is None and not is_first_block:
return None
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block,
prev_block_hash,
cur_block_token_ids=self.token_ids)
return self._cached_content_hash
@staticmethod
def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
cur_block_token_ids: List[int]) -> int:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching.
NOTE: Content-based hashing does not yet support LoRA.
Parameters:
- is_first_block (bool): A flag indicating if the block is the first in
the sequence.
- prev_block_hash (Optional[int]): The hash of the previous block. None
if this is the first block.
- cur_block_token_ids (List[int]): A list of token ids in the current
block. The current block is assumed to be full.
Returns:
- int: The computed hash value for the block.
"""
assert (prev_block_hash is None) == is_first_block
return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
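# Chaining example (hypothetical token ids): a first block [1, 2] hashes
# as hash((True, None, 1, 2)); a second block with the same tokens hashes
# as hash((False, h0, 1, 2)) where h0 is the first hash, so identical
# token ids at different positions in a sequence yield different hashes.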
def assert_prefix_caching_block_or_none(block: Optional[Block]):
if block is None:
return
assert isinstance(block, PrefixCachingBlock)

View File

@@ -0,0 +1,625 @@
"""A block manager that manages token blocks."""
import math
from abc import ABC, abstractmethod
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Set
from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
logger = init_logger(__name__)
class BlockAllocatorBase(ABC):
"""Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when
requested. When a block is freed, its reference count is decremented. If
the reference count becomes zero, the block is added back to the free list.
"""
@abstractmethod
def __init__(self,
device: Device,
block_size: int,
num_blocks: int,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
pass
@abstractmethod
def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
pass
@abstractmethod
def free(self, block: PhysicalTokenBlock) -> None:
pass
@abstractmethod
def get_num_free_blocks(self) -> int:
pass
@abstractmethod
def get_num_total_blocks(self) -> int:
pass
@abstractmethod
def contains_block(self, block_hash: int) -> bool:
pass
@abstractmethod
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
pass
class CachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when
requested. When a block is freed, its reference count is decremented. If
the reference count becomes zero, the block is added back to the free list.
"""
def __init__(self,
device: Device,
block_size: int,
num_blocks: int,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
self.current_num_blocks = 0
self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}
self.evictor: Evictor = make_evictor(eviction_policy)
self.default_hash_ctr = count()
def allocate_block(self, block_hash: int,
num_hashed_tokens: int) -> PhysicalTokenBlock:
if self.current_num_blocks == self.num_blocks:
block = self.evictor.evict()
block.block_hash = block_hash
block.num_hashed_tokens = num_hashed_tokens
return block
block = PhysicalTokenBlock(device=self.device,
block_number=self.current_num_blocks,
block_size=self.block_size,
block_hash=block_hash,
num_hashed_tokens=num_hashed_tokens)
self.current_num_blocks += 1
return block
def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
if block_hash is None:
block_hash = next(self.default_hash_ctr)
if block_hash in self.evictor:
assert block_hash not in self.cached_blocks
block = self.evictor.remove(block_hash)
assert block.ref_count == 0
self.cached_blocks[block_hash] = block
block.ref_count += 1
assert block.block_hash == block_hash
return block
if block_hash not in self.cached_blocks:
self.cached_blocks[block_hash] = self.allocate_block(
block_hash, num_hashed_tokens)
block = self.cached_blocks[block_hash]
assert block.block_hash == block_hash
block.ref_count += 1
return block
def free(self, block: PhysicalTokenBlock) -> None:
if block.ref_count == 0:
raise ValueError(f"Double free! {block} is already freed.")
block.ref_count -= 1
if block.ref_count == 0:
assert block.block_hash not in self.evictor
self.evictor.add(block)
# Remove the block from the cached_blocks
del self.cached_blocks[block.block_hash]
def get_num_free_blocks(self) -> int:
return (self.num_blocks - self.current_num_blocks +
self.evictor.num_blocks)
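# Accounting example (hypothetical counts): with num_blocks=8, all 8
# blocks handed out at least once (current_num_blocks == 8), and 3
# unreferenced blocks sitting in the evictor, 8 - 8 + 3 == 3 blocks are
# considered free.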
def get_num_total_blocks(self) -> int:
return self.num_blocks
def contains_block(self, block_hash: int) -> bool:
return block_hash in self.cached_blocks or block_hash in self.evictor
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
# Update the hash of block and the cached_blocks dictionary.
assert not self.contains_block(block_hash)
old_hash = block.block_hash
block.block_hash = block_hash
del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block
class UncachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when
requested. When a block is freed, its reference count is decremented. If
the reference count becomes zero, the block is added back to the free list.
"""
def __init__(
self,
device: Device,
block_size: int,
num_blocks: int,
) -> None:
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
# Initialize the free blocks.
self.free_blocks: BlockTable = []
for i in range(num_blocks):
block = PhysicalTokenBlock(device=device,
block_number=i,
block_size=block_size,
block_hash=-1,
num_hashed_tokens=0)
self.free_blocks.append(block)
def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
if not self.free_blocks:
raise ValueError("Out of memory! No free blocks are available.")
block = self.free_blocks.pop()
block.ref_count = 1
return block
def free(self, block: PhysicalTokenBlock) -> None:
if block.ref_count == 0:
raise ValueError(f"Double free! {block} is already freed.")
block.ref_count -= 1
if block.ref_count == 0:
self.free_blocks.append(block)
def get_num_free_blocks(self) -> int:
return len(self.free_blocks)
def get_num_total_blocks(self) -> int:
return self.num_blocks
def contains_block(self, block_hash: int) -> bool:
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")
class BlockSpaceManagerV1(BlockSpaceManager):
"""Manages the mapping between logical and physical token blocks."""
def __init__(
self,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
watermark: float = 0.01,
sliding_window: Optional[int] = None,
enable_caching: bool = False,
) -> None:
self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks
if enable_caching and sliding_window is not None:
raise NotImplementedError(
"Sliding window is not allowed with prefix caching enabled!")
self.block_sliding_window = None
if sliding_window is not None:
# Round up to nearest block size to regularize sliding window
# allocation sizes.
self.block_sliding_window = math.ceil(sliding_window / block_size)
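# Example (hypothetical values): sliding_window=1024 with block_size=16
# gives math.ceil(1024 / 16) == 64 blocks.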
self.watermark = watermark
assert watermark >= 0.0
self.enable_caching = enable_caching
self.watermark_blocks = int(watermark * num_gpu_blocks)
if self.enable_caching:
logger.info("Automatic prefix caching is enabled.")
self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.CPU, block_size, num_cpu_blocks)
else:
self.gpu_allocator = UncachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator = UncachedBlockAllocator(
Device.CPU, block_size, num_cpu_blocks)
# Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {}
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks)
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.block_sliding_window)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction.
if (self.num_total_gpu_blocks - num_required_blocks <
self.watermark_blocks):
return AllocStatus.NEVER
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
else:
return AllocStatus.LATER
def allocate(self, seq_group: SequenceGroup) -> None:
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = len(seq.logical_token_blocks)
block_table: BlockTable = []
for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window):
block = block_table[logical_idx % self.block_sliding_window]
# Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs()
elif self.enable_caching:
block = self.gpu_allocator.allocate(
seq.hash_of_block(logical_idx),
seq.num_hashed_tokens_of_block(logical_idx))
else:
block = self.gpu_allocator.allocate()
# Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs()
block_table.append(block)
# Assign the block table for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()
def can_append_slots(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> bool:
assert (num_lookahead_slots == 0
), "lookahead allocation not supported in BlockSpaceManagerV1"
# Simple heuristic: If there is at least one free block
# for each sequence, we can append.
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks
def _promote_last_block(
self,
seq: Sequence,
last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
assert self.enable_caching
# Compute a new hash for the block so that it can be shared by other
# Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
# if new_hash is already in the cached table, then free last_block
# and return the cached version
if self.gpu_allocator.contains_block(new_hash):
self.gpu_allocator.free(last_block)
return self.gpu_allocator.allocate(new_hash)
else:
self.gpu_allocator.update_hash(new_hash, last_block)
return last_block
def _is_last_block_full(
self,
seq: Sequence,
) -> bool:
token_ids_len = seq.data.get_len()
return token_ids_len > 0 and token_ids_len % seq.block_size == 0
def _maybe_promote_last_block(
self,
seq: Sequence,
last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
if self._is_last_block_full(seq):
return self._promote_last_block(seq, last_block)
else:
return last_block
def _allocate_last_physical_block(
self,
seq: Sequence,
) -> PhysicalTokenBlock:
# Called before a new block is appended.
# This is in charge of allocating a new physical block (to be appended).
        # The block hash is None if the last block is not full; otherwise,
        # it is set to the content hash of that block.
if not self.enable_caching:
return self.gpu_allocator.allocate()
block_hash: Optional[int] = None
if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(
len(seq.logical_token_blocks) - 1)
# num_hashed_tokens is used to compute future hashes
# (e.g. in the hashing function, it is used to ask the sequence for
# prefix tokens)
new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
        # If the block hash is None, then the block is not full.
# If the block is not full, then we expect it to have a refcount of 1.
if block_hash is None:
assert new_block.ref_count == 1
return new_block
def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int = 0,
) -> Dict[int, List[int]]:
"""Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks
block_table = self.block_tables[seq.seq_id]
# If we need to allocate a new physical block
if len(block_table) < len(logical_blocks):
# Currently this code only supports adding one physical block
assert len(block_table) == len(logical_blocks) - 1
if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window):
# reuse a block
block_table.append(block_table[len(block_table) %
self.block_sliding_window])
else:
                # The sequence has a new logical block.
# Allocate a new physical block.
new_block = self._allocate_last_physical_block(seq)
block_table.append(new_block)
return {}
# We want to append the token to the last physical block.
last_block = block_table[-1]
assert last_block.device == Device.GPU
if last_block.ref_count == 1:
# Not shared with other sequences. Appendable.
if self.enable_caching:
# If the last block is now complete, we may reuse an old block
# to save memory.
maybe_new_block = self._maybe_promote_last_block(
seq, last_block)
block_table[-1] = maybe_new_block
return {}
else:
# The last block is shared with other sequences.
# Copy on Write: Allocate a new block and copy the tokens.
new_block = self._allocate_last_physical_block(seq)
block_table[-1] = new_block
self.gpu_allocator.free(last_block)
return {last_block.block_number: [new_block.block_number]}
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM.
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy()
# When using a sliding window, blocks will be eventually reused.
# In this case the block tables will contain repeated blocks.
# When forking, we must make sure that each block's `ref_count`
# is only incremented by one, so we deduplicate them by wrapping
# them in a set.
for block in set(src_block_table):
block.ref_count += 1
def _get_physical_blocks(
self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
# NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group.
blocks: Set[PhysicalTokenBlock] = set()
for seq in seq_group.get_seqs():
if seq.is_finished():
continue
blocks.update(self.block_tables[seq.seq_id])
return list(blocks)
def can_swap_in(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
assert (num_lookahead_slots == 0
), "BlockSpaceManagerV1 does not support lookahead allocation"
blocks = self._get_physical_blocks(seq_group)
num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
num_free_blocks = self.gpu_allocator.get_num_free_blocks()
# NOTE: Conservatively, we assume that every sequence will allocate
# at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slots().
num_required_blocks = len(blocks) + num_swapped_seqs
if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
return AllocStatus.NEVER
elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
else:
return AllocStatus.LATER
def swap_in(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> Dict[int, int]:
assert (num_lookahead_slots == 0
), "BlockSpaceManagerV1 does not support lookahead allocation"
# CPU block -> GPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id]
for cpu_block in block_table:
if cpu_block in mapping:
gpu_block = mapping[cpu_block]
gpu_block.ref_count += 1
else:
gpu_block = self.gpu_allocator.allocate(
cpu_block.block_hash, cpu_block.num_hashed_tokens)
mapping[cpu_block] = gpu_block
new_block_table.append(gpu_block)
# Free the CPU block swapped in to GPU.
self.cpu_allocator.free(cpu_block)
self.block_tables[seq.seq_id] = new_block_table
block_number_mapping = {
cpu_block.block_number: gpu_block.block_number
for cpu_block, gpu_block in mapping.items()
}
return block_number_mapping
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
blocks = self._get_physical_blocks(seq_group)
return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
# GPU block -> CPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id]
for gpu_block in block_table:
if gpu_block in mapping:
cpu_block = mapping[gpu_block]
cpu_block.ref_count += 1
else:
cpu_block = self.cpu_allocator.allocate(
gpu_block.block_hash, gpu_block.num_hashed_tokens)
mapping[gpu_block] = cpu_block
new_block_table.append(cpu_block)
# Free the GPU block swapped out to CPU.
self.gpu_allocator.free(gpu_block)
self.block_tables[seq.seq_id] = new_block_table
block_number_mapping = {
gpu_block.block_number: cpu_block.block_number
for gpu_block, cpu_block in mapping.items()
}
return block_number_mapping
def _free_block_table(self, block_table: BlockTable) -> None:
# when using a sliding window, each seq will only use up
# to `self.block_sliding_window` blocks. When freeing
# the block table, we must make sure to not free blocks more
# than once. If no sliding window is used, there is no block
# reuse in the block table, so we must free all blocks.
blocks_to_free = (block_table[-self.block_sliding_window:]
if self.block_sliding_window is not None else
block_table)
for block in set(blocks_to_free):
if block.device == Device.GPU:
self.gpu_allocator.free(block)
else:
self.cpu_allocator.free(block)
def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
# Already freed or haven't been scheduled yet.
return
block_table = self.block_tables[seq.seq_id]
self._free_block_table(block_table)
del self.block_tables[seq.seq_id]
def reset(self) -> None:
for block_table in self.block_tables.values():
self._free_block_table(block_table)
self.block_tables.clear()
def get_block_table(self, seq: Sequence) -> List[int]:
block_table = self.block_tables[seq.seq_id]
return [block.block_number for block in block_table]
def get_num_free_gpu_blocks(self) -> int:
return self.gpu_allocator.get_num_free_blocks()
def get_num_free_cpu_blocks(self) -> int:
return self.cpu_allocator.get_num_free_blocks()
def access_all_blocks_in_seq(
self,
seq: Sequence,
access_time: float,
) -> None:
if self.enable_caching:
# Update the last accessed time of all the blocks accessed
# in this step.
block_table = self.block_tables[seq.seq_id]
for block in block_table:
block.last_accessed = access_time
def compute_full_blocks_in_seq(self, seq: Sequence):
if seq.seq_id not in self.block_tables:
return
max_full_block = seq.get_len() // self.block_size - 1
block_table = self.block_tables[seq.seq_id]
if max_full_block == -1:
return
for i in reversed(range(max_full_block)):
if block_table[i].computed:
break
block_table[i].computed = True
def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
if seq.seq_id not in self.block_tables:
return []
block_table = self.block_tables[seq.seq_id]
# NOTE We exclude the last block to avoid the case where the entire
# prompt is cached. This would cause erroneous behavior in model
# runner.
return [
b.block_number
for b in takewhile(lambda b: b.computed, block_table[:-1])
]
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
"""
# Can return non-empty result only with prefix caching enabled.
if not self.enable_caching:
return []
ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
return commonprefix([ids for ids in ids_list if ids != []])
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
if self.enable_caching:
for seq in seq_group.seqs_dict.values():
self.compute_full_blocks_in_seq(seq)
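The copy-on-write rule in append_slots above is easiest to see with plain ref-counted blocks. A minimal sketch, not part of this file: Block, fork, and append_to_last below are simplified stand-ins for the vLLM classes, not their real APIs.
class Block:
    def __init__(self, number: int) -> None:
        self.number = number
        self.ref_count = 1
def fork(block_table: list) -> list:
    # Forking shares physical blocks: each unique block gains one reference.
    for block in set(block_table):
        block.ref_count += 1
    return block_table.copy()
def append_to_last(block_table: list, next_block_number: int) -> dict:
    last = block_table[-1]
    if last.ref_count == 1:
        return {}  # exclusive owner: append in place, nothing to copy
    # Shared block: drop our reference, allocate a fresh block, record a CoW.
    last.ref_count -= 1
    new = Block(next_block_number)
    block_table[-1] = new
    return {last.number: [new.number]}
parent = [Block(0), Block(1)]
child = fork(parent)
print(append_to_last(child, 2))  # {1: [2]}: copy old block 1 into new block 2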

258
vllm/core/block_manager_v2.py Normal file
View File

@@ -0,0 +1,258 @@
"""A block manager that manages token blocks."""
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
SeqId = int
class BlockSpaceManagerV2(BlockSpaceManager):
"""BlockSpaceManager which manages the allocation of KV cache.
It owns responsibility for allocation, swapping, allocating memory for
autoregressively-generated tokens, and other advanced features such as
prefix caching, forking/copy-on-write, and sliding-window memory allocation.
The current implementation is partial; in particular prefix caching and
sliding-window are not feature complete. This class implements the design
described in https://github.com/vllm-project/vllm/pull/3492.
Lookahead slots
The block manager has the notion of a "lookahead slot". These are slots
in the KV cache that are allocated for a sequence. Unlike the other
allocated slots, the content of these slots is undefined -- the worker
may use the memory allocations in any way.
In practice, a worker could use these lookahead slots to run multiple
forward passes for a single scheduler invocation. Each successive
forward pass would write KV activations to the corresponding lookahead
slot. This allows low inter-token latency use-cases, where the overhead
of continuous batching scheduling is amortized over >1 generated tokens.
Speculative decoding uses lookahead slots to store KV activations of
proposal tokens.
See https://github.com/vllm-project/vllm/pull/3250 for more information
on lookahead scheduling.
Args:
block_size (int): The size of each memory block.
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
num_cpu_blocks (int): The number of memory blocks allocated on CPU.
watermark (float, optional): The threshold used for memory swapping.
Defaults to 0.01.
sliding_window (Optional[int], optional): The size of the sliding
window. Defaults to None.
enable_caching (bool, optional): Flag indicating whether caching is
enabled. Defaults to False.
"""
def __init__(
self,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
watermark: float = 0.01,
sliding_window: Optional[int] = None,
enable_caching: bool = False,
) -> None:
self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks
assert sliding_window is None, "Sliding window not yet supported"
self.block_sliding_window = None
self.watermark = watermark
assert watermark >= 0.0
self.enable_caching = enable_caching
self.watermark_blocks = int(watermark * num_gpu_blocks)
self.block_allocator = CpuGpuBlockAllocator.create(
allocator_type="prefix_caching" if enable_caching else "naive",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=block_size,
)
self.block_tables: Dict[SeqId, BlockTable] = {}
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = BlockTable.get_num_required_blocks(
seq.get_token_ids(),
block_size=self.block_size,
)
assert self.block_sliding_window is None
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.block_sliding_window)
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
device=Device.GPU)
# Use watermark to avoid frequent cache eviction.
if (self.num_total_gpu_blocks - num_required_blocks <
self.watermark_blocks):
return AllocStatus.NEVER
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
else:
return AllocStatus.LATER
def allocate(self, seq_group: SequenceGroup) -> None:
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
assert not (set(seq.seq_id for seq in waiting_seqs)
& self.block_tables.keys()), "block table already exists"
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = waiting_seqs[0]
block_table = BlockTable(
block_size=self.block_size,
block_allocator=self.block_allocator,
)
assert self.block_sliding_window is None
block_table.allocate(seq.get_token_ids())
self.block_tables[seq.seq_id] = block_table
# Assign the block table for each sequence.
for seq in waiting_seqs[1:]:
self.block_tables[seq.seq_id] = block_table.fork()
def can_append_slots(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
"""Determine if there is enough space in the GPU KV cache to continue
generation of the specified sequence group.
We use a worst-case heuristic: assume each touched block will require a
new allocation (either via CoW or new block). We can append slots if the
number of touched blocks is less than the number of free blocks.
"Lookahead slots" are slots that are allocated in addition to the slots
for known tokens. The contents of the lookahead slots are not defined.
This is used by speculative decoding when speculating future tokens.
"""
num_touched_blocks = 0
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
block_table = self.block_tables[seq.seq_id]
num_touched_blocks += (
block_table.get_num_blocks_touched_by_append_slots(
token_ids=block_table.get_unseen_token_ids(
seq.get_token_ids()),
num_lookahead_slots=num_lookahead_slots,
))
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
Device.GPU)
return num_touched_blocks <= num_free_gpu_blocks
def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int,
) -> Dict[int, List[int]]:
block_table = self.block_tables[seq.seq_id]
block_table.append_token_ids(
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
num_lookahead_slots=num_lookahead_slots,
)
# Return any new copy-on-writes.
new_cows = self.block_allocator.clear_copy_on_writes()
return new_cows
def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
# Already freed or haven't been scheduled yet.
return
self.block_tables[seq.seq_id].free()
del self.block_tables[seq.seq_id]
def get_block_table(self, seq: Sequence) -> List[int]:
assert seq.seq_id in self.block_tables
block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids)
return block_ids # type: ignore
def access_all_blocks_in_seq(self, seq: Sequence, now: float):
# Update the last accessed time of all the blocks accessed
# in this step.
        # The access time currently only matters for prefix caching: it feeds
        # the internal eviction policy that decides which freed cached blocks
        # may be reclaimed, so that cached content is reused as much as
        # possible.
if self.enable_caching:
block_table = self.block_tables[seq.seq_id]
block_ids = []
for block_id in block_table.physical_block_ids:
block_ids.append(block_id)
self.block_allocator.mark_blocks_as_accessed(
block_ids, # type: ignore
now)
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        # Marking blocks as computed is only needed for prefix caching, and
        # in v2 a block is considered computed iff it has a content hash.
        # So this function is a no-op for block_manager_v2.
pass
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Determine which blocks for which we skip prefill.
With prefix caching we can skip prefill for previously-generated blocks.
Currently, the attention implementation only supports skipping cached
blocks if they are a contiguous prefix of cached blocks.
This method determines which blocks can be safely skipped for all
sequences in the sequence group.
"""
seq_block_ids = [
self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
]
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids(
seq_block_ids) # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork()
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> AllocStatus:
return AllocStatus.LATER
def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> Dict[int, int]:
raise NotImplementedError
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
return False
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
raise NotImplementedError
def get_num_free_gpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.GPU)
def get_num_free_cpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.CPU)
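A sketch of the worst-case arithmetic behind can_append_slots above, under stated assumptions: num_blocks_touched is an illustrative helper, not the BlockTable API, which computes this internally.
import math
def num_blocks_touched(tokens_in_last_block: int, num_new_tokens: int,
                       block_size: int) -> int:
    # Blocks written when appending num_new_tokens (unseen tokens plus
    # lookahead slots): the partially filled tail block, if any, plus one
    # fresh block per block_size chunk of overflow.
    touched = 1 if tokens_in_last_block % block_size else 0
    free_in_tail = (-tokens_in_last_block) % block_size
    overflow = num_new_tokens - free_in_tail
    if overflow > 0:
        touched += math.ceil(overflow / block_size)
    return touched
# block_size=16, tail holds 14 tokens, 1 new token plus 4 lookahead slots:
# the tail plus one fresh block are touched, so 2 free blocks are required.
print(num_blocks_touched(14, 5, 16))  # 2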

105
vllm/core/evictor_v1.py Normal file
View File

@@ -0,0 +1,105 @@
import enum
from abc import ABC, abstractmethod, abstractproperty
from typing import OrderedDict
from vllm.block import PhysicalTokenBlock
class EvictionPolicy(enum.Enum):
"""Enum for eviction policy used by make_evictor to instantiate the correct
Evictor subclass.
"""
LRU = enum.auto()
class Evictor(ABC):
"""The Evictor subclasses should be used by the BlockAllocator class to
handle eviction of freed PhysicalTokenBlocks.
"""
@abstractmethod
def __init__(self):
pass
@abstractmethod
def __contains__(self, block_hash: int) -> bool:
pass
@abstractmethod
def evict(self) -> PhysicalTokenBlock:
"""Runs the eviction algorithm and returns the evicted block"""
pass
@abstractmethod
def add(self, block: PhysicalTokenBlock):
"""Adds block to the evictor, making it a candidate for eviction"""
pass
@abstractmethod
def remove(self, block_hash: int) -> PhysicalTokenBlock:
"""Simply removes the block with the hash value block_hash from the
evictor. Caller is responsible for making sure that block_hash is
contained in the evictor before calling remove. Should be used to
"bring back" blocks that have been freed but not evicted yet.
"""
pass
@abstractproperty
def num_blocks(self) -> int:
pass
class LRUEvictor(Evictor):
"""Evicts in a least-recently-used order using the last_accessed timestamp
that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and
    highest num_hashed_tokens value, then one will be chosen arbitrarily.
"""
def __init__(self):
self.free_table: OrderedDict[int, PhysicalTokenBlock] = OrderedDict()
def __contains__(self, block_hash: int) -> bool:
return block_hash in self.free_table
def evict(self) -> PhysicalTokenBlock:
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")
evicted_block = next(iter(self.free_table.values()))
# The blocks with the lowest timestamps should be placed consecutively
# at the start of OrderedDict. Loop through all these blocks to
# find the one with maximum number of hashed tokens.
for _, block in self.free_table.items():
if evicted_block.last_accessed < block.last_accessed:
break
if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
evicted_block = block
self.free_table.pop(evicted_block.block_hash)
evicted_block.computed = False
return evicted_block
def add(self, block: PhysicalTokenBlock):
self.free_table[block.block_hash] = block
def remove(self, block_hash: int) -> PhysicalTokenBlock:
if block_hash not in self.free_table:
raise ValueError(
"Attempting to remove block that's not in the evictor")
block: PhysicalTokenBlock = self.free_table[block_hash]
self.free_table.pop(block_hash)
return block
@property
def num_blocks(self) -> int:
return len(self.free_table)
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
if eviction_policy == EvictionPolicy.LRU:
return LRUEvictor()
else:
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")

127
vllm/core/evictor_v2.py Normal file
View File

@@ -0,0 +1,127 @@
import enum
from abc import ABC, abstractmethod, abstractproperty
from typing import OrderedDict, Tuple
class EvictionPolicy(enum.Enum):
"""Enum for eviction policy used by make_evictor to instantiate the correct
Evictor subclass.
"""
LRU = enum.auto()
class Evictor(ABC):
"""The Evictor subclasses should be used by the BlockAllocator class to
handle eviction of freed PhysicalTokenBlocks.
"""
@abstractmethod
def __init__(self):
pass
@abstractmethod
def __contains__(self, block_id: int) -> bool:
pass
@abstractmethod
def evict(self) -> Tuple[int, int]:
"""Runs the eviction algorithm and returns the evicted block's
content hash along with physical block id along with physical block id
"""
pass
@abstractmethod
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: float):
"""Adds block to the evictor, making it a candidate for eviction"""
pass
@abstractmethod
def update(self, block_id: int, last_accessed: float):
"""Update corresponding block's access time in metadata"""
pass
@abstractmethod
def remove(self, block_id: int):
"""Remove a given block id from the cache."""
pass
@abstractproperty
def num_blocks(self) -> int:
pass
class BlockMetaData():
"""Data structure for storing key data describe cached block, so that
evitor could use to make its decision which one to choose for eviction
Here we use physical block id as the dict key, as there maybe several
blocks with the same content hash, but their physical id is unique.
"""
def __init__(self, content_hash: int, num_hashed_tokens: int,
last_accessed: float):
self.content_hash = content_hash
self.num_hashed_tokens = num_hashed_tokens
self.last_accessed = last_accessed
class LRUEvictor(Evictor):
"""Evicts in a least-recently-used order using the last_accessed timestamp
that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and
    highest num_hashed_tokens value, then one will be chosen arbitrarily.
"""
def __init__(self):
self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict()
def __contains__(self, block_id: int) -> bool:
return block_id in self.free_table
def evict(self) -> Tuple[int, int]:
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")
evicted_block = next(iter(self.free_table.values()))
evicted_block_id = next(iter(self.free_table.keys()))
# The blocks with the lowest timestamps should be placed consecutively
# at the start of OrderedDict. Loop through all these blocks to
# find the one with maximum number of hashed tokens.
for _id, block in self.free_table.items():
if evicted_block.last_accessed > block.last_accessed or (
evicted_block.last_accessed == block.last_accessed and
evicted_block.num_hashed_tokens < block.num_hashed_tokens):
evicted_block = block
evicted_block_id = _id
self.free_table.pop(evicted_block_id)
return evicted_block_id, evicted_block.content_hash
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: float):
self.free_table[block_id] = BlockMetaData(content_hash,
num_hashed_tokens,
last_accessed)
def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed
def remove(self, block_id: int):
if block_id not in self.free_table:
raise ValueError(
"Attempting to remove block that's not in the evictor")
self.free_table.pop(block_id)
@property
def num_blocks(self) -> int:
return len(self.free_table)
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
if eviction_policy == EvictionPolicy.LRU:
return LRUEvictor()
else:
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")

113
vllm/core/interfaces.py Normal file
View File

@@ -0,0 +1,113 @@
import enum
from abc import ABC, abstractmethod
from typing import Dict, List
from typing import Sequence as GenericSequence
from vllm.sequence import Sequence, SequenceGroup
class AllocStatus(enum.Enum):
"""Result for BlockSpaceManager.can_allocate
1. Ok: seq_group can be allocated now.
    2. Later: seq_group cannot be allocated now.
      The allocator has enough total capacity, so it may be allocated later.
    3. Never: seq_group can never be allocated.
      The seq_group is too large to be allocated in the GPU.
"""
OK = enum.auto()
LATER = enum.auto()
NEVER = enum.auto()
class BlockSpaceManager(ABC):
@staticmethod
def get_block_space_manager_class(version: str):
version = version.lower()
if version == "v1":
from vllm.core.block_manager_v1 import BlockSpaceManagerV1
return BlockSpaceManagerV1
if version == "v2":
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
return BlockSpaceManagerV2
raise ValueError(f"Unknown version {version=}")
@abstractmethod
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
pass
@abstractmethod
def allocate(self, seq_group: SequenceGroup) -> None:
pass
@abstractmethod
def can_append_slots(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
pass
@abstractmethod
def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int,
) -> Dict[int, List[int]]:
pass
@abstractmethod
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
pass
@abstractmethod
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> AllocStatus:
pass
@abstractmethod
def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> Dict[int, int]:
pass
@abstractmethod
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
pass
@abstractmethod
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
pass
@abstractmethod
def free(self, seq: Sequence) -> None:
pass
@abstractmethod
def get_block_table(self, seq: Sequence) -> List[int]:
pass
@abstractmethod
def get_num_free_gpu_blocks(self) -> int:
pass
@abstractmethod
def get_num_free_cpu_blocks(self) -> int:
pass
@abstractmethod
def access_all_blocks_in_seq(
self,
seq: Sequence,
access_time: float,
) -> None:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
pass
@abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass
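A usage sketch for the factory above (block counts are illustrative):
manager_cls = BlockSpaceManager.get_block_space_manager_class("v2")
manager = manager_cls(block_size=16, num_gpu_blocks=1024, num_cpu_blocks=256)
print(manager.get_num_free_gpu_blocks())  # 1024: all GPU blocks start free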

45
vllm/core/policy.py Normal file
View File

@@ -0,0 +1,45 @@
from collections import deque
from typing import Deque
from vllm.sequence import SequenceGroup
class Policy:
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
raise NotImplementedError
def sort_by_priority(
self,
now: float,
seq_groups: Deque[SequenceGroup],
) -> Deque[SequenceGroup]:
return deque(
sorted(
seq_groups,
key=lambda seq_group: self.get_priority(now, seq_group),
reverse=True,
))
class FCFS(Policy):
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
return now - seq_group.metrics.arrival_time
class PolicyFactory:
_POLICY_REGISTRY = {'fcfs': FCFS}
@classmethod
def get_policy(cls, policy_name: str, **kwargs) -> Policy:
return cls._POLICY_REGISTRY[policy_name](**kwargs)
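A minimal sketch of the FCFS policy above; the SimpleNamespace objects mimic only the fields the policy reads and are not real SequenceGroups:
from collections import deque
from types import SimpleNamespace
policy = PolicyFactory.get_policy("fcfs")
groups = deque(
    SimpleNamespace(metrics=SimpleNamespace(arrival_time=t), name=n)
    for t, n in [(10.0, "late"), (1.0, "early"), (5.0, "mid")])
# Priority is waiting time (now - arrival_time), sorted in descending order,
# so the oldest arrival comes out first.
ordered = policy.sort_by_priority(now=12.0, seq_groups=groups)
print([g.name for g in ordered])  # ['early', 'mid', 'late']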

1163
vllm/core/scheduler.py Normal file

File diff suppressed because it is too large

3
vllm/distributed/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .communication_op import *
from .parallel_state import *
from .utils import *

237
vllm/distributed/communication_op.py Normal file
View File

@@ -0,0 +1,237 @@
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch.distributed import ProcessGroup
from .parallel_state import (get_cpu_world_group,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
is_pynccl_enabled_for_all_reduce)
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group.
NOTE: This operation will be applied in-place on the input tensor if
disable_custom_all_reduce is set to True. Otherwise, this operation may or
may not be applied in place depending on whether custom all reduce is
invoked for a particular tensor, which further depends on the tensor size
and GPU topology.
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
from vllm.distributed.device_communicators import pymccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
custom_all_reduce)
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
out = custom_all_reduce(input_)
if out is not None:
return out
if is_pynccl_enabled_for_all_reduce():
pymccl_utils.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
return input_
def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=get_tensor_model_parallel_group())
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size * input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
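# Worked example (illustrative): with world_size=2 and an input of shape
# (4, 8), the gathered tensor starts as (2, 4, 8); after movedim and the
# reshape, dim=-1 yields (4, 16) while dim=0 yields (8, 8).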
def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> torch.Tensor:
"""Gather the input tensor across model parallel group.
NOTE: We assume that the input tensor is on the same device across
all the ranks.
"""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if get_tensor_model_parallel_rank() == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=dst,
group=get_tensor_model_parallel_group())
if get_tensor_model_parallel_rank() == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def broadcast(input_: torch.Tensor,
src: int = 0,
group: Optional[ProcessGroup] = None):
"""Broadcast the input tensor."""
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return input_
# Broadcast.
torch.distributed.broadcast(input_, src=src, group=group)
return input_
def broadcast_object_list(obj_list: List[Any],
src: int = 0,
group: Optional[ProcessGroup] = None):
"""Broadcast the input object list."""
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return obj_list
# Broadcast.
torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
return obj_list
TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list = []
tensor_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note(youkaichao): currently this only supports broadcasting
            # tensors on musa. In the future, we can add device as a field in
# TensorMetadata to support broadcasting tensors on different
# devices.
assert value.is_musa, (
f"Tensor {key}: {value} is not on musa. Currently we only "
f"support broadcasting tensors on musa.")
metadata_list.append((key, TensorMetadata(value.dtype,
value.size())))
tensor_list.append(value)
else:
metadata_list.append((key, value))
return metadata_list, tensor_list
def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
`group` is used to broadcast the tensors, while `metadata_group` is used
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
group = group or torch.distributed.group.WORLD
metadata_group = metadata_group or get_cpu_world_group()
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return tensor_dict
rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` involves serialization and deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=metadata_group)
async_handles = []
for tensor in tensor_list:
async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True))
for async_handle in async_handles:
async_handle.wait()
else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=metadata_group)
assert recv_metadata_list[0] is not None
tensor_dict = {}
async_handles = []
for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device="musa")
async_handle = torch.distributed.broadcast(tensor,
src=src,
async_op=True,
group=group)
async_handles.append(async_handle)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
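Illustratively, the split above replaces each tensor with (dtype, size) metadata so the dict structure can travel over the cheap CPU group while the raw tensors are broadcast asynchronously on the device group. A sketch with the device assertion left out (CPU tensors used purely for demonstration):
import torch
d = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "step": 3}
metadata_list = [(k, TensorMetadata(v.dtype, v.size()))
                 if isinstance(v, torch.Tensor) else (k, v)
                 for k, v in d.items()]
print(metadata_list)
# [('input_ids', TensorMetadata(dtype=torch.int64, size=torch.Size([2, 8]))),
#  ('step', 3)]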

274
vllm/distributed/device_communicators/custom_all_reduce.py Normal file
View File

@@ -0,0 +1,274 @@
from contextlib import contextmanager
from typing import Any, List, Optional
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import init_logger
try:
import pynvml
from vllm_C import custom_ar
except ImportError:
# For AMD GPUs
custom_ar = None
pynvml = None
logger = init_logger(__name__)
_CA_HANDLE: Optional["CustomAllreduce"] = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
def init_custom_ar() -> None:
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
global _CA_HANDLE
if _CA_HANDLE is not None:
return
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in _SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.", world_size,
str(_SUPPORTED_WORLD_SIZES))
return
num_dev = torch.musa.device_count()
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if num_dev < world_size:
logger.warning(
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
return
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(num_dev))
# this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(device_ids)
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
_CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
def begin_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = True
def end_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = False
def is_capturing() -> bool:
return _IS_CAPTURING and _CA_HANDLE is not None
def get_handle() -> Optional["CustomAllreduce"]:
return _CA_HANDLE
def is_initialized() -> bool:
return _CA_HANDLE is not None
@contextmanager
def capture():
try:
begin_capture()
yield
finally:
end_capture()
handle = get_handle()
if handle is not None:
handle.register_graph_buffers()
def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle()
# when custom allreduce is disabled, this will be None
if ca_handle is None:
return None
if is_capturing():
if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_reg(input)
else:
if ca_handle.should_custom_ar(input):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return torch.empty_like(input)
else:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input)
return None
@contextmanager
def _nvml():
try:
pynvml.nvmlInit()
yield
finally:
pynvml.nvmlShutdown()
@_nvml()
def _is_full_nvlink(device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
so it works on real physical device ids.
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError as error:
logger.error(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped.",
exc_info=error)
return False
return True
def _can_p2p(rank: int, world_size: int) -> bool:
from vllm.distributed.utils import gpu_p2p_access_check
for i in range(world_size):
if i == rank:
continue
if not gpu_p2p_access_check(rank, i):
return False
return True
class CustomAllreduce:
# max_size: max supported allreduce size
def __init__(self,
rank,
world_size,
full_nvlink,
max_size=8192 * 1024) -> None:
        # Buffer memory is owned by this Python class and passed to C++.
        # The meta tensor is composed of two parts: metadata for
        # synchronization (256 bytes) and a temporary buffer for storing
        # intermediate allreduce results.
self.meta = torch.zeros(custom_ar.meta_size() + max_size,
dtype=torch.uint8,
device="musa")
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer = torch.empty(max_size, dtype=torch.uint8, device="musa")
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
        # needs fewer than 10000 registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device="musa")
self.max_size = max_size
self.world_size = world_size
handles, offsets = self._get_ipc_meta(self.meta)
self.full_nvlink = full_nvlink
self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
handles, offsets, rank,
self.full_nvlink)
self.register_buffer(self.buffer)
def _get_ipc_meta(self, inp: torch.Tensor):
data = inp.untyped_storage()._share_cuda_()
shard_data = (
data[1], # ipc handle to base ptr
data[3], # offset of base ptr
)
return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data):
all_data: List[Optional[Any]] = [None] * self.world_size
dist.all_gather_object(all_data, shard_data)
handles = []
offsets = []
for i in range(len(all_data)):
handles.append(all_data[i][0]) # type: ignore
offsets.append(all_data[i][1]) # type: ignore
return handles, offsets
def register_buffer(self, inp: torch.Tensor):
handles, offsets = self._get_ipc_meta(inp)
custom_ar.register_buffer(self._ptr, inp, handles, offsets)
def register_graph_buffers(self):
handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
logger.info("Registering %d cuda graph addresses", len(offset))
custom_ar.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
self.full_nvlink)
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
custom_ar.all_reduce_reg(self._ptr, inp, out)
return out
# all reduce, assuming inp tensor is NOT IPC registered
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
return out
def close(self):
if self._ptr:
custom_ar.dispose(self._ptr)
self._ptr = 0
def __del__(self):
self.close()
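The eligibility checks in init_custom_ar above reduce to a small predicate. A summary sketch, not part of the module:
def custom_ar_enabled(world_size: int, full_nvlink: bool, p2p_ok: bool) -> bool:
    # Mirrors the gating above: supported world sizes only, full NVLink
    # required beyond two GPUs, and the P2P access check must pass.
    if world_size not in (2, 4, 6, 8):
        return False
    if world_size > 2 and not full_nvlink:
        return False
    return p2p_ok
print(custom_ar_enabled(4, full_nvlink=False, p2p_ok=True))  # False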

284
vllm/distributed/device_communicators/pymccl.py Normal file
View File

@@ -0,0 +1,284 @@
# This file is a pure Python wrapper for the MCCL library.
# The main purpose is to use MCCL combined with MUSA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls MCCL correctly, but `cupy` itself
# often gets stuck when initializing the MCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential musa APIs, that are not allowed during
# capturing the MUSA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for MCCL. It is usually
# doable, but we often encounter issues related with mccl versions, and need
# to switch between different versions of MCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of MCCL by
# changing the environment variable `VLLM_MCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import platform
from typing import Optional, Union
# ===================== import region =====================
import torch
import torch_musa
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger
from vllm.utils import find_mccl_library, mccl_integrity_check
logger = init_logger(__name__)
so_file = find_mccl_library()
try:
# load the library in another process.
# if it core dumps, it will not crash the current process
mccl_integrity_check(so_file)
mccl = ctypes.CDLL(so_file)
except Exception as e:
logger.error(
"Failed to load MCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the mccl library might not exist, be corrupted "
"or it does not support the current platform %s."
"One solution is to download libmccl2 version 2.18 from "
"https://developer.download.nvidia.com/compute/musa/repos/ "
"and extract the libmccl.so.2 file. If you already have the "
"library, please set the environment variable VLLM_MCCL_SO_PATH"
" to point to the correct mccl library path.", so_file,
platform.platform())
raise e
# === export types and functions from mccl to Python ===
# for the original mccl definition, please check
# https://github.com/NVIDIA/mccl/blob/master/src/mccl.h.in
mcclResult_t = ctypes.c_int
_c_mcclGetErrorString = mccl.mcclGetErrorString
_c_mcclGetErrorString.restype = ctypes.c_char_p
_c_mcclGetErrorString.argtypes = [mcclResult_t]
def MCCL_CHECK(result: mcclResult_t) -> None:
if result != 0:
error_str = _c_mcclGetErrorString(result)
error_str = error_str.decode("utf-8")
raise RuntimeError(f"MCCL error: {error_str}")
# equivalent to c declaration:
# mcclResult_t mcclGetVersion(int *version);
_c_mcclGetVersion = mccl.mcclGetVersion
_c_mcclGetVersion.restype = ctypes.c_int
_c_mcclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
def mcclGetVersion() -> str:
version = ctypes.c_int()
MCCL_CHECK(_c_mcclGetVersion(ctypes.byref(version)))
version_str = str(version.value)
return version_str
class McclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
# equivalent to c declaration:
# mcclResult_t mcclGetUniqueId(mcclUniqueId* uniqueId);
_c_mcclGetUniqueId = mccl.mcclGetUniqueId
_c_mcclGetUniqueId.restype = ctypes.c_int
_c_mcclGetUniqueId.argtypes = [ctypes.POINTER(McclUniqueId)]
def mcclGetUniqueId() -> McclUniqueId:
unique_id = McclUniqueId()
MCCL_CHECK(_c_mcclGetUniqueId(ctypes.byref(unique_id)))
return unique_id
# equivalent to c declaration:
# mcclResult_t mcclCommInitRank(
# mcclComm_t* comm, int nranks, mcclUniqueId commId, int rank);
# note that mcclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
_c_mcclCommInitRank = mccl.mcclCommInitRank
_c_mcclCommInitRank.restype = ctypes.c_int
_c_mcclCommInitRank.argtypes = [
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, McclUniqueId, ctypes.c_int
]
mcclDataType_t = ctypes.c_int
class mcclDataTypeEnum:
mcclInt8 = 0
mcclChar = 0
mcclUint8 = 1
mcclInt32 = 2
mcclInt = 2
mcclUint32 = 3
mcclInt64 = 4
mcclUint64 = 5
mcclFloat16 = 6
mcclHalf = 6
mcclFloat32 = 7
mcclFloat = 7
mcclFloat64 = 8
mcclDouble = 8
mcclBfloat16 = 9
mcclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.mcclInt8
if dtype == torch.uint8:
return cls.mcclUint8
if dtype == torch.int32:
return cls.mcclInt32
if dtype == torch.int64:
return cls.mcclInt64
if dtype == torch.float16:
return cls.mcclFloat16
if dtype == torch.float32:
return cls.mcclFloat32
if dtype == torch.float64:
return cls.mcclFloat64
if dtype == torch.bfloat16:
return cls.mcclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
mcclRedOp_t = ctypes.c_int
class mcclRedOpTypeEnum:
mcclSum = 0
mcclProd = 1
mcclMax = 2
mcclMin = 3
mcclAvg = 4
mcclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.mcclSum
if op == ReduceOp.PRODUCT:
return cls.mcclProd
if op == ReduceOp.MAX:
return cls.mcclMax
if op == ReduceOp.MIN:
return cls.mcclMin
if op == ReduceOp.AVG:
return cls.mcclAvg
raise ValueError(f"Unsupported op: {op}")
# equivalent to c declaration:
# mcclResult_t mcclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# mcclDataType_t datatype, mcclRedOp_t op, mcclComm_t comm,
# musaStream_t stream);
# note that musaStream_t is a pointer type, so the last argument is a pointer
_c_mcclAllReduce = mccl.mcclAllReduce
_c_mcclAllReduce.restype = ctypes.c_int
_c_mcclAllReduce.argtypes = [
    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, mcclDataType_t,
    mcclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
]
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# equivalent to c declaration:
# mcclResult_t mcclCommDestroy(mcclComm_t comm);
_c_mcclCommDestroy = mccl.mcclCommDestroy
_c_mcclCommDestroy.restype = ctypes.c_int
_c_mcclCommDestroy.argtypes = [ctypes.c_void_p]
class MCCLCommunicator:
def __init__(
self,
group: Optional[ProcessGroup] = None,
device: Optional[Union[int, str, torch.device]] = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the MCCLCommunicator to. If None,
                it will be bound to f"musa:{local_rank}".
            It is the caller's responsibility to make sure each communicator
            is bound to a unique device.
"""
assert dist.is_initialized()
group = get_cpu_world_group() if group is None else group
assert dist.get_backend(group) != dist.Backend.MCCL, (
"MCCLCommunicator should be attached to a non-MCCL group.")
self.group = group
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
if self.rank == 0:
self.unique_id = mcclGetUniqueId()
else:
self.unique_id = McclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p()
if device is None:
local_rank = get_local_rank()
device = torch.device(f"musa:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"musa:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
# mccl communicator and stream will use this device
# `torch.musa.device` is a context manager that changes the
# current musa device to the specified one
with torch.musa.device(device):
MCCL_CHECK(
_c_mcclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank))
self.stream = torch.musa.Stream()
def all_reduce(self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
# mccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert tensor.device == self.device, (
f"this mccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
MCCL_CHECK(
_c_mcclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
mcclDataTypeEnum.from_torch(tensor.dtype),
mcclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.musa_stream)))
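A small sketch of the dtype/op translation above, plus a hedged usage outline; the multi-process part is commented out because it needs one process per GPU with torch.distributed already initialized:
print(mcclDataTypeEnum.from_torch(torch.float16))  # 6 (mcclFloat16)
print(mcclRedOpTypeEnum.from_torch(ReduceOp.SUM))  # 0 (mcclSum)
# Typical use (illustrative), one process per Mthreads GPU:
#   comm = MCCLCommunicator()                 # binds to musa:{local_rank}
#   x = torch.ones(1024, device=comm.device)
#   comm.all_reduce(x)                        # in-place sum across ranks
#   torch.musa.synchronize()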

66
vllm/distributed/device_communicators/pymccl_utils.py Normal file
View File

@@ -0,0 +1,66 @@
import contextlib
from typing import Optional
import torch
from torch.distributed import ProcessGroup, ReduceOp
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
from vllm.distributed.device_communicators.pymccl import (MCCLCommunicator,
mcclGetVersion)
except Exception as e:
# in non-MTHREADS environments, we can't import the mccl module
# e.g. when running on machines with AMD GPUs
logger.info("Failed to import MCCL library: %s", e)
logger.info("It is expected if you are not running on Mthreads GPUs.")
pass
comm: Optional["MCCLCommunicator"] = None
def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return comm is not None
@contextlib.contextmanager
def set_pymccl_stream(stream: torch.musa.Stream):
    """Set the musa stream for communication"""
try:
assert comm is not None
comm.stream = stream
yield
finally:
pass
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
assert not is_initialized()
global comm
logger.info("vLLM is using nccl==%s", mcclGetVersion())
comm = MCCLCommunicator(group=group)
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_musa, f"{input_} should be a musa tensor"
assert comm is not None
comm.all_reduce(input_, op)
def destroy_process_group() -> None:
global comm
comm = None
def get_world_size() -> int:
"""Returns the world size."""
assert comm is not None
return comm.world_size
def get_nccl_backend() -> Optional["MCCLCommunicator"]:
return comm

287
vllm/distributed/device_communicators/pynccl.py Normal file
View File

@@ -0,0 +1,287 @@
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related with nccl versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
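# For illustration only (a sketch, not used below): switching NCCL builds at
# runtime boils down to pointing ctypes at a different shared object:
#   import ctypes
#   lib = ctypes.CDLL("/opt/nccl-2.18/lib/libnccl.so.2")  # hypothetical path
#   version = ctypes.c_int()
#   lib.ncclGetVersion(ctypes.byref(version))  # same symbols, different build
# which is exactly what `VLLM_NCCL_SO_PATH` / the `so_file` variable select.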
import ctypes
import platform
from typing import Optional, Union
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check
logger = init_logger(__name__)
so_file = find_nccl_library()
try:
# load the library in another process.
# if it core dumps, it will not crash the current process
nccl_integrity_check(so_file)
nccl = ctypes.CDLL(so_file)
except Exception as e:
logger.error(
"Failed to load NCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the nccl library might not exist, be corrupted "
"or it does not support the current platform %s."
"One solution is to download libnccl2 version 2.18 from "
"https://developer.download.nvidia.com/compute/cuda/repos/ "
"and extract the libnccl.so.2 file. If you already have the "
"library, please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.", so_file,
platform.platform())
raise e
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t = ctypes.c_int
_c_ncclGetErrorString = nccl.ncclGetErrorString
_c_ncclGetErrorString.restype = ctypes.c_char_p
_c_ncclGetErrorString.argtypes = [ncclResult_t]
def NCCL_CHECK(result: ncclResult_t) -> None:
if result != 0:
error_str = _c_ncclGetErrorString(result)
error_str = error_str.decode("utf-8")
raise RuntimeError(f"NCCL error: {error_str}")
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
_c_ncclGetVersion.restype = ctypes.c_int
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
def ncclGetVersion() -> str:
version = ctypes.c_int()
NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
# something like 21903 --> "2.19.3"
version_str = str(version.value)
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
class NcclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
# equivalent to c declaration:
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
_c_ncclGetUniqueId.restype = ctypes.c_int
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
def ncclGetUniqueId() -> NcclUniqueId:
unique_id = NcclUniqueId()
NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
return unique_id
# equivalent to c declaration:
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
_c_ncclCommInitRank = nccl.ncclCommInitRank
_c_ncclCommInitRank.restype = ctypes.c_int
_c_ncclCommInitRank.argtypes = [
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
]
ncclDataType_t = ctypes.c_int
class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
ncclInt32 = 2
ncclInt = 2
ncclUint32 = 3
ncclInt64 = 4
ncclUint64 = 5
ncclFloat16 = 6
ncclHalf = 6
ncclFloat32 = 7
ncclFloat = 7
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
ncclMin = 3
ncclAvg = 4
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
return cls.ncclProd
if op == ReduceOp.MAX:
return cls.ncclMax
if op == ReduceOp.MIN:
return cls.ncclMin
if op == ReduceOp.AVG:
return cls.ncclAvg
raise ValueError(f"Unsupported op: {op}")
# equivalent to c declaration:
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument is a pointer
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
]
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy = nccl.ncclCommDestroy
_c_ncclCommDestroy.restype = ctypes.c_int
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
class NCCLCommunicator:
def __init__(
self,
group: Optional[ProcessGroup] = None,
device: Optional[Union[int, str, torch.device]] = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the NCCLCommunicator to. If None,
it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bound to a unique device.
"""
assert dist.is_initialized()
group = get_cpu_world_group() if group is None else group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"NCCLCommunicator should be attached to a non-NCCL group.")
self.group = group
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
if self.rank == 0:
self.unique_id = ncclGetUniqueId()
else:
self.unique_id = NcclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p()
if device is None:
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with torch.cuda.device(device):
NCCL_CHECK(
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank))
self.stream = torch.cuda.Stream()
def all_reduce(self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
NCCL_CHECK(
_c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream)))
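The CUDA-side communicator mirrors the MUSA one above; a minimal sketch (illustrative; same two-process gloo setup, and the module path `vllm.distributed.device_communicators.pynccl` is assumed):

import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.pynccl import NCCLCommunicator

dist.init_process_group(backend="gloo")  # must be a non-NCCL group
comm = NCCLCommunicator(group=dist.group.WORLD, device=dist.get_rank())
t = torch.ones(16, device=comm.device)
comm.all_reduce(t)  # in-place SUM, enqueued on comm.stream
torch.cuda.synchronize()
assert int(t[0].item()) == comm.world_size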

View File

@@ -0,0 +1,341 @@
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib
from typing import Optional
import torch
import torch_musa
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
# Tensor model parallel group that the current rank belongs to.
_TP_DEVICE_GROUP = None
_TP_CPU_GROUP = None
# Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# when people blindly call `torch.distributed.all_reduce` etc,
# it will use this group. It is initialized with the `backend`
# parameter of `init_distributed_environment` below.
# Essentially, this is `torch.distributed.group.WORLD`.
# We leave a line here to note that this is device-specific.
# Note that this variable is not safe to use, because when users
# call `init_distributed_environment` first, and then destroy
# the process group themselves, this variable will keep a reference to the
# destroyed process group, which is not useful.
_DEVICE_WORLD_GROUP = None
# during `init_distributed_environment`, we will also initialize a
# group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
_CPU_WORLD_GROUP = None
# In summary, after calling `init_distributed_environment`, we will
# always have two groups: one for device-specific (and is the default)
# and one for CPU. All processes will be part of both groups.
# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
_LOCAL_RANK = -1
def get_local_rank():
global _LOCAL_RANK
return _LOCAL_RANK
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
distributed_init_method: str = "env://",
local_rank: int = -1,
backend: str = "nccl",
):
logger.debug(
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
distributed_init_method, backend)
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
"distributed environment")
# this backend is used for WORLD
torch.distributed.init_process_group(
backend=backend,
init_method=distributed_init_method,
world_size=world_size,
rank=rank)
global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
_DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
ranks = list(range(torch.distributed.get_world_size()))
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
backend="gloo")
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
if local_rank == -1 and distributed_init_method == "env://":
local_rank = envs.LOCAL_RANK
global _LOCAL_RANK
_LOCAL_RANK = local_rank
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
Initialize model parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model
parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model
parallelism.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
4 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 pipeline model-parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend()
if (world_size !=
tensor_model_parallel_size * pipeline_model_parallel_size):
raise RuntimeError(
f"world_size ({world_size}) is not equal to "
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
num_tensor_model_parallel_groups: int = (world_size //
tensor_model_parallel_size)
num_pipeline_model_parallel_groups: int = (world_size //
pipeline_model_parallel_size)
rank = torch.distributed.get_rank()
# Build the tensor model-parallel groups.
global _TP_DEVICE_GROUP, _TP_CPU_GROUP
assert _TP_DEVICE_GROUP is None, (
"tensor model parallel group is already initialized")
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks, backend=backend)
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if rank in ranks:
_TP_DEVICE_GROUP = group
_TP_CPU_GROUP = cpu_group
# Build the pipeline model-parallel groups.
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
"pipeline model parallel group is already initialized")
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks, backend=backend)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
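# Worked example (matches the docstring above): world_size=8, TP=2, PP=4 gives
# num_tensor_model_parallel_groups = 8 // 2 = 4 -> [0,1] [2,3] [4,5] [6,7]
# num_pipeline_model_parallel_groups = 8 // 4 = 2 -> [0,2,4,6] [1,3,5,7]
# so e.g. rank 5 joins TP group [4,5] and PP group [1,3,5,7].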
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend()
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size, backend)
return
assert (
get_tensor_model_parallel_world_size() == tensor_model_parallel_size
), ("tensor parallel group already initialized, but of unexpected size: "
f"{get_tensor_model_parallel_world_size()=} vs. "
f"{tensor_model_parallel_size=}")
assert (get_pipeline_model_parallel_world_size(
) == pipeline_model_parallel_size), (
"pipeline parallel group already initialized, but of unexpected size: "
f"{get_pipeline_model_parallel_world_size()=} vs. "
f"{pipeline_model_parallel_size=}")
def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return (_TP_DEVICE_GROUP is not None
and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
def get_cpu_world_group():
"""Get the CPU world group."""
assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
return _CPU_WORLD_GROUP
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TP_DEVICE_GROUP is not None, (
"tensor model parallel group is not initialized")
return _TP_DEVICE_GROUP
def get_tensor_model_parallel_cpu_group():
"""Get the tensor model parallel cpu group the caller rank belongs to."""
assert _TP_CPU_GROUP is not None, (
"tensor model parallel cpu group is not initialized")
return _TP_CPU_GROUP
def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
"pipeline model parallel group is not initialized")
return _PIPELINE_MODEL_PARALLEL_GROUP
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return torch.distributed.get_world_size(
group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_world_size():
"""Return world size for the pipeline model parallel group."""
return torch.distributed.get_world_size(
group=get_pipeline_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_rank():
"""Return my rank for the pipeline model parallel group."""
return torch.distributed.get_rank(
group=get_pipeline_model_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
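# Example (illustrative): with a TP world size of 2, global ranks 4 and 5
# share a TP group, and both compute src rank (4 // 2) * 2 == (5 // 2) * 2 == 4.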
def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
def get_pipeline_model_parallel_prev_rank():
"""Return the global rank that precedes the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP_DEVICE_GROUP
if _TP_DEVICE_GROUP:
torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
_TP_DEVICE_GROUP = None
global _TP_CPU_GROUP
if _TP_CPU_GROUP:
torch.distributed.destroy_process_group(_TP_CPU_GROUP)
_TP_CPU_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
if _PIPELINE_MODEL_PARALLEL_GROUP:
torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None
from vllm.distributed.device_communicators import pymccl_utils
# Destroy the pymccl states if any.
pymccl_utils.destroy_process_group()
# Whether to use pymccl for all reduce.
# We use pymccl for all reduce when capturing a device graph, because
# torch.distributed is not well supported during graph capture.
_ENABLE_PYNCCL_FOR_ALL_REDUCE = False
@contextlib.contextmanager
def with_pynccl_for_all_reduce():
"""Use pymccl instead of torch.distributed for all reduce"""
from vllm.distributed.device_communicators import pymccl_utils
tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1:
# No-op.
# NOTE(woosuk): We don't initialize pymccl when tp_size is 1.
yield
else:
global _ENABLE_PYNCCL_FOR_ALL_REDUCE
old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
_ENABLE_PYNCCL_FOR_ALL_REDUCE = True
stream = torch.musa.current_stream()
with pymccl_utils.set_pymccl_stream(stream):
yield
_ENABLE_PYNCCL_FOR_ALL_REDUCE = old
def is_pynccl_enabled_for_all_reduce():
"""check if pynccl is enabled for all reduce"""
global _ENABLE_PYNCCL_FOR_ALL_REDUCE
return _ENABLE_PYNCCL_FOR_ALL_REDUCE
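A hedged sketch of how the context manager above is meant to wrap graph capture (illustrative; the function name is hypothetical):

import torch
from vllm.distributed import parallel_state
from vllm.distributed.device_communicators import pymccl_utils

def tp_all_reduce_for_capture(t: torch.Tensor) -> torch.Tensor:
    # Inside the context, all-reduce goes through pymccl, which can be
    # recorded into a device graph, unlike torch.distributed.
    with parallel_state.with_pynccl_for_all_reduce():
        if parallel_state.is_pynccl_enabled_for_all_reduce():
            pymccl_utils.all_reduce(t)
        # else: tp_size == 1, nothing to reduce
    return t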

137
vllm/distributed/utils.py Normal file
View File

@@ -0,0 +1,137 @@
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import json
import os
from typing import Dict, Optional, Sequence
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import init_logger
from .parallel_state import get_cpu_world_group, get_local_rank
logger = init_logger(__name__)
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
""" Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor into.
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# NOTE: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
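# Example (illustrative): a (4, 6) tensor split into 3 partitions yields three
# (4, 2) views; pass contiguous_split_chunks=True when downstream kernels
# require contiguous memory, at the cost of one copy per chunk.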
# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def _can_actually_p2p(idx_a, idx_b):
dev_i = f"musa:{idx_a}"
dev_j = f"musa:{idx_b}"
a = torch.randn(5, device=dev_i) + 123.0
b = a.to(dev_j)
c = b.to(dev_i)
return torch.all(a == c).cpu().item()
# why do we need this cache?
# 1. we can have runtime checks for P2P access, where every process checks
# P2P access to all other GPUs. Unfortunately, the test might create many
# (world_size * world_size) cuda contexts and reduce the memory available
# for the model. see https://github.com/vllm-project/vllm/issues/3821
# 2. alternatively, we can have a p2p map that is generated by the master
# process and broadcast to all other processes. This still requires
# #world_size cuda contexts, belonging to the master process, on each GPU.
# 3. we can have a cache file that records the p2p access status. The first
# time the master process checks the p2p access, it will generate the cache
# file, at the cost of #world_size cuda contexts. Later on, all processes
# can read the cache file to check the p2p access status without any
# additional cuda context.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
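# Illustrative cache file contents for two visible devices (actual values
# depend on the topology):
#   {"0->0": true, "0->1": true, "1->0": true, "1->1": true}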
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
def gpu_p2p_access_check(i: int, j: int) -> bool:
"""Check if GPU i can access GPU j."""
# if the cache variable is already calculated,
# read from the cache instead of checking it again
global _gpu_p2p_access_cache
if _gpu_p2p_access_cache is not None:
return _gpu_p2p_access_cache[f"{i}->{j}"]
is_distributed = dist.is_initialized()
num_dev = torch.musa.device_count()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
path = os.path.expanduser(
f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
)
os.makedirs(os.path.dirname(path), exist_ok=True)
if (not is_distributed or get_local_rank() == 0) \
and (not os.path.exists(path)):
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
logger.info("generating GPU P2P access cache for in %s", path)
cache = {}
for _i in range(num_dev):
for _j in range(num_dev):
# on some platforms, P2P support might be buggy and we need
# additional checks. See also:
# https://github.com/vllm-project/vllm/issues/2728
cache[f"{_i}->{_j}"] = torch.musa.can_device_access_peer(
_i, _j) and _can_actually_p2p(_i, _j)
with open(path, "w") as f:
json.dump(cache, f, indent=4)
if is_distributed:
cpu_world_group = get_cpu_world_group()
dist.barrier(cpu_world_group)
logger.info("reading GPU P2P access cache from %s", path)
with open(path, "r") as f:
cache = json.load(f)
_gpu_p2p_access_cache = cache
return _gpu_p2p_access_cache[f"{i}->{j}"]

0
vllm/engine/__init__.py Normal file
View File

649
vllm/engine/arg_utils.py Normal file
View File

@@ -0,0 +1,649 @@
import argparse
import dataclasses
from dataclasses import dataclass
from typing import List, Optional, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple
def nullable_str(val: str):
if not val or val == "None":
return None
return val
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
model: str
served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
enable_prefix_caching: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_logprobs: int = 5 # OpenAI default value
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
lora_dtype: str = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
model_loader_extra_config: Optional[dict] = None
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
image_token_id: Optional[int] = None
image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False
guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine."""
# Model arguments
parser.add_argument(
'--model',
type=str,
default='facebook/opt-125m',
help='Name or path of the huggingface model to use.')
parser.add_argument(
'--tokenizer',
type=nullable_str,
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use.')
parser.add_argument(
'--skip-tokenizer-init',
action='store_true',
help='Skip initialization of tokenizer and detokenizer')
parser.add_argument(
'--revision',
type=nullable_str,
default=None,
help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--code-revision',
type=nullable_str,
default=None,
help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-revision',
type=nullable_str,
default=None,
help='The specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
parser.add_argument('--download-dir',
type=nullable_str,
default=EngineArgs.download_dir,
help='Directory to download and load the weights, '
'default to the default cache dir of '
'huggingface.')
parser.add_argument(
'--load-format',
type=str,
default=EngineArgs.load_format,
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available.\n'
'* "pt" will load the weights in the pytorch bin format.\n'
'* "safetensors" will load the weights in the safetensors format.\n'
'* "npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading.\n'
'* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave which assumes tensorizer_uri is set to the location of '
'the serialized weights.')
parser.add_argument(
'--dtype',
type=str,
default=EngineArgs.dtype,
choices=[
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
],
help='Data type for model weights and activations.\n\n'
'* "auto" will use FP16 precision for FP32 and FP16 models, and '
'BF16 precision for BF16 models.\n'
'* "half" for FP16. Recommended for AWQ quantization.\n'
'* "float16" is the same as "half".\n'
'* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.')
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
default=None,
help='Path to the JSON file containing the KV cache '
'scaling factors. This should generally be supplied, when '
'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
'default to 1.0, which may cause accuracy issues. '
'FP8_E5M2 (without scaling) is only supported on cuda version '
'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
parser.add_argument('--max-model-len',
type=int,
default=EngineArgs.max_model_len,
help='Model context length. If unspecified, will '
'be automatically derived from the model config.')
parser.add_argument(
'--guided-decoding-backend',
type=str,
default='outlines',
choices=['outlines', 'lm-format-enforcer'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently support '
'https://github.com/outlines-dev/outlines and '
'https://github.com/noamgat/lm-format-enforcer.'
' Can be overridden per request via guided_decoding_backend'
' parameter.')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
help='Use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU.')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
default=EngineArgs.pipeline_parallel_size,
help='Number of pipeline stages.')
parser.add_argument('--tensor-parallel-size',
'-tp',
type=int,
default=EngineArgs.tensor_parallel_size,
help='Number of tensor parallel replicas.')
parser.add_argument(
'--max-parallel-loading-workers',
type=int,
default=EngineArgs.max_parallel_loading_workers,
help='Load model sequentially in multiple batches, '
'to avoid RAM OOM when using tensor '
'parallel and large models.')
parser.add_argument(
'--ray-workers-use-nsight',
action='store_true',
help='If specified, use nsight to profile Ray workers.')
# KV cache arguments
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32],
help='Token block size for contiguous chunks of '
'tokens.')
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='Enables automatic prefix caching.')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceManagerV2.')
parser.add_argument(
'--num-lookahead-slots',
type=int,
default=EngineArgs.num_lookahead_slots,
help='Experimental scheduling config necessary for '
'speculative decoding. This will be replaced by '
'speculative config in the future; it is present '
'to enable correctness tests until then.')
parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
help='Random seed for operations.')
parser.add_argument('--swap-space',
type=int,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU.')
parser.add_argument(
'--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
help='The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9.')
parser.add_argument(
'--num-gpu-blocks-override',
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number '
'of GPU blocks. Used for testing preemption.')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
help='Maximum number of batched tokens per '
'iteration.')
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='Maximum number of sequences per iteration.')
parser.add_argument(
'--max-logprobs',
type=int,
default=EngineArgs.max_logprobs,
help=('Max number of log probs to return when logprobs is '
'specified in SamplingParams.'))
parser.add_argument('--disable-log-stats',
action='store_true',
help='Disable logging statistics.')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=nullable_str,
choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
'None, we first check the `quantization_config` '
'attribute in the model config file. If that is '
'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data '
'type of the weights.')
parser.add_argument('--enforce-eager',
action='store_true',
help='Always use eager-mode PyTorch. If False, '
'will use eager mode and CUDA graph in hybrid '
'for maximal performance and flexibility.')
parser.add_argument('--max-context-len-to-capture',
type=int,
default=EngineArgs.max_context_len_to_capture,
help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq_len-to-capture instead'
')')
parser.add_argument('--max-seq_len-to-capture',
type=int,
default=EngineArgs.max_seq_len_to_capture,
help='Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.')
parser.add_argument('--disable-custom-all-reduce',
action='store_true',
default=EngineArgs.disable_custom_all_reduce,
help='See ParallelConfig.')
parser.add_argument('--tokenizer-pool-size',
type=int,
default=EngineArgs.tokenizer_pool_size,
help='Size of tokenizer pool to use for '
'asynchronous tokenization. If 0, will '
'use synchronous tokenization.')
parser.add_argument('--tokenizer-pool-type',
type=str,
default=EngineArgs.tokenizer_pool_type,
help='Type of tokenizer pool to use for '
'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.')
parser.add_argument('--tokenizer-pool-extra-config',
type=nullable_str,
default=EngineArgs.tokenizer_pool_extra_config,
help='Extra config for tokenizer pool. '
'This should be a JSON string that will be '
'parsed into a dictionary. Ignored if '
'tokenizer_pool_size is 0.')
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--max-loras',
type=int,
default=EngineArgs.max_loras,
help='Max number of LoRAs in a single batch.')
parser.add_argument('--max-lora-rank',
type=int,
default=EngineArgs.max_lora_rank,
help='Max LoRA rank.')
parser.add_argument(
'--lora-extra-vocab-size',
type=int,
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument(
'--lora-dtype',
type=str,
default=EngineArgs.lora_dtype,
choices=['auto', 'float16', 'bfloat16', 'float32'],
help=('Data type for LoRA. If auto, will default to '
'base model dtype.'))
parser.add_argument(
'--max-cpu-loras',
type=int,
default=EngineArgs.max_cpu_loras,
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= max_num_seqs. '
'Defaults to max_num_seqs.'))
parser.add_argument(
'--fully-sharded-loras',
action='store_true',
help=('By default, only half of the LoRA computation is '
'sharded with tensor parallelism. '
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
choices=["auto", "cuda", "neuron", "cpu", "musa"],
help='Device type for vLLM execution.')
# Related to Vision-language models such as llava
parser.add_argument(
'--image-input-type',
type=nullable_str,
default=None,
choices=[
t.name.lower() for t in VisionLanguageConfig.ImageInputType
],
help=('The image input type passed into vLLM. '
'Should be one of "pixel_values" or "image_features".'))
parser.add_argument('--image-token-id',
type=int,
default=None,
help=('Input id for image token.'))
parser.add_argument(
'--image-input-shape',
type=nullable_str,
default=None,
help=('The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM\'s profile_run.'))
parser.add_argument(
'--image-feature-size',
type=int,
default=None,
help=('The image feature size along the context dimension.'))
parser.add_argument(
'--scheduler-delay-factor',
type=float,
default=EngineArgs.scheduler_delay_factor,
help='Apply a delay (of delay factor multiplied by previous '
'prompt latency) before scheduling next prompt.')
parser.add_argument(
'--enable-chunked-prefill',
action='store_true',
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument(
'--speculative-model',
type=nullable_str,
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
parser.add_argument(
'--num-speculative-tokens',
type=int,
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-max-model-len',
type=int,
default=EngineArgs.speculative_max_model_len,
help='The maximum sequence length supported by the '
'draft model. Sequences over this length will skip '
'speculation.')
parser.add_argument(
'--ngram-prompt-lookup-max',
type=int,
default=EngineArgs.ngram_prompt_lookup_max,
help='Max size of window for ngram prompt lookup in speculative '
'decoding.')
parser.add_argument(
'--ngram-prompt-lookup-min',
type=int,
default=EngineArgs.ngram_prompt_lookup_min,
help='Min size of window for ngram prompt lookup in speculative '
'decoding.')
parser.add_argument('--model-loader-extra-config',
type=nullable_str,
default=EngineArgs.model_loader_extra_config,
help='Extra config for model loader. '
'This will be passed to the model loader '
'corresponding to the chosen load_format. '
'This should be a JSON string that will be '
'parsed into a dictionary.')
parser.add_argument(
"--served-model-name",
nargs="+",
type=str,
default=None,
help="The model name(s) used in the API. If multiple "
"names are provided, the server will respond to any "
"of the provided names. The model name in the model "
"field of a response will be the first name in this "
"list. If not specified, the model name will be the "
"same as the `--model` argument. Noted that this name(s)"
"will also be used in `model_name` tag content of "
"prometheus metrics, if multiple names provided, metrics"
"tag will take the first one.")
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args
def create_engine_config(self, ) -> EngineConfig:
device_config = DeviceConfig(self.device)
model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode,
self.trust_remote_code, self.dtype, self.seed, self.revision,
self.code_revision, self.tokenizer_revision, self.max_model_len,
self.quantization, self.quantization_param_path,
self.enforce_eager, self.max_context_len_to_capture,
self.max_seq_len_to_capture, self.max_logprobs,
self.skip_tokenizer_init, self.served_model_name)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
self.num_gpu_blocks_override,
model_config.get_sliding_window(),
self.enable_prefix_caching)
parallel_config = ParallelConfig(
self.pipeline_parallel_size, self.tensor_parallel_size,
self.worker_use_ray, self.max_parallel_loading_workers,
self.disable_custom_all_reduce,
TokenizerPoolConfig.create_config(
self.tokenizer_pool_size,
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
), self.ray_workers_use_nsight)
speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
)
scheduler_config = SchedulerConfig(
self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
self.use_v2_block_manager,
num_lookahead_slots=(self.num_lookahead_slots
if speculative_config is None else
speculative_config.num_lookahead_slots),
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
load_config = LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
)
if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
raise ValueError(
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')
vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.
get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
)
else:
vision_language_config = None
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
return EngineConfig(model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
load_config=load_config,
decoding_config=decoding_config)
@dataclass
class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False
max_log_len: Optional[int] = None
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser,
async_args_only: bool = False) -> argparse.ArgumentParser:
if not async_args_only:
parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray',
action='store_true',
help='Use Ray to start the LLM engine in a '
'separate process from the server process.')
parser.add_argument('--disable-log-requests',
action='store_true',
help='Disable logging requests.')
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers printed in the log.'
'\n\nDefault: Unlimited')
return parser
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
return EngineArgs.add_cli_args(argparse.ArgumentParser())
def _async_engine_args_parser():
return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
async_args_only=True)
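A minimal sketch of driving the dataclass above from the command line (illustrative argument values):

import argparse
from vllm.engine.arg_utils import EngineArgs

parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m", "-tp", "1"])
engine_args = EngineArgs.from_cli_args(args)
engine_config = engine_args.create_engine_config()  # bundles all sub-configs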

View File

@@ -0,0 +1,737 @@
import asyncio
import time
from functools import partial
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
import vllm.envs as envs
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
class AsyncEngineDeadError(RuntimeError):
pass
def _raise_exception_on_finish(
task: asyncio.Task, error_callback: Callable[[Exception],
None]) -> None:
msg = ("Task finished unexpectedly. This should never happen! "
"Please open an issue on Github.")
exception = None
try:
task.result()
# NOTE: This will be raised if the task exits normally (which it should not)
raise AsyncEngineDeadError(msg)
except Exception as e:
exception = e
logger.error("Engine background task failed", exc_info=e)
error_callback(exception)
raise AsyncEngineDeadError(
msg + " See stack trace above for the actual cause.") from e
class AsyncStream:
"""A stream of RequestOutputs for a request that can be
iterated over asynchronously."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, Exception]) -> None:
if self._finished:
return
self._queue.put_nowait(item)
def finish(self) -> None:
self._queue.put_nowait(StopAsyncIteration())
self._finished = True
@property
def finished(self) -> bool:
return self._finished
def __aiter__(self):
return self
async def __anext__(self) -> RequestOutput:
result = await self._queue.get()
if isinstance(result, Exception):
raise result
return result
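# Illustrative consumption of an AsyncStream (AsyncLLMEngine.generate does
# this on behalf of callers):
#   async for request_output in stream:
#       ...  # partial RequestOutputs arrive as tokens are generated
#   # iteration stops once the tracker calls stream.finish()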
class RequestTracker:
"""Synchronous abstraction for tracking requests."""
def __init__(self) -> None:
self._request_streams: Dict[str, AsyncStream] = {}
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
self.new_requests_event = asyncio.Event()
def __contains__(self, item):
return item in self._request_streams
def __len__(self) -> int:
return len(self._request_streams)
def propagate_exception(self,
exc: Exception,
request_id: Optional[str] = None) -> None:
"""Propagate an exception to request streams
(all if request_id is None)."""
if request_id is not None:
self._request_streams[request_id].put(exc)
self.abort_request(request_id)
else:
for rid, stream in self._request_streams.items():
stream.put(exc)
self.abort_request(rid)
def process_request_output(self,
request_output: RequestOutput,
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
request_id = request_output.request_id
self._request_streams[request_id].put(request_output)
if request_output.finished:
if verbose:
logger.info("Finished request %s.", request_id)
self.abort_request(request_id)
def process_exception(self,
request_id: str,
exception: Exception,
*,
verbose: bool = False) -> None:
"""Propagate an exception from the engine."""
self._request_streams[request_id].put(exception)
if verbose:
logger.info("Finished request %s.", request_id)
self.abort_request(request_id)
def add_request(self, request_id: str,
**engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background
loop iteration."""
if request_id in self._request_streams:
raise KeyError(f"Request {request_id} already exists.")
stream = AsyncStream(request_id)
self._new_requests.put_nowait((stream, {
"request_id": request_id,
**engine_add_request_kwargs
}))
self.new_requests_event.set()
return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
logger.info("Aborted request %s.", request_id)
self._finished_requests.put_nowait(request_id)
if request_id not in self._request_streams or self._request_streams[
request_id].finished:
# The request has already finished or been aborted.
return
self._request_streams[request_id].finish()
def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
"""Get the new requests and finished requests to be
sent to the engine."""
new_requests: List[Dict] = []
finished_requests: Set[str] = set()
while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
finished_requests.add(request_id)
self._request_streams.pop(request_id, None)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if stream.request_id in finished_requests:
# The request has already been aborted.
stream.finish()
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
return new_requests, finished_requests
async def wait_for_new_requests(self):
if not self.has_new_requests():
await self.new_requests_event.wait()
self.new_requests_event.clear()
def has_new_requests(self):
return not self._new_requests.empty()
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
async def step_async(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
The workers are run asynchronously if possible.
This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copied. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if not scheduler_outputs.is_empty():
# Execute the model.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
output = await self.model_executor.execute_model_async(
execute_model_req)
else:
output = []
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
return request_outputs
async def encode_request_async(
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = await self.tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
async def add_request_async(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
return self.add_request(request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
async def check_health_async(self) -> None:
self.model_executor.check_health()
class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine.
This class is used to wrap the LLMEngine class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
from the LLMEngine to the caller.
NOTE: For the comprehensive list of arguments, see `LLMEngine`.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
async frontend will be executed in a separate process as the
model workers.
log_requests: Whether to log the requests.
max_log_len: Maximum number of prompt characters or prompt ID numbers
printed in the log.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for LLMEngine.
**kwargs: Keyword arguments for LLMEngine.
"""
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
def __init__(self,
worker_use_ray: bool,
engine_use_ray: bool,
*args,
log_requests: bool = True,
max_log_len: Optional[int] = None,
start_engine_loop: bool = True,
**kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray
self.log_requests = log_requests
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)
self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# collected
self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
self.start_engine_loop = start_engine_loop
self._errored_with: Optional[BaseException] = None
# Lazy initialized fields
self._request_tracker: RequestTracker
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "cpu":
assert not engine_config.parallel_config.worker_use_ray, (
"Ray is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(
engine_config.parallel_config.worker_use_ray,
engine_args.engine_use_ray,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
)
return engine
@property
def is_running(self) -> bool:
return (self.background_loop is not None
and self._background_loop_unshielded is not None
and not self._background_loop_unshielded.done())
@property
def is_stopped(self) -> bool:
return self.errored or (self.background_loop is not None and
self._background_loop_unshielded is not None
and self._background_loop_unshielded.done())
@property
def errored(self) -> bool:
return self._errored_with is not None
def set_errored(self, exc: Exception) -> None:
self._errored_with = exc
def _error_callback(self, exc: Exception) -> None:
self.set_errored(exc)
self._request_tracker.propagate_exception(exc)
async def get_tokenizer(self) -> "PreTrainedTokenizer":
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote() # type: ignore
else:
return self.engine.get_tokenizer()
def start_background_loop(self) -> None:
"""Start the background loop."""
if self.errored:
raise AsyncEngineDeadError(
"Background loop has errored already.") from self._errored_with
if self.is_running:
raise RuntimeError("Background loop is already running.")
# Initialize the RequestTracker here so it uses the right event loop.
self._request_tracker = RequestTracker()
self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish,
error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, *args,
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
if not self.engine_use_ray:
engine_class = self._engine_class
elif self.worker_use_ray:
engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
else:
# FIXME(woosuk): This is a bit hacky. Be careful when changing the
# order of the arguments.
cache_config = kwargs["cache_config"]
parallel_config = kwargs["parallel_config"]
if parallel_config.tensor_parallel_size == 1:
num_gpus = cache_config.gpu_memory_utilization
else:
num_gpus = 1
engine_class = ray.remote(num_gpus=num_gpus)(
self._engine_class).remote
return engine_class(*args, **kwargs)
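    # Note on the num_gpus choice above (illustrative): with
    # tensor_parallel_size == 1 the engine actor must share a GPU with the
    # worker, so it requests a fractional GPU equal to
    # gpu_memory_utilization (e.g. 0.9); otherwise it reserves a full GPU.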
async def engine_step(self) -> bool:
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
new_requests, finished_requests = (
self._request_tracker.get_new_and_finished_requests())
for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue.
# TODO: Maybe add add_request_batch to reduce Ray overhead
try:
if self.engine_use_ray:
await self.engine.add_request.remote( # type: ignore
**new_request)
else:
await self.engine.add_request_async(**new_request)
except ValueError as e:
# TODO: use a vLLM specific error for failed validation
self._request_tracker.process_exception(
new_request["request_id"],
e,
verbose=self.log_requests,
)
if finished_requests:
await self._engine_abort(finished_requests)
if self.engine_use_ray:
request_outputs = await self.engine.step.remote() # type: ignore
else:
request_outputs = await self.engine.step_async()
# Put the outputs into the corresponding streams.
for request_output in request_outputs:
self._request_tracker.process_request_output(
request_output, verbose=self.log_requests)
return len(request_outputs) > 0
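    # Illustrative sketch (not part of the public API): run_engine_loop()
    # below drives engine_step() roughly like
    #
    #   while True:
    #       await self._request_tracker.wait_for_new_requests()
    #       while await self.engine_step():
    #           pass
    #
    # i.e. it keeps stepping while any request is still in progress.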
async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
await self.engine.abort_request.remote(request_ids) # type: ignore
else:
self.engine.abort_request(request_ids)
async def run_engine_loop(self):
has_requests_in_progress = False
while True:
if not has_requests_in_progress:
logger.debug("Waiting for new requests...")
await self._request_tracker.wait_for_new_requests()
logger.debug("Got new requests!")
# Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts).
try:
has_requests_in_progress = await asyncio.wait_for(
self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
except asyncio.TimeoutError as exc:
logger.error(
"Engine iteration timed out. This should never happen!")
self.set_errored(exc)
raise
await asyncio.sleep(0)
async def add_request(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
shortened_token_ids = prompt_token_ids
if self.max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:self.max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self.
max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"sampling_params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt,
sampling_params, shortened_token_ids, lora_request)
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
else:
raise AsyncEngineDeadError(
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
if arrival_time is None:
arrival_time = time.time()
if self.engine_use_ray:
prompt_token_ids = await (
self.engine.encode_request_async.remote( # type: ignore
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request))
else:
prompt_token_ids = await self.engine.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
return stream
async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `RequestOutput` objects from the LLMEngine for the
request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
:meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "prompt": "What is LLM?",
>>> "stream": False, # assume the non-streaming case
>>> "temperature": 0.0,
>>> "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.generate(
>>> example_input["prompt"],
>>> SamplingParams(temperature=example_input["temperature"]),
>>> example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
        >>>     # (`request` is e.g. the HTTP request object used in
        >>>     # entrypoints/api_server.py)
        >>>     if await request.is_disconnected():
>>> # Abort the request if the client disconnects.
>>> await engine.abort(request_id)
>>> # Return or raise an error
>>> ...
>>> final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
"""
# Preprocess the request.
arrival_time = time.time()
try:
stream = await self.add_request(
request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
async for request_output in stream:
yield request_output
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the
# request.
self._abort(request_id)
raise e
async def abort(self, request_id: str) -> None:
"""Abort a request.
Abort a submitted request. If the request is finished or not found,
this method will be a no-op.
Args:
request_id: The unique id of the request.
"""
if not self.is_running:
raise AsyncEngineDeadError(
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
return self._abort(request_id)
def _abort(self, request_id: str) -> None:
"""Abort a request.
Abort a submitted request. If the request is finished or not found,
this method will be a no-op.
Args:
request_id: The unique id of the request.
"""
self._request_tracker.abort_request(request_id,
verbose=self.log_requests)
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_model_config.remote() # type: ignore
else:
return self.engine.get_model_config()
async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_decoding_config.remote( # type: ignore
)
else:
return self.engine.get_decoding_config()
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None:
if self.engine_use_ray:
await self.engine.do_log_stats.remote( # type: ignore
scheduler_outputs, model_output)
else:
            self.engine.do_log_stats(scheduler_outputs, model_output)
async def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
t = time.perf_counter()
logger.debug("Starting health check...")
if self.is_stopped:
raise AsyncEngineDeadError("Background loop is stopped.")
if self.engine_use_ray:
try:
await self.engine.check_health.remote() # type: ignore
except ray.exceptions.RayActorError as e:
raise RuntimeError("Engine is dead.") from e
else:
await self.engine.check_health_async()
logger.debug("Health check took %fs", time.perf_counter() - t)

784
vllm/engine/llm_engine.py Normal file
View File

@@ -0,0 +1,784 @@
import time
from typing import Iterable, List, Optional, Type, Union
from transformers import GenerationConfig, PreTrainedTokenizer
import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig):
try:
return GenerationConfig.from_pretrained(
model_config.model,
revision=model_config.revision,
).to_diff_dict()
except OSError:
# Not found.
return {}
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
This is the main class for the vLLM engine. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `EngineArgs` class. For the
comprehensive list of arguments, see `EngineArgs`.
Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA.
vision_language_config (Optional): The configuration related to vision
language models.
speculative_config (Optional): The configuration related to speculative
decoding.
executor_class: The model executor class for managing distributed
execution.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection
"""
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
"max_seq_len=%d, download_dir=%r, load_format=%s, "
"tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d, served_model_name=%s)",
vllm.__version__,
model_config.model,
speculative_config,
model_config.tokenizer,
model_config.skip_tokenizer_init,
model_config.tokenizer_mode,
model_config.revision,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
model_config.max_model_len,
load_config.download_dir,
load_config.load_format,
parallel_config.tensor_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
model_config.enforce_eager,
cache_config.cache_dtype,
model_config.quantization_param_path,
device_config.device,
decoding_config,
model_config.seed,
model_config.served_model_name,
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.log_stats = log_stats
if not self.model_config.skip_tokenizer_init:
self.tokenizer: BaseTokenizerGroup
self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
else:
self.detokenizer = None
self.tokenizer = None
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
load_config=load_config,
)
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
get_architecture_class_name)
usage_message.report_usage(
get_architecture_class_name(model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(model_config.dtype),
"tensor_parallel_size":
parallel_config.tensor_parallel_size,
"block_size":
cache_config.block_size,
"gpu_memory_utilization":
cache_config.gpu_memory_utilization,
# Quantization
"quantization":
model_config.quantization,
"kv_cache_dtype":
cache_config.cache_dtype,
# Feature flags
"enable_lora":
bool(lora_config),
"enable_prefix_caching":
cache_config.enable_prefix_caching,
"enforce_eager":
model_config.enforce_eager,
"disable_custom_all_reduce":
parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
# Create the scheduler.
        # NOTE: the cache_config here has been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
self.get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
self.get_tokenizer_for_seq,
),
))
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
The workers will determine the number of blocks in both the GPU cache
and the swap CPU cache.
"""
num_gpu_blocks, num_cpu_blocks = (
self.model_executor.determine_num_available_blocks())
if self.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_gpu_blocks,
num_gpu_blocks_override)
num_gpu_blocks = num_gpu_blocks_override
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
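    # Sizing intuition (illustrative numbers, not defaults): each block
    # holds `block_size` tokens of KV cache, so num_gpu_blocks=2048 with
    # block_size=16 leaves room for ~32K cached tokens across all
    # sequences.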
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
# Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
# Create the LLM engine.
engine = cls(
**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
)
return engine
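    # Offline usage sketch (hypothetical model name; the `LLM` class is the
    # supported wrapper for this flow):
    #
    #   engine = LLMEngine.from_engine_args(
    #       EngineArgs(model="facebook/opt-125m"))
    #   engine.add_request("0", "Hello", SamplingParams())
    #   while engine.has_unfinished_requests():
    #       for out in engine.step():
    #           ...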
def __reduce__(self):
# This is to ensure that the LLMEngine is not referenced in
# the closure used to initialize Ray worker actors
raise RuntimeError("LLMEngine should not be pickled!")
def __del__(self):
# Shutdown model executor when engine is garbage collected
# Use getattr since __init__ can fail before the field is set
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()
def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None)
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs):
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
def encode_request(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
def add_request(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
scheduler as `engine.step()` is called. The exact scheduling policy is
determined by the scheduler.
Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current time.
            lora_request: LoRA request to use for generation, if any.
            multi_modal_data: Multi modal data per request.
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `best_of` number of :class:`~vllm.Sequence` objects.
- Create a :class:`~vllm.SequenceGroup` object
from the list of :class:`~vllm.Sequence`.
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
Example:
>>> # initialize engine
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> # set request arguments
>>> example_prompt = "Who is the president of the United States?"
>>> sampling_params = SamplingParams(temperature=0.0)
>>> request_id = 0
>>>
>>> # add the request to the engine
>>> engine.add_request(
>>> str(request_id),
>>> example_prompt,
>>> SamplingParams(temperature=0.0))
>>> # continue the request processing
>>> ...
"""
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
max_logprobs = self.get_model_config().max_logprobs
if (sampling_params.logprobs
and sampling_params.logprobs > max_logprobs) or (
sampling_params.prompt_logprobs
and sampling_params.prompt_logprobs > max_logprobs):
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = self.encode_request(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = None
if self.tokenizer:
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
else:
logger.warning("Use None for EOS token id because tokenizer is "
"not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# Add the eos token id into the sampling_params to support min_tokens
# processing
if seq.eos_token_id is not None:
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
sampling_params.update_from_generation_config(
self.generation_config_fields)
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time, lora_request, multi_modal_data)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID.
Args:
request_id: The ID(s) of the request to abort.
Details:
- Refer to the
:meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
from class :class:`~vllm.core.scheduler.Scheduler`.
Example:
>>> # initialize engine and add a request with request_id
>>> request_id = str(0)
>>> # abort the request
>>> engine.abort_request(request_id)
"""
self.scheduler.abort_seq_group(request_id)
def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config
def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
def _process_model_outputs(
self,
output: List[SamplerOutput],
scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[RequestOutput]:
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
"""
now = time.time()
# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group(
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip(
scheduled_seq_groups, output_by_sequence_group,
seq_group_metadata_list):
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
self.output_processor.process_prompt_logprob(seq_group, outputs)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
# Create the outputs.
request_outputs: List[RequestOutput] = []
for scheduled_seq_group in scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
for seq_group in ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
return request_outputs
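    # Illustration of the reorganization above (hypothetical shapes): for
    # two scheduled groups g0, g1 and one sampler output per step,
    #   output                   = [step0, step1]            # [step][group]
    #   output_by_sequence_group = [[step0[g0], step1[g0]],
    #                               [step0[g1], step1[g1]]]  # [group][step]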
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png
:alt: Overview of the step function
:align: center
Overview of the step function.
Details:
- Step 1: Schedules the sequences to be executed in the next
iteration and the token blocks to be swapped in/out/copy.
- Depending on the scheduling policy,
sequences may be `preempted/reordered`.
            - A Sequence Group (SG) refers to a group of sequences
that are generated from the same prompt.
- Step 2: Calls the distributed executor to execute the model.
- Step 3: Processes the model output. This mainly includes:
- Decodes the relevant outputs.
- Updates the scheduled sequence groups with model outputs
based on its `sampling parameters` (`use_beam_search` or not).
- Frees the finished sequence groups.
- Finally, it creates and returns the newly generated results.
Example:
>>> # Please see the example/ folder for more detailed examples.
>>>
>>> # initialize engine and request arguments
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> example_inputs = [(0, "What is LLM?",
>>> SamplingParams(temperature=0.0))]
>>>
>>> # Start the engine with an event loop
>>> while True:
>>> if example_inputs:
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> engine.add_request(str(req_id), prompt, sampling_params)
>>>
>>> # continue the request processing
>>> request_outputs = engine.step()
>>> for request_output in request_outputs:
>>> if request_output.finished:
>>> # return or show the request output
>>>
>>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if not scheduler_outputs.is_empty():
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
else:
output = []
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
return request_outputs
def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None:
"""Forced log when no requests active."""
if self.log_stats:
self.stat_logger.log(
self._get_stats(scheduler_outputs, model_output))
def _get_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs],
model_output: Optional[List[SamplerOutput]] = None) -> Stats:
"""Get Stats to be Logged to Prometheus.
Args:
scheduler_outputs: Optional, used to populate metrics related to
the scheduled batch,
model_output: Optional, used to emit speculative decoding metrics
which are created by the workers.
"""
now = time.time()
# System State
# Scheduler State
num_running_sys = len(self.scheduler.running)
num_swapped_sys = len(self.scheduler.swapped)
num_waiting_sys = len(self.scheduler.waiting)
# KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage_sys = 0.
if num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []
# Request stats
# Latency
time_e2e_requests: List[float] = []
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
best_of_requests: List[int] = []
n_requests: List[int] = []
finished_reason_requests: List[str] = []
# NOTE: This loop assumes prefill seq_groups are before
# decode seq_groups in scheduled_seq_groups.
if scheduler_outputs is not None:
num_generation_tokens_from_prefill_groups = 0.
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # len(scheduler_outputs.scheduled_seq_groups) !=
            # scheduler_outputs.num_prefill_groups, then chunked prefills
            # have been detected.
for idx, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
group_was_prefill = idx < scheduler_outputs.num_prefill_groups
seq_group = scheduled_seq_group.seq_group
# NOTE: a seq_group that completed all of its prefill tokens
# in the last iteration will have seq_group.is_prefill() = False
# with group_was_prefill = True
if group_was_prefill:
# Number of prompt tokens.
num_prompt_tokens_iter += (
scheduled_seq_group.token_chunk_size)
# If the seq_group just finished the prefill state
# get TTFT.
if not seq_group.is_prefill():
latency = seq_group.get_last_latency(now)
time_to_first_tokens_iter.append(latency)
# One generation token per finished prefill.
num_generation_tokens_from_prefill_groups += (
seq_group.num_seqs())
else:
# TPOTs.
latency = seq_group.get_last_latency(now)
time_per_output_tokens_iter.append(latency)
# Because of chunked prefill, we can have a single sequence
# group that does multiple prompt_runs. To prevent logging
# the same metadata more than once per request, we standardize
# on logging request level information for finished requests,
# which can only happen once.
if seq_group.is_finished():
# Latency timings
time_e2e_requests.append(now -
seq_group.metrics.arrival_time)
# Metadata
num_prompt_tokens_requests.append(
len(seq_group.prompt_token_ids))
num_generation_tokens_requests.extend([
seq.get_output_len()
for seq in seq_group.get_finished_seqs()
])
best_of_requests.append(seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n)
finished_reason_requests.extend([
SequenceStatus.get_finished_reason(seq.status)
for seq in seq_group.get_finished_seqs()
])
# Number of generation tokens.
# num_batched_tokens equals the number of prompt_tokens plus the
# number of decode_tokens in a single iteration. So,
# num_generation_tokens = num_batched_tokens - num_prompt_tokens
# + num_generation_tokens_from_prefill_groups (since we generate
# one token on prefills on iters where the prefill finishes).
num_generation_tokens_iter = (
scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
num_generation_tokens_from_prefill_groups)
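            # Worked example (hypothetical numbers): num_batched_tokens=100
            # with num_prompt_tokens_iter=90 and two prefills (1 seq each)
            # finishing this iteration gives 100 - 90 + 2 = 12 generation
            # tokens.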
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if model_output and (model_output[0].spec_decode_worker_metrics
is not None):
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
else:
spec_decode_metrics = None
return Stats(
now=now,
# System stats
# Scheduler State
num_running_sys=num_running_sys,
num_swapped_sys=num_swapped_sys,
num_waiting_sys=num_waiting_sys,
# KV Cache Usage in %
gpu_cache_usage_sys=gpu_cache_usage_sys,
cpu_cache_usage_sys=cpu_cache_usage_sys,
# Iteration stats
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
spec_decode_metrics=spec_decode_metrics,
# Request stats
# Latency
time_e2e_requests=time_e2e_requests,
# Metadata
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
best_of_requests=best_of_requests,
n_requests=n_requests,
finished_reason_requests=finished_reason_requests,
)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_executor.remove_lora(lora_id)
def list_loras(self) -> List[int]:
return self.model_executor.list_loras()
def check_health(self) -> None:
self.model_executor.check_health()

368
vllm/engine/metrics.py Normal file
View File

@@ -0,0 +1,368 @@
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Protocol, Union
import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
disable_created_metrics)
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
logger = init_logger(__name__)
disable_created_metrics()
# The begin-* and end-* markers here are used by the documentation generator
# to extract the metrics definitions.
# begin-metrics-definitions
class Metrics:
labelname_finish_reason = "finished_reason"
def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors
for collector in list(REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
REGISTRY.unregister(collector)
# Config Information
self.info_cache_config = Info(
name='vllm:cache_config',
documentation='information of cache_config')
# System stats
# Scheduler State
self.gauge_scheduler_running = Gauge(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames)
self.gauge_scheduler_waiting = Gauge(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
self.gauge_scheduler_swapped = Gauge(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
# KV Cache Usage in %
self.gauge_gpu_cache_usage = Gauge(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
self.gauge_cpu_cache_usage = Gauge(
name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
# Iteration stats
self.counter_prompt_tokens = Counter(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labelnames)
self.counter_generation_tokens = Counter(
name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames)
self.histogram_time_to_first_token = Histogram(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
])
self.histogram_time_per_output_token = Histogram(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5
])
# Request stats
# Latency
self.histogram_e2e_time_request = Histogram(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Metadata
self.histogram_num_prompt_tokens_request = Histogram(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_num_generation_tokens_request = Histogram(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_best_of_request = Histogram(
name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = Histogram(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.counter_request_success = Counter(
name="vllm:request_success_total",
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])
# Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = Gauge(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
)
# Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = Gauge(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
)
# end-metrics-definitions
def build_1_2_5_buckets(max_value: int):
"""
Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values (1, 2, 5) until the value exceeds the specified maximum.
Example:
>>> build_1_2_5_buckets(100)
[1, 2, 5, 10, 20, 50, 100]
"""
mantissa_lst = [1, 2, 5]
exponent = 0
buckets = []
while True:
for m in mantissa_lst:
value = m * 10**exponent
if value <= max_value:
buckets.append(value)
else:
return buckets
exponent += 1
@dataclass
class Stats:
"""Created by LLMEngine for use by StatLogger."""
now: float
# System stats (should have _sys suffix)
# Scheduler State
num_running_sys: int
num_waiting_sys: int
num_swapped_sys: int
# KV Cache Usage in %
gpu_cache_usage_sys: float
cpu_cache_usage_sys: float
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int
num_generation_tokens_iter: int
time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float]
# Request stats (should have _requests suffix)
# Latency
time_e2e_requests: List[float]
# Metadata
num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int]
best_of_requests: List[int]
n_requests: List[int]
finished_reason_requests: List[str]
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class SupportsMetricsInfo(Protocol):
def metrics_info(self) -> Dict[str, str]:
...
class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None:
# Metadata for logging locally.
self.last_local_log = time.time()
self.local_interval = local_interval
# Tracked stats over current local logging interval.
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []
# Prometheus metrics
self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()),
max_model_len=max_model_len)
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info())
def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
return float(np.sum(tracked_stats) / (now - self.last_local_log))
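    # e.g. tracked per-iteration token counts [10, 20, 30] over a 5-second
    # interval yield a throughput of 60 / 5 = 12 tokens/s.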
def _local_interval_elapsed(self, now: float) -> bool:
elapsed_time = now - self.last_local_log
return elapsed_time > self.local_interval
def _log_prometheus(self, stats: Stats) -> None:
# System state data
self._log_gauge(self.metrics.gauge_scheduler_running,
stats.num_running_sys)
self._log_gauge(self.metrics.gauge_scheduler_swapped,
stats.num_swapped_sys)
self._log_gauge(self.metrics.gauge_scheduler_waiting,
stats.num_waiting_sys)
self._log_gauge(self.metrics.gauge_gpu_cache_usage,
stats.gpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_cache_usage,
stats.cpu_cache_usage_sys)
# Iteration level data
self._log_counter(self.metrics.counter_prompt_tokens,
stats.num_prompt_tokens_iter)
self._log_counter(self.metrics.counter_generation_tokens,
stats.num_generation_tokens_iter)
self._log_histogram(self.metrics.histogram_time_to_first_token,
stats.time_to_first_tokens_iter)
self._log_histogram(self.metrics.histogram_time_per_output_token,
stats.time_per_output_tokens_iter)
# Request level data
# Latency
self._log_histogram(self.metrics.histogram_e2e_time_request,
stats.time_e2e_requests)
# Metadata
finished_reason_counter = CollectionsCounter(
stats.finished_reason_requests)
self._log_counter_labels(self.metrics.counter_request_success,
finished_reason_counter,
Metrics.labelname_finish_reason)
self._log_histogram(self.metrics.histogram_num_prompt_tokens_request,
stats.num_prompt_tokens_requests)
self._log_histogram(
self.metrics.histogram_num_generation_tokens_request,
stats.num_generation_tokens_requests)
self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
self._log_histogram(self.metrics.histogram_best_of_request,
stats.best_of_requests)
def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _log_counter(self, counter: Counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self.labels).inc(data)
def _log_counter_labels(self, counter: Counter, data: CollectionsCounter,
label_key: str) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
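    # e.g. a finished-reason counter of Counter({"stop": 3, "length": 1})
    # increments vllm:request_success_total once per label with the
    # corresponding count.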
def _log_histogram(self, histogram: Histogram,
data: Union[List[int], List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
# Logs metrics to prometheus that are computed every logging_interval.
# Support legacy gauge metrics that make throughput calculations on
# the vLLM side. Moving forward, we should use counters like
        # counter_prompt_tokens and counter_generation_tokens, which log
        # raw data and let summaries be computed with rate() on the
        # Grafana/Prometheus side. See
# https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
self.metrics.gauge_avg_prompt_throughput.labels(
**self.labels).set(prompt_throughput)
self.metrics.gauge_avg_generation_throughput.labels(
**self.labels).set(generation_throughput)
def log(self, stats: Stats) -> None:
"""Called by LLMEngine.
Logs to prometheus and tracked stats every iteration.
Logs to Stdout every self.local_interval seconds."""
# Log to prometheus.
self._log_prometheus(stats)
# Save tracked stats for token counters.
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now):
# Compute summary metrics for tracked stats (and log them
            # to prometheus if applicable).
prompt_throughput = self._get_throughput(self.num_prompt_tokens,
now=stats.now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now=stats.now)
self._log_prometheus_interval(
prompt_throughput=prompt_throughput,
generation_throughput=generation_throughput)
# Log to stdout.
logger.info(
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%",
prompt_throughput,
generation_throughput,
stats.num_running_sys,
stats.num_swapped_sys,
stats.num_waiting_sys,
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")

76
vllm/engine/output_processor/interfaces.py Normal file
View File

@@ -0,0 +1,76 @@
from abc import ABC, abstractmethod
from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
class SequenceGroupOutputProcessor(ABC):
"""Interface for logic that processes new token ids in sequence groups,
managing detokenization, stop checking, and freeing/forking sequences with
the scheduler.
This is highly coupled with the LLMEngine and should be seen as an extension
of it. The logic is separated to simplify the LLMEngine class and allow
separate implementations for single-step decoding (which supports beam
search sequence forking) and multi-step decoding (which does not support
beam search, but does support speculative decoding).
"""
@staticmethod
def create_output_processor(
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker",
):
"""Create an output processor.
This returns a single-step output processor if num_lookahead_slots is
zero, else returns a multi-step output processor.
"""
if scheduler_config.num_lookahead_slots == 0:
# Importing here to avoid cycle.
from vllm.engine.output_processor.single_step import (
SingleStepOutputProcessor)
return SingleStepOutputProcessor(
scheduler_config,
detokenizer,
scheduler,
seq_counter,
stop_checker,
)
else:
# Importing here to avoid cycle.
from vllm.engine.output_processor.multi_step import (
MultiStepOutputProcessor)
return MultiStepOutputProcessor(
detokenizer,
scheduler,
seq_counter,
get_tokenizer_for_seq,
stop_checker,
)
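        # e.g. speculative decoding configures num_lookahead_slots > 0 and
        # therefore gets the multi-step processor; ordinary decoding keeps
        # num_lookahead_slots == 0 and uses the single-step processor.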
@abstractmethod
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the
scheduler.
"""
pass
@abstractmethod
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Update prompt logprobs received from outputs to seq_group."""
pass

142
vllm/engine/output_processor/multi_step.py Normal file
View File

@@ -0,0 +1,142 @@
import functools
from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"""SequenceGroupOutputProcessor which handles logic related to
detokenization and stopping conditions. It specializes to "multi-step
decoding", where vLLM's worker may generate multiple tokens per invocation.
This is currently mutually exclusive with advanced sampling techniques like
beam search, which motivates the separation of this logic from the single
step output processor.
This class is responsible for things such as correctly appending all new
token ids to their sequence, detokenizing new token ids, truncating new
output tokens after an eos token, and correctly handling the case where the
number of new output tokens per sequence differs in a single batch.
"""
def __init__(
self,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker,
):
self.detokenizer = detokenizer
self.scheduler = scheduler
self.seq_counter = seq_counter
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.stop_checker = stop_checker
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
# TODO(sang): Prompt logprob currently not implemented in multi step
# workers.
self._log_prompt_logprob_unsupported_warning_once()
@staticmethod
@functools.lru_cache()
def _log_prompt_logprob_unsupported_warning_once():
logger.warning(
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).")
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than
one new token per sequence.
This applies logic like stop condition checking and detokenization,
including freeing finished sequences. It also handles cases where there
are tokens emitted after the EOS token.
"""
seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
assert seqs, "expected running sequences"
assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.")
seq = seqs[0]
# Since there's only one sequence per sequence group, we can take the
# first sample.
samples = [outputs[step].samples[0] for step in range(len(outputs))]
# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
valid_samples = [
sample for sample in samples if sample.output_token != -1
]
assert valid_samples
self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
def _process_seq_outputs(self, seq: Sequence,
valid_samples: List[SequenceOutput],
sampling_params: SamplingParams) -> None:
output_token_ids = [sample.output_token for sample in valid_samples]
output_logprobs = [sample.logprobs for sample in valid_samples]
# Truncate to max_tokens if necessary.
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
len(output_token_ids))
if remaining_tokens < 0:
valid_samples = valid_samples[:remaining_tokens]
output_token_ids = output_token_ids[:remaining_tokens]
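        # Truncation example (illustrative): with max_tokens=10, an existing
        # output length of 8 and 4 proposed tokens, remaining_tokens is
        # 10 - (8 + 4) = -2, so the [:-2] slices keep only the first 2
        # proposals.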
# Truncate any tokens after EOS. This is required as spec decode
# generates a fixed number of tokens without evaluating stopping
# conditions within the block. This can cause an eos token to be
# unintentionally ignored.
if not sampling_params.ignore_eos:
eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
# Avoiding .index calls as exception throwing in the happy path
# is expensive.
for i in range(len(output_token_ids)):
if output_token_ids[i] == eos_token_id:
output_token_ids = output_token_ids[:i + 1]
valid_samples = valid_samples[:i + 1]
break
# Incrementally append tokens to the sequence, as if we had only one new
# token.
for output_token_id, output_logprob in zip(output_token_ids,
output_logprobs):
seq.append_token_id(
token_id=output_token_id,
logprobs=output_logprob,
)
new_char_count = 0
if sampling_params.detokenize:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count=new_char_count,
sampling_params=sampling_params)
if seq.is_finished():
break
if seq.is_finished():
self.scheduler.free_seq(seq)

284
vllm/engine/output_processor/single_step.py Normal file
View File

@@ -0,0 +1,284 @@
from typing import Dict, List, Tuple, Union
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
"""SequenceGroupOutputProcessor which handles "output processing" logic,
which happens after the model returns generated token ids and before
scheduling of the next batch. Output processing logic includes
detokenization, and determining if a sequence is finished (e.g. via max len
or eos token).
The SingleStepOutputProcessor is specialized to the case where the model
emits at most a single token per invocation, which precludes configurations
such as speculative decoding or multi-step decoding. This enables beam
search sampling, which requires forking/finishing/freeing sequences in a way
that is currently difficult to schedule multiple steps ahead of time.
"""
def __init__(
self,
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Counter,
stop_checker: StopChecker,
):
self.scheduler_config = scheduler_config
self.detokenizer = detokenizer
self.scheduler = scheduler
self.seq_counter = seq_counter
self.stop_checker = stop_checker
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Append all new tokens to sequences in the sequence group. Fork any
surviving beam candidates; free any unsurviving ones.
Invokes detokenizer to detokenize new tokens, and also marks sequences
as finished if they meet stop conditions.
"""
assert (len(outputs) == 1
), f"{type(self)} does not support multiple outputs per step"
return self._process_sequence_group_outputs(sequence_group, outputs[0])
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
        assert len(outputs) == 1, ("Single step should only have 1 output.")
output = outputs[0]
prompt_logprobs = output.prompt_logprobs
if (prompt_logprobs is not None
and seq_group.sampling_params.detokenize and self.detokenizer):
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs)
if not seq_group.prompt_logprobs:
                # The first prompt token's logprob is None because it has
                # no preceding tokens.
seq_group.prompt_logprobs = [None]
seq_group.prompt_logprobs.extend(prompt_logprobs)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutput] = parent_child_dict[
parent.seq_id]
if len(child_samples) == 0:
                # This parent sequence has no child samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id: int = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
if seq_group.sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, seq_group.sampling_params)
else:
new_char_count = 0
self.stop_checker.maybe_stop_sequence(seq, new_char_count,
seq_group.sampling_params)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs = []
unselected_child_seqs = []
beam_width = seq_group.sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False)
for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
                # parent sequence, we will need to remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
        # Select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
seq_group.sampling_params.early_stopping,
seq_group.sampling_params, best_running_seq, current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq)
def _check_beam_search_early_stopping(
self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=current_worst_seq.eos_token_id)
if early_stopping is False:
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id)
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() +
sampling_params.max_tokens,
self.scheduler_config.max_model_len)
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id))
return current_worst_score >= highest_attainable_score
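
For intuition, a standalone sketch of the comparison above, assuming the
conventional length-penalized beam score (cumulative logprob divided by
seq_len ** length_penalty, which is what Sequence.get_beam_search_score
computes in spirit):

def beam_search_score(cumulative_logprob, seq_len, length_penalty=1.0):
    # Higher is better; length_penalty > 1.0 favors longer sequences.
    return cumulative_logprob / (seq_len**length_penalty)

# Once the worst kept finished beam outscores what the best running beam
# can still attain, continuing cannot improve the top-k set.
worst_finished = beam_search_score(-9.0, 12)
best_running = beam_search_score(-11.5, 12)
print(worst_finished >= best_running)  # True -> stop the beam search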

101
vllm/engine/output_processor/stop_checker.py Normal file
View File

@@ -0,0 +1,101 @@
from typing import Callable, Optional
from transformers import PreTrainedTokenizer
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
class StopChecker:
"""LLMEngine helper class which separates out the logic involving stop
checking. This checks things such as: whether the eos token was emitted,
whether the max_tokens has been consumed, whether a stop string has been
emitted, or if we have exceeded the max model len.
"""
def __init__(self, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence],
PreTrainedTokenizer]):
self.max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> None:
"""Stop the finished sequences.
new_char_count is the number of chars added to the
sequence's output text for the newly generated token
"""
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop_str = self._check_stop_strings(seq, new_char_count,
sampling_params)
if stop_str is not None:
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
@staticmethod
def _check_stop_strings(seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> Optional[str]:
"""Check if any stop strings are matched and truncate sequence
output text accordingly.
Returns the stop string if matched or else None.
"""
if not new_char_count:
return None
for stop_str in sampling_params.stop:
stop_string_len = len(stop_str)
# Avoid searching already-searched text.
stop_index = seq.output_text.find(
stop_str, -new_char_count - stop_string_len)
if stop_index == -1:
continue
if sampling_params.include_stop_str_in_output:
# Truncate to end of stop string.
stop_index += stop_string_len
if stop_index >= len(seq.output_text):
# No truncation required.
return stop_str
# Truncate the output text to either the beginning
# or end of the stop string.
seq.output_text = seq.output_text[:stop_index]
return stop_str
return None
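
To see why the search starts at -new_char_count - stop_string_len: a match
must end within the newly decoded characters, but it may have begun up to
len(stop_str) characters earlier. A toy example with hypothetical values:

output_text = "The answer is 42.<|en"  # text before the new token
new_text = "d|>"                       # chars the new token contributed
output_text += new_text
stop_str = "<|end|>"
new_char_count = len(new_text)

start = -new_char_count - len(stop_str)  # same window as above
stop_index = output_text.find(stop_str, start)
print(stop_index)                # 17: the match spans old and new chars
print(output_text[:stop_index])  # "The answer is 42."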

19
vllm/engine/output_processor/util.py Normal file
View File

@@ -0,0 +1,19 @@
from typing import List
from vllm.sequence import SamplerOutput, SequenceGroupOutput
def create_output_by_sequence_group(
sampler_outputs: List[SamplerOutput],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
    output_by_sequence_group: List[List[SequenceGroupOutput]] = [
[] for _ in range(num_seq_groups)
]
for step in sampler_outputs:
for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output)
return output_by_sequence_group
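
A toy usage sketch, with strings standing in for SequenceGroupOutput
objects:

steps = [
    ["g0_step0", "g1_step0"],  # step 0: one output per sequence group
    ["g0_step1", "g1_step1"],  # step 1
]
by_group = create_output_by_sequence_group(steps, num_seq_groups=2)
print(by_group)  # [['g0_step0', 'g0_step1'], ['g1_step0', 'g1_step1']]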

View File

119
vllm/entrypoints/api_server.py Normal file
View File

@@ -0,0 +1,119 @@
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file; please change
`vllm/entrypoints/openai/api_server.py` instead.
"""
import argparse
import json
import ssl
from typing import AsyncGenerator
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
engine = None
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.post("/generate")
async def generate(request: Request) -> Response:
"""Generate completion for the request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [
prompt + output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
if stream:
return StreamingResponse(stream_results())
# Non-streaming case
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return Response(status_code=499)
final_output = request_output
assert final_output is not None
prompt = final_output.prompt
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument("--ssl-ca-certs",
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument("--log-level", type=str, default="debug")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER)
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
port=args.port,
log_level=args.log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)
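
A client-side sketch for this demo endpoint (assumes a server on
localhost:8000 and the third-party requests package; in streaming mode the
body is NUL-delimited JSON, matching stream_results above):

import json

import requests

resp = requests.post("http://localhost:8000/generate",
                     json={"prompt": "San Francisco is a",
                           "stream": True,
                           "max_tokens": 16},
                     stream=True)
for chunk in resp.iter_lines(delimiter=b"\0"):
    if chunk:
        print(json.loads(chunk)["text"])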

259
vllm/entrypoints/llm.py Normal file
View File

@@ -0,0 +1,259 @@
from typing import List, Optional, Union
import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
This class includes a tokenizer, a language model (possibly distributed
across multiple GPUs), and GPU memory space allocated for intermediate
states (aka KV cache). Given a batch of prompts and sampling parameters,
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: The name or path of a HuggingFace Transformers model.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. The caller must then provide valid prompt_token_ids
            and set prompt to None.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
If None, we first check the `quantization_config` attribute in the
model config file. If that is None, we assume the model weights are
not quantized and use `dtype` to determine the data type of
the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
"""
def __init__(
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
engine_args = EngineArgs(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()
def get_tokenizer(
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer.tokenizer = tokenizer
def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
NOTE: This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: A list of prompts to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts.
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
if self.llm_engine.model_config.skip_tokenizer_init \
and prompts is not None:
raise ValueError("prompts must be None if skip_tokenizer_init "
"is True")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if (prompts is not None and prompt_token_ids is not None
and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
if prompts is not None:
num_requests = len(prompts)
else:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids)
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
elif isinstance(sampling_params,
list) and len(sampling_params) != num_requests:
raise ValueError("The lengths of prompts and sampling_params "
"must be the same.")
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine.
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request(
prompt,
sampling_params[i]
if isinstance(sampling_params, list) else sampling_params,
token_ids,
lora_request=lora_request,
# Get ith image while maintaining the batch dim.
multi_modal_data=MultiModalData(
type=multi_modal_data.type,
data=multi_modal_data.data[i].unsqueeze(0))
if multi_modal_data else None,
)
return self._run_engine(use_tqdm)
def _add_request(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests,
desc="Processed prompts",
dynamic_ncols=True)
# Run the engine.
outputs: List[RequestOutput] = []
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
pbar.update(1)
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # their predecessors.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs
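
A minimal offline-inference sketch using this class (the model name is
only an example):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any HF model name or local path
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
for out in outputs:
    print(out.outputs[0].text)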

View File

186
vllm/entrypoints/openai/api_server.py Normal file
View File

@@ -0,0 +1,186 @@
import asyncio
import importlib
import inspect
import re
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Any, Set
import fastapi
import uvicorn
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from starlette.routing import Mount
import vllm
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
logger = init_logger(__name__)
_running_tasks: Set[asyncio.Task[Any]] = set()
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
await engine.do_log_stats()
if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
yield
app = fastapi.FastAPI(lifespan=lifespan)
def parse_args():
parser = make_arg_parser()
return parser.parse_args()
# Add prometheus asgi middleware to route /metrics requests
route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(route)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
@app.get("/health")
async def health() -> Response:
"""Health check."""
await openai_serving_chat.engine.check_health()
return Response(status_code=200)
@app.get("/v1/models")
async def show_available_models():
models = await openai_serving_chat.show_available_models()
return JSONResponse(content=models.model_dump())
@app.get("/version")
async def show_version():
ver = {"version": vllm.__version__}
return JSONResponse(content=ver)
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump())
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
return JSONResponse(content=generator.model_dump())
if __name__ == "__main__":
args = parse_args()
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
if token := envs.VLLM_API_KEY or args.api_key:
@app.middleware("http")
async def authentication(request: Request, call_next):
root_path = "" if args.root_path is None else args.root_path
if not request.url.path.startswith(f"{root_path}/v1"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token:
return JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return await call_next(request)
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.")
logger.info("vLLM API server version %s", vllm.__version__)
logger.info("args: %s", args)
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
openai_serving_chat = OpenAIServingChat(engine, served_model_names,
args.response_role,
args.lora_modules,
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(
engine, served_model_names, args.lora_modules)
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)
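
A client sketch against this server using the official openai Python
package (v1+ API; assumes a server on localhost:8000 and that the example
model name below is actually being served):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1",
                api_key="EMPTY")  # any string works if --api-key is unset
chat = client.chat.completions.create(
    model="facebook/opt-125m",  # must match a served model name
    messages=[{"role": "user", "content": "Say hello."}],
)
print(chat.choices[0].message.content)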

115
vllm/entrypoints/openai/cli_args.py Normal file
View File

@@ -0,0 +1,115 @@
"""
This file contains the command line arguments for the vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
purposes.
"""
import argparse
import json
import ssl
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
setattr(namespace, self.dest, lora_list)
def make_arg_parser():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host",
type=nullable_str,
default=None,
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--uvicorn-log-level",
type=str,
default="info",
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
help="log level for uvicorn")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--api-key",
type=nullable_str,
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument(
"--lora-modules",
type=nullable_str,
default=None,
nargs='+',
action=LoRAParserAction,
help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.")
parser.add_argument("--chat-template",
type=nullable_str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
type=nullable_str,
default=None,
help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile",
type=nullable_str,
default=None,
help="The file path to the SSL cert file")
parser.add_argument("--ssl-ca-certs",
type=nullable_str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=nullable_str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--middleware",
type=nullable_str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server "
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
parser = AsyncEngineArgs.add_cli_args(parser)
return parser
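
A small sketch of how LoRAParserAction turns name=path pairs into
LoRAModulePath objects (standalone parser; the adapter paths are
hypothetical):

demo = argparse.ArgumentParser()
demo.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)
ns = demo.parse_args(
    ["--lora-modules", "sql-lora=/adapters/sql", "chat-lora=/adapters/chat"])
print(ns.lora_modules)  # -> two LoRAModulePath entries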

460
vllm/entrypoints/openai/protocol.py Normal file
View File

@@ -0,0 +1,460 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
import torch
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
class ErrorResponse(OpenAIBaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int
class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
allow_create_engine: bool = False
allow_sampling: bool = True
allow_logprobs: bool = True
allow_search_indices: bool = False
allow_view: bool = True
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
is_blocking: bool = False
class ModelCard(OpenAIBaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "vllm"
root: Optional[str] = None
parent: Optional[str] = None
permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(OpenAIBaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = Field(None,
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
user: Optional[str] = None
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
echo: Optional[bool] = Field(
default=False,
description=(
"If true, the new message will be prepended with the last message "
"if they belong to the same role."),
)
add_generation_prompt: Optional[bool] = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
include_stop_str_in_output: Optional[bool] = Field(
default=False,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-chat-completion-extra-params
def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
raise ValueError("Top logprobs must be set when logprobs is.")
logits_processors = None
if self.logit_bias:
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
logits[int(token_id)] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams(
n=self.n,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
)
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
prompt: Union[List[int], List[List[int]], str, List[str]]
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = 16
n: int = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = Field(None,
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
suffix: Optional[str] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
user: Optional[str] = None
# doc: begin-completion-sampling-params
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
include_stop_str_in_output: Optional[bool] = Field(
default=False,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
)
response_format: Optional[ResponseFormat] = Field(
default=None,
description=
("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
"supported."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-completion-extra-params
    def to_sampling_params(self) -> SamplingParams:
echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = None
if self.logit_bias:
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
logits[int(token_id)] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
logprobs=self.logprobs,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=(self.spaces_between_special_tokens),
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
class LogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
class CompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
class ChatMessage(OpenAIBaseModel):
role: str
content: str
class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
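
A standalone sketch of the logit_bias clamping used by the
to_sampling_params methods above (toy vocabulary of four tokens; the
OpenAI API sends token ids as strings):

import torch

logit_bias = {"2": 250.0, "3": -15.0}  # 250 will be clamped to 100
logits = torch.zeros(4)
for token_id, bias in logit_bias.items():
    bias = min(100, max(-100, bias))  # clamp per the OpenAI API spec
    logits[int(token_id)] += bias
print(logits)  # tensor([  0.,   0., 100., -15.])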

392
vllm/entrypoints/openai/serving_chat.py Normal file
View File

@@ -0,0 +1,392 @@
import asyncio
import codecs
import time
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
Optional, Tuple, TypedDict, Union, final)
from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartParam,
ChatCompletionRole)
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.utils import random_uuid
logger = init_logger(__name__)
@final  # So that it is compatible with Dict[str, str]
class ConversationMessage(TypedDict):
role: str
content: str
class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
response_role: str,
lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template: Optional[str] = None):
super().__init__(engine=engine,
served_model_names=served_model_names,
lora_modules=lora_modules,
await_post_init=self._load_chat_template(
chat_template=chat_template))
self.response_role = response_role
def _parse_chat_message_content(
self,
role: ChatCompletionRole,
content: Optional[Union[str,
Iterable[ChatCompletionContentPartParam]]],
) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
if content is None:
return [], []
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)], []
texts: List[str] = []
for _, part in enumerate(content):
if part["type"] == "text":
text = part["text"]
texts.append(text)
else:
raise NotImplementedError(f"Unknown part type: {part['type']}")
return [ConversationMessage(role=role, content="\n".join(texts))], []
async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
) -> Union[ErrorResponse, AsyncGenerator[str, None],
ChatCompletionResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI
ChatCompletion API.
NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
try:
conversation: List[ConversationMessage] = []
for m in request.messages:
messages, _ = self._parse_chat_message_content(
m["role"], m["content"])
conversation.extend(messages)
prompt = self.tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
request_id = f"cmpl-{random_uuid()}"
try:
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
request, prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logits_processor)
except ValueError as e:
return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt_text, sampling_params,
request_id, prompt_ids,
lora_request)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
conversation)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
return self.response_role
else:
return request.messages[-1]["role"]
async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
first_iteration = True
# Send response for each token for each request.n (index)
assert request.n is not None
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n
try:
async for res in result_generator:
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
                # response (by the try...except).
if first_iteration:
# Send first response for each request.n (index) with
# the role
role = self.get_chat_request_role(request)
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the
# last message
if request.echo:
last_msg_content = ""
if conversation and conversation[-1].get(
"content") and conversation[-1].get(
"role") == role:
last_msg_content = conversation[-1]["content"]
if last_msg_content:
for i in range(request.n):
choice_data = (
ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
content=last_msg_content),
finish_reason=None))
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
logprobs=None,
model=model_name)
data = chunk.model_dump_json(
exclude_unset=True)
yield f"data: {data}\n\n"
first_iteration = False
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs:
logprobs = self._create_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens +
previous_num_tokens[i],
)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if final_usage is not None:
chunk.usage = final_usage
data = chunk.model_dump_json(exclude_unset=True,
exclude_none=True)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0]
created_time = int(time.time())
final_res: Optional[RequestOutput] = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
assert final_res is not None
choices = []
role = self.get_chat_request_role(request)
for output in final_res.outputs:
token_ids = output.token_ids
top_logprobs = output.logprobs
if request.logprobs:
logprobs = self._create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
choices.append(choice_data)
if request.echo:
last_msg_content = ""
if conversation and conversation[-1].get(
"content") and conversation[-1].get("role") == role:
last_msg_content = conversation[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
async def _load_chat_template(self, chat_template: Optional[str]):
while self.tokenizer is None:
# Give the parent class time to load the tokenizer
await asyncio.sleep(0.1)
tokenizer = self.tokenizer
if chat_template is not None:
try:
with open(chat_template, "r") as f:
tokenizer.chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e
# If opening the file fails, treat the argument as an inline
# template string and decode it so escape sequences are
# interpreted correctly.
tokenizer.chat_template = codecs.decode(
chat_template, "unicode_escape")
logger.info("Using supplied chat template:\n%s",
tokenizer.chat_template)
elif tokenizer.chat_template is not None:
logger.info("Using default chat template:\n%s",
tokenizer.chat_template)
else:
logger.warning(
"No chat template provided. Chat API will not work.")

View File

@@ -0,0 +1,347 @@
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple)
from fastapi import Request
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
LogProbs, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
def parse_prompt_format(prompt) -> Tuple[bool, list]:
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
prompt_is_tokens = False
prompts = [prompt] # case 1: a string
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
elif isinstance(prompt[0], str):
prompt_is_tokens = False
prompts = prompt # case 2: array of strings
elif isinstance(prompt[0], int):
prompt_is_tokens = True
prompts = [prompt] # case 3: array of tokens
elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays
else:
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
return prompt_is_tokens, prompts
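# Quick sanity checks of the four accepted shapes (illustrative values only):
assert parse_prompt_format("hi") == (False, ["hi"])
assert parse_prompt_format(["hi", "bye"]) == (False, ["hi", "bye"])
assert parse_prompt_format([1, 2, 3]) == (True, [[1, 2, 3]])
assert parse_prompt_format([[1, 2], [3, 4]]) == (True, [[1, 2], [3, 4]])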
class OpenAIServingCompletion(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]] = None):
super().__init__(engine=engine,
served_model_names=served_model_names,
lora_modules=lora_modules)
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following feature:
- suffix (the language models we currently support do not support
suffix)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
model_name = self.served_model_names[0]
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
# Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = []
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logit_processor)
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats
generators.append(
self.engine.generate(prompt_text,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=prompt_ids,
lora_request=lora_request))
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, RequestOutput]] = merge_async_iterators(*generators)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
# beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
# Streaming response
if stream:
return self.completion_stream_generator(request,
raw_request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts))
# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = self.request_output_to_completion_response(
final_res_batch, request, request_id, created_time, model_name)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
if request.stream:
response_json = response.model_dump_json()
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return fake_stream_generator()
return response
async def completion_stream_generator(
self,
request: CompletionRequest,
raw_request: Request,
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
request_id: str,
created_time: int,
model_name: str,
num_prompts: int,
) -> AsyncGenerator[str, None]:
assert request.n is not None
previous_texts = [""] * request.n * num_prompts
previous_num_tokens = [0] * request.n * num_prompts
has_echoed = [False] * request.n * num_prompts
try:
async for prompt_idx, res in result_generator:
# Abort the request if the client disconnects.
if await raw_request.is_disconnected():
await self.engine.abort(f"{request_id}-{prompt_idx}")
raise StopAsyncIteration()
for output in res.outputs:
i = output.index + prompt_idx * request.n
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
# only return the prompt
delta_text = res.prompt
delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = (res.prompt_token_ids +
output.token_ids)
top_logprobs = res.prompt_logprobs + (output.logprobs
or [])
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text[len(previous_texts[i]):]
delta_token_ids = output.token_ids[
previous_num_tokens[i]:]
top_logprobs = output.logprobs[previous_num_tokens[
i]:] if output.logprobs else None
if request.logprobs is not None:
logprobs = self._create_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason
if output.finish_reason is not None: # return final usage
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
else:
final_usage = None
response_json = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
stop_reason=stop_reason,
)
],
usage=final_usage,
).model_dump_json(exclude_unset=True)
yield f"data: {response_json}\n\n"
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
def request_output_to_completion_response(
self,
final_res_batch: List[RequestOutput],
request: CompletionRequest,
request_id: str,
created_time: int,
model_name: str,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
for final_res in final_res_batch:
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
top_logprobs = (prompt_logprobs + output.logprobs
if request.logprobs else None)
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
top_logprobs = output.logprobs
output_text = output.text
if request.logprobs is not None:
assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs "
"is requested")
logprobs = self._create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=len(choices),
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
choices.append(choice_data)
num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
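# Sketch of the fan-in behavior the streaming path above relies on:
# merge_async_iterators (from vllm.utils) tags each item with the index of
# the generator it came from, which is why choice index i is recovered as
# output.index + prompt_idx * request.n. A self-contained toy demo (behavior
# only; the real implementation may differ):
import asyncio

async def _demo():
    async def gen(tag: str, n: int):
        for k in range(n):
            await asyncio.sleep(0)
            yield f"{tag}{k}"

    async for prompt_idx, item in merge_async_iterators(gen("a", 2),
                                                        gen("b", 1)):
        print(prompt_idx, item)  # e.g. (0, "a0"), (1, "b0"), (0, "a1")

asyncio.run(_demo())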

View File

@@ -0,0 +1,234 @@
import asyncio
import json
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest, ErrorResponse,
LogProbs, ModelCard, ModelList,
ModelPermission)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import get_tokenizer
logger = init_logger(__name__)
@dataclass
class LoRAModulePath:
name: str
local_path: str
class OpenAIServing:
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]],
await_post_init: Optional[Awaitable[Any]] = None):
self.engine = engine
self.served_model_names = served_model_names
if lora_modules is None:
self.lora_requests = []
else:
self.lora_requests = [
LoRARequest(
lora_name=lora.name,
lora_int_id=i,
lora_local_path=lora.local_path,
) for i, lora in enumerate(lora_modules, start=1)
]
self.max_model_len = 0
# Lazy initialized
self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
event_loop = None
if event_loop is not None and event_loop.is_running():
# If this instance is created by Ray Serve,
# there is already a running event loop
event_loop.create_task(self._post_init(await_post_init))
else:
# When using single vLLM without engine_use_ray
asyncio.run(self._post_init(await_post_init))
async def _post_init(self, await_post_init):
engine_model_config = await self.engine.get_model_config()
self.max_model_len = engine_model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
self.tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
tokenizer_revision=engine_model_config.tokenizer_revision,
trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left")
if await_post_init is not None:
await await_post_init
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=served_model_name,
root=self.served_model_names[0],
permission=[ModelPermission()])
for served_model_name in self.served_model_names
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=self.served_model_names[0],
permission=[ModelPermission()])
for lora in self.lora_requests
]
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
def _create_logprobs(
self,
token_ids: List[int],
top_logprobs: List[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
if num_output_top_logprobs:
logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = self.tokenizer.decode(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(None)
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append(None)
else:
token_logprob = step_top_logprobs[token_id].logprob
token = step_top_logprobs[token_id].decoded_token
logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob)
if num_output_top_logprobs:
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
p.decoded_token: max(p.logprob, -9999.0)
for p in step_top_logprobs.values()
} if step_top_logprobs else None)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
return logprobs
def create_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
return ErrorResponse(message=message,
type=err_type,
code=status_code.value)
def create_streaming_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
json_str = json.dumps({
"error":
self.create_error_response(message=message,
err_type=err_type,
status_code=status_code).model_dump()
})
return json_str
async def _check_model(
self, request: Union[CompletionRequest, ChatCompletionRequest]
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
return None
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora(
self, request: Union[CompletionRequest, ChatCompletionRequest]
) -> Optional[LoRARequest]:
if request.model in self.served_model_names:
return None
for lora in self.lora_requests:
if request.model == lora.lora_name:
return lora
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")
if prompt_ids is None:
tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
"truncation": True,
"max_length": truncate_prompt_tokens,
}
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
else:
input_ids = prompt_ids
input_text = prompt if prompt is not None else self.tokenizer.decode(
prompt_ids)
token_num = len(input_ids)
if request.max_tokens is None:
if token_num >= self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the messages, "
f"Please reduce the length of the messages.", )
request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )
else:
return input_ids, input_text
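# Worked example of the max_tokens bookkeeping above (hypothetical numbers):
# max_model_len = 4096, prompt of 1000 tokens.
#   - request.max_tokens unset  -> filled in as 4096 - 1000 = 3096
#   - request.max_tokens = 3500 -> 1000 + 3500 = 4500 > 4096, rejected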

217
vllm/envs.py Normal file
View File

@@ -0,0 +1,217 @@
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if TYPE_CHECKING:
VLLM_HOST_IP: str = ""
VLLM_USE_MODELSCOPE: bool = False
VLLM_INSTANCE_ID: Optional[str] = None
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
VLLM_API_KEY: Optional[str] = None
S3_ACCESS_KEY_ID: Optional[str] = None
S3_SECRET_ACCESS_KEY: Optional[str] = None
S3_ENDPOINT_URL: Optional[str] = None
VLLM_CONFIG_ROOT: str = ""
VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai"
VLLM_NO_USAGE_STATS: bool = False
VLLM_DO_NOT_TRACK: bool = False
VLLM_USAGE_SOURCE: str = ""
VLLM_CONFIGURE_LOGGING: int = 1
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_TARGET_DEVICE: str = "musa"
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
VLLM_BUILD_WITH_NEURON: bool = False
VLLM_USE_PRECOMPILED: bool = False
VLLM_INSTALL_PUNICA_KERNELS: bool = False
CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False
# The begin-* and end-* here are used by the documentation generator
# to extract the used env vars.
# begin-env-vars-definition
environment_variables: Dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu]
"VLLM_TARGET_DEVICE":
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"),
# Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs
"MAX_JOBS":
lambda: os.getenv("MAX_JOBS", None),
# Number of threads to use for nvcc
# By default this is 1.
# If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU.
"NVCC_THREADS":
lambda: os.getenv("NVCC_THREADS", None),
# If set, vllm will build with Neuron support
"VLLM_BUILD_WITH_NEURON":
lambda: bool(os.environ.get("VLLM_BUILD_WITH_NEURON", False)),
# If set, vllm will use precompiled binaries (*.so)
"VLLM_USE_PRECOMPILED":
lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")),
# If set, vllm will install Punica kernels
"VLLM_INSTALL_PUNICA_KERNELS":
lambda: bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))),
# CMake build type
# If not set, defaults to "Debug" or "RelWithDebInfo"
# Available options: "Debug", "Release", "RelWithDebInfo"
"CMAKE_BUILD_TYPE":
lambda: os.getenv("CMAKE_BUILD_TYPE"),
# If set, vllm will print verbose logs during installation
"VERBOSE":
lambda: bool(int(os.getenv('VERBOSE', '0'))),
# Root directory for VLLM configuration files
# Note that this not only affects how vllm finds its configuration files
# during runtime, but also affects how vllm installs its configuration
# files during **installation**.
"VLLM_CONFIG_ROOT":
lambda: os.environ.get("VLLM_CONFIG_ROOT", None) or os.getenv(
"XDG_CONFIG_HOME", None) or os.path.expanduser("~/.config"),
# ================== Runtime Env Vars ==================
# used in distributed environment to determine the master address
'VLLM_HOST_IP':
lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""),
# If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers
"VLLM_USE_MODELSCOPE":
lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true",
# Instance id represents an instance of the VLLM. All processes in the same
# instance should have the same instance id.
"VLLM_INSTANCE_ID":
lambda: os.environ.get("VLLM_INSTANCE_ID", None),
# path to cudatoolkit home directory, under which should be bin, include,
# and lib directories.
"CUDA_HOME":
lambda: os.environ.get("CUDA_HOME", None),
# Path to the NCCL library file. It is needed because nccl>=2.19 brought
# by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234
"VLLM_NCCL_SO_PATH":
lambda: os.environ.get("VLLM_NCCL_SO_PATH", None),
# when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl
# library file in the locations specified by `LD_LIBRARY_PATH`
"LD_LIBRARY_PATH":
lambda: os.environ.get("LD_LIBRARY_PATH", None),
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":
lambda: int(os.environ.get("LOCAL_RANK", "0")),
# used to control the visible devices in the distributed setting
"CUDA_VISIBLE_DEVICES":
lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),
# timeout for each iteration in the engine
"VLLM_ENGINE_ITERATION_TIMEOUT_S":
lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")),
# API key for VLLM API server
"VLLM_API_KEY":
lambda: os.environ.get("VLLM_API_KEY", None),
# S3 access information, used for tensorizer to load model from S3
"S3_ACCESS_KEY_ID":
lambda: os.environ.get("S3_ACCESS_KEY", None),
"S3_SECRET_ACCESS_KEY":
lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None),
"S3_ENDPOINT_URL":
lambda: os.environ.get("S3_ENDPOINT_URL", None),
# Usage stats collection
"VLLM_USAGE_STATS_SERVER":
lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"),
"VLLM_NO_USAGE_STATS":
lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1",
"VLLM_DO_NOT_TRACK":
lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get(
"DO_NOT_TRACK", None) or "0") == "1",
"VLLM_USAGE_SOURCE":
lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"),
# Logging configuration
# If set to 0, vllm will not configure logging
# If set to 1, vllm will configure logging using the default configuration
# or the configuration file specified by VLLM_LOGGING_CONFIG_PATH
"VLLM_CONFIGURE_LOGGING":
lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")),
"VLLM_LOGGING_CONFIG_PATH":
lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"),
# Trace function calls
# If set to 1, vllm will trace function calls
# Useful for debugging
"VLLM_TRACE_FUNCTION":
lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")),
# Backend for attention computation
# Available options:
# - "TORCH_SDPA": use torch.nn.MultiheadAttention
# - "FLASH_ATTN": use FlashAttention
# - "XFORMERS": use XFormers
# - "ROCM_FLASH": use ROCmFlashAttention
"VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
# CPU key-value cache space
# default is 4GB
"VLLM_CPU_KVCACHE_SPACE":
lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
"VLLM_USE_RAY_COMPILED_DAG":
lambda: bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)),
# Use dedicated multiprocess context for workers.
# Both spawn and fork work
"VLLM_WORKER_MULTIPROC_METHOD":
lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
}
# end-env-vars-definition
def __getattr__(name):
# lazy evaluation of environment variables
if name in environment_variables:
return environment_variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
return list(environment_variables.keys())
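# Usage sketch (from client code, not this module): each attribute access
# re-reads the process environment, so changes made after import are observed
# without reloading the module.
#
# import os
# import vllm.envs as envs
# os.environ["VLLM_TRACE_FUNCTION"] = "1"
# assert envs.VLLM_TRACE_FUNCTION == 1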

View File

@@ -0,0 +1,152 @@
from typing import List, Set, Tuple
import torch
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
logger = init_logger(__name__)
class CPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
assert self.lora_config is None, "cpu backend doesn't support LoRA"
self.model_config = _verify_and_get_model_config(self.model_config)
self.cache_config = _verify_and_get_cache_config(self.cache_config)
self.scheduler_config = _verify_and_get_scheduler_config(
self.scheduler_config)
# Instantiate the worker and load the model to CPU.
self._init_worker()
def _init_worker(self):
from vllm.worker.cpu_worker import CPUWorker
assert self.parallel_config.world_size == 1, (
"CPUExecutor only supports single CPU socket currently.")
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = CPUWorker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
# NOTE: For the CPU backend, a `cpu block` lives in CPU memory but is
# referred to as a `gpu block` so that the existing block management
# procedure can be reused.
logger.info("# CPU blocks: %d", num_gpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def check_health(self) -> None:
# CPUExecutor will always be healthy as long as
# it's running.
return
class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output
async def check_health_async(self) -> None:
# CPUExecutor will always be healthy as long as
# it's running.
return
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
if config.dtype == torch.float16:
logger.warning("float16 is not supported on CPU, casting to bfloat16.")
config.dtype = torch.bfloat16
if not config.enforce_eager:
logger.warning(
"CUDA graph is not supported on CPU, fallback to the eager "
"mode.")
config.enforce_eager = True
return config
def _verify_and_get_scheduler_config(
config: SchedulerConfig) -> SchedulerConfig:
if config.chunked_prefill_enabled:
logger.warning("Chunked prefill is not supported on CPU, disable it.")
config.chunked_prefill_enabled = False
return config
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
_GB = 1 << 30
if config.enable_prefix_caching:
logger.warning("Prefix caching is not supported on CPU, disable it.")
config.enable_prefix_caching = False
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
if kv_cache_space >= 0:
if kv_cache_space == 0:
config.cpu_kvcache_space_bytes = 4 * _GB # type: ignore
logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
"for CPU backend is not set, using 4 by default.")
else:
config.cpu_kvcache_space_bytes = kv_cache_space * _GB # type: ignore
else:
raise RuntimeError(
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
f" {kv_cache_space}, expect a positive integer value.")
return config
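# Worked example of the sizing logic above (hypothetical setting): with
# VLLM_CPU_KVCACHE_SPACE=8, cpu_kvcache_space_bytes = 8 * (1 << 30)
# = 8589934592 bytes; leaving the variable unset (0) falls back to 4 GiB with
# the warning above, and a negative value raises RuntimeError.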

View File

@@ -0,0 +1,115 @@
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput
logger = init_logger(__name__)
class DistributedGPUExecutor(GPUExecutor):
"""Abstract superclass of multi-GPU executor implementations."""
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers("determine_num_available_blocks", )
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache in all workers.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
all_outputs = self._run_workers("execute_model",
driver_args=args,
driver_kwargs=kwargs)
# Only the driver worker returns the sampling results.
return all_outputs[0]
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"add_lora",
lora_request=lora_request,
)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"remove_lora",
lora_id=lora_id,
)
def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")
@abstractmethod
def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
raise NotImplementedError
class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
@abstractmethod
async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
raise NotImplementedError
async def execute_model_async(self, *args,
**kwargs) -> List[SamplerOutput]:
all_outputs = await self._run_workers_async("execute_model",
driver_args=args,
driver_kwargs=kwargs)
# Only the driver worker returns the sampling results.
return all_outputs[0]
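# Worked example of the min-reduction in determine_num_available_blocks
# (hypothetical worker reports): two workers returning (800, 128) and
# (750, 160) yield (750, 128), so the shared scheduler never allocates a
# block that some worker cannot back.
num_blocks = [(800, 128), (750, 160)]
assert min(b[0] for b in num_blocks) == 750
assert min(b[1] for b in num_blocks) == 128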

View File

@@ -0,0 +1,115 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
class ExecutorBase(ABC):
"""Base class for all executors.
An executor is responsible for executing the model on a specific device
type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
that can execute the model on multiple devices.
"""
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.speculative_config = speculative_config
self._init_executor()
@abstractmethod
def _init_executor(self) -> None:
pass
@abstractmethod
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
Normally, this should simply delegate to the underlying Worker. Some
ExecutorBase may require modification of the result, e.g. to ensure the
selected cache sizes are compatible with all workers.
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError
@abstractmethod
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
"""
raise NotImplementedError
@abstractmethod
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences."""
raise NotImplementedError
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
@abstractmethod
def list_loras(self) -> Set[int]:
raise NotImplementedError
@abstractmethod
def check_health(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise NotImplementedError
def shutdown(self) -> None:
"""Shutdown the executor."""
return
def __del__(self):
self.shutdown()
class ExecutorAsyncBase(ExecutorBase):
@abstractmethod
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes one model step on the given sequences."""
raise NotImplementedError
async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
self.check_health()

View File

@@ -0,0 +1,150 @@
from typing import Any, Dict, List, Optional, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class GPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
If speculative decoding is enabled, we instead create the speculative
worker.
"""
if self.speculative_config is None:
self._init_non_spec_worker()
else:
self._init_spec_worker()
def _get_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
"""Return worker init args for a given rank."""
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=rank == 0,
)
def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
def _init_non_spec_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _init_spec_worker(self):
"""Initialize a SpecDecodeWorker, using a draft model for proposals.
"""
assert self.speculative_config is not None
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
target_worker = self._create_worker()
draft_worker_kwargs = self._get_worker_kwargs()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=self.speculative_config.draft_model_config,
parallel_config=self.speculative_config.draft_parallel_config,
# TODO allow draft-model specific load config.
#load_config=self.load_config,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
)
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = spec_decode_worker
# Load model handled in spec decode worker.
self.driver_worker.init_device()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return
class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output
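# Sketch of the make_async pattern used by GPUExecutorAsync above: vllm.utils
# wraps a blocking callable so it runs in asyncio's default thread pool. A
# minimal equivalent (behavior only; not the exact upstream code):
import asyncio
import functools

def _make_async_sketch(func):
    async def _wrapper(*args, **kwargs):
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            None, functools.partial(func, *args, **kwargs))
    return _wrapper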

View File

@@ -0,0 +1,263 @@
import asyncio
import multiprocessing
import os
import sys
import threading
import traceback
import uuid
from dataclasses import dataclass
from multiprocessing import Queue
from multiprocessing.connection import wait
from multiprocessing.process import BaseProcess
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
TypeVar, Union)
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
T = TypeVar('T')
_TERMINATE = "TERMINATE" # sentinel
# ANSI color codes
CYAN = '\033[1;36m'
RESET = '\033[0;0m'
JOIN_TIMEOUT_S = 2
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
mp = multiprocessing.get_context(mp_method)
@dataclass
class Result(Generic[T]):
"""Result of task dispatched to worker"""
task_id: uuid.UUID
value: Optional[T] = None
exception: Optional[BaseException] = None
class ResultFuture(threading.Event, Generic[T]):
"""Synchronous future for non-async case"""
def __init__(self):
super().__init__()
self.result: Optional[Result[T]] = None
def set_result(self, result: Result[T]):
self.result = result
self.set()
def get(self) -> T:
self.wait()
assert self.result is not None
if self.result.exception is not None:
raise self.result.exception
return self.result.value # type: ignore[return-value]
def _set_future_result(future: Union[ResultFuture, asyncio.Future],
result: Result):
if isinstance(future, ResultFuture):
future.set_result(result)
return
loop = future.get_loop()
if result.exception is not None:
loop.call_soon_threadsafe(future.set_exception, result.exception)
else:
loop.call_soon_threadsafe(future.set_result, result.value)
class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""
def __init__(self) -> None:
super().__init__(daemon=True)
self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
def run(self):
for result in iter(self.result_queue.get, _TERMINATE):
future = self.tasks.pop(result.task_id)
_set_future_result(future, result)
# Ensure that all waiters will receive an exception
for task_id, future in self.tasks.items():
_set_future_result(
future,
Result(task_id=task_id,
exception=ChildProcessError("worker died")))
def close(self):
self.result_queue.put(_TERMINATE)
class WorkerMonitor(threading.Thread):
"""Monitor worker status (in background thread)"""
def __init__(self, workers: List['ProcessWorkerWrapper'],
result_handler: ResultHandler):
super().__init__(daemon=True)
self.workers = workers
self.result_handler = result_handler
self._close = False
def run(self) -> None:
# Blocks until any worker exits
dead_sentinels = wait([w.process.sentinel for w in self.workers])
if not self._close:
self._close = True
# Kill / cleanup all workers
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode)
# Cleanup any remaining workers
logger.info("Killing local vLLM worker processes")
for worker in self.workers:
worker.kill_worker()
# Must be done after worker task queues are all closed
self.result_handler.close()
for worker in self.workers:
worker.process.join(JOIN_TIMEOUT_S)
def close(self):
if self._close:
return
self._close = True
logger.info("Terminating local vLLM worker processes")
for worker in self.workers:
worker.terminate_worker()
# Must be done after worker task queues are all closed
self.result_handler.close()
class ProcessWorkerWrapper:
"""Local process wrapper for vllm.worker.Worker,
for handling single-node multi-GPU tensor parallel."""
def __init__(self, result_handler: ResultHandler,
worker_factory: Callable[[], Any]) -> None:
self._task_queue = mp.Queue()
self.result_queue = result_handler.result_queue
self.tasks = result_handler.tasks
self.process: BaseProcess = mp.Process( # type: ignore[attr-defined]
target=_run_worker_process,
name="VllmWorkerProcess",
kwargs=dict(
worker_factory=worker_factory,
task_queue=self._task_queue,
result_queue=self.result_queue,
),
daemon=True)
self.process.start()
def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
method: str, args, kwargs):
task_id = uuid.uuid4()
self.tasks[task_id] = future
try:
self._task_queue.put((task_id, method, args, kwargs))
except BaseException as e:
del self.tasks[task_id]
raise ChildProcessError("worker died") from e
def execute_method(self, method: str, *args, **kwargs):
future: ResultFuture = ResultFuture()
self._enqueue_task(future, method, args, kwargs)
return future
async def execute_method_async(self, method: str, *args, **kwargs):
future = asyncio.get_running_loop().create_future()
self._enqueue_task(future, method, args, kwargs)
return await future
def terminate_worker(self):
try:
self._task_queue.put(_TERMINATE)
except ValueError:
self.process.kill()
self._task_queue.close()
def kill_worker(self):
self._task_queue.close()
self.process.kill()
def _run_worker_process(
worker_factory: Callable[[], Any],
task_queue: Queue,
result_queue: Queue,
) -> None:
"""Worker process event loop"""
# Add process-specific prefix to stdout and stderr
process_name = mp.current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
# Initialize worker
worker = worker_factory()
del worker_factory
# Accept tasks from the engine in task_queue
# and return task output in result_queue
logger.info("Worker ready; awaiting tasks")
try:
for items in iter(task_queue.get, _TERMINATE):
output = None
exception = None
task_id, method, args, kwargs = items
try:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
except BaseException as e:
tb = traceback.format_exc()
logger.error(
"Exception in worker %s while processing method %s: %s, %s",
process_name, method, e, tb)
exception = e
result_queue.put(
Result(task_id=task_id, value=output, exception=exception))
except KeyboardInterrupt:
pass
except Exception:
logger.exception("Worker failed")
logger.info("Worker exiting")
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
"""Prepend each output line with process-specific prefix"""
prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
file_write = file.write
def write_with_prefix(s: str):
if not s:
return
if file.start_new_line: # type: ignore[attr-defined]
file_write(prefix)
idx = 0
while (next_idx := s.find('\n', idx)) != -1:
next_idx += 1
file_write(s[idx:next_idx])
if next_idx == len(s):
file.start_new_line = True # type: ignore[attr-defined]
return
file_write(prefix)
idx = next_idx
file_write(s[idx:])
file.start_new_line = False # type: ignore[attr-defined]
file.start_new_line = True # type: ignore[attr-defined]
file.write = write_with_prefix # type: ignore[method-assign]
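# Minimal usage sketch of the helpers above (hypothetical _EchoWorker; the
# real callers live in the multiprocessing executor). The worker factory must
# be picklable under spawn, hence the module-level class.
class _EchoWorker:

    def echo(self, x):
        return x

if __name__ == "__main__":
    result_handler = ResultHandler()
    result_handler.start()
    worker = ProcessWorkerWrapper(result_handler, _EchoWorker)
    future = worker.execute_method("echo", 42)
    assert future.get() == 42
    worker.terminate_worker()
    result_handler.close()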

View File

@@ -0,0 +1,91 @@
from typing import List, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import make_async
logger = init_logger(__name__)
class NeuronExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert (self.lora_config is
None), "LoRA is not supported for Neuron backend."
assert (not self.speculative_config
), "Speculative decoding not yet supported for Neuron backend."
# Instantiate the worker and load the model to the device.
self._init_worker()
def _init_worker(self):
from vllm.worker.neuron_worker import NeuronWorker
self.driver_worker = NeuronWorker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
assert (execute_model_req.blocks_to_swap_in == {}
and execute_model_req.blocks_to_swap_out == {}
and execute_model_req.blocks_to_copy == {}), (
"Cache operations are not supported for Neuron backend.")
assert execute_model_req.num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.")
output = self.driver_worker.execute_model(
execute_model_req.seq_group_metadata_list)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def check_health(self) -> None:
# NeuronExecutor will always be healthy as long as
# it's running.
return
class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(
self.driver_worker.execute_model
)(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
return output
async def check_health_async(self) -> None:
# NeuronExecutor will always be healthy as long as
# it's running.
return
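# The Neuron path forwards only sequence-group metadata; any request carrying
# swap or copy work trips the assertions in execute_model above. Sketch of a
# well-formed call (hypothetical metadata list):
#
# req = ExecuteModelRequest(seq_group_metadata_list=metadata,
#                           blocks_to_swap_in={},
#                           blocks_to_swap_out={},
#                           blocks_to_copy={},
#                           num_lookahead_slots=0)
# outputs = executor.execute_model(req)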

View File

@@ -0,0 +1,327 @@
import asyncio
import os
import pickle
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None:
assert (not self.speculative_config
), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.worker_use_ray
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
runtime_env.update({
"nsight": {
"t": "cuda,cudnn,cublas",
"o": "'worker_process_%p'",
"cuda-graph-trace": "node",
}
})
return ray_remote_kwargs
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
# For single GPU case, we use a ray worker with constrained memory.
num_gpus = self.cache_config.gpu_memory_utilization
else:
# Otherwise, the ray workers are allocated with a full GPU.
num_gpus = 1
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)
# Create the workers.
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
trust_remote_code=self.model_config.trust_remote_code,
)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
trust_remote_code=self.model_config.trust_remote_code,
)
else:
# Else, added to the list of workers.
self.workers.append(worker)
if self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)
node_workers = defaultdict(list)
node_gpus = defaultdict(list)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id])),
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for (node_id, _) in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = [
self._get_worker_kwargs(
local_rank=node_workers[node_id].index(rank),
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
all_outputs = self._run_workers(
"execute_model",
driver_kwargs={"execute_model_req": execute_model_req},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
# Only the driver worker returns the sampling results.
return all_outputs[0]
def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
ways:
- args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
if driver_args is None:
driver_args = args if all_args is None else all_args[0]
if driver_kwargs is None:
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
else:
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
else:
assert self.driver_dummy_worker is not None
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs
def _compiled_ray_dag(self):
import pkg_resources
required_version = "2.9"
current_version = pkg_resources.get_distribution("ray").version
        # Compare parsed versions; plain string comparison would order
        # "2.10" before "2.9".
        if pkg_resources.parse_version(current_version) < \
                pkg_resources.parse_version(required_version):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.worker_use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
with InputNode() as input_data:
forward_dag = MultiOutputNode([
worker.execute_model_compiled_dag_remote.
bind( # type: ignore[attr-defined]
input_data) for worker in self.workers
])
return forward_dag.experimental_compile()
def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
self._check_if_any_actor_is_dead()
def _check_if_any_actor_is_dead(self):
if not self.workers:
return
dead_actors = []
for actor in self.workers:
actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access
if actor_state["State"] == "DEAD":
dead_actors.append(actor)
if dead_actors:
raise RuntimeError("At least one Worker is dead. "
f"Dead Workers: {dead_actors}. ")
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_executor = make_async(self.driver_worker.execute_method)
async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
coros = []
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
coros.append(
self.driver_executor(method, *driver_args, **driver_kwargs))
# Run the ray workers asynchronously.
for worker in self.workers:
coros.append(worker.execute_method.remote(method, *args, **kwargs))
all_outputs = await asyncio.gather(*coros)
return all_outputs
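# A standalone sketch of how _run_workers splits per-worker arguments
# (the same islice/repeat logic as above, extracted for illustration;
# split_worker_args is a hypothetical helper, not part of vLLM):

from itertools import islice, repeat

def split_worker_args(num_workers, args=(), all_args=None):
    """Return (driver_args, list of per-worker args)."""
    driver_args = args if all_args is None else all_args[0]
    worker_args = repeat(args, num_workers) if all_args is None \
        else islice(all_args, 1, None)
    return driver_args, list(worker_args)

# Shared args: every worker sees the same tuple.
assert split_worker_args(2, args=(1, )) == ((1, ), [(1, ), (1, )])
# Per-worker args: entry 0 is reserved for the driver worker.
assert split_worker_args(2, all_args=[("d", ), ("w0", ), ("w1", )]) == \
    (("d", ), [("w0", ), ("w1", )])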

119
vllm/executor/ray_utils.py Normal file
View File

@@ -0,0 +1,119 @@
import pickle
from typing import List, Optional, Tuple
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_ip, is_hip
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
try:
import ray
class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device.
# The flag indicates is set_device is called on
# that thread.
self.compiled_dag_cuda_device_set = False
def get_node_ip(self) -> str:
return get_ip()
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
node_id = ray.get_runtime_context().get_node_id()
gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids
def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled."""
import torch
if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
output = self.worker.execute_model()
output = pickle.dumps(output)
return output
except ImportError as e:
logger.warning(
"Failed to import Ray with %r. For distributed inference, "
"please install Ray with `pip install ray`.", e)
ray = None # type: ignore
RayWorkerWrapper = None # type: ignore
def initialize_ray_cluster(
parallel_config: ParallelConfig,
ray_address: Optional[str] = None,
):
"""Initialize the distributed cluster with Ray.
    It will connect to the Ray cluster and create a placement group
for the workers, which includes the specification of the resources
for each distributed worker.
Args:
parallel_config: The configurations for parallel execution.
ray_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
"""
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
if is_hip():
ray.init(address=ray_address,
ignore_reinit_error=True,
num_gpus=parallel_config.world_size)
else:
ray.init(address=ray_address, ignore_reinit_error=True)
if parallel_config.placement_group:
# Placement group is already set.
return
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:
# We are in a placement group
bundles = current_placement_group.bundle_specs
# Verify that we can use the placement group.
gpu_bundles = 0
for bundle in bundles:
bundle_gpus = bundle.get("GPU", 0)
if bundle_gpus > 1:
raise ValueError(
"Placement group bundle cannot have more than 1 GPU.")
if bundle_gpus:
gpu_bundles += 1
if parallel_config.world_size > gpu_bundles:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs in the placement group.")
else:
num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
if parallel_config.world_size > num_gpus_in_cluster:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs in the cluster.")
# Create a new placement group
placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
current_placement_group = ray.util.placement_group(
placement_group_specs)
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
ray.get(current_placement_group.ready(), timeout=1800)
# Set the placement group in the parallel config
parallel_config.placement_group = current_placement_group
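# Typical usage of the helper above (requires a Ray installation and
# enough GPUs; the ParallelConfig arguments shown are illustrative):

from vllm.config import ParallelConfig
from vllm.executor.ray_utils import initialize_ray_cluster

parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=2,
                                 worker_use_ray=True)
# Connects to (or starts) Ray and attaches a placement group with one
# 1-GPU bundle per worker to the config.
initialize_ray_cluster(parallel_config)
assert parallel_config.placement_group is not None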

153
vllm/logger.py Normal file
View File

@@ -0,0 +1,153 @@
"""Logging configuration for vLLM."""
import datetime
import json
import logging
import os
import sys
from functools import partial
from logging import Logger
from logging.config import dictConfig
from os import path
from typing import Dict, Optional
import vllm.envs as envs
VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"
DEFAULT_LOGGING_CONFIG = {
"formatters": {
"vllm": {
"class": "vllm.logging.NewLineFormatter",
"datefmt": _DATE_FORMAT,
"format": _FORMAT,
},
},
"handlers": {
"vllm": {
"class": "logging.StreamHandler",
"formatter": "vllm",
"level": "INFO",
"stream": "ext://sys.stdout",
},
},
"loggers": {
"vllm": {
"handlers": ["vllm"],
"level": "DEBUG",
"propagate": False,
},
},
"version": 1,
}
def _configure_vllm_root_logger() -> None:
logging_config: Optional[Dict] = None
if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
raise RuntimeError(
"VLLM_CONFIGURE_LOGGING evaluated to false, but "
"VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
"implies VLLM_CONFIGURE_LOGGING. Please enable "
"VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.")
if VLLM_CONFIGURE_LOGGING:
logging_config = DEFAULT_LOGGING_CONFIG
if VLLM_LOGGING_CONFIG_PATH:
if not path.exists(VLLM_LOGGING_CONFIG_PATH):
raise RuntimeError(
"Could not load logging config. File does not exist: %s",
VLLM_LOGGING_CONFIG_PATH)
with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8",
mode="r") as file:
custom_config = json.loads(file.read())
if not isinstance(custom_config, dict):
raise ValueError("Invalid logging config. Expected Dict, got %s.",
type(custom_config).__name__)
logging_config = custom_config
if logging_config:
dictConfig(logging_config)
def init_logger(name: str) -> Logger:
"""The main purpose of this function is to ensure that loggers are
retrieved in such a way that we can be sure the root vllm logger has
already been configured."""
return logging.getLogger(name)
# The root logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_configure_vllm_root_logger()
logger = init_logger(__name__)
def _trace_calls(log_path, root_dir, frame, event, arg=None):
if event in ['call', 'return']:
# Extract the filename, line number, function name, and the code object
filename = frame.f_code.co_filename
lineno = frame.f_lineno
func_name = frame.f_code.co_name
if not filename.startswith(root_dir):
# only log the functions in the vllm root_dir
return
# Log every function call or return
try:
last_frame = frame.f_back
if last_frame is not None:
last_filename = last_frame.f_code.co_filename
last_lineno = last_frame.f_lineno
last_func_name = last_frame.f_code.co_name
else:
# initial frame
last_filename = ""
last_lineno = 0
last_func_name = ""
with open(log_path, 'a') as f:
if event == 'call':
f.write(f"{datetime.datetime.now()} Call to"
f" {func_name} in {filename}:{lineno}"
f" from {last_func_name} in {last_filename}:"
f"{last_lineno}\n")
else:
f.write(f"{datetime.datetime.now()} Return from"
f" {func_name} in {filename}:{lineno}"
f" to {last_func_name} in {last_filename}:"
f"{last_lineno}\n")
except NameError:
# modules are deleted during shutdown
pass
return partial(_trace_calls, log_path, root_dir)
def enable_trace_function_call(log_file_path: str,
root_dir: Optional[str] = None):
"""
Enable tracing of every function call in code under `root_dir`.
This is useful for debugging hangs or crashes.
`log_file_path` is the path to the log file.
`root_dir` is the root directory of the code to trace. If None, it is the
vllm root directory.
    Note that tracing is enabled per thread: only threads that call this
    function will have the trace enabled. Other threads are not affected.
"""
logger.warning(
"VLLM_TRACE_FUNCTION is enabled. It will record every"
" function executed by Python. This will slow down the code. It "
"is suggested to be used for debugging hang or crashes only.")
logger.info("Trace frame log is saved to %s", log_file_path)
if root_dir is None:
# by default, this is the vllm root directory
root_dir = os.path.dirname(os.path.dirname(__file__))
sys.settrace(partial(_trace_calls, log_file_path, root_dir))
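# As a concrete illustration, the file pointed to by
# VLLM_LOGGING_CONFIG_PATH just needs to hold a JSON-encoded
# logging.config.dictConfig schema, e.g. (path below is illustrative):

import json

custom_config = {
    "version": 1,
    "formatters": {
        "vllm": {
            "class": "vllm.logging.NewLineFormatter",
            "format": "%(levelname)s %(asctime)s %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "vllm",
            "level": "WARNING",  # quieter than the default INFO
        },
    },
    "loggers": {
        "vllm": {"handlers": ["console"], "level": "DEBUG"},
    },
}

with open("/tmp/vllm_logging.json", "w") as f:
    json.dump(custom_config, f)
# Then run with VLLM_LOGGING_CONFIG_PATH=/tmp/vllm_logging.json
# (VLLM_CONFIGURE_LOGGING must stay enabled, which is its default).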

5
vllm/logging/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from vllm.logging.formatter import NewLineFormatter
__all__ = [
"NewLineFormatter",
]

15
vllm/logging/formatter.py Normal file
View File

@@ -0,0 +1,15 @@
import logging
class NewLineFormatter(logging.Formatter):
"""Adds logging prefix to newlines to align multi-line messages."""
def __init__(self, fmt, datefmt=None, style="%"):
logging.Formatter.__init__(self, fmt, datefmt, style)
def format(self, record):
msg = logging.Formatter.format(self, record)
if record.message != "":
parts = msg.split(record.message)
msg = msg.replace("\n", "\r\n" + parts[0])
return msg
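# A quick standalone check of the behavior above: every line of a
# multi-line message is emitted with the same log prefix.

import logging
from vllm.logging import NewLineFormatter

handler = logging.StreamHandler()
handler.setFormatter(
    NewLineFormatter("%(levelname)s %(asctime)s] %(message)s",
                     datefmt="%m-%d %H:%M:%S"))
demo_logger = logging.getLogger("newline_demo")
demo_logger.addHandler(handler)
demo_logger.setLevel(logging.INFO)
# Prints two lines, both starting with "INFO <timestamp>]".
demo_logger.info("first line\nsecond line")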

0
vllm/lora/__init__.py Normal file
View File

262
vllm/lora/fully_sharded_layers.py Normal file
View File

@@ -0,0 +1,262 @@
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
RowParallelLinearWithLoRA)
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
if TYPE_CHECKING:
pass
def _fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
return (can_replace(*args, **kwargs)
and kwargs['lora_config'].fully_sharded_loras)
return dec
# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
"""
Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
bgmv(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
# now have column partitioned output
output = output.view(*out_orig_shape)
return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
def _mcp_apply_weights(x, bias, layer):
"""
MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same
    LoRA weight application method. The main difference is the stride of
    shard_size used for lora_b, which can vary for
    QKVParallelLinearWithShardedLora but is constant for
    MergedColumnParallelLinearWithShardedLoRA.
"""
# expecting 2 for column parallel and 3 for qkv
n = len(layer.lora_a_stacked)
output = layer.base_layer.linear_method.apply_weights(
layer.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device)
for idx in range(n):
bgmv(buffers[idx], x, layer.lora_a_stacked[idx],
layer.indices[:layer.indices_len[0]], 0, 1.0)
buffers = tensor_model_parallel_all_gather(buffers)
left_offset = 0
for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2]
dispatch_bgmv_low_level(output, buffers[idx],
layer.lora_b_stacked[idx],
layer.indices[:layer.indices_len[0]], 0, 1.0,
left_offset, shard_size)
left_offset += shard_size
output = output.view(*out_orig_shape)
# now have column partitioned and packed output
return output
class MergedColumnParallelLinearWithShardedLoRA(
MergedColumnParallelLinearWithLoRA):
"""
Differs from MergedColumnParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[i][:, output_start_idx:output_start_idx + output_shard_size]
for i in range(2)
]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
"""
    Differs from MergedQKVParallelLinearWithLora by slicing the
    LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[i][:, start_idx[i]:start_idx[i] +
shard_size[i]] if lora_a[i] is not None else None
for i in range(3)
]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
"""
Differs from RowParallelLinearWithLoRA by slicing the
LoRA B's also.
Based on S-LoRA, slicing happens along the output dim.
This yields a combined partial sum from the row parallel base
layer and column partitioned output from the LoRA.
"""
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
return lora_b
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = tensor_model_parallel_all_reduce(buffer)
        # Following S-LoRA, this allows fusing the all_gather and all_reduce
        # by adding the column-partitioned LoRA output to a slice of the
        # output tensor, which is a partial sum due to row parallelism.
        # All that remains is a standard all_reduce. Note that the output
        # is not the same as for a normal row-parallel layer: it must be
        # reduced before being used.
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0,
start_idx, shard_size)
output = output.view(*out_orig_shape)
return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
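# The rank-dimension slicing shared by the classes above, shown with
# plain tensors (a standalone sketch; the sizes are illustrative):

import torch

tp_size, rank = 4, 16                  # 4-way tensor parallelism, LoRA rank 16
lora_a = torch.randn(4096, rank)       # [input_dim, rank]
shard_size = rank // tp_size

for tp_rank in range(tp_size):
    start_idx = tp_rank * shard_size
    # Each TP rank keeps only its slice of the rank dim; the all_gather
    # in apply_weights reassembles the full rank-16 buffer.
    shard = lora_a[:, start_idx:start_idx + shard_size]
    assert shard.shape == (4096, shard_size)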

1181
vllm/lora/layers.py Normal file

File diff suppressed because it is too large

167
vllm/lora/lora.py Normal file
View File

@@ -0,0 +1,167 @@
from typing import List, Optional
import torch
from vllm.utils import is_pin_memory_available
class LoRALayerWeights:
"""LoRA weights for a layer composed of two low rank matrixes."""
def __init__(
self,
module_name: str,
rank: int,
lora_alpha: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor] = None,
scaling: Optional[float] = None,
) -> None:
self.module_name = module_name
self.rank = rank
self.lora_alpha = lora_alpha
self.lora_a = lora_a
self.lora_b = lora_b
self.embeddings_tensor = embeddings_tensor
if scaling is None:
self.scaling = self.lora_alpha / self.rank
else:
self.scaling = scaling
def optimize(self) -> "LoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
if self.scaling == 1:
return self
self.lora_b *= self.scaling
self.scaling = 1
return self
@property
def input_dim(self) -> int:
return self.lora_a.shape[0]
@property
def output_dim(self) -> int:
return self.lora_b.shape[1]
@property
def is_packed(self) -> bool:
return False
@property
def extra_vocab_size(self) -> int:
        if self.embeddings_tensor is None:
            return 0
        return self.embeddings_tensor.shape[0]
@classmethod
def create_dummy_lora_weights(
cls,
module_name: str,
input_dim: int,
output_dim: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank],
dtype=dtype,
device=device,
pin_memory=pin_memory)
lora_b = torch.zeros([rank, output_dim],
dtype=dtype,
device=device,
pin_memory=pin_memory)
embeddings_tensor = torch.rand(
10,
embeddings_tensor_dim,
dtype=dtype,
device=device,
pin_memory=pin_memory) if embeddings_tensor_dim else None
return cls(
module_name,
rank=rank,
lora_alpha=1,
lora_a=lora_a,
lora_b=lora_b,
embeddings_tensor=embeddings_tensor,
)
class PackedLoRALayerWeights(LoRALayerWeights):
"""LoRA used for packed layers (eg. qkv_proj)."""
def __init__(
self,
module_name: str,
rank: int,
lora_alphas: List[Optional[int]],
lora_a: List[Optional[torch.Tensor]],
lora_b: List[Optional[torch.Tensor]],
scaling: Optional[List[float]] = None,
) -> None:
super().__init__(
module_name=module_name,
rank=rank,
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling, # type: ignore
embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:
self.scaling = [ # type: ignore
lora_alpha / self.rank # type: ignore # noqa
for lora_alpha in self.lora_alphas
]
@classmethod
def pack(
cls, loras: List[Optional["LoRALayerWeights"]]
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
"""
first_lora = next(lora for lora in loras if lora is not None)
for lora in loras:
if lora is None:
continue
lora.optimize()
rank = first_lora.rank
module_name = first_lora.module_name
obj = cls(
module_name,
rank,
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
scaling=[
1 if lora is not None else None # type: ignore
for lora in loras
])
return obj
def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
continue
self.lora_b[i] *= self.scaling[i] # type: ignore
self.scaling[i] = 1 # type: ignore
return self
@property
def input_dim(self) -> int:
raise NotImplementedError()
@property
def output_dim(self) -> int:
raise NotImplementedError()
@property
def is_packed(self) -> bool:
return True
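# To make the scaling fold-in concrete: optimize() multiplies lora_b by
# lora_alpha / rank once so the hot path can skip it. A small sketch
# using the class above (sizes are illustrative):

import torch
from vllm.lora.lora import LoRALayerWeights

lora = LoRALayerWeights(module_name="layers.0.self_attn.qkv_proj",
                        rank=8,
                        lora_alpha=16,
                        lora_a=torch.randn(4096, 8),
                        lora_b=torch.randn(8, 4096))
assert lora.scaling == 2.0             # lora_alpha / rank

before = lora.lora_b.clone()
lora.optimize()
assert lora.scaling == 1               # now folded into lora_b
assert torch.allclose(lora.lora_b, before * 2.0)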

645
vllm/lora/models.py Normal file
View File

@@ -0,0 +1,645 @@
import copy
import json
import math
import os
import re
from typing import Callable, Dict, List, Optional, Tuple, Type
import safetensors.torch
import torch
from torch import nn
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache, is_pin_memory_available
logger = init_logger(__name__)
_GLOBAL_LORA_ID = 0
def convert_mapping(
mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
max_loras: int, vocab_size: int, extra_vocab_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
            same as base_indices. For prefill, this will map requests
            to LoRA indices.
        sampler_indices_padded: Tensor of shape [batch_size] mapping
            requests to LoRA indices for sampler with padding.
            Same as sampler_indices, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
indices_len: List of lengths of the above tensors.
"""
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
index_mapping_indices[i] = i
lora_indices[i] = lora_idx
indices = torch.tensor(
[index_mapping_indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size)
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
indices_len = [
base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
]
return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len)
def get_lora_id():
global _GLOBAL_LORA_ID
_GLOBAL_LORA_ID += 1
return _GLOBAL_LORA_ID
class LoRAModel:
"""A LoRA fine-tuned model."""
def __init__(
self,
lora_model_id: int,
rank: int,
loras: Dict[str, LoRALayerWeights],
) -> None:
self.id = lora_model_id
assert (lora_model_id >
0), f"a valid lora id should be greater than 0, got {self.id}"
self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras
@property
def extra_vocab_size(self) -> int:
return max(lora.extra_vocab_size
for lora in self.loras.values()) if self.loras else 0
def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
"""Get LoRA for a given module by name"""
return self.loras.get(module_name, None)
# (yard1): TODO see if we can derive target_embedding_padding automatically
@classmethod
def from_lora_tensors(
cls,
lora_model_id: int,
rank: int,
lora_alpha: int,
tensors: Dict[str, torch.Tensor],
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
assert embedding_modules is not None
embeddings_module = next(
(k for k in embedding_modules if k in module_name),
None)
if embeddings_module:
lora_embeddings_tensor = embeddings[
embedding_modules[embeddings_module]].to(
device=device, dtype=dtype)
if pin_memory:
lora_embeddings_tensor = (
lora_embeddings_tensor.pin_memory())
loras[module_name] = LoRALayerWeights(module_name, rank,
lora_alpha, None, None,
lora_embeddings_tensor)
if is_lora_a:
loras[module_name].lora_a = tensor.to(device=device,
dtype=dtype).t()
if pin_memory:
loras[module_name].lora_a = loras[
module_name].lora_a.pin_memory()
else:
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
assert embedding_padding_modules is not None
if any(name in module_name
for name in embedding_padding_modules
) and target_embedding_padding is not None:
lora_b = loras[module_name].lora_b
assert target_embedding_padding >= lora_b.shape[1]
addition = target_embedding_padding - lora_b.shape[1]
loras[module_name].lora_b = torch.nn.functional.pad(
lora_b, (0, addition))
if pin_memory:
loras[module_name].lora_b = loras[
module_name].lora_b.pin_memory()
for lora in loras.values():
lora.optimize()
return cls(lora_model_id, rank, loras)
@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: List[str],
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint."""
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin")
with open(lora_config_path) as f:
config = json.load(f)
target_modules = config["target_modules"]
unexpected_modules = []
for module in target_modules:
            # Also handle nested module names such as layers.11.self_attn.k_proj.
part_name = module.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module)
# loaded lora's target modules must be a subset of expected_lora_modules
if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
if os.path.isfile(lora_tensor_path):
tensors = safetensors.torch.load_file(lora_tensor_path)
elif os.path.isfile(lora_bin_file_path):
tensors = torch.load(lora_bin_file_path)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")
embeddings = None
if os.path.isfile(new_embeddings_tensor_path):
embeddings = safetensors.torch.load_file(
new_embeddings_tensor_path)
elif os.path.isfile(new_embeddings_bin_file_path):
embeddings = torch.load(new_embeddings_bin_file_path)
rank = config["r"]
lora_alpha = config["lora_alpha"]
return cls.from_lora_tensors(
lora_model_id=get_lora_id()
if lora_model_id is None else lora_model_id,
rank=rank,
lora_alpha=lora_alpha,
tensors=tensors,
device=device,
dtype=dtype,
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules,
)
class LoRAModelManager:
"""A manager that manages multiple LoRA-fine-tuned models."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
):
"""Create a LoRAModelManager and adapter for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens model can run
in a single batch.
vocab_size: the vocab size of the model.
lora_config: the LoRA configuration.
"""
self.lora_config = lora_config
self.max_num_seqs = max_num_seqs
assert self.capacity >= self.lora_slots
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.base_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.embeddings_indices = torch.empty(2,
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
        # 4 is the number of indices tensors defined above:
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len: List[Optional[int]] = [None] * 4
self.model: nn.Module = model
if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules)
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self
@property
def capacity(self) -> int:
return self.lora_config.max_cpu_loras
@property
def lora_slots(self) -> int:
return self.lora_config.max_loras
def __len__(self) -> int:
return len(self._registered_loras)
def activate_lora(
self,
lora_id: int,
) -> bool:
"""Move LoRA into a GPU buffer to be used in the forward pass."""
if lora_id in self._active_loras:
return False
first_free_slot = next(
((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
if lora_id is None), None)
if first_free_slot is None:
raise ValueError("No free lora slots")
index, _ = first_free_slot
self._active_loras[lora_id] = None
lora_model = self._registered_loras[lora_id]
logger.debug("Activating LoRA. int id: %d, slot index: %d",
lora_model.id, index)
self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items():
module_lora = lora_model.get_lora(module_name)
if module_lora:
module_lora.optimize()
module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
module_lora.embeddings_tensor)
else:
module.reset_lora(index)
return True
def _deactivate_lora(self, lora_id: int):
try:
index = self.lora_index_to_id.index(lora_id)
self.lora_index_to_id[index] = None
except ValueError:
pass
def deactivate_lora(self, lora_id: int) -> bool:
"""Remove a LoRA from a GPU buffer."""
if lora_id in self._active_loras:
self._deactivate_lora(lora_id)
self._active_loras.pop(lora_id)
return True
return False
def _add_lora(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager CPU cache."""
if lora.id not in self._registered_loras:
if len(self._registered_loras) >= self.capacity:
raise RuntimeError("No free LoRA slots.")
self._add_lora(lora)
return True
return False
def remove_lora(self, lora_id: int) -> bool:
"""Remove a LoRAModel from the manager CPU cache."""
# TODO: should we check active lora?
self.deactivate_lora(lora_id)
return bool(self._registered_loras.pop(lora_id, None))
# TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices,
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
self.lora_slots + 1, self.vocab_size,
self.lora_config.lora_extra_vocab_size)
self.base_indices[:base_indices.shape[0]].copy_(base_indices)
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
sampler_indices_padded)
        self.embeddings_indices[:embeddings_indices.shape[0],
                                :embeddings_indices.shape[1]].copy_(
                                    embeddings_indices)
# Maintain the reference
self.indices_len[:] = indices_len
def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
if self._last_mapping != lora_mapping:
self._set_lora_mapping(lora_mapping)
self._last_mapping = lora_mapping
def list_loras(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras)
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)
def remove_all_loras(self):
"""Remove all LoRAModels from the manager."""
self._registered_loras.clear()
self.lora_index_to_id = [None] * self.lora_slots
self._active_loras.clear()
def _create_lora_modules(self):
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name):
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
new_module = replace_submodule(
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config))
# (yard1): TODO make this more robust
if "lm_head" in module_name:
logits_processor_module = self.model.get_submodule(
"logits_processor")
new_module = replace_submodule(
self.model, "logits_processor",
from_layer_logits_processor(logits_processor_module,
module, self.lora_slots,
self.lora_config,
self.model.config))
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
new_module.set_mapping(self.base_indices, self.sampler_indices,
self.sampler_indices_padded,
self.embeddings_indices, self.indices_len)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA)
self.modules[module_name] = module
def create_dummy_lora(
self,
lora_id: int,
rank: int,
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {})
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance(
module, BaseLayerWithLoRA):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
assert embedding_modules is not None
if parts[-1] in embedding_modules:
input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if
hasattr(module.base_layer, "org_vocab_size")
else module.base_layer.weight.shape[1])
output_dim = module.base_layer.embedding_dim if hasattr(
module.base_layer,
"embedding_dim") else module.base_layer.weight.shape[0]
embeddings_tensor_dim = (module.base_layer.embedding_dim if
hasattr(module.base_layer,
"embedding_dim") else
module.base_layer.weight.shape[1])
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
input_dim,
output_dim,
rank,
module.lora_a_stacked.dtype,
"cpu",
embeddings_tensor_dim=embeddings_tensor_dim)
else:
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.lora_a_stacked.shape[-1],
module.lora_b_stacked.shape[-2],
rank,
module.lora_a_stacked.dtype,
"cpu",
)
lora.optimize()
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
subloras: List[Optional["LoRALayerWeights"]] = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
module.lora_a_stacked[i].shape[-1],
module.lora_b_stacked[i].shape[-2],
rank,
module.lora_a_stacked[i].dtype,
"cpu",
)
lora.optimize()
subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora
return model
def _match_target_modules(self, module_name: str):
return any(
re.match(
r".*\.{target_module}$".format(target_module=target_module),
module_name) or target_module == module_name
for target_module in self.supported_lora_modules)
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
replacements = self.packed_modules_mapping.get(module_name, [])
# When replacements is less than or equal to 1, it indicates that this
# module is not a packed module.
if len(replacements) <= 1:
return
prefix = ".".join(parts[:-1])
self.packed_modules[module_full_name] = [
prefix + "." + r if prefix else r for r in replacements
]
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras: List[Optional[LoRALayerWeights]] = []
has_replacement = False
for r in new_module_names:
lora = lora_model.get_lora(r)
replacement_loras.append(lora)
if lora:
has_replacement = True
if not has_replacement:
continue
for i in range(len(replacement_loras)):
if replacement_loras[i]:
continue
replacement_loras[i] = None
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras)
class LoRALRUCache(LRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
bool]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: int, value: LoRAModel):
logger.debug("Removing LoRA. int id: %d", key)
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
class LRUCacheLoRAModelManager(LoRAModelManager):
"""A model manager that manages multiple LoRAs with LRU cache."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
):
super().__init__(model, max_num_seqs, max_num_batched_tokens,
vocab_size, lora_config)
self._registered_loras: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_lora)
self._active_loras: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_lora)
def list_loras(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras.cache)
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager."""
if lora.id not in self._registered_loras:
self._add_lora(lora)
was_added = True
else:
# We always touch to update the LRU cache order
self._registered_loras.touch(lora.id)
was_added = False
return was_added
def activate_lora(
self,
lora_id: int,
) -> bool:
if lora_id not in self._active_loras and len(
self._active_loras) >= self.lora_slots:
self._active_loras.remove_oldest()
result = super().activate_lora(lora_id)
# We always touch to update the LRU cache order
self._active_loras.touch(lora_id)
return result
def remove_oldest_lora(self) -> bool:
if len(self._registered_loras) > 0:
self._registered_loras.remove_oldest()
return True
return False
def create_lora_manager(
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not hasattr(model, "supported_lora_modules"):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls(
model=model,
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
vocab_size=vocab_size,
lora_config=lora_config,
**kwargs)
return lora_manager
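# Assuming a standard PEFT-style adapter directory (the path and module
# names below are placeholders), loading a checkpoint through the
# classmethod above looks like:

from vllm.lora.models import LoRAModel

# The directory must contain adapter_config.json plus either
# adapter_model.safetensors or adapter_model.bin.
lora = LoRAModel.from_local_checkpoint(
    "/path/to/sql-lora",
    expected_lora_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    device="cpu",
    embedding_modules={},
    embedding_padding_modules=[],
)
print(lora.id, lora.rank, sorted(lora.loras))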

213
vllm/lora/punica.py Normal file
View File

@@ -0,0 +1,213 @@
# Based on code from https://github.com/punica-ai/punica
from typing import Optional
import torch
def _raise_import_error(e):
if torch.cuda.get_device_capability() < (8, 0):
raise ImportError(
"punica LoRA kernels require compute capability >= 8.0") from e
else:
raise ImportError(
"punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set.") from e
def bgmv(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
matrices.
indicies: Shape: `[B]`. Indices of the weight matrices.
layer_idx: Layer index of the weight matrices.
scale: Scaling factor.
"""
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, indicies: torch.LongTensor,
layer_idx: int, scale: float, y_offset: int,
y_slice_size: int):
"""
Same as `bgmv` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
all of the transposed LoRA matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
punica_kernels.dispatch_bgmv_low_level(
y,
x,
w_t_all,
indicies,
layer_idx,
scale,
x.size(1),
y_slice_size,
y_offset,
)
def add_lora(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
*,
buffer: Optional[torch.Tensor] = None):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
"""
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default to avoid
# numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
scale)
def add_lora_slice(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
y_offset: int,
y_slice_size: int,
*,
buffer: Optional[torch.Tensor] = None):
"""
Same as `add_lora` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default to avoid
# numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
punica_kernels.dispatch_bgmv_low_level(
buffer,
x,
wa_t_all,
indicies,
layer_idx,
1.0,
x.size(1),
buffer.size(1),
0,
)
punica_kernels.dispatch_bgmv_low_level(
y,
buffer,
wb_t_all,
indicies,
layer_idx,
scale,
buffer.size(1),
y_slice_size,
y_offset,
)
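# The "Semantics" blocks in the docstrings above translate directly into
# (slow) eager PyTorch. A reference sketch for bgmv, useful for
# sanity-checking the kernels on small inputs (not the punica kernel):

import torch

def bgmv_reference(y: torch.Tensor, x: torch.Tensor,
                   w_t_all: torch.Tensor, indicies: torch.LongTensor,
                   layer_idx: int, scale: float) -> None:
    for i in range(x.size(0)):
        # [H2, H1] weight slice for this token's LoRA, transposed to
        # [H1, H2] so x[i] @ w yields an H2-sized row.
        w = w_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
        delta = (x[i].unsqueeze(0).float() @ w.float()).squeeze(0)
        y[i] += (delta * scale).to(y.dtype)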

32
vllm/lora/request.py Normal file
View File

@@ -0,0 +1,32 @@
from dataclasses import dataclass
@dataclass
class LoRARequest:
"""
Request for a LoRA adapter.
    Note that this class should only be used internally. For online
serving, it is recommended to not allow users to use this class but
instead provide another layer of abstraction to prevent users from
accessing unauthorized LoRA adapters.
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
"""
lora_name: str
lora_int_id: int
lora_local_path: str
def __post_init__(self):
if self.lora_int_id < 1:
raise ValueError(
f"lora_int_id must be > 0, got {self.lora_int_id}")
def __eq__(self, value: object) -> bool:
return isinstance(
value, LoRARequest) and self.lora_int_id == value.lora_int_id
def __hash__(self) -> int:
return self.lora_int_id
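# A request object is constructed once per adapter and identified by its
# integer id (the name and path below are placeholders):

from vllm.lora.request import LoRARequest

sql_lora = LoRARequest(lora_name="sql-adapter",
                       lora_int_id=1,
                       lora_local_path="/path/to/sql-lora")

# Equality and hashing are based only on lora_int_id.
assert sql_lora == LoRARequest("other-name", 1, "/elsewhere")
assert hash(sql_lora) == 1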

98
vllm/lora/utils.py Normal file
View File

@@ -0,0 +1,98 @@
from typing import List, Optional, Set, Tuple, Type
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__)
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
}
def from_layer(layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig] = None) -> nn.Module:
for lora_cls in _all_lora_classes:
# specifying kwargs so they can be easily accessed in decorator
if lora_cls.can_replace_layer(source_layer=layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config):
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer
def from_layer_logits_processor(
layer: LogitsProcessor,
lm_head: ParallelLMHead,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
lm_head.weight.dtype, lm_head.weight.device)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
def replace_submodule(model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
"""Parse the name of lora weights.
args:
name: the name of the fine-tuned LoRA, e.g.
base_model.model.dense1.weight
return:
Tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b.
"""
parts = name.split(".")
assert parts[0] == "base_model"
assert parts[1] == "model"
if parts[-1] == "weight":
assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
    raise ValueError(f"{name} is in an unsupported format")
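
For concreteness, the expected behavior on two representative checkpoint tensor names (the module names are illustrative):

from vllm.lora.utils import parse_fine_tuned_lora_name

assert parse_fine_tuned_lora_name(
    "base_model.model.model.layers.0.mlp.down_proj.lora_A.weight"
) == ("model.layers.0.mlp.down_proj", True)
assert parse_fine_tuned_lora_name(
    "base_model.model.lm_head.lora_embedding_B"
) == ("lm_head", False)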

251
vllm/lora/worker_manager.py Normal file
View File

@@ -0,0 +1,251 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Set, Type
import torch
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
logger = init_logger(__name__)
class AbstractWorkerLoRAManager(ABC):
"""Abstract class for managing LoRA models on the worker side."""
def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
vocab_size: int, lora_config: LoRAConfig,
device: torch.device):
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.device = device
self.lora_config = lora_config
@abstractproperty
def is_enabled(self) -> bool:
...
@abstractmethod
def create_lora_manager(
self,
model: torch.nn.Module,
) -> Any:
...
@abstractmethod
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
...
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
...
@abstractmethod
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
...
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
...
@abstractmethod
def remove_all_loras(self):
...
@abstractmethod
def list_loras(self) -> Set[int]:
...
class WorkerLoRAManager(AbstractWorkerLoRAManager):
"""WorkerLoRAManager that manages LoRA models on the worker side.
Every request, the requested LoRAs will be loaded (unless they are already
loaded), and every other LoRA will be unloaded."""
_lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager
def __init__(
self,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
embedding_modules: Dict[str, str],
embedding_padding_modules: List[str],
lora_model_cls: Type[LoRAModel] = LoRAModel,
):
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules
# Lazily initialized by create_lora_manager.
self._lora_manager: LoRAModelManager
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device)
@property
def is_enabled(self) -> bool:
return True
def create_lora_manager(
self,
model: torch.nn.Module,
) -> Any:
lora_manager = create_lora_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls,
)
self._lora_manager = lora_manager
return lora_manager.model
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
self._apply_loras(lora_requests)
self._lora_manager.set_lora_mapping(lora_mapping)
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_that_exist = self.list_loras()
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
new_loras = set(loras_map)
loras_to_add = new_loras - loras_that_exist
loras_to_remove = loras_that_exist - new_loras
for lora_id in loras_to_remove:
self.remove_lora(lora_id)
for lora_id in loras_to_add:
self.add_lora(loras_map[lora_id])
def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
try:
model = self._lora_manager.model
supported_lora_modules = model.supported_lora_modules
packed_modules_mapping = model.packed_modules_mapping
expected_lora_modules = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(
packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path,
expected_lora_modules,
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
target_embedding_padding=self.vocab_size +
self.lora_config.lora_extra_vocab_size,
embedding_modules=self.embedding_modules,
embedding_padding_modules=self.embedding_padding_modules,
)
except Exception as e:
raise RuntimeError(
f"Loading lora {lora_request.lora_local_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank:
raise ValueError(
f"LoRA rank {lora.rank} is greater than max_lora_rank "
f"{self.lora_config.max_lora_rank}.")
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
f"is greater than lora_extra_vocab_size "
f"{self.lora_config.lora_extra_vocab_size}.")
return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_loras():
return False
return self._lora_manager.add_lora(
self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
rank, self.embedding_modules))
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id in self.list_loras():
return False
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
self._lora_manager.activate_lora(lora.id)
return loaded
def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id)
def remove_all_loras(self):
self._lora_manager.remove_all_loras()
def list_loras(self) -> Set[int]:
return set(self._lora_manager.list_loras())
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
"""WorkerLoRAManager that manages LoRA models on the worker side.
Uses an LRU Cache. Every request, the requested LoRAs will be loaded
(unless they are already loaded) and least recently used LoRAs will
be unloaded if the cache is above capacity."""
_lora_manager_cls: Type[
LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
def create_lora_manager(
self,
model: torch.nn.Module,
) -> Any:
lora_manager = create_lora_manager(
model,
lora_manager_cls=self._lora_manager_cls,
max_num_seqs=self.max_num_seqs,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens,
)
self._lora_manager = lora_manager
return lora_manager.model
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
for lora in loras_map.values():
self.add_lora(lora)
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id not in self.list_loras():
# Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity:
assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
self._lora_manager.remove_oldest_lora()
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
else:
# If the lora is already loaded, just touch it to
# update its position in the caches
loaded = self._lora_manager.get_lora(
lora_request.lora_int_id) is not None
self._lora_manager.activate_lora(lora_request.lora_int_id)
return loaded
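
The core of WorkerLoRAManager._apply_loras is a set difference between what is loaded and what the batch requests; the LRU variant above skips the explicit removal step and instead evicts lazily inside add_lora. A toy illustration of the swap logic:

loaded = {1, 2, 3}                     # adapter ids currently on this worker
requested = {2, 3, 4}                  # adapter ids needed for the next batch
loras_to_remove = loaded - requested   # {1} -> remove_lora(1)
loras_to_add = requested - loaded      # {4} -> add_lora(<request for id 4>)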

7
vllm/model_executor/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
__all__ = [
"SamplingMetadata",
"set_random_seed",
]

25
vllm/model_executor/guided_decoding/__init__.py Normal file
View File

@@ -0,0 +1,25 @@
from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor
async def get_guided_decoding_logits_processor(
guided_decoding_backend: str, request: Union[CompletionRequest,
ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
if guided_decoding_backend == 'outlines':
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")

70
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py Normal file
View File

@@ -0,0 +1,70 @@
from functools import lru_cache
from json import loads as json_loads
from typing import Optional, Union
from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
RegexParser, StringParser,
TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import (
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor
async def get_lm_format_enforcer_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if request.guided_json:
schema = _normalize_json_schema_object(request.guided_json)
character_level_parser = JsonSchemaParser(schema)
elif request.guided_choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in request.guided_choice])
elif request.guided_regex:
character_level_parser = RegexParser(request.guided_regex)
elif request.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
elif (request.response_format is not None
and request.response_format.type == "json_object"):
character_level_parser = JsonSchemaParser(
None) # None means any json object
else:
return None
logits_processor = build_vllm_logits_processor(tokenizer_data,
character_level_parser)
return logits_processor
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
if isinstance(schema, str):
return json_loads(schema)
if isinstance(schema, dict):
return schema
if isinstance(schema, BaseModel):
return schema.model_json_schema()
raise AssertionError(f"Unsupported schema type {schema}")
@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
return build_vllm_token_enforcer_tokenizer_data(tokenizer)
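
A small sketch of the schema normalization above, assuming pydantic v2 (model_json_schema); the Point model is purely illustrative, and the helper is private to this module:

from pydantic import BaseModel
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
    _normalize_json_schema_object)

class Point(BaseModel):
    x: int
    y: int

schema = _normalize_json_schema_object(Point(x=0, y=0))
assert schema["type"] == "object"                   # via model_json_schema()
assert _normalize_json_schema_object('{"a": 1}') == {"a": 1}  # str -> dict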

130
vllm/model_executor/guided_decoding/outlines_decoding.py Normal file
View File

@@ -0,0 +1,130 @@
import asyncio
import concurrent.futures
from copy import copy
from enum import Enum
from functools import lru_cache
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Tuple, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
class GuidedDecodingMode(Enum):
JSON = "json"
REGEX = "regex"
CHOICE = "choice"
GRAMMAR = "grammar"
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# The main difference is that we changed `start: value` to
# `start: object | array`, so scalar values are not allowed as the root
# of the JSON. Allowing scalars at the root seems to cause Llama to
# generate without stopping.
JSON_GRAMMAR = r"""
?start: object | array
?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null
array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value
%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""
global_thread_pool = None # used for generating logits processor fsm
async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor,
                            CFGLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
global global_thread_pool
guide, mode = _get_guide_and_mode(request)
if not guide:
return None
if global_thread_pool is None:
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=2)
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(global_thread_pool,
_get_cached_logits_processor, guide,
tokenizer, mode,
request.guided_whitespace_pattern)
logits_processor = copy(result)
# reset logits processor's internal state
logits_processor.init_state()
return logits_processor
def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest]
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
if request.guided_json:
json = request.guided_json
if isinstance(json, dict):
# turn dict into hashable string
json = json_dumps(json)
elif isinstance(json, BaseModel):
# use pydantic signature so that different model classes
# with the same fields will get hashed the same
json = str(json.__signature__)
return json, GuidedDecodingMode.JSON
elif request.guided_regex:
return request.guided_regex, GuidedDecodingMode.REGEX
elif request.guided_choice:
# choice just uses regex
choices = [
regex_escape(str(choice)) for choice in request.guided_choice
]
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE
elif request.guided_grammar:
return request.guided_grammar, GuidedDecodingMode.GRAMMAR
elif (request.response_format is not None
and request.response_format.type == "json_object"):
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
else:
return None, None
@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str,
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None]):
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.GRAMMAR:
return CFGLogitsProcessor(guide, tokenizer)
else:
raise ValueError(f"Unknown guided decoding mode {mode}")

184
vllm/model_executor/guided_decoding/outlines_logits_processors.py Normal file
View File

@@ -0,0 +1,184 @@
# Copyright 2024- the Outlines developers
# This file is adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import math
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union
import torch
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
class BaseLogitsProcessor:
def __init__(self):
        # Child classes should initialize self.fsm in their __init__.
self.fsm: FSM
def init_state(self):
"""Initialize the FSM states."""
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""
seq_id = hash(tuple(input_ids))
if len(input_ids) == 0:
self.init_state()
else:
last_token = input_ids[-1]
last_seq_id = hash(tuple(input_ids[:-1]))
self.fsm_state[seq_id] = self.fsm.next_state(
self.fsm_state[last_seq_id], last_token)
allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
mask = torch.full((scores.shape[-1], ),
-math.inf,
device=scores.device)
mask[allowed_tokens] = 0
scores.add_(mask)
return scores
class RegexLogitsProcessor(BaseLogitsProcessor):
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
"""Compile the FSM that drives the regex-structured generation.
Parameters
----------
regex_string
A string that represents a regular expression
tokenizer
The model's tokenizer
"""
tokenizer = _adapt_tokenizer(tokenizer)
fsm = RegexFSM(regex_string, tokenizer)
self.fsm = fsm
class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None]):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
----------
schema
A JSON schema that encodes the structure we want the model to
generate
tokenizer
The model's tokenizer
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact
string literals)
Example: allow only a single space or newline with
`whitespace_pattern=r"[\n ]?"`
"""
if isinstance(schema, type(BaseModel)):
schema_str = json.dumps(schema.model_json_schema())
elif isinstance(schema, Dict):
schema_str = json.dumps(schema)
elif isinstance(schema, str):
schema_str = schema
else:
raise ValueError(
f"Cannot parse schema {schema}. The schema must be either "
f"a Pydantic object, a dictionary or a string that contains "
f"the JSON Schema specification")
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer)
class CFGLogitsProcessor(BaseLogitsProcessor):
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
"""Compile the FSM that drives the context free grammar generation.
Parameters
----------
cfg
A string that represents a context-free grammar
tokenizer
The model's tokenizer
"""
tokenizer = _adapt_tokenizer(tokenizer)
fsm = CFGFSM(cfg, tokenizer)
self.fsm = fsm
def init_state(self):
"""Initialize state with a CFGFSM copy."""
super().init_state()
self.fsm = self.fsm.copy()
@lru_cache
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.
"""
if getattr(tokenizer, "_outlines_adapted", False):
return tokenizer
tokenizer = copy.deepcopy(tokenizer)
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = tokenizer.convert_tokens_to_string([token])
        # A hack to handle missing spaces in HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
def change_decoder(
decoder: Callable[[List[int]],
str]) -> Callable[[List[int]], List[str]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def new_decoder(inp_tokens: List[int]) -> List[str]:
return [decoder(inp_tokens)]
return new_decoder
tokenizer.convert_token_to_string = convert_token_to_string
tokenizer.decode = change_decoder(tokenizer.decode)
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
return tokenizer
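
The masking step in BaseLogitsProcessor.__call__ on a toy vocabulary, showing how disallowed tokens are driven to -inf before sampling:

import math
import torch

scores = torch.zeros(8)              # logits over an 8-token toy vocab
allowed_tokens = [1, 4, 7]           # as returned by fsm.allowed_token_ids
mask = torch.full((scores.shape[-1], ), -math.inf)
mask[allowed_tokens] = 0
scores.add_(mask)                    # every other token becomes -inf
assert torch.isfinite(scores[allowed_tokens]).all()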

View File

173
vllm/model_executor/layers/activation.py Normal file
View File

@@ -0,0 +1,173 @@
"""Custom activation functions."""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
return out
class GeluAndMul(nn.Module):
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
"""
def __init__(self, approximate: str = "none"):
super().__init__()
self.approximate = approximate
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if self.approximate == "none":
ops.gelu_and_mul(out, x)
elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x)
return out
def extra_repr(self) -> str:
return f'approximate={repr(self.approximate)}'
class NewGELU(nn.Module):
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c *
(x + 0.044715 * torch.pow(x, 3.0))))
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
ops.gelu_new(out, x)
return out
class FastGELU(nn.Module):
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
(1.0 + 0.044715 * x * x)))
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
ops.gelu_fast(out, x)
return out
class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
This is used for some quantization methods like AWQ.
"""
def __init__(
self,
act_module: nn.Module,
intermediate_size: int,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.act = act_module
self.input_is_parallel = input_is_parallel
if input_is_parallel:
tp_size = get_tensor_model_parallel_world_size()
intermediate_size_per_partition = divide(intermediate_size,
tp_size)
else:
intermediate_size_per_partition = intermediate_size
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.scales = nn.Parameter(
torch.empty(intermediate_size_per_partition, dtype=params_dtype))
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.act(x) / self.scales
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
if self.input_is_parallel:
tp_rank = get_tensor_model_parallel_rank()
shard_size = param_data.shape[0]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
"gelu_fast": FastGELU(),
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
}
def get_act_fn(
act_fn_name: str,
quant_config: Optional[QuantizationConfig] = None,
intermediate_size: Optional[int] = None,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")
act_fn = _ACTIVATION_REGISTRY[act_fn_name]
if (quant_config is not None
and act_fn_name in quant_config.get_scaled_act_names()):
if intermediate_size is None:
raise ValueError("intermediate_size must be specified for scaled "
"activation functions.")
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn
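
A quick sanity check of the gated split in SiluAndMul, using only the PyTorch-native _forward path (no custom CUDA ops required):

import torch
import torch.nn.functional as F
from vllm.model_executor.layers.activation import SiluAndMul

x = torch.randn(4, 16)               # last dim is 2 * d, here d = 8
ref = F.silu(x[..., :8]) * x[..., 8:]
assert torch.allclose(SiluAndMul()._forward(x), ref)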

7
vllm/model_executor/layers/fused_moe/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_moe, get_config_file_name)
__all__ = [
"fused_moe",
"get_config_file_name",
]

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
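
These JSON files map a candidate batch size (the number of tokens, the string keys above) to Triton launch parameters for the fused MoE kernel. One plausible lookup, shown as a sketch, picks the nearest benchmarked key; the exact policy used by fused_moe may differ:

import json

def pick_config(path: str, num_tokens: int) -> dict:
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    nearest = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[nearest]  # e.g. {"BLOCK_SIZE_M": 16, ..., "num_stages": 4}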

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

Some files were not shown because too many files have changed in this diff.