This commit is contained in:
2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

@@ -0,0 +1,982 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional
import os
import numpy as np
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
# from vllm_vacc.vllm.attention.backends.vacc_attn import (VACCAttentionBackendImpl,
# VACCAttentionMetadata)
# from vllm_vacc.vllm.attention.backends.vacc_attn import VACCAttentionBackendImpl
# from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: list[int],
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
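# Worked example (illustration only, not executed): for seq_len = 3 and a
# single slope 0.5, bias[None, :] - bias[:, None] gives the relative-position
# matrix [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]; after scaling by the slope and
# adding the upper-triangular -inf mask, the per-head bias is
#   [[ 0.0, -inf, -inf],
#    [-0.5,  0.0, -inf],
#    [-1.0, -0.5,  0.0]]
# i.e. each query position is penalized in proportion to its distance from
# earlier key positions.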
def _make_sliding_window_bias(
seq_lens: list[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = torch.log(mask)
attn_biases.append(mask.to(dtype))
return attn_biases
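# Worked example (illustration only): for seq_len = 4 and window_size = 2, the
# tril/triu pair keeps ones only where i - window_size < j <= i, and torch.log
# maps 1 -> 0 and 0 -> -inf, so the additive bias is
#   [[ 0.0, -inf, -inf, -inf],
#    [ 0.0,  0.0, -inf, -inf],
#    [-inf,  0.0,  0.0, -inf],
#    [-inf, -inf,  0.0,  0.0]]
# i.e. each token attends to itself and at most window_size - 1 previous tokens.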
@dataclass
class VACCAttentionMetadata(AttentionMetadata):
"""Metadata for VACCAttentionMetadata.
"""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
    # Maximum decode sequence length in the batch. 0 if the batch is
    # prefill-only.
max_decode_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
chunked_prefill: bool
seq_lens: Optional[list[int]] = None # For non-chunked prefill
# For chunked prefill only
max_query_len: Optional[int] = None
max_kv_len: Optional[int] = None
prefill_query_start_loc: Optional[torch.Tensor] = None
kv_start_loc: Optional[torch.Tensor] = None
prefill_block_tables: Optional[torch.Tensor] = None
# For V1 logits index only
query_start_loc: Optional[torch.Tensor] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[list[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self):
# Set during the execution of the first attention op.
        # It is a list because it needs to be set per prompt when ALiBi
        # slopes are used, due to a limitation of the xformers API.
        # These fields will not appear in the __repr__ and __init__.
self.attn_bias: Optional[list[torch.Tensor]] = None
self.encoder_attn_bias: Optional[list[torch.Tensor]] = None
self.cross_attn_bias: Optional[list[torch.Tensor]] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
@property
def prefill_metadata(self) -> Optional["VACCAttentionMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefill_tokens == 0:
return None
return self
@property
def decode_metadata(self) -> Optional["VACCAttentionMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
return None
return self
def get_seq_lens(
self,
attn_type: AttentionType,
):
'''
Extract appropriate sequence lengths from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths tensor for query
* Appropriate sequence lengths tensor for key & value
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
seq_lens_q = self.seq_lens
seq_lens_kv = self.seq_lens
elif attn_type == AttentionType.ENCODER:
seq_lens_q = self.encoder_seq_lens
seq_lens_kv = self.encoder_seq_lens
elif attn_type == AttentionType.ENCODER_DECODER:
seq_lens_q = self.seq_lens
seq_lens_kv = self.encoder_seq_lens
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
return seq_lens_q, seq_lens_kv
def get_attn_bias(
self,
attn_type: AttentionType,
) -> Optional[list[torch.Tensor]]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
return self.attn_bias
elif attn_type == AttentionType.ENCODER:
return self.encoder_attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
return self.cross_attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def set_attn_bias(
self,
attn_bias: list[torch.Tensor],
attn_type: AttentionType,
) -> None:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
self.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
self.encoder_attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
self.cross_attn_bias = attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_seq_len_block_table_args(
self,
attn_type: str,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return (self.seq_lens_tensor, self.max_decode_seq_len,
self.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
self.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
# class VACCMetadataBuilder(AttentionMetadataBuilder[VACCAttentionMetadata]):
# def __init__(self, input_builder: ModelInputForVACCBuilder) -> None:
# self.chunked_prefill = input_builder.chunked_prefill
# self.input_builder = input_builder
# def prepare(self):
# self.input_data = self.input_builder.input_data
# def build(self, seq_lens: list[int], query_lens: list[int],
# cuda_graph_pad_size: int, batch_size: int) -> VACCAttentionMetadata:
# input_data = self.input_data
# prefill_seq_lens = seq_lens[0:input_data.num_prefills]
# prefill_query_lens = query_lens[0:input_data.num_prefills]
# slot_mapping = torch.tensor(input_data.slot_mapping,
# dtype=torch.int32,
# device=self.input_builder.device)
# # For chunked-prefill
# if self.chunked_prefill and input_data.num_prefill_tokens != 0:
# prefill_block_tables = make_tensor_with_pad(
# self.input_data.prefill_block_tables,
# pad=0,
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# query_lens_tensor = torch.tensor(prefill_query_lens,
# dtype=torch.int32,
# device=self.input_builder.device)
# kv_lens_tensor = torch.tensor(prefill_seq_lens,
# dtype=torch.int32,
# device=self.input_builder.device)
# query_start_loc = torch.zeros(input_data.num_prefills + 1,
# dtype=torch.int32,
# device=self.input_builder.device)
# kv_start_loc = torch.zeros(input_data.num_prefills + 1,
# dtype=torch.int32,
# device=self.input_builder.device)
# torch.cumsum(query_lens_tensor,
# dim=0,
# dtype=torch.int32,
# out=query_start_loc[1:])
# torch.cumsum(kv_lens_tensor,
# dim=0,
# dtype=torch.int32,
# out=kv_start_loc[1:])
# max_query_len = max(prefill_query_lens)
# max_kv_len = max(prefill_seq_lens)
# else:
# prefill_block_tables = None
# query_start_loc = None
# kv_start_loc = None
# max_query_len = None
# max_kv_len = None
# # For paged attention
# if input_data.num_decode_tokens != 0:
# seq_lens_tensor = torch.tensor(
# input_data.seq_lens[input_data.num_prefills:],
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# block_tables = make_tensor_with_pad(
# self.input_data.decode_block_tables,
# pad=0,
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# # lowest_dim_size = block_tables.size(-1)
# # if lowest_dim_size < 1024:
# # padding_amount = 1024 - lowest_dim_size
# # padding = torch.zeros(*block_tables.size()[:-1], padding_amount, dtype=block_tables.dtype, device=block_tables.device)
# # block_tables = torch.cat((block_tables, padding), dim=-1)
# else:
# block_tables = torch.tensor([])
# seq_lens_tensor = torch.tensor(
# input_data.seq_lens[:input_data.num_prefills],
# dtype=torch.int32,
# device=self.input_builder.device,
# )
# # For multi-modal models
# placeholder_index_maps = None
# if len(input_data.multi_modal_inputs_list) != 0:
# placeholder_index_maps = {
# modality: placeholder_map.index_map()
# for modality, placeholder_map in
# input_data.multi_modal_placeholder_maps.items()
# }
# attn_metadata = VACCAttentionMetadata(
# chunked_prefill=self.chunked_prefill,
# seq_lens=seq_lens, #prefill_seq_lens,
# seq_lens_tensor=seq_lens_tensor,
# max_query_len=max_query_len,
# max_kv_len=max_kv_len,
# query_start_loc=query_start_loc,
# kv_start_loc=kv_start_loc,
# max_decode_seq_len=None,
# num_prefills=input_data.num_prefills,
# num_prefill_tokens=input_data.num_prefill_tokens,
# num_decode_tokens=input_data.num_decode_tokens,
# block_tables=block_tables,
# prefill_block_tables=prefill_block_tables,
# slot_mapping=slot_mapping,
# multi_modal_placeholder_index_maps=placeholder_index_maps,
# enable_kv_scales_calculation=False,
# )
# return attn_metadata
def fp32_attention(
query_layer,
key_layer,
value_layer,
mask,
norm_factor,
out_type=None,
):
ori_type = out_type if out_type is not None else query_layer.dtype
query_layer = query_layer.to(torch.float32)
key_layer = key_layer.to(torch.float32)
value_layer = value_layer.to(torch.float32)
    # GQA: expand KV heads to match the number of query heads
    if query_layer.size(1) != key_layer.size(1):
        assert query_layer.size(1) % key_layer.size(1) == 0, (
            "number of query heads must be divisible by number of KV heads")
        groups = query_layer.size(1) // key_layer.size(1)
        key_layer = torch.repeat_interleave(key_layer, groups, dim=1)
        value_layer = torch.repeat_interleave(value_layer, groups, dim=1)
matmul_result = torch.bmm(
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
) * norm_factor
    mask_output = matmul_result
    if mask is not None:
        mask = mask if mask.dim() >= 3 else mask.unsqueeze(0)
        # Masked (True) positions are filled with a large negative value so
        # they vanish after the softmax.
        mask_output = matmul_result.masked_fill_(mask, -10000.0)  # [b * np, sq, sk]
probs = torch.nn.Softmax(dim=-1)(mask_output)
context_layer = torch.bmm(probs, value_layer.transpose(0, 1))
return context_layer.transpose(0, 1).to(ori_type)
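# Shape sketch for fp32_attention (hedged illustration; the sizes below are
# assumed, not taken from the code): inputs are laid out [seq, heads, head_dim],
# hence the transpose(0, 1) calls before bmm. With query [sq=4, np=8, hn=64]
# and key/value [sk=16, nkv=2, hn=64], GQA gives groups = 8 // 2 = 4 and
# repeat_interleave expands k/v to 8 heads, so
#   scores = bmm([8, 4, 64], [8, 64, 16]) -> [8, 4, 16]
#   out    = bmm(softmax(scores), [8, 16, 64]) -> [8, 4, 64]
# which is transposed back to [4, 8, 64] and cast to the original dtype.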
class VACCAttentionBackendImpl(AttentionImpl[VACCAttentionMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
# if logits_soft_cap is not None:
# logger.warning_once("Torch SPDA does not support logits soft cap. "
# "Outputs may be slightly off.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
        if kv_cache_dtype != "auto":
            raise NotImplementedError(
                "The VACC attention backend does not support FP8 KV cache.")
self.attn_type = attn_type
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: VACCAttentionMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
if (key is not None) and (value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
layer._k_scale, layer._v_scale)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
else:
            # Encoder attention - chunked prefill is not applicable;
            # derive the token count from the query shape and treat them
            # all as prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0
or prefill_meta.block_tables.numel() == 0):
self._run_vacc_forward(
output,
query,
key,
value,
prefill_meta,
attn_type=attn_type)
            else:
                # Prefix-enabled attention; this path still relies on the
                # IPEX kernel inherited from the CPU backend.
                assert not self.need_mask
                import intel_extension_for_pytorch.llm.modules as ipex_modules
ipex_modules.PagedAttention.flash_attn_varlen_func(
output[:prefill_meta.num_prefill_tokens, :, :],
query[:prefill_meta.num_prefill_tokens, :, :],
key_cache,
value_cache,
prefill_meta.query_start_loc,
prefill_meta.kv_start_loc,
prefill_meta.max_query_len,
prefill_meta.max_kv_len,
self.scale,
True,
prefill_meta.prefill_block_tables,
self.alibi_slopes,
)
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
# Decoding run.
# (
# seq_lens_arg,
# max_seq_len_arg,
# block_tables_arg,
# ) = decode_meta.get_seq_len_block_table_args(attn_type)
            # NOTE: decode attention still uses the SDPA method.
            # Reshape k/v cache to
            # (num_block_groups, block_group_size, num_kv_heads, head_size).
            k_cache = key_cache.view(-1, env_blk_grp_size, key_cache.shape[2], key_cache.shape[3])
            v_cache = value_cache.view(-1, env_blk_grp_size, value_cache.shape[2], value_cache.shape[3])
            block_per_group = env_blk_grp_size // 16
            # Convert vLLM block ids to block-group indices.
            block_tables = (decode_meta.block_tables // block_per_group).to(torch.int32)
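            # Index arithmetic sketch (hedged; BLOCK_GROUP_SIZE is read from
            # the environment, 512 is only an assumed example value): with
            # env_blk_grp_size = 512, block_per_group = 512 // 16 = 32, so
            # vLLM block id 70 maps to cache group 70 // 32 = 2, and k_cache
            # row 2 holds one contiguous group of 512 token slots.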
attn_outs = []
for i in range(len(decode_meta.seq_lens_tensor)):
seq_len = decode_meta.seq_lens_tensor[i]
                k_slices = k_cache[block_tables[i], ...]
                k = torch.cat([k_slices[j, ...] for j in range(len(block_tables[i]))], dim=0)[:seq_len]
                v_slices = v_cache[block_tables[i], ...]
                v = torch.cat([v_slices[j, ...] for j in range(len(block_tables[i]))], dim=0)[:seq_len]
q = query[i : i + 1, ...]
if q.dtype == torch.bfloat16:
attn_out = fp32_attention(
q.cpu(),
k.cpu(),
v.cpu(),
None,
self.scale
).to(query.dtype).to(query.device)
else:
attn_out = torch.vacc.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale,
)
attn_outs.append(attn_out)
output = torch.cat(attn_outs, dim=0)
# PagedAttention.forward_decode(
# output[attn_metadata.num_prefill_tokens:, :, :],
# query[attn_metadata.num_prefill_tokens:, :, :],
# key_cache,
# value_cache,
# block_tables_arg,
# seq_lens_arg,
# max_seq_len_arg,
# self.kv_cache_dtype,
# self.num_kv_heads,
# self.scale,
# self.alibi_slopes,
# layer._k_scale,
# layer._v_scale,
# )
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _run_vacc_forward(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: VACCAttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
):
# if self.num_kv_heads != self.num_heads:
# key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
# value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
attn_masks = attn_metadata.get_attn_bias(attn_type)
if attn_masks is None:
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
assert attn_metadata.seq_lens is not None
attn_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
attn_masks = [None] * len(seq_lens)
attn_metadata.set_attn_bias(attn_masks, attn_type)
causal_attn = (attn_type == AttentionType.DECODER)
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
start_q, start_kv = 0, 0
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
attn_masks):
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv
            sub_out = torch.vacc.scaled_dot_product_attention(
                # The softmax scale is folded into the query, so sm_scale=1.
                query[start_q:end_q, :, :].to(torch.float16) * self.scale,
                key[start_kv:end_kv, :, :].to(torch.float16),
                value[start_kv:end_kv, :, :].contiguous().to(torch.float16),
                attn_mask=None,
                dropout_p=0.0,
                is_causal=causal_attn,  # upstream used: causal_attn and not self.need_mask
                is_train=False,
                recompute=False,
                flash_attention=False,
                sm_scale=1)
            output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv
return output
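    # Offset walk example for _run_vacc_forward (illustration only): for
    # decoder self-attention with seq_lens_q = seq_lens_kv = [4, 6], the loop
    # runs SDPA on query[0:4]/kv[0:4] and then query[4:10]/kv[4:10]. Note the
    # softmax scale is folded into the query (query * self.scale), which is
    # why sm_scale=1 is passed to the kernel.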
class VACCAttentionBackend(AttentionBackend):
accept_output_buffer: bool = False
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "TORCH_SDPA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["VACCAttentionBackendImpl"]:
return VACCAttentionBackendImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return VACCAttentionMetadata
# @staticmethod
# def get_state_cls() -> type["CommonAttentionState"]:
# return CommonAttentionState
@staticmethod
def get_builder_cls() -> type["VACCAttentionMetadataBuilderV1"]:
return VACCAttentionMetadataBuilderV1
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
        cache_dtype_str: str,
) -> tuple[int, ...]:
# return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
# num_kv_heads, head_size)
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
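    # Example (hedged; the sizes are assumed): num_blocks=1024, block_size=16,
    # num_kv_heads=8, head_size=128 yields shape (2, 1024, 16, 8, 128), where
    # index 0 of the leading dim is the key plane and index 1 the value plane,
    # as unpacked by PagedAttention.split_kv_cache in the impl above.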
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
class VACCAttentionMetadataBuilderV1(AttentionMetadataBuilder[VACCAttentionMetadata]):
# def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
# block_table: BlockTable) -> None:
# self.runner = runner
# self.block_table = block_table
# # For reorder
# self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
# dtype=np.int64)
# self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
# dtype=np.int64)
# self.num_prompt_req: int = 0
# self.seq_start_loc_cpu = torch.zeros(
# runner.max_num_reqs + 1,
# dtype=torch.int32,
# device="cpu",
# )
# self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device) -> None:
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.scheduler_config = vllm_config.scheduler_config
# For reorder
self.reorder_prompt_req_index_list = np.empty(
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
self.reorder_decode_req_index_list = np.empty(
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
self.num_prompt_req: int = 0
self.seq_start_loc_cpu = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32,
device="cpu",
)
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
# def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
# vllm_config: VllmConfig, device: torch.device):
# super().__init__(kv_cache_spec, layer_names, vllm_config, device)
# self.block_size = kv_cache_spec.block_size
# model_config = vllm_config.model_config
# self.num_heads_q = model_config.get_num_attention_heads(
# vllm_config.parallel_config)
# self.num_heads_kv = model_config.get_num_kv_heads(
# vllm_config.parallel_config)
# self.headdim = model_config.get_head_size()
def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
prompt_list_idx = 0
decode_list_idx = 0
for req_index in range(input_batch.num_reqs):
if input_batch.num_computed_tokens_cpu[
req_index] < input_batch.num_prompt_tokens[req_index]:
# prompt stage
self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
prompt_list_idx += 1
else:
# decode stage
self.reorder_decode_req_index_list[decode_list_idx] = req_index
decode_list_idx += 1
assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
# Update prompt requests number
self.num_prompt_req = prompt_list_idx
reorder_req_num = 0
for req_index in range(decode_list_idx):
if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
reorder_req_num += 1
else:
break
if reorder_req_num == 0:
return False
reorder_prompt_list = (
self.reorder_prompt_req_index_list[:prompt_list_idx]
[-reorder_req_num:])
reorder_decode_list = (
self.reorder_decode_req_index_list[:decode_list_idx]
[:reorder_req_num])
assert reorder_decode_list.size == reorder_prompt_list.size
for idx in range(reorder_req_num):
prompt_req_index = reorder_prompt_list[idx].item()
decode_req_index = reorder_decode_list[idx].item()
input_batch.swap_states(prompt_req_index, decode_req_index)
return True
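    # Reorder walk-through (illustration only): for a batch ordered
    # [prompt0, decode1, decode2, prompt3], the prompt indices are [0, 3] and
    # the decode indices are [1, 2]; decode index 1 < num_prompt_req (2), so a
    # single swap of requests 3 and 1 yields [prompt0, prompt3, decode2,
    # decode1], i.e. all prompt requests packed at the front of the batch.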
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens = common_attn_metadata.seq_lens
# seq_lens = common_attn_metadata.seq_lens
# runner = self.runner
# block_table = self.block_table
# seq_lens = runner.seq_lens[:num_reqs]
num_prompt_req = self.num_prompt_req
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
num_prefill_tokens)
# print('query_start_loc_cpu', query_start_loc_cpu)
# print('num_prompt_req', num_prompt_req)
# print('num_reqs', num_reqs)
# num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
# num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
# ) - num_prefill_tokens
# block_table.slot_mapping[:num_actual_tokens].copy_(
# block_table.slot_mapping_cpu[:num_actual_tokens],
# non_blocking=True)
# slot_mapping = block_table.slot_mapping[:num_actual_tokens] #.long()
# block_table_tensor = block_table.get_device_tensor()
slot_mapping = common_attn_metadata.slot_mapping
block_table_tensor = common_attn_metadata.block_table_tensor
block_num_per_group = env_blk_grp_size // 16
        block_table_tensor_new = block_table_tensor[:num_reqs - num_prompt_req, ::block_num_per_group].contiguous()
        # [bs, seq//16] => [bs, seq//16//block_num_per_group, block_num_per_group]
        # => take the leading rows and keep only the block ids that start each
        #    block_num_per_group-sized group
attn_metadata = VACCAttentionMetadata(
num_prefills=num_prompt_req,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
slot_mapping=slot_mapping,
seq_lens_tensor=seq_lens, # decode
max_decode_seq_len=None, # decode
block_tables=block_table_tensor_new, # decode
chunked_prefill=False,
# max_query_len=max_query_len,
# max_kv_len=max_prefill_seq_len,
# prefill_query_start_loc=runner.
# query_start_loc_cpu[:num_prompt_req + 1], # prefill
# kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
# 1], # prefill
prefill_block_tables=block_table_tensor[:
num_prompt_req], # prefill
query_start_loc=query_start_loc_cpu[:num_reqs +
1], # for logits index
# multi_modal_placeholder_index_maps=None,
# enable_kv_scales_calculation=False,
)
return attn_metadata


@@ -0,0 +1,953 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, List, Optional, Dict, TypeVar, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType,
)
from vllm.attention.backends.utils import get_mla_dims
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonMetadata)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
                                              get_per_layer_parameters,
                                              infer_global_hyperparameters,
                                              split_decodes_and_prefills)
from vllm.v1.attention.backends.mla.common import MLACommonImpl
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
logger = init_logger(__name__)
M = TypeVar("M", bound=MLACommonMetadata)
def vacc_paged_attention_naive(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: list[int],
    out: Optional[torch.Tensor] = None,
    sm_scale: float = -1,
) -> torch.Tensor:
    # Fast path: guarantee batch=1 performance.
    if len(seq_lens) == 1:
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[0]]
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[0]]
attn_out = torch.vacc.scaled_dot_product_attention(
query=query,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=sm_scale
)
else:
# t0 = time.time()
attn_outs = []
        for i in range(len(seq_lens)):
            k_slices = key_cache[block_table[i], :, :, :]
            k = torch.cat([k_slices[j, :, :, :].unsqueeze(1) for j in range(len(block_table[i]))], dim=0)
            k = k.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
            v_slices = value_cache[block_table[i], :, :, :]
            v = torch.cat([v_slices[j, :, :, :].unsqueeze(1) for j in range(len(block_table[i]))], dim=0)
            v = v.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
attn_out = torch.vacc.scaled_dot_product_attention(
query=query[i:i+1,:,:],
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=sm_scale
)
attn_outs.append(attn_out)
attn_out = torch.cat(attn_outs, dim=0)
# print(f'{os.getpid()} call spda(seq: {seq_lens}) time: {time.time() - t0}')
return attn_out
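# KV gather sketch for vacc_paged_attention_naive (illustration only, assumed
# sizes): for request i with block_table[i] = [7, 3], block_size = 16 and
# seq_lens[i] = 20, indexing key_cache[block_table[i]] gathers blocks 7 and 3
# into a (2, 16, H, D) tensor, which is flattened to (32, H, D) and truncated
# to the first 20 valid token slots before scaled_dot_product_attention runs.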
@dataclass
class MLACommonPrefillMetadata:
""" Prefill Specific Metadata """
@dataclass
class ChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
    block_tables: torch.Tensor  # renamed block_table -> block_tables for v0 compatibility
query_start_loc: torch.Tensor
# max_query_len: int
seq_lens: list[int]
chunked_context: Optional[ChunkedContextMetadata] = None
@dataclass
class MLACommonDecodeMetadata:
    block_tables: torch.Tensor  # renamed block_table -> block_tables for v0 compatibility
seq_lens: torch.Tensor
class VACCMLAMetadata(MLACommonMetadata):
"""Metadata for VACCMLAMetadata.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding-only batch.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Maximum query length in the batch.
max_query_len: Optional[int] = None
input_positions: Optional[torch.Tensor] = None
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional["VACCMLAMetadata"] = None
_cached_decode_metadata: Optional["VACCMLAMetadata"] = None
num_prefill_tokens: int
num_kv_splits: int = 4 # TODO(lucas) add heuristic
attn_logits: Optional[torch.Tensor] = None
req_idx: Optional[torch.Tensor] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
def __post_init__(self):
supported_head_sizes = VACCMLABackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")
@property
def prefill_metadata(self) -> Optional["VACCMLAMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
input_positions = (None if self.input_positions is None else
self.input_positions[:self.num_prefill_tokens])
self._cached_prefill_metadata = VACCMLAMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
input_positions=input_positions,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=None,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
head_dim=self.head_dim)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["VACCMLAMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.seq_lens_tensor is not None
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
input_positions = (None if self.input_positions is None else
self.input_positions[self.num_prefill_tokens:])
self._cached_decode_metadata = VACCMLAMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=self.seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
input_positions=input_positions,
head_dim=self.head_dim)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForVACCWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
        # When using cudagraph, num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries.
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
# self.num_prefill_tokens = 0
# self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
# assert self.max_query_len == 1
# assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
# self.max_decode_seq_len = None
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class VACCMLAImpl(MLACommonImpl[VACCMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
# blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**kwargs) -> None:
# super().__init__(num_heads, head_size, scale, num_kv_heads,
# alibi_slopes, sliding_window, kv_cache_dtype,
# logits_soft_cap, attn_type,
# kv_sharing_target_layer_name, **kwargs)
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
        # Manually replicate the old base-class initialization from kwargs.
        self.q_lora_rank = kwargs.get('q_lora_rank')
        self.kv_lora_rank = kwargs.get('kv_lora_rank')
        self.qk_nope_head_dim = kwargs.get('qk_nope_head_dim')
        self.qk_head_dim = kwargs.get('qk_head_dim')
        self.v_head_dim = kwargs.get('v_head_dim')
        self.rotary_emb = kwargs.get('rotary_emb')
        self.q_proj = kwargs.get('q_proj')
        self.kv_b_proj = kwargs.get('kv_b_proj')
        self.o_proj = kwargs.get('o_proj')
unsupported_features = [
alibi_slopes, sliding_window, logits_soft_cap
]
if any(unsupported_features):
            raise NotImplementedError(
                "VACCMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"VACCMLAImpl")
def extract_weights(self):
weights = {}
if hasattr(self, 'W_Q'):
weights["W_Q"] = self.W_Q
if hasattr(self, 'W_Q_scales'):
weights["W_Q_scales"] = self.W_Q_scales
if hasattr(self, 'W_QR'):
weights['W_QR'] = self.W_QR
if hasattr(self, 'W_QR_scales'):
weights["W_QR_scales"] = self.W_QR_scales
if hasattr(self, 'W_Q_QR'):
weights["W_Q_QR"] = self.W_Q_QR
if hasattr(self, 'W_Q_QR_scales'):
weights["W_Q_QR_scales"] = self.W_Q_QR_scales
if hasattr(self, 'W_UK'):
weights['W_UK'] = self.W_UK
if hasattr(self, 'W_UK_scales'):
weights['W_UK_scales'] = self.W_UK_scales
if hasattr(self, 'W_Q_UK_scales'):
weights['W_Q_UK_scales'] = self.W_Q_UK_scales
if hasattr(self, 'W_UV'):
weights['W_UV'] = self.W_UV
if hasattr(self, 'W_UV_scales'):
weights['W_UV_scales'] = self.W_UV_scales
if hasattr(self, 'W_UV_O'):
weights['W_UV_O'] = self.W_UV_O
if hasattr(self, 'W_UV_O_scales'):
weights['W_UV_O_scales'] = self.W_UV_O_scales
return weights
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: VACCMLAMetadata,
) -> torch.Tensor:
assert isinstance(attn_metadata, VACCMLAV1Metadata)
kv_nope = self.kv_b_proj(kv_c_normed)[0]\
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
v = v.contiguous()
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
# value=0)
# attn_output = torch.vacc.scaled_dot_product_attention(
# query=q,
# key=k,
# value=v_padded,
# attn_mask=None,
# dropout_p=0,
# is_causal=True,
# is_train=False,
# recompute=False,
# flash_attention=True,
# sm_scale=self.scale
# )
# attn_output = attn_output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
seq_lens = attn_metadata.prefill_metadata.seq_lens
if len(seq_lens) == 1:
            # VACC supports different head dims for V and QK, so V needs no padding.
attn_output = torch.vacc.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=True,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale
)
attn_out = attn_output.view(-1, self.num_heads * v.shape[-1])
else:
attn_outs = []
start = 0
for seq in seq_lens:
end = start + seq
attn_out = torch.vacc.scaled_dot_product_attention(
query=q[start:end, :],
key=k[start:end, :],
value=v[start:end, :],
attn_mask=None,
dropout_p=0,
is_causal=True,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale
)
start = end
attn_outs.append(attn_out)
attn_out = torch.cat(attn_outs, dim=0).view(-1, self.num_heads * v.shape[-1])
return self.o_proj(attn_out)[0]
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: VACCMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
# print(f"kv_c_and_k_pe_cache: {kv_c_and_k_pe_cache.shape} ")
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
# Run MQA using paged_attention
# o = torch.vacc.paged_attention(
# query=q,
# key_cache=kv_c_and_k_pe_cache,
# value_cache=kv_c_cache,
# block_table=decode_meta.block_tables,
# seq_len=decode_meta.seq_lens_tensor,
# out=o,
# sm_scale=self.scale
# )
# Run MQA using spda
# t0 = time.time()
o = vacc_paged_attention_naive(
q,
kv_c_and_k_pe_cache,
kv_c_cache,
block_table = decode_meta.block_tables,
# seq_lens = decode_meta.seq_lens_tensor,
seq_lens=decode_meta.seq_lens,
out = o,
sm_scale=self.scale)
return self._v_up_proj_and_o_proj(o)
# patch from MLACommonBackend
class VACCMLABackend(AttentionBackend):
accept_output_buffer: bool = False
@staticmethod
def get_name() -> str:
return "TRITON_MLA_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return VACCMLAMetadata
@staticmethod
def get_impl_cls() -> Type["VACCMLAImpl"]:
return VACCMLAImpl
@staticmethod
def get_builder_cls() -> type["MLAVaccMetadataBuilder"]:
return MLAVaccMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
D = TypeVar("D", bound=MLACommonDecodeMetadata)
# patch from MLACommonMetadata
@dataclass
class VACCMLAV1Metadata(Generic[D]):
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
num_prefill_tokens: int
# The dimension of the attention heads
head_dim: Optional[int] = None
decode_metadata: Optional[D] = None
prefill_metadata: Optional[MLACommonPrefillMetadata] = None
def __post_init__(self):
if self.head_dim is not None:
MLACommonBackend.validate_head_size(self.head_dim)
class MLAVaccMetadataBuilder(AttentionMetadataBuilder[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
reorder_batch_threshold: ClassVar[int] = 2
    # TODO: use separate thresholds for prefill and decode.
def __init__(self,
# runner: "GPUModelRunner",
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[type[M]] = None):
self._global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, layer_names,
MLACommonImpl))
self.metadata_cls = metadata_cls \
if metadata_cls is not None else VACCMLAV1Metadata
self.kv_cache_spec = kv_cache_spec
scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.device = device
# self.runner = runner
# scheduler_config = runner.scheduler_config
# model_config = runner.model_config
# cache_config = runner.cache_config
self.chunked_prefill_enabled = False
self.num_heads = self.model_config.get_num_attention_heads(
parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda()
self.kv_cache_spec = kv_cache_spec
# Dont try to access the runner on AMD
if self.aot_schedule:
self.page_size = self.kv_cache_spec.block_size
# if self.chunked_prefill_enabled:
# self.chunked_prefill_workspace_size = min(
# # Max sure there is enough for 8 full length request or at least
# # 4 pages of cache per request
# max(
# 8 * model_config.max_model_len, 4 *
# scheduler_config.max_num_seqs * cache_config.block_size),
# # For long-context models try not to over-allocate limiting
# # kv-cache space, limiting it to 64k tokens,
# # which would result in the workspace being:
# # 2*(576)*(64*1024) = 144mb
# # (assuming 576 MLA head dim, and fp16)
# # which would result in up-projected context being
# # 2*(192*128)*(64*1024) = 3gb
# # (assuming 192 QK head dim, 128 heads, and fp16)
# 128 * 1024)
# assert self.chunked_prefill_workspace_size >= \
# scheduler_config.max_num_seqs * cache_config.block_size
# self.chunked_prefill_workspace = torch.empty(
# (self.chunked_prefill_workspace_size,
# model_config.get_head_size()),
# dtype=model_config.dtype,
# device=runner.device,
# )
# self.block_table = block_table
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
        # We now want to reorder the batch so that the "decode" requests are
        # at the front and the "prefill" requests are at the back, using the
        # least number of swaps possible. (NOTE: for now we loosely use
        # "decode" to mean requests where attention is likely memory-bound and
        # "prefill" to mean requests where attention is likely compute-bound;
        # TODO(lucas): figure out better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if its not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if scheduler_output.scheduled_cached_reqs.num_computed_tokens != []:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
        # We hope this is fairly minimal since decodes should be around for a
        # number of iterations, so hopefully they are relatively stationary
        # (and new requests are generally appended to the persistent batch, so
        # they should already be at the back).
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back"
        # (i.e. past where the last decode should be in the reordered batch)
        # with prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor):
return MLACommonDecodeMetadata(
block_tables=block_table_tensor,
seq_lens=seq_lens,
)
# def build_for_cudagraph_capture(
# self, common_attn_metadata: CommonAttentionMetadata) -> M:
# """
# This method builds the metadata for full cudagraph capture.
# Currently, only decode is supported for full cudagraphs with MLA.
# """
# m = common_attn_metadata
# assert m.num_reqs == m.num_actual_tokens, \
# "MLA only supports decode-only full CUDAGraph capture. " \
# "Make sure all cudagraph capture sizes <= max_num_seq."
# # m.max_query_len = 1 # decode-only
# # Update state usually set in reorder_batch.
# self._num_decodes = m.num_reqs
# self._num_decode_tokens = m.num_actual_tokens
# self._num_prefills = 0
# self._num_prefill_tokens = 0
# return self.build(0, m)
    def append_seqlen(self, seq_len: list[int], all_len: int):
        """Expand per-request seq lens into per-token seq lens when each
        request schedules multiple tokens (e.g. MTP speculative decoding)."""
if all_len > len(seq_len) and all_len % len(seq_len) == 0:
new_seq_len = []
mtp_num = all_len // len(seq_len)
for start_len in seq_len:
                for i in range(1, 1 + mtp_num):
                    new_seq_len.append(start_len - mtp_num + i)
return new_seq_len
return seq_len
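    # Hedged example of the MTP expansion above (hypothetical values): two
    # requests each scheduled with 2 tokens (1 verified + 1 speculative):
    #   append_seqlen([40, 37], 4) -> [39, 40, 36, 37]
    # and with no speculative tokens the input is returned unchanged:
    #   append_seqlen([40, 37], 2) -> [40, 37]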
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> M:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
# max_query_len = common_attn_metadata.max_query_len
# assert self._num_decodes + self._num_prefills == num_reqs
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
# device = self.runner.device
# block_table = self.block_table
# block_table_tensor = block_table.get_device_tensor()#[:num_reqs]
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
        # Example decode-phase CommonAttentionMetadata (1 request, 2 tokens
        # due to MTP): query_start_loc=[0, 2], seq_lens=[40], num_reqs=1,
        # num_actual_tokens=2, max_query_len=2, max_seq_len=40.
# block_table.slot_mapping[:num_actual_tokens].copy_(
# block_table.slot_mapping_cpu[:num_actual_tokens],
# non_blocking=True)
# # block_table.slot_mapping[num_actual_tokens:].fill_(-1)
# slot_mapping = block_table.slot_mapping[:num_actual_tokens]
# num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
# split_decodes_and_prefills(common_attn_metadata,
# decode_threshold=self.reorder_batch_threshold)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = 0, 0, 0, 0
        token_nums = (common_attn_metadata.query_start_loc_cpu[1:] -
                      common_attn_metadata.query_start_loc_cpu[:-1])
if token_nums.max().item() > self.reorder_batch_threshold:
num_prefills = num_reqs
num_prefill_tokens = num_actual_tokens
else:
num_decodes = num_reqs
num_decode_tokens = num_actual_tokens
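        # e.g. query_start_loc_cpu=[0, 1, 2] gives token_nums=[1, 1]: a pure
        # decode batch; query_start_loc_cpu=[0, 35] gives token_nums=[35] >
        # reorder_batch_threshold, so the whole batch is treated as prefill.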
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
prefill_metadata = None
if num_prefills > 0:
reqs_start = num_decodes # prefill_start
# context_lens_cpu = self.runner.input_batch.\
# num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
# max_context_len_cpu = context_lens_cpu.max().item()
# num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
chunked_context_metadata = None
# if self.chunked_prefill_enabled and self._num_prefills > 0 \
# and max_context_len_cpu > 0:
            # Chunked prefill is not supported here.
prefill_metadata = MLACommonPrefillMetadata(
block_tables=block_table_tensor[reqs_start:reqs_start+num_prefills, ...],
query_start_loc=prefill_query_start_loc,
# max_query_len=None,
chunked_context=chunked_context_metadata,
seq_lens=seq_lens[:num_prefills],
)
if not isinstance(seq_lens, list):
            # TODO: initialize seq_lens as a list in vllm/v1/spec_decode/eagle.py
seq_lens = seq_lens.tolist()
decode_metadata = None
if num_decodes > 0:
block_num_per_group = env_blk_grp_size // 16
block_table_tensor_new = block_table_tensor[:num_decodes, ::block_num_per_group].contiguous()
seq_lens_new = self.append_seqlen(seq_lens[:slot_mapping.shape[-1]], slot_mapping.shape[-1])
if slot_mapping.shape[-1] > num_decodes:
                # Equivalent to query_start_loc[1:] - query_start_loc[:-1].
                mtp_numbers = [
                    query_start_loc[i + 1] - query_start_loc[i]
                    for i in range(len(query_start_loc) - 1)
                ]
                block_table_tensor_list = []
                for bi, mtp_number in enumerate(mtp_numbers):
                    for _ in range(mtp_number):
                        block_table_tensor_list.append(
                            block_table_tensor_new[bi:bi + 1])
                block_table_tensor_new = torch.concatenate(
                    block_table_tensor_list, 0)
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor_new,
seq_lens=seq_lens_new,
)
return self.metadata_cls(
num_actual_tokens=num_actual_tokens,
query_start_loc=query_start_loc,
slot_mapping=slot_mapping,
head_dim=self.model_config.get_head_size(),
            # prefill_seq_lens=seq_lens[:self._num_prefills].tolist(),  # device to host, TODO: optimize
# MLACommonMetadata Chunk prefill specific
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
prefill_metadata=prefill_metadata,
decode_metadata=decode_metadata,
)
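    # Hedged sketch of the decode-side block-table striding above, assuming
    # BLOCK_GROUP_SIZE=64 and 16-token blocks (block_num_per_group=4): every
    # 4th entry is a group's base block, so
    #   block_table_tensor = [[8, 9, 10, 11, 20, 21, 22, 23]]
    #   block_table_tensor[:, ::4] -> [[8, 20]]
    # keeps one block id per 64-token block group.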

View File

View File

@@ -0,0 +1,78 @@
from collections import defaultdict
from collections.abc import Iterable
from typing import Callable, Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens)
from vllm.v1.request import Request
from vllm.v1.core.block_pool import BlockHashToBlockMap
logger = init_logger(__name__)
class BlockPool:
"""BlockPool that manages KVCacheBlocks.
It provides methods to allocate, free and cache the kv cache blocks. The
free_block_queue stores the free blocks in eviction order to enable
allocation, free, and cache eviction. The cached_block_hash_to_block
maps between block hash and cached block to support finding cached blocks
by their block hash.
Args:
num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching.
enable_kv_cache_events: Whether to enable kv cache events.
"""
def __init__(
self,
num_gpu_blocks: int,
enable_caching: bool,
enable_kv_cache_events: bool = False,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
# All kv-cache blocks.
self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
# Cache for block lookup
self.cached_block_hash_to_block: BlockHashToBlockMap = \
BlockHashToBlockMap()
# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, needs special care to
# avoid freeing it.
# self.null_block = self.free_block_queue.popleft()
# self.null_block.is_null = True
        # Patch: peek at the head of the free list instead of popping it,
        # and keep it usable (is_null=False) so block 0 stays allocatable.
        self.null_block = \
            self.free_block_queue.fake_free_list_head.next_free_block
        self.null_block.is_null = False
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue: list[KVCacheEvent] = []
def get_usage(self) -> float:
"""Get the KV cache usage.
Returns:
The KV cache usage (between 0.0 and 1.0).
"""
total_gpu_blocks = self.num_gpu_blocks
if not total_gpu_blocks:
            return 0.0
return 1.0 - (self.get_num_free_blocks() / total_gpu_blocks)

View File

@@ -0,0 +1,156 @@
from vllm.config import VllmConfig
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import logger
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
from vllm.v1.kv_cache_interface import KVCacheSpec
def query_device_avaliable_memory(vllm_config):
import os
import torch
import torch_vacc
    available_kv_cache_memory = int(
        os.getenv("VLLM_VACC_KVCACHE_SPACE", "16")) * GiB_bytes
    if available_kv_cache_memory == 0:
torch.vacc.empty_cache()
torch.vacc.reset_peak_memory_stats()
total_memory = torch.vacc.mem_get_info()[1]
torch.vacc.synchronize()
peak_memory = torch.vacc.max_memory_allocated()
torch.vacc.empty_cache()
torch_allocated_bytes = torch.vacc.memory_stats(
)["allocated_bytes.all.current"]
total_allocated_bytes = torch.vacc.mem_get_info(
)[1] - torch.vacc.mem_get_info()[0]
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
peak_memory += non_torch_allocations
        available_kv_cache_memory = (
            total_memory * vllm_config.cache_config.gpu_memory_utilization -
            peak_memory)
return available_kv_cache_memory
def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
available_memory: int, page_size: int) -> int:
"""
Get the number of kv cache blocks.
Args:
vllm_config: The global VllmConfig
num_layers: The number of layers
available_memory: Memory available for KV cache in bytes.
page_size: The page size of the KV cache.
"""
num_blocks = int(available_memory // page_size // num_layers)
num_blocks = max(num_blocks, 0)
if vllm_config.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = \
vllm_config.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
num_blocks = num_gpu_blocks_override
block_num_per_group = env_blk_grp_size // 16
num_blocks = num_blocks // block_num_per_group * block_num_per_group
return num_blocks
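# Hedged example of the group alignment above, assuming BLOCK_GROUP_SIZE=64
# and 16-token blocks (block_num_per_group=4): 1023 raw blocks are rounded
# down to 1020 usable blocks (255 whole groups).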
def estimate_max_model_len(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> int:
"""
Estimates the maximum model length that can fit in the available memory
using binary search.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The estimated maximum model length that can fit in the available memory.
"""
from vllm.v1.core.kv_cache_utils import max_memory_usage_bytes
available_kv_cache_memory = query_device_avaliable_memory(vllm_config)
# Define a function to check if a given model length fits in memory
def fits_in_memory(model_len: int) -> bool:
# Modify the max_model_len for this calculation
vllm_config.model_config.max_model_len = model_len
# Calculate memory needed for the given model length
memory_needed = max_memory_usage_bytes(vllm_config,
kv_cache_spec.values())
        # Add a VACC buffer check; this deviates from the original design:
        # if the memory needed fits within the total pre-allocated space or
        # the remaining physical memory, allow allocation to proceed without
        # being capped by the score model's max_model_len pre-allocation.
return memory_needed <= max(available_memory, available_kv_cache_memory)
# Binary search for the maximum model length
current_max = vllm_config.model_config.max_model_len
left, right = 1, current_max
# If even the smallest model length doesn't fit, return 0
if not fits_in_memory(left):
return 0
# Binary search for the maximum model length that fits
result = 1
while left <= right:
mid = (left + right) // 2
if fits_in_memory(mid):
result = mid
left = mid + 1
else:
right = mid - 1
return result
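# Hedged usage sketch (hypothetical numbers): if max_model_len is 131072 but
# the available memory only covers roughly 40k tokens of KV cache, the binary
# search above converges on the largest model_len whose estimated
# max_memory_usage_bytes still fits, e.g. ~40960.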
def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int):
"""
Checks whether `available_memory` is enough for the KV cache to hold at
least one request with the model's max_model_len.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Raises:
ValueError: If there is not enough memory available for the KV cache.
"""
from vllm.v1.core.kv_cache_utils import max_memory_usage_bytes
if available_memory <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_model_len = vllm_config.model_config.max_model_len
needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())
available_kv_cache_memory = query_device_avaliable_memory(vllm_config)
    # Add a VACC buffer check; this deviates from the original design:
    # if the memory needed fits within the total pre-allocated space or the
    # remaining physical memory, allow allocation to proceed without being
    # capped by the score model's max_model_len pre-allocation.
if needed_memory > max(available_memory, available_kv_cache_memory):
# Estimate the maximum model length that can fit in the available memory
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
available_memory)
print("exec max model len is:", estimated_max_len)
estimated_msg = ""
if estimated_max_len > 0:
estimated_msg = (
"Based on the available memory, "
f"the estimated maximum model length is {estimated_max_len}.")
raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory/GiB_bytes:.2f} GiB). "
f"{estimated_msg} "
f"Try increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine.")

View File

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from vllm._bc_linter import bc_linter_include
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata)
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
@bc_linter_include
@dataclass
class NewRequestData:
req_id: str
prompt_token_ids: Optional[list[int]]
mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
prompt_embeds: Optional[torch.Tensor] = None
    deepstack_input_embeds: Optional[torch.Tensor] = None  # patch
@classmethod
def from_request(
cls,
request: Request,
block_ids: tuple[list[int], ...],
) -> NewRequestData:
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
prompt_embeds=request.prompt_embeds,
deepstack_input_embeds=request.deepstack_input_embeds,
)
def __repr__(self) -> str:
        prompt_embeds_shape = (self.prompt_embeds.shape
                               if self.prompt_embeds is not None else None)
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request},"
f"prompt_embeds_shape={prompt_embeds_shape}"
")")
# Version of __repr__ with the prompt data obfuscated
def anon_repr(self) -> str:
prompt_token_ids_len = len(
self.prompt_token_ids
) if self.prompt_token_ids is not None else None
        prompt_embeds_shape = (self.prompt_embeds.shape
                               if self.prompt_embeds is not None else None)
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids_len={prompt_token_ids_len},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request},"
f"prompt_embeds_shape={prompt_embeds_shape}"
")")

View File

@@ -0,0 +1,930 @@
from __future__ import annotations
import itertools
import time
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Optional, Union
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
create_request_queue)
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
logger = init_logger(__name__)
prefill_first = True
def _is_in_prefill(req: Request) -> bool:
    # The prompt has not been fully computed yet.
return req.num_computed_tokens < req.num_prompt_tokens
def align_up(value: int, alignment: int) -> int:
if alignment <= 0:
return value
return ((value + alignment - 1) // alignment) * alignment
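# e.g. align_up(100, 64) == 128, align_up(128, 64) == 128, and a
# non-positive alignment is a no-op: align_up(5, 0) == 5.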
class Scheduler(SchedulerInterface):
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
) -> None:
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.kv_cache_config = kv_cache_config
self.kv_events_config = vllm_config.kv_events_config
self.parallel_config = vllm_config.parallel_config
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager
self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self.finished_req_ids_dict: Optional[dict[int, set[str]]] = (
defaultdict(set) if include_finished_set else None)
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = \
self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len
self.enable_kv_cache_events = (
self.kv_events_config is not None
and self.kv_events_config.enable_kv_cache_events)
# Create KVConnector for the Scheduler. Note that each Worker
# will have a corresponding KVConnector with Role=WORKER.
# KV Connector pushes/pull of remote KVs for P/D and offloading.
self.connector = None
if self.vllm_config.kv_transfer_config is not None:
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"Multiple KV cache groups are not currently supported "
"with KV connectors")
assert not self.is_encoder_decoder, (
"Encoder-decoder models are not currently supported "
"with KV connectors")
self.connector = KVConnectorFactory.create_connector(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
self.parallel_config.data_parallel_rank,
)
num_gpu_blocks = self.cache_config.num_gpu_blocks
assert num_gpu_blocks is not None and num_gpu_blocks > 0
self.block_size = self.cache_config.block_size
self.max_total_kv_tokens: Optional[int] = None
if self.block_size is not None:
self.max_total_kv_tokens = (
self.kv_cache_config.num_blocks * self.block_size)
self.dcp_world_size = \
vllm_config.parallel_config.decode_context_parallel_size
        # Note(hc): The scheduler's block_size must be multiplied
# by dcp_world_size, since block hashes are computed on the
# original full token sequence at a granularity of
# original_block_size × dcp_world_size.
if self.dcp_world_size > 1:
self.block_size *= self.dcp_world_size
# req_id -> Request
self.requests: dict[str, Request] = {}
# Scheduling policy
if self.scheduler_config.policy == "priority":
self.policy = SchedulingPolicy.PRIORITY
elif self.scheduler_config.policy == "fcfs":
self.policy = SchedulingPolicy.FCFS
else:
raise ValueError(
f"Unknown scheduling policy: {self.scheduler_config.policy}")
# Priority queues for requests.
self.waiting = create_request_queue(self.policy)
self.running: list[Request] = []
# Pending barrier groups: gid -> list[Request]
self._barrier_groups: dict[str, list[Request]] = {}
# Expected group sizes: gid -> expected size
self._barrier_expected: dict[str, int] = {}
# Force decode-only scheduling in the next step when prefills block.
self.force_decode_next_step = False
# The request IDs that are finished in between the previous and the
# current steps. This is used to notify the workers about the finished
# requests so that they can free the cached states for those requests.
# This is flushed at the end of each scheduling step.
self.finished_req_ids: set[str] = set()
# KV Connector: requests in process of async KV loading or recving
self.finished_recving_kv_req_ids: set[str] = set()
# Encoder-related.
# Calculate encoder cache size if applicable
# NOTE: For now we use the same budget for both compute and space.
# This can be changed when we make encoder cache for embedding caching
# across requests.
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
mm_registry=mm_registry,
)
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed) for MM models as well as encoder-decoder
# transformers.
self.max_num_encoder_input_tokens = encoder_compute_budget
# NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0
# for these models.
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)
speculative_config = vllm_config.speculative_config
self.use_eagle = False
self.num_spec_tokens = self.num_lookahead_tokens = 0
if speculative_config:
self.num_spec_tokens = speculative_config.num_speculative_tokens
if speculative_config.use_eagle():
self.use_eagle = True
self.num_lookahead_tokens = self.num_spec_tokens
# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
enable_caching=self.cache_config.enable_prefix_caching,
use_eagle=self.use_eagle,
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
def add_request(self, request: Request) -> None:
# self.waiting.add_request(request)
# Always track in global dict first
self.requests[request.request_id] = request
gid = getattr(request, "barrier_group_id", None)
gsz = getattr(request, "barrier_group_size", 0) or 0
if gid and gsz > 1:
# Stage into pending barrier group
group_list = self._barrier_groups.get(gid)
if group_list is None:
self._barrier_groups[gid] = group_list = []
self._barrier_expected[gid] = gsz
else:
# Reconcile inconsistent sizes by taking the max
self._barrier_expected[gid] = max(self._barrier_expected[gid],
gsz)
group_list.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)
# Flush group when complete
if len(group_list) >= self._barrier_expected[gid]:
# Preserve arrival order
for r in sorted(group_list, key=lambda x: x.arrival_time):
self.waiting.add_request(r)
# Cleanup
del self._barrier_groups[gid]
del self._barrier_expected[gid]
return
# No barrier group; enqueue directly
self.waiting.add_request(request)
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)
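    # Hedged example of the barrier-group staging above: three requests
    # arrive with barrier_group_id="g1" and barrier_group_size=3 (values are
    # hypothetical). The first two are parked in _barrier_groups; only when
    # the third arrives are all three enqueued to self.waiting in
    # arrival-time order, so the group enters scheduling together.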
def _schedule_running_requests_for_mode(
self,
prefill_mode: bool,
token_budget: int,
encoder_compute_budget: int,
scheduled_timestamp: float,
scheduled_running_reqs: list[Request],
preempted_reqs: list[Request],
req_to_new_blocks: dict[str, KVCacheBlocks],
num_scheduled_tokens: dict[str, int],
scheduled_spec_decode_tokens: dict[str, list[int]],
scheduled_encoder_inputs: dict[str, list[int]],
scan_index: int,
next_req_index: int,
skip_request_ids: set[str],
) -> tuple[int, int, int, int]:
"""Schedule running requests under the given mode.
Args:
scan_index: The position in ``self.running`` to start scanning from.
next_req_index: The next logical "running index" used when we
populate ``structured_output_request_ids``.
"""
while scan_index < len(self.running) and token_budget > 0:
request = self.running[scan_index]
if request.request_id in skip_request_ids:
scan_index += 1
continue
if prefill_mode and not _is_in_prefill(request):
scan_index += 1
continue
if (not prefill_mode) and _is_in_prefill(request):
scan_index += 1
continue
num_new_tokens = max(
request.num_tokens_with_spec - request.num_computed_tokens, 0)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
# chunked prefill has to be enabled explicitly to allow pooling
# requests to be chunked
if (not self.scheduler_config.chunked_prefill_enabled
and num_new_tokens > token_budget):
break
num_new_tokens = min(num_new_tokens, token_budget)
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - 1 - request.num_computed_tokens)
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
# new_encoder_budget = encoder_budget
# if request.has_encoder_inputs:
# (encoder_inputs_to_schedule, num_new_tokens,
# new_encoder_budget) = self._try_schedule_encoder_inputs(
# request, request.num_computed_tokens, num_new_tokens,
# encoder_budget)
new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_compute_budget
) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_compute_budget)
if num_new_tokens == 0:
scan_index += 1
continue
# num_draft_tokens = max(
# num_new_tokens + request.num_computed_tokens -
# request.num_tokens, 0)
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
# num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None:
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
if _is_in_prefill(preempted_req):
preempted_req.num_computed_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
can_schedule = False
break
else:
can_schedule = True
break
if not can_schedule:
break
assert new_blocks is not None
scheduled_running_reqs.append(request)
req_to_new_blocks[request.request_id] = new_blocks
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
next_req_index += 1
if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
# encoder_budget = new_encoder_budget
encoder_compute_budget = new_encoder_compute_budget
scan_index += 1
return token_budget, encoder_compute_budget, scan_index, next_req_index
def finish_requests(
self,
request_ids: Union[str, Iterable[str]],
finished_status: RequestStatus,
) -> None:
"""Handles the finish signal from outside the scheduler.
For example, the API server can abort a request when the client
disconnects.
"""
assert RequestStatus.is_finished(finished_status)
if isinstance(request_ids, str):
request_ids = (request_ids, )
else:
request_ids = set(request_ids)
running_requests_to_remove = []
waiting_requests_to_remove = []
valid_requests = []
# First pass: collect requests to remove from queues
for req_id in request_ids:
request = self.requests.get(req_id)
if request is None:
# Invalid request ID.
continue
valid_requests.append(request)
if request.status == RequestStatus.RUNNING:
running_requests_to_remove.append(request)
else:
waiting_requests_to_remove.append(request)
# Remove all requests from queues at once for better efficiency
for request in running_requests_to_remove:
self.running.remove(request)
if waiting_requests_to_remove:
self.waiting.remove_requests(waiting_requests_to_remove)
# Handle requests that are pending in barrier groups:
# if any item from a pending group is finished/aborted, flush the rest.
for request in valid_requests:
gid = getattr(request, "barrier_group_id", None)
if not gid:
continue
group_list = self._barrier_groups.get(gid)
if not group_list:
continue
# Remove this request from pending group if present
try:
if request in group_list:
group_list.remove(request)
except ValueError:
pass
# Flush remaining members (if any) to waiting, then cleanup.
if group_list:
for r in sorted(group_list, key=lambda x: x.arrival_time):
self.waiting.add_request(r)
self._barrier_groups.pop(gid, None)
self._barrier_expected.pop(gid, None)
# Second pass: set status and free requests
for request in valid_requests:
request.status = finished_status
self._free_request(request)
def _estimate_future_kv_tokens(self, request: Request) -> int:
"""Estimate the KV tokens a request may need for its full lifetime."""
alignment = min(env_blk_grp_size, self.max_model_len)
alignment = max(alignment, 1)
prompt_tokens = request.num_prompt_tokens # prefill length
prompt_reserved = align_up(prompt_tokens, alignment)
remaining_room = max(self.max_model_len - prompt_tokens, 0)
max_decode_tokens = request.max_tokens # output length
if max_decode_tokens < 0:
max_decode_tokens = remaining_room
else:
max_decode_tokens = min(max_decode_tokens, remaining_room)
decode_budget = max_decode_tokens
if self.num_lookahead_tokens > 0:
decode_budget = min(decode_budget + self.num_lookahead_tokens,
remaining_room)
remaining_capacity_in_prompt = max(prompt_reserved - prompt_tokens, 0)
decode_after_prompt = max(decode_budget - remaining_capacity_in_prompt,
0)
if decode_after_prompt == 0:
total_reserved = prompt_reserved
else:
total_reserved = prompt_reserved + \
align_up(decode_after_prompt, alignment)
return min(total_reserved, self.max_model_len)
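    # Hedged worked example (hypothetical numbers, BLOCK_GROUP_SIZE=64, no
    # lookahead tokens, large max_model_len):
    #   prompt_tokens=100  -> prompt_reserved = align_up(100, 64) = 128
    #   max_tokens=200     -> decode_budget = 200
    #   capacity left in the prompt's last group = 128 - 100 = 28
    #   decode_after_prompt = 200 - 28 = 172 -> align_up(172, 64) = 192
    #   total_reserved = 128 + 192 = 320 KV token slots for this request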
def _compute_total_future_kv_tokens(
self, requests: Iterable[Request]
) -> int:
return sum(self._estimate_future_kv_tokens(req) for req in requests)
def make_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats] = None,
kv_connector_stats: Optional["KVConnectorStats"] = None,
) -> Optional[SchedulerStats]:
if not self.log_stats:
return None
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
assert prefix_cache_stats is not None
return SchedulerStats(num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
                              running_seqlens=[
                                  len(r._all_token_ids) for r in self.running
                              ],
kv_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=prefix_cache_stats,
spec_decoding_stats=spec_decoding_stats,
num_corrupted_reqs=sum(req.is_output_corrupted
for req in self.running),
kv_connector_stats=kv_connector_stats.data
if kv_connector_stats else None)
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. num_tokens_with_spec =
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens_with_spec. This is general enough to cover
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
        # NOTE: structured_output_request_ids maps the request_id of each
        # request that uses structured output to its running request index.
        # This helps us determine how to slice the grammar bitmask and
        # apply a valid mask only to requests that use structured decoding.
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_compute_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# For logging.
scheduled_timestamp = time.monotonic()
# Determine whether we should prioritize prefills in this step.
has_running_prefill = any(_is_in_prefill(r) for r in self.running)
has_waiting_prefill = False
if self.waiting and prefill_first:
for waiting_req in self.waiting:
if _is_in_prefill(waiting_req):
has_waiting_prefill = True
break
        prefill_mode = (prefill_first
                        and (has_running_prefill or has_waiting_prefill
                             or not self.running)
                        and len(self.running) < self.max_num_running_reqs)
if self.force_decode_next_step:
prefill_mode = False
# First, schedule the RUNNING requests.
req_index = 0
token_budget, encoder_budget, _, req_index = (
self._schedule_running_requests_for_mode(
prefill_mode,
token_budget,
encoder_compute_budget,
scheduled_timestamp,
scheduled_running_reqs,
preempted_reqs,
req_to_new_blocks,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
scheduled_encoder_inputs,
0,
req_index,
set(),
))
# Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
# prefill_allocation_failed = False
# Next, schedule the WAITING requests.
if not preempted_reqs:
future_kv_token_usage = 0
prefill_all_len = 0
if self.max_total_kv_tokens is not None:
future_kv_token_usage = (
self._compute_total_future_kv_tokens(self.running))
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
request = self.waiting.peek_request()
request_future_tokens: Optional[int] = None
if self.max_total_kv_tokens is not None:
request_future_tokens = self._estimate_future_kv_tokens(
request)
if prefill_mode:
prefill_all_len += request.num_prompt_tokens
                if prefill_mode and (
                        (self.max_total_kv_tokens is not None
                         and future_kv_token_usage + request_future_tokens
                         > self.max_total_kv_tokens)
                        or prefill_all_len > LLM_MAX_PREFILL_SEQ_LEN):
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
# prefill_allocation_failed = True
self.force_decode_next_step = True
continue
in_prefill = _is_in_prefill(request)
if prefill_mode and not in_prefill:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
if (not prefill_mode) and in_prefill:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
else:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Check that adding the request still respects the max_loras
# constraint.
if (self.lora_config and request.lora_request and
(len(scheduled_loras) == self.lora_config.max_loras and
request.lora_request.lora_int_id not in scheduled_loras)):
# Scheduling would exceed max_loras, skip.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_external_computed_tokens = 0
load_kv_async = False
# Get already-cached tokens.
if request.num_computed_tokens == 0:
# Get locally-cached tokens.
new_computed_blocks, num_new_local_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
if num_external_computed_tokens is None:
# The request cannot be scheduled because
# the KVConnector couldn't determine
# the number of matched tokens.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
else:
new_computed_blocks = (
self.kv_cache_manager.create_empty_block_list())
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
# new_encoder_budget = encoder_budget
new_encoder_compute_budget = encoder_compute_budget
# KVTransfer: loading remote KV, do not allocate for new work.
if load_kv_async:
assert num_external_computed_tokens > 0
num_new_tokens = 0
# Number of tokens to be scheduled.
else:
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_compute_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_compute_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
# Handles an edge case when P/D Disaggregation
# is used with Spec Decoding where an
# extra block gets allocated which
# creates a mismatch between the number
# of local and remote blocks.
effective_lookahead_tokens = (0 if request.num_computed_tokens
== 0 else
self.num_lookahead_tokens)
# Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs:
# TODO(russellb): For Whisper, we know that the input is
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
                    num_encoder_tokens = \
                        self.scheduler_config.max_num_encoder_input_tokens
else:
num_encoder_tokens = 0
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_local_computed_tokens,
new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens,
delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens,
)
if new_blocks is None:
# The request cannot be scheduled in this step.
if prefill_mode:
# prefill_allocation_failed = True
self.force_decode_next_step = True
break
                # KVTransfer: the connector uses this info to determine
                # whether a KV load is needed for this request.
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
                # The request was only peeked above; remove it from the
                # waiting queue now that it is definitely being scheduled.
                request = self.waiting.pop_request()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
req_index += 1
self.running.append(request)
if self.max_total_kv_tokens is not None:
if request_future_tokens is None:
request_future_tokens = (
self._estimate_future_kv_tokens(request))
future_kv_token_usage += request_future_tokens
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED,
scheduled_timestamp)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(
f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
# req_to_new_block_ids[request.request_id] = (
# self.kv_cache_manager.get_block_ids(request.request_id))
req_to_new_blocks[request.request_id] = new_blocks
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
# encoder_budget = new_encoder_budget
encoder_compute_budget = new_encoder_compute_budget
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
if not prefill_mode:
self.force_decode_next_step = False
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
# Since some requests in the RUNNING queue may not be scheduled in
# this step, the total number of scheduled requests can be smaller than
# len(self.running).
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
len(scheduled_running_reqs) <= len(self.running))
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(
self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running)))
# Construct the scheduler output.
# new_reqs_data = [
# NewRequestData.from_request(req,
# req_to_new_block_ids[req.request_id])
# for req in scheduled_new_reqs
# ]
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
# req_to_new_block_ids,
req_to_new_blocks,
)
scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs +
scheduled_resumed_reqs)
structured_output_request_ids, grammar_bitmask = (
self.get_grammar_bitmask(scheduled_requests,
scheduled_spec_decode_tokens))
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
# NOTE(Kuntai): this function is designed for multiple purposes:
# 1. Plan the KV cache store
# 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
events = self.kv_cache_manager.take_events()
# collect KV cache events from connector
if self.connector is not None:
connector_events = self.connector.take_events()
if connector_events:
if events is None:
events = list(connector_events)
else:
events.extend(connector_events)
# publish collected KV cache events
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
self._update_after_schedule(scheduler_output)
return scheduler_output

View File

@@ -0,0 +1,61 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
from vllm.v1.request import Request
import os
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
class SingleTypeKVCacheManager(ABC):
def free(self, request_id: str) -> None:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
# Default to [] in case a request is freed (aborted) before alloc.
req_blocks = self.req_to_blocks.pop(request_id, [])
        # Upstream frees blocks in reverse order so that the tail blocks
        # are freed first; that reversal is disabled here.
        # ordered_blocks = reversed(req_blocks)
        self.block_pool.free_blocks(req_blocks)
self.num_cached_block.pop(request_id, None)
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> list[KVCacheBlock]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
Returns:
The new allocated blocks.
"""
req_blocks = self.req_to_blocks[request_id]
num_required_blocks = cdiv(num_tokens, self.block_size)
        # Round the required block count up to a whole block group.
        block_size_number_per_group = env_blk_grp_size // self.block_size  # e.g. 512
        num_required_blocks = (
            (num_required_blocks + block_size_number_per_group - 1) //
            block_size_number_per_group * block_size_number_per_group)
num_new_blocks = num_required_blocks - len(req_blocks)
if num_new_blocks <= 0:
return []
else:
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
return new_blocks
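    # Hedged example of the group rounding above, assuming BLOCK_GROUP_SIZE=64
    # and block_size=16 (4 blocks per group): num_tokens=40 needs
    # cdiv(40, 16)=3 blocks, rounded up to 4, so a request always owns whole
    # 64-token block groups.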

View File

@@ -0,0 +1,94 @@
import enum
import time
from collections.abc import Mapping
from typing import Any, Optional, Union
import msgspec
import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
from vllm.v1.engine import EngineCoreOutput, UtilityOutput
from vllm_vacc.vllm.v1.metrics.stats import SchedulerStats
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")
class EngineCoreRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
request_id: str
prompt_token_ids: Optional[list[int]]
mm_features: Optional[list[MultiModalFeatureSpec]]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]
cache_salt: Optional[str]
data_parallel_rank: Optional[int]
prompt_embeds: Optional[torch.Tensor] = None
deepstack_input_embeds: Optional[torch.Tensor] = None
# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
client_index: int = 0
# Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before
# a wave finished notification is received.
current_wave: int = 0
priority: int = 0
trace_headers: Optional[Mapping[str, str]] = None
class EngineCoreOutputs(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
# NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout
engine_index: int = 0
# [num_reqs]
outputs: list[EngineCoreOutput] = []
scheduler_stats: Optional[SchedulerStats] = None
timestamp: float = 0.0
utility_output: Optional[UtilityOutput] = None
finished_requests: Optional[set[str]] = None
# In DP case, used to signal that the current wave of requests
# has finished and the engines are paused.
wave_complete: Optional[int] = None
# In DP case, used to signal that a request was received for an
# "old" wave, so the next wave needs to be started in other engines.
start_wave: Optional[int] = None
def __post_init__(self):
if self.timestamp == 0.0:
self.timestamp = time.monotonic()
class EngineCoreRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so it can be sent over sockets
without separate encoding step.
"""
ADD = b'\x00'
ABORT = b'\x01'
START_DP_WAVE = b'\x02'
UTILITY = b'\x03'
# Sentinel used within EngineCoreProc.
EXECUTOR_FAILED = b'\x04'
ADD_BULK = b'\x05'
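    # The raw byte doubles as the first ZMQ frame on the wire, so receivers
    # recover the type with no separate decoding step, e.g.
    # EngineCoreRequestType(b'\x05') is EngineCoreRequestType.ADD_BULK.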

View File

@@ -0,0 +1,108 @@
from typing import Any, Optional, Union
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory
from vllm.engine.protocol import EngineClient
from vllm_vacc.vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.async_llm import logger
class AsyncLLM(EngineClient):
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
enable_log_requests: bool = False,
disable_log_stats: bool = False,
client_addresses: Optional[dict[str, str]] = None,
client_count: int = 1,
client_index: int = 0,
disable_log_requests: bool = True, # Deprecated, will be removed
) -> "AsyncLLM":
        # VACC supports spec_num = 1; validate the speculative config.
from .vllm_config_checker import check_spec_model
check_spec_model(vllm_config)
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
# Create the LLMEngine.
from vllm.v1.engine.async_llm import AsyncLLM as DefaultAsyncLLM
async_cls = DefaultAsyncLLM
return async_cls(
vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
start_engine_loop=start_engine_loop,
stat_loggers=stat_loggers,
log_requests=enable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
client_addresses=client_addresses,
client_count=client_count,
client_index=client_index,
)
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""
# Create the engine configs.
vllm_config = engine_args.create_engine_config(usage_context)
executor_class = Executor.get_class(vllm_config)
        # VACC supports spec_num = 1; validate the speculative config.
from .vllm_config_checker import check_spec_model
check_spec_model(vllm_config)
# Create the AsyncLLM.
from vllm.v1.engine.async_llm import AsyncLLM as DefaultAsyncLLM
async_cls = DefaultAsyncLLM
return async_cls(
vllm_config=vllm_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
async def _add_request(self, request: EngineCoreRequest,
prompt: Optional[str],
parent_req: Optional[ParentRequest], index: int,
queue: RequestOutputCollector):
# Add the request to OutputProcessor (this process).
self.output_processor.add_request(request, prompt, parent_req, index,
queue)
# Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(request)
if self.log_requests:
if request.prompt_token_ids is not None:
logger.info("Added request: %s, prompt length: %s", request.request_id, len(request.prompt_token_ids))
else:
logger.info("Added request %s.", request.request_id)

View File

@@ -0,0 +1,209 @@
import os
import queue
import signal
import sys
import threading
import time
from collections import deque
from collections.abc import Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union
import msgspec
import zmq
from vllm.config import ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.utils import make_zmq_socket
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.engine import (EngineCoreRequest, EngineCoreRequestType)
from vllm.v1.core.kv_cache_utils import (BlockHash,
generate_scheduler_kv_cache_config,
get_kv_cache_configs,
get_request_block_hasher,
init_none_hash)
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.engine.core import EngineCore
from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__)
POLLING_TIMEOUT_S = 2.5
HANDSHAKE_TIMEOUT_MINS = 5
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'
def process_input_sockets(self, input_addresses: list[str],
coord_input_address: Optional[str],
identity: bytes, ready_event: threading.Event):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder()
bulk_add_decoder = MsgpackDecoder(list[EngineCoreRequest])
with ExitStack() as stack, zmq.Context() as ctx:
input_sockets = [
stack.enter_context(
make_zmq_socket(ctx,
input_address,
zmq.DEALER,
identity=identity,
bind=False))
for input_address in input_addresses
]
if coord_input_address is None:
coord_socket = None
else:
coord_socket = stack.enter_context(
make_zmq_socket(ctx,
coord_input_address,
zmq.XSUB,
identity=identity,
bind=False))
# Send subscription message to coordinator.
coord_socket.send(b'\x01')
# Register sockets with poller.
poller = zmq.Poller()
for input_socket in input_sockets:
# Send initial message to each input socket - this is required
# before the front-end ROUTER socket can send input messages
# back to us.
input_socket.send(b'')
poller.register(input_socket, zmq.POLLIN)
if coord_socket is not None:
poller.register(coord_socket, zmq.POLLIN)
ready_event.set()
del ready_event
while True:
for input_socket, _ in poller.poll():
# (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(
copy=False)
request_type = EngineCoreRequestType(
bytes(type_frame.buffer))
if request_type == EngineCoreRequestType.ADD_BULK:
# Key step: decode as list[EngineCoreRequest], then fan out in place on the receiving thread.
requests = bulk_add_decoder.decode(data_frames)
for r in requests:
r = self.preprocess_add_request(r)
self.input_queue.put_nowait((EngineCoreRequestType.ADD, r))
continue
# Deserialize the request data.
if request_type == EngineCoreRequestType.ADD:
request = add_request_decoder.decode(data_frames)
request = self.preprocess_add_request(request)
else:
request = generic_decoder.decode(data_frames)
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
class EngineCore:
"""Inner loop of vLLM's Engine."""
def _initialize_kv_caches(
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time()
# Get all kv cache needed by the model
kv_cache_specs = self.model_executor.get_kv_cache_specs()
# get_kv_cache_specs in model_runner
# for layer_name, layer in vllm_config.compilation_config.static_forward_context.items():
# print(f'layer_name = {layer_name}; layer = {layer}')
# # Only MoE layers show up here; attention layers unavailable?
# # TODO: handle models without a KV cache
# has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
# if has_kv_cache:
# if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
# dp_group = getattr(self, "dp_group", None)
# assert dp_group is not None
# self.available_gpu_memory_for_kv_cache = \
# ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
# available_gpu_memory = [
# self.available_gpu_memory_for_kv_cache
# ] * len(kv_cache_specs)
# else:
# # Profiles the peak memory usage of the model to determine how
# # much memory can be allocated for kv cache.
# available_gpu_memory = (
# self.model_executor.determine_available_memory())
# self.available_gpu_memory_for_kv_cache = \
# available_gpu_memory[0]
# else:
# # Attention free models don't need memory for kv cache
# available_gpu_memory = [0] * len(kv_cache_specs)
memory_blocks = self.model_executor.determine_available_memory_block()  # one (memory_bytes, num_blocks) pair per rank
available_gpu_memory = [memory_block[0] for memory_block in memory_blocks]
num_gpu_blocks = memory_blocks[0][1]
assert len(kv_cache_specs) == len(available_gpu_memory)
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
available_gpu_memory)
### Patch: support long seq_length for MTP by forcing the profiled block count.
for kv_cache_config in kv_cache_configs:
for ii in range(len(kv_cache_config.kv_cache_tensors)):
kv_cache_config.kv_cache_tensors[ii].size = (
kv_cache_config.kv_cache_tensors[ii].size
* num_gpu_blocks // kv_cache_config.num_blocks)
kv_cache_config.num_blocks = num_gpu_blocks
### End of the long-seq MTP patch.
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
kv_cache_configs)
num_gpu_blocks = scheduler_kv_cache_config.num_blocks
num_cpu_blocks = 0
# Initialize kv cache and warmup the execution
self.model_executor.initialize_from_config(kv_cache_configs)
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
def preprocess_add_request(
self, request: EngineCoreRequest) -> tuple[Request, int]:
"""Preprocess the request.
This function can be called directly from the input processing thread so
that request initialization runs in parallel with the model forward pass.
"""
# Note on thread safety: no race condition.
# `mm_receiver_cache` is reset at the end of LLMEngine init,
# and will only be accessed in the input processing thread afterwards.
if self.mm_receiver_cache is not None and request.mm_features:
request.mm_features = (
self.mm_receiver_cache.get_and_update_features(
request.mm_features))
req = Request.from_engine_core_request(request,
self.request_block_hasher)
if req.use_structured_output:
# Note on thread safety: no race condition.
# `grammar_init` is only invoked in input processing thread. For
# `structured_output_manager`, each request is independent and
# grammar compilation is async. Scheduler always checks grammar
# compilation status before scheduling request.
self.structured_output_manager.grammar_init(req)
return req, request.current_wave
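# --- Illustrative sketch (not part of the original file) ---
# The long-seq MTP patch above rescales each KV-cache tensor so its
# per-block byte size is preserved while the block count is forced to the
# profiled num_gpu_blocks. A minimal stand-alone check (the numbers are
# made-up assumptions):
def _rescale_kv_tensor_size(size_bytes: int, old_num_blocks: int,
                            new_num_blocks: int) -> int:
    # Per-block bytes stay constant: size / old == new_size / new.
    return size_bytes * new_num_blocks // old_num_blocks

assert _rescale_kv_tensor_size(1 << 30, 4096, 8192) == 1 << 31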

View File

@@ -0,0 +1,76 @@
import asyncio
import contextlib
import queue
import sys
import uuid
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass
from threading import Thread
from typing import Any, Callable, Optional, TypeVar, Union, List
import msgspec.msgpack
import zmq
import zmq.asyncio
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import get_open_zmq_inproc_path, make_zmq_socket
from vllm.v1.engine import (EngineCoreOutputs,
EngineCoreRequestType, UtilityOutput)
from vllm_vacc.vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager, launch_core_engines)
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.engine.core_client import MPClient
logger = init_logger(__name__)
AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc
EngineIdentity = bytes
class EngineCoreClient(ABC):
"""
EngineCoreClient: subclasses handle different methods for pushing
and pulling from the EngineCore for asyncio / multiprocessing.
Subclasses:
* InprocClient: In process EngineCore (for V0-style LLMEngine use)
* SyncMPClient: ZMQ + background proc EngineCore (for LLM)
* AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM)
"""
@abstractmethod
def _send_input(self, req_type: EngineCoreRequestType, payload: Any) -> None:
"""Send a request to EngineCore."""
raise NotImplementedError
def add_requests(self, requests: List["EngineCoreRequest"]) -> None:
"""一次性发多条 ADDADD_BULK"""
if not requests:
return
self._send_input(EngineCoreRequestType.ADD_BULK, requests)
class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore."""
def add_requests(self, requests: List[EngineCoreRequest]) -> None:
if not requests:
return
if self.is_dp:  # keep behavior consistent with add_request
self.engines_running = True
self._send_input(EngineCoreRequestType.ADD_BULK, requests)
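# --- Illustrative sketch (not part of the original file) ---
# The idea behind ADD_BULK, shown with plain msgspec: one msgpack payload
# carries a whole list of requests instead of N separate ADD messages.
# `_DemoReq` is a stand-in assumption, not the real EngineCoreRequest.
class _DemoReq(msgspec.Struct):
    request_id: str
    prompt_token_ids: list[int]

_batch = [_DemoReq(request_id=f"r{i}", prompt_token_ids=[1, 2, 3])
          for i in range(4)]
_payload = msgspec.msgpack.encode(_batch)  # one frame on the wire
_decoded = msgspec.msgpack.decode(_payload, type=list[_DemoReq])
assert [r.request_id for r in _decoded] == ["r0", "r1", "r2", "r3"]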

View File

@@ -0,0 +1,176 @@
from collections.abc import Mapping
from copy import copy
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
# from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)
class LLMEngine:
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_stats: bool = False,
) -> "LLMEngine":
# VACC supports only num_speculative_tokens == 1.
from .vllm_config_checker import check_spec_model
check_spec_model(vllm_config)
from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
default_cls = DefaultLLM
return default_cls(vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
log_stats=(not disable_log_stats),
usage_context=usage_context,
stat_loggers=stat_loggers,
multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
enable_multiprocessing: bool = False,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
vllm_config = engine_args.create_engine_config(usage_context)
executor_class = Executor.get_class(vllm_config)
# VACC supports only num_speculative_tokens == 1.
from .vllm_config_checker import check_spec_model
check_spec_model(vllm_config)
if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
logger.debug("Enabling multiprocessing for LLMEngine.")
enable_multiprocessing = True
# Create the LLMEngine.
from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
default_cls = DefaultLLM
return default_cls(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
multiprocess_mode=enable_multiprocessing)
"""Legacy LLMEngine for backwards compatibility."""
def add_requests(
self,
items: list[tuple[
str, # request_id
PromptType, # prompt
Union[SamplingParams, PoolingParams], # params
Optional[float], # arrival_time
Optional[LoRARequest], # lora_request
Optional[dict], # tokenization_kwargs
Optional[dict], # trace_headers
# Optional[PromptAdapterRequest], # prompt_adapter_request
int, # priority
]],
) -> None:
"""批量把请求送入 EngineCore一次性触发 ADD_BULK。"""
core_reqs: list[EngineCoreRequest] = []
for (request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers,
priority) in items:
# Reuse the parsing entry point of the existing per-request flow.
prompt_str, request = self.processor.process_inputs(
request_id=request_id,
prompt=prompt,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
# prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
n = params.n if isinstance(params, SamplingParams) else 1
if n == 1:
# Make a new RequestState and queue.
self.output_processor.add_request(request, prompt_str, None, 0)
# Add the request to EngineCore.
core_reqs.append(request)
continue
# self.engine_core.add_request(request)
# return
# Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params)
for idx in range(n):
request_id, params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
# Make a new RequestState and queue.
self.output_processor.add_request(child_request, prompt_str,
parent_req, idx)
# Add the request to EngineCore.
# self.engine_core.add_request(child_request)
# print("add_requests: child_request id=", child_request.request_id)
core_reqs.append(child_request)
# output_processor must index every req_id that actually enters the engine.
# For SamplingParams with n > 1 (or best_of > 1), split into parent/children;
# otherwise register the single request directly.
# if isinstance(params, SamplingParams) and (
# (params.n is not None and params.n > 1) or
# (getattr(params, "best_of", 1) and getattr(params, "best_of", 1) > 1)
# ):
# parent = self.parallel_sampler.create_parent(request_id, params)
# # Note: the last child can reuse `request` directly; the others use copy().
# children = self.parallel_sampler.materialize_children(parent, request)
# for child_idx, child in enumerate(children):
# self.output_processor.add_request(
# request=child,
# prompt_str=prompt_str,
# parent=parent,
# child_index=child_idx,
# )
# core_reqs.append(child)
# else:
# self.output_processor.add_request(request, prompt_str)
# core_reqs.append(request)
# Key step: hand the whole batch to the Core at once; EngineCoreClient
# will send a single ADD_BULK.
logger.debug("engine_core client: %s", self.engine_core)
self.engine_core.add_requests(core_reqs)
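# --- Illustrative usage sketch (not part of the original file) ---
# How a caller might assemble the positional tuples that add_requests
# expects; the prompts and sampling settings are placeholder assumptions.
def _demo_add_requests(engine: "LLMEngine") -> None:
    items = [
        (f"req-{i}",                     # request_id
         f"prompt #{i}",                 # prompt
         SamplingParams(max_tokens=16),  # params
         None, None, None, None,         # arrival_time .. trace_headers
         0)                              # priority
        for i in range(8)
    ]
    engine.add_requests(items)  # one ADD_BULK for the whole batch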

View File

@@ -0,0 +1,184 @@
import time
from collections.abc import Mapping
from typing import Any, Literal, Optional, Union
import torch
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
validate_structured_output_request_lm_format_enforcer)
from vllm.v1.structured_output.backend_outlines import (
validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import (
validate_xgrammar_grammar)
logger = init_logger(__name__)
class Processor:
def process_inputs(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> tuple[Optional[str], EngineCoreRequest]:
# TODO(woosuk): Support pooling models.
self._validate_lora(lora_request)
self._validate_params(params)
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
data_parallel_size):
raise ValueError(f"data_parallel_rank {data_parallel_rank} "
f"is out of range [0, {data_parallel_size}).")
if arrival_time is None:
arrival_time = time.time()
# Optionally generate multimodal hash overrides to avoid hashing
# multimodal data items by their content as their identifiers.
# NOTE: when users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore identifying multimodal data items
# by their content is no longer necessary, and we create uuids with
# request id-modality-index as multimodal hash overrides.
if (self.model_config.multimodal_config and
self.model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.cache_config.enable_prefix_caching):
mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
else:
# Otherwise, use user-provided uuids as multimodal hash overrides
# if provided.
self._validate_multi_modal_uuids(prompt)
if isinstance(prompt, dict):
mm_uuids = prompt.get("multi_modal_uuids")
else:
mm_uuids = None
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
from vllm.platforms import current_platform
current_platform.validate_request(
prompt=prompt,
params=params,
processed_inputs=processed_inputs,
)
eos_token_id = self.input_preprocessor.get_eos_token_id()
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
self._validate_model_inputs(encoder_inputs, decoder_inputs)
# Mypy does not always properly infer the types of some elements of
# discriminated unions of TypedDicts, because of how it handles
# inheritance of TypedDict. If we explicitly extract the items we want
# we can avoid type errors from using `dict.get` later in the method.
prompt_str: Optional[str] = None if decoder_inputs[
"type"] == "embeds" else decoder_inputs.get("prompt")
prompt_token_ids = decoder_inputs[
"prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None
prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[
"type"] == "embeds" else None
deepstack_input_embeds = decoder_inputs["deepstack_input_embeds"] if decoder_inputs[
"type"] == "embeds" else None
# Handle deepstack_input_embeds provided via the llm.generate() path.
if isinstance(deepstack_input_embeds, dict):
all_tensors = []
for key in deepstack_input_embeds:
if isinstance(deepstack_input_embeds[key], torch.Tensor):
all_tensors.append(deepstack_input_embeds[key].unsqueeze(0))
if len(all_tensors) > 0:
deepstack_input_embeds = torch.concatenate(all_tensors, 0)
sampling_params = None
pooling_params = None
if isinstance(params, SamplingParams):
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
seq_len = length_from_prompt_token_ids_or_embeds(
prompt_token_ids, prompt_embeds)
sampling_params.max_tokens = \
self.model_config.max_model_len - seq_len
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
if self.tokenizer is not None:
sampling_params.update_from_tokenizer(self.tokenizer)
else:
pooling_params = params.clone()
# Multimodal related.
mm_features: Optional[list[MultiModalFeatureSpec]] = None
if decoder_inputs["type"] == "multimodal":
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
decoder_mm_positions = decoder_inputs["mm_placeholders"]
decoder_mm_hashes = decoder_inputs["mm_hashes"]
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
mm_features = []
for modality, idx in sorted_mm_idxs:
mm_features.append(
MultiModalFeatureSpec(
data=decoder_mm_inputs[modality][idx],
modality=modality,
identifier=decoder_mm_hashes[modality][idx],
mm_position=decoder_mm_positions[modality][idx]))
return prompt_str, EngineCoreRequest(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
prompt_embeds=prompt_embeds,
deepstack_input_embeds=deepstack_input_embeds,
mm_features=mm_features,
sampling_params=sampling_params,
pooling_params=pooling_params,
eos_token_id=eos_token_id,
arrival_time=arrival_time,
lora_request=lora_request,
cache_salt=decoder_inputs.get("cache_salt"),
priority=priority,
data_parallel_rank=data_parallel_rank,
trace_headers=trace_headers,
)
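# --- Illustrative sketch (not part of the original file) ---
# Mirrors the deepstack_input_embeds handling above: per-layer tensors in a
# dict are unsqueezed and concatenated along a new leading dim. Shapes are
# made-up assumptions.
_embeds = {"layer0": torch.zeros(5, 8), "layer1": torch.ones(5, 8)}
_stacked = torch.concatenate([t.unsqueeze(0) for t in _embeds.values()], 0)
assert _stacked.shape == (2, 5, 8)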

View File

@@ -0,0 +1,17 @@
# Check the speculative-decoding model config and record the spec-model
# method in the VACC config manager.
def check_spec_model(vllm_config):
# Detect whether a speculative config is present.
speculative_mode = hasattr(vllm_config, 'speculative_config')
if speculative_mode and \
hasattr(vllm_config.speculative_config, 'num_speculative_tokens') and \
vllm_config.speculative_config.num_speculative_tokens != 1:
raise ValueError(f'run_mp_engine: only num_speculative_tokens == 1 is supported, but got {vllm_config.speculative_config.num_speculative_tokens}')
default_model_infos = "default"
if speculative_mode:
if hasattr(vllm_config.speculative_config, 'method'):
default_model_infos = vllm_config.speculative_config.method
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
vllm_vacc_config_manager().update_model_infos(default_model_infos)
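# --- Illustrative sketch (not part of the original file) ---
# Minimal stand-in objects showing when check_spec_model raises; the
# attribute names mirror the function above, everything else is assumed.
class _DemoSpec:
    num_speculative_tokens = 2
    method = "mtp"

class _DemoCfg:
    speculative_config = _DemoSpec()

try:
    check_spec_model(_DemoCfg())
except ValueError:
    pass  # only num_speculative_tokens == 1 is supported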

View File

View File

@@ -0,0 +1,24 @@
from concurrent.futures import Future
from typing import Callable, Union
import torch
import torch.distributed as dist
from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import ( # noqa
UniProcExecutor as UniProcExecutorV0)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
FailureCallback = Callable[[], None]
class Executor(ExecutorBase):
def determine_available_memory_block(self) -> list[tuple[int, int]]:  # one (memory_bytes, num_blocks) pair per rank
output = self.collective_rpc("determine_available_memory_block")
return output
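# --- Illustrative sketch (not part of the original file) ---
# How the per-rank (memory_bytes, num_blocks) pairs returned above are
# typically unpacked (mirroring _initialize_kv_caches); values are made up.
_memory_blocks = [(8 << 30, 4096), (8 << 30, 4096)]  # one pair per rank
_available = [m for m, _ in _memory_blocks]
_num_gpu_blocks = _memory_blocks[0][1]
assert _available == [8 << 30, 8 << 30] and _num_gpu_blocks == 4096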

View File

View File

@@ -0,0 +1,49 @@
import time
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.loggers import logger
class LoggingStatLogger(StatLoggerBase):
def log(self):
now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now)
self._reset(now)
scheduler_stats = self.last_scheduler_stats
log_fn = logger.info
if not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput)):
# Avoid log noise on an idle production system
log_fn = logger.debug
self.last_generation_throughput = generation_throughput
self.last_prompt_throughput = prompt_throughput
# Format and print output.
log_fn(
"Engine %03d: "
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%, "
"running seqlens: %s ",
self.engine_index,
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
str(scheduler_stats.running_seqlens),
)
self.spec_decoding_logging.log(log_fn=log_fn)
self.kv_transfer_logging.log(log_fn=log_fn)
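# --- Illustrative sketch (not part of the original file) ---
# The any(...) guard above demotes the log line to DEBUG only when both the
# current and previous intervals saw zero throughput; any nonzero value
# keeps it at INFO.
def _demo_pick_level(cur_p, cur_g, last_p, last_g):
    return "info" if any((cur_p, cur_g, last_p, last_g)) else "debug"

assert _demo_pick_level(0.0, 0.0, 0.0, 0.0) == "debug"
assert _demo_pick_level(0.0, 12.5, 0.0, 0.0) == "info"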

View File

@@ -0,0 +1,32 @@
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.metrics.stats import PrefixCacheStats
@dataclass
class SchedulerStats:
"""Stats associated with the scheduler."""
num_running_reqs: int = 0
num_waiting_reqs: int = 0
running_seqlens: list[int] = field(default_factory=list)
# These are used for internal DP load-balancing.
step_counter: int = 0
current_wave: int = 0
kv_cache_usage: float = 0.0
prefix_cache_stats: PrefixCacheStats = field(
default_factory=PrefixCacheStats)
spec_decoding_stats: Optional[SpecDecodingStats] = None
kv_connector_stats: Optional[dict[str, Any]] = None
num_corrupted_reqs: int = 0
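# --- Illustrative sketch (not part of the original file) ---
# With a default_factory, every SchedulerStats instance gets its own
# running_seqlens list instead of a shared (or None) default:
_s1, _s2 = SchedulerStats(), SchedulerStats()
_s1.running_seqlens.append(128)
assert _s2.running_seqlens == []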

View File

@@ -0,0 +1,355 @@
import enum
import time
from collections.abc import Mapping
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList
from vllm.v1.request import RequestStatus
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.v1.core.kv_cache_utils import BlockHash
class Request:
def __init__(
self,
request_id: str,
prompt_token_ids: Optional[list[int]],
sampling_params: Optional[SamplingParams],
pooling_params: Optional[PoolingParams],
eos_token_id: Optional[int],
client_index: int = 0,
arrival_time: Optional[float] = None,
prompt_embeds: Optional[torch.Tensor] = None,
mm_features: Optional[list[MultiModalFeatureSpec]] = None,
lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
priority: int = 0,
trace_headers: Optional[Mapping[str, str]] = None,
block_hasher: Optional[Callable[["Request"],
list["BlockHash"]]] = None,
) -> None:
self.request_id = request_id
self.client_index = client_index
self.priority = priority
self.sampling_params = sampling_params
self.pooling_params = pooling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.structured_output_request = structured_output_request
self.arrival_time = arrival_time if arrival_time is not None else \
time.time()
self.status = RequestStatus.WAITING
self.use_structured_output = False
self.events: list[EngineCoreEvent] = []
self.stop_reason: Union[int, str, None] = None
# P/D: Connector-specific KV transfer parameters.
self.kv_transfer_params: Optional[dict[str, Any]] = None
if pooling_params is not None:
# Pooling models.
self.max_tokens = 1
elif sampling_params is not None:
# Generative models.
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
if sampling_params.structured_outputs is not None:
self.status = RequestStatus.WAITING_FOR_FSM
self.use_structured_output = True
if sampling_params.extra_args is not None:
self.kv_transfer_params = \
sampling_params.extra_args.get("kv_transfer_params")
else:
raise ValueError(
"sampling_params and pooling_params can't both be unset")
self.prompt_token_ids = prompt_token_ids
self.prompt_embeds = prompt_embeds
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
prompt_token_ids, prompt_embeds)
self._output_token_ids: list[int] = []
self._all_token_ids: list[int] = (
self.prompt_token_ids.copy() if self.prompt_token_ids is not None
else [0] * self.num_prompt_tokens)
self.num_output_placeholders = 0 # Used in async scheduling.
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
self.cache_salt: Optional[str] = cache_salt
# Multi-modal related
self.mm_features = mm_features or []
self.num_encoder_inputs = len(self.mm_features)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# Read-only views
# Prevent directly appending to these lists since
# they should also be updated simultaneously.
self.output_token_ids = ConstantList(self._output_token_ids)
self.all_token_ids = ConstantList(self._all_token_ids)
# trace_headers
self.trace_headers = trace_headers
# State
# The number of tokens with prefix cache hits.
self.num_cached_tokens = -1
# The number of NaNs in logits. A value greater than 0
# indicates that the output is corrupted
self.num_nans_in_logits = 0
self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Optional[Callable[
[], list[BlockHash]]] = None
if block_hasher is not None:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
@classmethod
def from_engine_core_request(
cls, request: EngineCoreRequest,
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
) -> "Request":
return cls(
request_id=request.request_id,
client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids,
prompt_embeds=request.prompt_embeds,
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params) \
if request.sampling_params else None,
cache_salt=request.cache_salt,
priority=request.priority,
trace_headers=request.trace_headers,
block_hasher=block_hasher,
)
def append_output_token_ids(
self,
token_ids: Union[int, list[int]],
) -> None:
if isinstance(token_ids, int):
self._output_token_ids.append(token_ids)
self._all_token_ids.append(token_ids)
else:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
if self.get_hash_new_full_blocks is not None:
self.block_hashes.extend(self.get_hash_new_full_blocks())
@property
def is_output_corrupted(self) -> bool:
return self.num_nans_in_logits > 0
@property
def num_tokens(self) -> int:
return len(self._all_token_ids)
@property
def num_tokens_with_spec(self) -> int:
return len(self._all_token_ids) + len(self.spec_token_ids)
@property
def num_output_tokens(self) -> int:
return len(self._output_token_ids)
def is_finished(self) -> bool:
return RequestStatus.is_finished(self.status)
def get_finished_reason(self) -> Union[FinishReason, None]:
return RequestStatus.get_finished_reason(self.status)
def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_features)
num_tokens = self.mm_features[input_id].mm_position.length
return num_tokens
def record_event(
self,
event_type: EngineCoreEventType,
timestamp: Optional[float] = None,
) -> None:
self.events.append(EngineCoreEvent.new_event(event_type, timestamp))
def take_events(self) -> Optional[list[EngineCoreEvent]]:
if not self.events:
return None
events, self.events = self.events, []
return events
def __init__(
self,
request_id: str,
prompt_token_ids: Optional[list[int]],
sampling_params: Optional[SamplingParams],
pooling_params: Optional[PoolingParams],
eos_token_id: Optional[int],
client_index: int = 0,
arrival_time: Optional[float] = None,
prompt_embeds: Optional[torch.Tensor] = None,
deepstack_input_embeds: Optional[torch.Tensor] = None,
mm_features: Optional[list[MultiModalFeatureSpec]] = None,
lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
priority: int = 0,
trace_headers: Optional[Mapping[str, str]] = None,
block_hasher: Optional[Callable[["Request"],
list["BlockHash"]]] = None,
) -> None:
self.request_id = request_id
self.client_index = client_index
self.priority = priority
self.sampling_params = sampling_params
self.pooling_params = pooling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.structured_output_request = structured_output_request
self.arrival_time = arrival_time if arrival_time is not None else \
time.time()
self.status = RequestStatus.WAITING
self.use_structured_output = False
self.events: list[EngineCoreEvent] = []
self.stop_reason: Union[int, str, None] = None
# P/D: Connector-specific KV transfer parameters.
self.kv_transfer_params: Optional[dict[str, Any]] = None
if pooling_params is not None:
# Pooling models.
self.max_tokens = 1
elif sampling_params is not None:
# Generative models.
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
if sampling_params.guided_decoding is not None:
self.status = RequestStatus.WAITING_FOR_FSM
self.use_structured_output = True
if sampling_params.extra_args is not None:
self.kv_transfer_params = \
sampling_params.extra_args.get("kv_transfer_params")
else:
raise ValueError(
"sampling_params and pooling_params can't both be unset")
# Strict barrier batching metadata (optional)
self.barrier_group_id: Optional[str] = None
self.barrier_group_size: int = 0
if sampling_params is not None and sampling_params.extra_args:
try:
if "barrier_group_id" in sampling_params.extra_args:
self.barrier_group_id = str(
sampling_params.extra_args["barrier_group_id"])
if "barrier_group_size" in sampling_params.extra_args:
self.barrier_group_size = int(
sampling_params.extra_args["barrier_group_size"])
except Exception:
# Be tolerant to malformed extra_args; just ignore on failure.
self.barrier_group_id = None
self.barrier_group_size = 0
self.prompt_token_ids = prompt_token_ids
self.prompt_embeds = prompt_embeds
self.deepstack_input_embeds = deepstack_input_embeds
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
prompt_token_ids, prompt_embeds)
self._output_token_ids: list[int] = []
self._all_token_ids: list[int] = (
self.prompt_token_ids.copy() if self.prompt_token_ids is not None
else [0] * self.num_prompt_tokens)
self.num_output_placeholders = 0 # Used in async scheduling.
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
self.cache_salt: Optional[str] = cache_salt
# Multi-modal related
self.mm_features = mm_features or []
self.num_encoder_inputs = len(self.mm_features)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# Read-only views
# Prevent directly appending to these lists since
# they should also be updated simultaneously.
self.output_token_ids = ConstantList(self._output_token_ids)
self.all_token_ids = ConstantList(self._all_token_ids)
# trace_headers
self.trace_headers = trace_headers
# State
# The number of tokens with prefix cache hits.
self.num_cached_tokens = -1
# The number of NaNs in logits. A value greater than 0
# indicates that the output is corrupted
self.num_nans_in_logits = 0
self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Optional[Callable[
[], list[BlockHash]]] = None
if block_hasher is not None:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
@classmethod
def from_engine_core_request(
cls, request: EngineCoreRequest,
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
) -> "Request":
return cls(
request_id=request.request_id,
client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids,
prompt_embeds=request.prompt_embeds,
deepstack_input_embeds=request.deepstack_input_embeds,
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params) \
if request.sampling_params else None,
cache_salt=request.cache_salt,
priority=request.priority,
trace_headers=request.trace_headers,
block_hasher=block_hasher,
)
def record_event(
self,
event_type: EngineCoreEventType,
timestamp: Optional[float] = None,
) -> None:
self.events.append(EngineCoreEvent.new_event(event_type, timestamp))
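# --- Illustrative sketch (not part of the original file) ---
# How the strict-barrier metadata parsed in __init__ above would arrive via
# SamplingParams.extra_args; the key names mirror the code, the values are
# made-up assumptions.
_sp = SamplingParams(
    max_tokens=8,
    extra_args={"barrier_group_id": "g1", "barrier_group_size": 4})
assert int(_sp.extra_args["barrier_group_size"]) == 4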

View File

View File

@@ -0,0 +1,91 @@
import torch
import torch_vacc
class BinCountTensorPooler:
"""
A pool of bin-count tensors used by the Sampler.
Its core job is to hand apply_penalties pre-created bin-count buffers so
they are not rebuilt from scratch on every step.
Note: if req_ids are not guaranteed unique within max_cache_count entries,
reconsider whether this pooler is applicable.
pooler = BinCountTensorPooler(15169, "cpu")
list_1 = pooler.request_tensors(['1','3','aad'])
bin_count_buffers = pooler.request_tensors(req_ids)
"""
def __init__(
self,
vocab_size: int,
device: torch.device,
max_cache_count: int = 20,
):
# Cache containers: two parallel lists keep request IDs and their tensors in insertion order (FIFO eviction).
self._cached_tensors: list[torch.Tensor] = []
self._cached_req_ids: list[str] = []
# Tensor width (+1 is typically for padding / a special token).
self.vocab_size: int = vocab_size + 1
self.device: torch.device = device
self.max_cache_count: int = max_cache_count
def request_tensors(self, req_ids: list[str]) -> list[torch.Tensor]:
"""
Request tensors in batch: cached entries are returned directly; uncached
ones are created and added to the cache.
Args:
req_ids: list of request IDs to look up
Returns:
bin-count tensors in the same order as req_ids
"""
if not req_ids:
return []  # fast path for empty input, avoids a useless loop
out_tensors = []
for req_id in req_ids:
if req_id not in self._cached_req_ids:
# Register the new req and create its cached tensor.
self._add_new_cache(req_id)
# Look up the cache index (a dict mapping would be faster; left as-is).
cache_idx = self._cached_req_ids.index(req_id)
out_tensors.append(self._cached_tensors[cache_idx])
return out_tensors
def _add_new_cache(self, req_id: str) -> None:
"""
Add a cache entry: create the tensor and cache it; when the cap is
exceeded, evict the oldest entry (FIFO).
Private helper: encapsulates internal logic, not meant for external calls.
"""
# FIFO eviction: when the cache is full, drop the oldest (head) entry.
if len(self._cached_req_ids) >= self.max_cache_count:  # >= is safer at the boundary
self._cached_tensors.pop(0)
self._cached_req_ids.pop(0)
# Create the bin-count tensor (int32, zero-initialized).
bin_count_tensor = torch.zeros(
size=(1, self.vocab_size),
dtype=torch.int32,
device=self.device,
requires_grad=False,  # explicitly disable grad; count tensors need no backprop
)
# Append the ID and tensor together so the two lists stay index-aligned.
self._cached_req_ids.append(req_id)
self._cached_tensors.append(bin_count_tensor)
def clear_cache(self) -> None:
"""清空所有缓存(可选扩展方法,方便外部手动清理)"""
self._cached_tensors.clear()
self._cached_req_ids.clear()
def get_cache_status(self) -> dict:
"""获取缓存状态(可选扩展方法,方便监控)"""
return {
"current_cache_count": len(self._cached_req_ids),
"max_cache_count": self.max_cache_count,
"cached_req_ids": self._cached_req_ids.copy(), # 返回副本避免外部修改
"tensor_shape": (1, self.vocab_size),
"tensor_dtype": torch.int32,
"device": str(self.device),
}
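# --- Illustrative usage sketch (not part of the original file) ---
# FIFO reuse contract: a repeated req_id returns the same cached buffer;
# the vocab size and capacity below are made-up assumptions.
_pooler = BinCountTensorPooler(vocab_size=100, device="cpu",
                               max_cache_count=2)
_a1 = _pooler.request_tensors(["a"])[0]
_a2 = _pooler.request_tensors(["a"])[0]
assert _a1 is _a2                      # cache hit returns the same tensor
_pooler.request_tensors(["b", "c"])    # "a" is evicted (FIFO, capacity 2)
assert "a" not in _pooler.get_cache_status()["cached_req_ids"]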

View File

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.v1.sample.logits_processor import LogitsProcessors
@dataclass
class SamplingMetadata:
temperature: Optional[torch.Tensor]
all_greedy: bool
all_random: bool
top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor]
generators: dict[int, torch.Generator]
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: Optional[int]
no_penalties: bool
prompt_token_ids: Optional[torch.Tensor]
frequency_penalties: torch.Tensor
presence_penalties: torch.Tensor
repetition_penalties: torch.Tensor
output_token_ids: list[list[int]]
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
allowed_token_ids_mask: Optional[torch.Tensor]
# req_index -> bad_words_token_ids
bad_words_token_ids: dict[int, list[list[int]]]
# Loaded logits processors
logitsprocs: LogitsProcessors
temperature_cpu: Optional[torch.Tensor]
top_p_cpu: Optional[torch.Tensor]
top_k_cpu: Optional[torch.Tensor]
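# --- Illustrative note (not part of the original file) ---
# The *_cpu fields added to this dataclass mirror their device-side
# counterparts; the patched Sampler feeds them to torch.vacc.sampler_v1,
# so the host copies must stay in step with the device tensors (the values
# below are assumptions):
_top_k = torch.tensor([40, 0], dtype=torch.int32)  # device-side in practice
_top_k_cpu = _top_k.to("cpu")                      # host mirror
assert torch.equal(_top_k_cpu, _top_k)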

View File

@@ -0,0 +1,230 @@
from typing import Optional
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.sample.rejection_sampler import generate_uniform_probs, compute_probs, rejection_random_sample_kernel, sample_recovered_tokens
from vllm.distributed import get_tensor_model_parallel_rank
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
def rejection_greedy_sample_python(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
max_spec_len,
num_warps
):
# print('max_spec_len', max_spec_len)
if max_spec_len == 1:
for bi in range(output_token_ids_ptr.shape[0]):
output_token_ids_ptr[bi, 0] = target_argmax_ptr[bi]
if target_argmax_ptr[bi].item() == draft_token_ids_ptr[bi].item():
output_token_ids_ptr[bi, 1] = bonus_token_ids_ptr[bi]
else:
raise ValueError('TODO: MTP with num_speculative_tokens > 1 is not supported yet')
class RejectionSampler(nn.Module):
def forward(
self,
metadata: SpecDecodeMetadata,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
Args:
metadata:
Metadata for spec decoding.
draft_probs (Optional[torch.Tensor]):
Probability distribution for the draft tokens. Shape is
[num_tokens, vocab_size]. Can be None if probabilities are
not provided, which is the case for ngram spec decode.
target_logits (torch.Tensor):
Target model's logits probability distribution.
Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information.
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
# print(sampling_metadata)
# rank_id = get_tensor_model_parallel_rank()
if metadata.max_spec_len == 1:
output_token_ids = torch.vacc.rejection_sampler_v1(
target_logits.to(torch.float32),
metadata.draft_token_ids,
bonus_token_ids,
sampling_metadata.temperature,
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.all_greedy,
sampling_metadata.all_random,
sampling_metadata.generators
)
else:
target_probs = compute_probs(
target_logits.to(torch.float32),
metadata.cu_num_draft_tokens,
sampling_metadata,
)
output_token_ids = rejection_sample(
metadata.draft_token_ids,
metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
draft_probs,
target_probs,
bonus_token_ids,
sampling_metadata,
)
# output_token_ids_cpu = output_token_ids.cpu().tolist()
# output_token_ids_dev_cpu = output_token_ids_dev.cpu().tolist()
# for i in range(len(output_token_ids_cpu)):
# for j in range(len(output_token_ids_cpu[0])):
# if output_token_ids_cpu[i][j] != output_token_ids_dev_cpu[i][j]:
# # print(output_token_ids_cpu)
# # print(output_token_ids_dev_cpu)
# print("cpu \n", output_token_ids, "vacc \n" , output_token_ids_dev)
# exit()
# print("cpu \n", output_token_ids, "vacc \n" , output_token_ids_dev)
return output_token_ids
def rejection_sample(
# [num_tokens]
draft_token_ids: torch.Tensor,
# [batch_size]
num_draft_tokens: list[int],
max_spec_len: int,
# [batch_size]
cu_num_draft_tokens: torch.Tensor,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2
assert cu_num_draft_tokens.ndim == 1
assert target_probs.ndim == 2
batch_size = len(num_draft_tokens)
num_tokens = draft_token_ids.shape[0]
vocab_size = target_probs.shape[-1]
device = target_probs.device
assert draft_token_ids.is_contiguous()
assert draft_probs is None or draft_probs.is_contiguous()
assert target_probs.is_contiguous()
assert bonus_token_ids.is_contiguous()
assert target_probs.shape == (num_tokens, vocab_size)
# Create output buffer.
output_token_ids = torch.empty(
(batch_size, max_spec_len + 1),
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
device=device,
)
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
if sampling_metadata.all_greedy:
is_greedy = None
else:
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
# rejection_greedy_sample_kernel[(batch_size, )](
rejection_greedy_sample_python(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
num_warps=1,
)
if sampling_metadata.all_greedy:
return output_token_ids
else:
# TODO: implement the random-sampling rejection path.
raise ValueError('random-sampling rejection is not supported yet')
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_probs = generate_uniform_probs(
num_tokens,
num_draft_tokens,
sampling_metadata.generators,
device,
)
# Sample recovered tokens for each position.
# [num_tokens]
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device,
)
# Rejection sampling for random sampling requests.
rejection_random_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
NO_DRAFT_PROBS=draft_probs is None,
num_warps=1,
)
return output_token_ids
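# --- Illustrative sketch (not part of the original file) ---
# The max_spec_len == 1 greedy path from rejection_greedy_sample_python in
# tensor form: slot 0 always takes the target argmax; slot 1 takes the
# bonus token only where the draft matched. The values are made up.
_target_argmax = torch.tensor([5, 7])
_draft = torch.tensor([5, 9])
_bonus = torch.tensor([11, 13])
_out = torch.full((2, 2), PLACEHOLDER_TOKEN_ID, dtype=torch.int64)
_out[:, 0] = _target_argmax
_accepted = _target_argmax == _draft
_out[_accepted, 1] = _bonus[_accepted]
assert _out.tolist() == [[5, 11], [7, PLACEHOLDER_TOKEN_ID]]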

View File

@@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""
import torch
import torch.nn as nn
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from .cached_pooler import BinCountTensorPooler
_SAMPLING_EPS = 1e-5
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
device: torch.device) -> torch.Tensor:
"""
Convert the different list data structures to tensors.
"""
output_tokens_tensor = make_tensor_with_pad(
output_token_ids,
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
pad=vocab_size,
device="cpu",
dtype=torch.int32, # init with int32
pin_memory=is_pin_memory_available(),
)
return output_tokens_tensor.to(device, non_blocking=True)
class Sampler(nn.Module):
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
is_first_calculate: bool,
req_ids: list[str],
) -> SamplerOutput:
# print(sampling_metadata.generators, len(sampling_metadata.generators), sampling_metadata.temperature_cpu)
# sampling_metadata.temperature = sampling_metadata.temperature.to(logits.device)
# sampling_metadata.top_p = sampling_metadata.top_p.to(logits.device)
# sampling_metadata.top_k = sampling_metadata.top_k.to(logits.device)
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
# TODO(rob): provide option for logprobs post sampling.
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
raw_logprobs = self.compute_logprobs(logits)
# Use float32 for the logits.
logits = logits.to(torch.float32)
# Apply allowed token ids.
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
# Apply bad words exclusion.
logits = self.apply_bad_words(logits, sampling_metadata)
# Apply logits processors which can impact greedy sampling
for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
logits = processor.apply(logits)
# Apply penalties (e.g., min_tokens, freq_penalties).
# logits = self.apply_penalties(logits, sampling_metadata)
if not sampling_metadata.no_penalties:
if (not hasattr(self, "bin_count_pooler")):
self.bin_count_pooler = BinCountTensorPooler(logits.shape[-1], logits.device)
assert sampling_metadata.prompt_token_ids is not None
buf_bin_buffer = self.bin_count_pooler.request_tensors(req_ids)
batch, vocab_size = logits.shape
logits_res = []
for i in range(batch):
output_tokens_t = _convert_to_tensors([sampling_metadata.output_token_ids[i]], vocab_size, logits.device).to(torch.int32)
output_tokens_t_lastone = output_tokens_t[:, -1:]
if logits.shape[0] > 1:
logits_i = torch.vacc.apply_penalties(logits[i:i+1],
output_tokens_t_lastone,
buf_bin_buffer[i],
vocab_size,
output_tokens_t_lastone.shape[-1],
[sampling_metadata.frequency_penalties[i]],
[sampling_metadata.presence_penalties[i]],
is_first_calculate)
else:
logits_i = torch.vacc.apply_penalties(logits,
output_tokens_t_lastone,
buf_bin_buffer[i],
vocab_size,
output_tokens_t_lastone.shape[-1],
[sampling_metadata.frequency_penalties[i]],
[sampling_metadata.presence_penalties[i]],
is_first_calculate)
logits_res.append(logits_i)
if len(logits_res) > 1:
logits = torch.concat(logits_res)
else:
logits = logits_res[0]
# Sample the next token.
# sampled = self.sample(logits, sampling_metadata)
sampled, _ = torch.vacc.sampler_v1(
logits, sampling_metadata.top_p_cpu, sampling_metadata.top_k_cpu,
sampling_metadata.temperature_cpu, int(sampling_metadata.all_greedy),
int(sampling_metadata.all_random), sampling_metadata.generators)
# Gather the logprobs of the topk and sampled token (if requested).
# Get logprobs and rank tensors (if requested)
logprobs_tensors = None if num_logprobs is None else \
self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled.long())
# These are GPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.unsqueeze(-1),
logprobs_tensors=logprobs_tensors,
)
return sampler_output
def apply_temperature(
self,
logits: torch.Tensor,
temp: torch.Tensor,
) -> torch.Tensor:
# Use in-place division to avoid creating a new tensor.
return logits.div_(temp.unsqueeze(dim=1))
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
return logits.argmax(dim=-1).view(-1)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
"""Sample logits based on sampling metadata.
The various logits processing functions called in this method
may update the logits tensor in-place.
"""
assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random)
if sampling_metadata.all_random:
greedy_sampled = None
else:
greedy_sampled = self.greedy_sample(logits)
if sampling_metadata.all_greedy:
return greedy_sampled
assert sampling_metadata.temperature is not None
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Apply logits processors that only apply to random sampling
# (argmax invariant)
for processor in sampling_metadata.logitsprocs.argmax_invariant:
logits = processor.apply(logits)
# Apply top_k and/or top_p.
random_sampled = self.topk_topp_sampler(
logits,
sampling_metadata.generators,
sampling_metadata.top_k,
sampling_metadata.top_p,
)
if greedy_sampled is None:
return random_sampled
sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled,
random_sampled,
out=greedy_sampled, # Reuse tensor
)
return sampled
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32)
def gather_logprobs(
self,
logprobs: torch.Tensor,
num_logprobs: int,
token_ids: torch.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == torch.int64
# Find the topK values.
topk_logprobs, topk_indices = torch.topk(logprobs,
num_logprobs,
dim=-1)
# Get with the logprob of the prompt or sampled token.
token_ids = token_ids.unsqueeze(-1)
token_logprobs = logprobs.gather(-1, token_ids)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
# Concatenate together with the topk.
indices = torch.cat((token_ids, topk_indices), dim=1)
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
# Use int32 to reduce the tensor size.
indices = indices.to(torch.int32)
return LogprobsTensors(indices, logprobs, token_ranks)
def apply_penalties(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if not sampling_metadata.no_penalties:
assert sampling_metadata.prompt_token_ids is not None
logits = apply_all_penalties(
logits,
sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids,
)
return logits
def apply_allowed_token_ids(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.allowed_token_ids_mask is not None:
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
float("-inf"))
return logits
def apply_bad_words(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.bad_words_token_ids:
apply_bad_words(
logits,
sampling_metadata.bad_words_token_ids,
sampling_metadata.output_token_ids,
)
return logits
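# --- Illustrative sketch (not part of the original file) ---
# The rank computed in gather_logprobs counts how many vocab entries score
# at least as high as the chosen token, so the argmax has rank 1. The
# values below are made-up assumptions.
_lp = torch.log_softmax(torch.tensor([[2.0, 0.5, 1.0]]), dim=-1)
_tok = torch.tensor([[2]])                 # chosen token id
_tok_lp = _lp.gather(-1, _tok)
_rank = (_lp >= _tok_lp).sum(-1)
assert _rank.item() == 2                   # logits 2.0 and 1.0 both qualify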

View File

@@ -0,0 +1,789 @@
import ast
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.distributed import get_tensor_model_parallel_rank
PADDING_SLOT_ID = -1
from vacc_tools.trace_logger import get_trace_api
trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
get_trace_api("deepseek")
)
# @trace_time('prepare_eagle_input_python')
def prepare_eagle_input_python(
out_ptr,
cu_query_lens_ptr,
cu_num_tokens_ptr,
# BLOCK_SIZE
):
"""
Pure-Python version of prepare_eagle_input_kernel.
Args:
out_ptr: output tensor
cu_query_lens_ptr: per-query start-index tensor
cu_num_tokens_ptr: cumulative per-query token counts
BLOCK_SIZE: block size (unused in this Python version)
"""
cu_query_lens_ptr_list = cu_query_lens_ptr
cu_num_tokens_ptr_list = cu_num_tokens_ptr
num_queries = len(cu_num_tokens_ptr) - 1
# out_ptr_list = np.zeros(cu_query_lens_ptr_list.shape, cu_query_lens_ptr_list.dtype)
for pid in range(num_queries):
start_pos = cu_num_tokens_ptr_list[pid]#.item()
end_pos = cu_num_tokens_ptr_list[pid + 1]#.item()
num_tokens = end_pos - start_pos
index_start = cu_query_lens_ptr_list[pid]#.item()
# offset = np.array([i for i in range(num_tokens)], dtype=cu_num_tokens_ptr_list.dtype)
# values = index_start + offset
# store into the output tensor
# out_ptr[start_pos + offset] = values
for i in range(num_tokens):
out_ptr[start_pos + i] = index_start + i
return
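# --- Illustrative sketch (not part of the original file) ---
# What prepare_eagle_input_python writes, on tiny made-up inputs: query
# starts [0, 4, 7] with cumulative token counts [0, 2, 5] yield
# out = [0, 1, 4, 5, 6].
_out = np.zeros(5, dtype=np.int64)
prepare_eagle_input_python(_out, np.array([0, 4, 7]), np.array([0, 2, 5]))
assert _out.tolist() == [0, 1, 4, 5, 6]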
class EagleProposer:
# @trace_time('EagleProposer_propose')
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
last_token_indices: Optional[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
if last_token_indices is None:
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
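            # E.g. (hypothetical): query_start_loc [0, 2, 5] puts the last
            # token of each request at indices [1, 4].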
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
# self.input_ids[last_token_indices] = next_token_ids
if isinstance(last_token_indices, list) and len(last_token_indices) == 1:
self.input_ids[last_token_indices[0] : last_token_indices[0]+1] = next_token_ids
else:
self.input_ids[last_token_indices] = next_token_ids
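        # The single-element fast path above uses a cheap slice assignment
        # instead of an advanced-index scatter when the batch holds only
        # one request; the tensor path handles the general case.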
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
ubatch_id = dbo_current_ubatch_id()
attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0)
# FIXME: support hybrid kv for draft model (remove separate indexer)
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = (
self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=0,
))
else:
draft_indexer_metadata = None
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
for layer_name in self.indexer_layer_names:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
# self.positions[:num_tokens] = target_positions
# self.hidden_states[:num_tokens] = target_hidden_states
if self.is_multimodal_model:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = self.model.get_input_embeddings(
input_ids,
multimodal_embeddings=mm_embeds or None,
)
self.inputs_embeds[:num_tokens] = inputs_embeds
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = None
else:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=target_positions,
hidden_states=target_hidden_states, #self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
# sample_hidden_states = last_hidden_states[last_token_indices]
if isinstance(last_token_indices, list):
if len(last_token_indices) == last_hidden_states.shape[0]:
sample_hidden_states = last_hidden_states
elif len(last_token_indices) == 1:
sample_hidden_states = last_hidden_states[last_token_indices[0] : last_token_indices[0] + 1]
else:
sample_hidden_states = last_hidden_states.new_empty(last_token_indices.shape + last_hidden_states.shape[1:])
                torch.ops.aten.index(last_hidden_states, [torch.tensor(last_token_indices, dtype=torch.int32)], out=sample_hidden_states)
else:
assert isinstance(last_token_indices, torch.Tensor)
if last_token_indices.shape[0] == last_hidden_states.shape[0]:
sample_hidden_states = last_hidden_states
else:
sample_hidden_states = last_hidden_states.new_empty(last_token_indices.shape + last_hidden_states.shape[1:])
                torch.ops.aten.index(last_hidden_states, [last_token_indices], out=sample_hidden_states)
logits = self.model.compute_logits(sample_hidden_states)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
draft_token_ids = logits.argmax(dim=-1)
return draft_token_ids.view(-1, 1)
else:
            raise ValueError(f'num_speculative_tokens > 1 is not supported, but got {self.num_speculative_tokens}')
'''
positions = target_positions[last_token_indices]
if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
hidden_states = None #self.hidden_states[last_token_indices]
else:
hidden_states = None #hidden_states[last_token_indices]
if isinstance(attn_metadata, TreeAttentionMetadata):
# Draft using tree attention.
draft_token_ids_list = self.propose_tree(
batch_size=batch_size,
logits=logits,
positions=positions,
hidden_states=hidden_states,
common_attn_metadata=common_attn_metadata,
)
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = logits.argmax(dim=-1)
if self.allowed_attn_types is not None and \
not isinstance(attn_metadata, self.allowed_attn_types):
raise ValueError(
f"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: "
f"{type(attn_metadata)}. Supported types are: "
f"{self.allowed_attn_types}")
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
self.token_arange_np[:batch_size + 1]).clone()
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
common_attn_metadata.seq_lens += 1
common_attn_metadata.seq_lens_cpu += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
1)
common_attn_metadata.num_computed_tokens_cpu = \
common_attn_metadata.seq_lens_cpu - 1
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = common_attn_metadata.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
common_attn_metadata.slot_mapping = (
block_ids * self.block_size +
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
common_attn_metadata.slot_mapping.masked_fill_(
exceeds_max_model_len, PADDING_SLOT_ID)
# Rebuild attention metadata
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
# self.hidden_states[:batch_size] = hidden_states
if self.is_multimodal_model:
inputs_embeds = self.model.get_input_embeddings(input_ids)
self.inputs_embeds[:batch_size] = inputs_embeds
inputs_embeds = self.inputs_embeds[:input_batch_size]
input_ids = None
else:
inputs_embeds = None
input_ids = self.input_ids[:input_batch_size]
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:input_batch_size],
# hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size])
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
'''
def prepare_next_token_ids_padded(self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_indices: torch.Tensor,
num_discarded_requests: int) -> \
tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids and the number of valid sampled tokens
for each request, considering the "discarded" requests whose next token
is not sampled and comes from `request.get_token_id()` instead.
It also accounts for the rejected tokens in `sampled_token_ids`.
This function must use device functions to operate on the inputs, and
should not introduce any blocking CPU-GPU synchronization.
"""
# TODO(Ben): Combine this into a custom fused kernel
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
self.backup_next_token_ids.np[:num_reqs] = np.array([
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i])
for i in range(num_reqs)
])
self.backup_next_token_ids.copy_to_gpu(num_reqs)
# Mask out the sampled tokens indices that should not be sampled.
discard_sampled_tokens_req_indices = \
discard_request_indices[:num_discarded_requests]
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
valid_sampled_token_ids_gpu.index_fill_(
0, discard_sampled_tokens_req_indices, -1)
# Generate a mask for all valid tokens within those requests
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
dtype=torch.bool)
else:
valid_mask = (
(valid_sampled_token_ids_gpu != -1) &
(valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))
# Count the number of valid tokens in each request
valid_sampled_tokens_count = valid_mask.sum(dim=1)
# Get the rightmost valid index per row
last_valid_indices = valid_sampled_tokens_count - 1
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
# Get last valid token from each row
# (assume undefined state where there is no valid token)
selected_tokens = torch.gather(
valid_sampled_token_ids_gpu, 1,
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
# Use last token if valid, pre-computed backup if not
batch_size = valid_sampled_token_ids_gpu.shape[0]
next_token_ids = torch.where(
last_valid_indices != -1, selected_tokens,
self.backup_next_token_ids.gpu[:batch_size])
return next_token_ids, valid_sampled_tokens_count
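    # E.g. (hypothetical): a sampled_token_ids row [t0, t1, -1] has two
    # valid tokens, so t1 is selected; a fully discarded row (all -1)
    # falls back to the precomputed backup_next_token_ids entry for that
    # request.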
# @trace_time('prepare_inputs_padded')
def prepare_inputs_padded(self,
common_attn_metadata: CommonAttentionMetadata,
spec_decode_metadata: SpecDecodeMetadata,
valid_sampled_tokens_count: torch.Tensor) -> \
tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding
It updates the common_attn_metadata for speculative decoding,
but does not consider the rejected tokens. Instead, all tokens
are included as inputs to the speculator, with the rejected tokens
used as padding and filtered out later by `token_indices_to_sample`.
No blocking CPU operations should be introduced in this function.
"""
num_draft_tokens_gpu = torch.cat([
spec_decode_metadata.cu_num_draft_tokens[0:1],
spec_decode_metadata.cu_num_draft_tokens[1:] -
spec_decode_metadata.cu_num_draft_tokens[:-1]
])
num_rejected_tokens_gpu = torch.where(
num_draft_tokens_gpu > 0,
num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
torch.zeros_like(num_draft_tokens_gpu))
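        # E.g. (hypothetical): cu_num_draft_tokens [2, 5, 7] yields
        # num_draft_tokens_gpu [2, 3, 2]; a request with d > 0 drafts and
        # v valid sampled tokens rejects d + 1 - v of them.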
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
new_query_len_per_req = (query_start_loc_cpu[1:] -
query_start_loc_cpu[:-1])
total_num_tokens = query_start_loc_cpu[-1].item()
token_indices = self.arange[:total_num_tokens]
spec_common_attn_metadata = CommonAttentionMetadata(
query_start_loc=common_attn_metadata.query_start_loc,
seq_lens=common_attn_metadata.seq_lens,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata.
num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
max_seq_len=max(common_attn_metadata.seq_lens_cpu),
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
causal=True,
)
token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \
- num_rejected_tokens_gpu
return spec_common_attn_metadata, token_indices, token_indices_to_sample
# @trace_time('prepare_inputs')
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: list[list[int]],
num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It updates to the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
"""
# E.g.
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1, q1 + q2, q1 + q2 + q3]
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
# num_rejected_tokens: [n1, n2, n3]
# This function computes the intermediate values:
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
# And returns:
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
# common_attn_metadata.seq_lens{_cpu}:
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
# token_indices: [0, 1, ..., q1 - n1 - 1,
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
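        # Concretely (hypothetical numbers): with query_start_loc
        # [0, 2, 4] (q = [2, 2]) and num_rejected_tokens = [0, 1], the
        # result is query_start_loc [0, 2, 3] and token_indices [0, 1, 2]:
        # request 0 keeps both tokens, request 1 keeps only its first.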
num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
# num_rejected_tokens = torch.tensor(num_rejected_tokens,
# dtype=torch.int32)
# device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
# new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
# - num_rejected_tokens
# new_seq_lens_cpu = [i-j for i,j in zip(common_attn_metadata.seq_lens_cpu, num_rejected_tokens)]
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
# new_query_len_per_req = (query_start_loc_cpu[1:] -
# query_start_loc_cpu[:-1])
# new_query_len_per_req = [query_start_loc_cpu[i+1] - query_start_loc_cpu[i] for i in range(len(query_start_loc_cpu)-1)]
new_query_len_per_req = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]).tolist() # [2] *bs+1
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
# new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
new_num_tokens_per_req = [i-j for i,j in zip(new_query_len_per_req, num_rejected_tokens)]
new_num_tokens_per_req_np = np.array(new_num_tokens_per_req)
        # common_attn_metadata.seq_lens_cpu is a list[int] of length
        # max_seq_num; it comes from VACCModelRunner._prepare_inputs,
        # which already added k + 1 in advance.
        # E.g. seq=[31,63], bs=2, max_seq_num=4 -> seq_lens_cpu=[31,63,0,0]
        # new_seq_lens_cpu must enumerate every real sequence length:
        # num_rejected_tokens=[0,1] -> new_seq_lens_cpu=[30,31,62]: req 1 accepted, req 2 rejected (only one recovered token)
        # num_rejected_tokens=[1,0] -> new_seq_lens_cpu=[30,62,63]: req 2 accepted, req 1 rejected (only one recovered token)
        # num_rejected_tokens=[0,0] -> new_seq_lens_cpu=[30,31,62,63]: both accepted
        # num_rejected_tokens=[1,1] -> new_seq_lens_cpu=[30,62]: both rejected
new_seq_lens_cpu = []
for i in range(len(num_rejected_tokens)):
for j in range(new_num_tokens_per_req[i]):
new_seq_lens_cpu.append(common_attn_metadata.seq_lens_cpu[i] - new_num_tokens_per_req[i] + 1 - num_rejected_tokens[i] + j)
# [q1 - n1, q2 - n2, q3 - n3] ->
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
new_query_start_loc_cpu = torch.zeros(
query_start_loc_cpu.shape,
dtype=torch.int32,
pin_memory=is_pin_memory_available())
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
total_num_tokens = new_query_start_loc_np[-1]
# Example assuming num_tokens_per_req_np = [2, 4, 3]
# this implies that `new_query_start_locs` is:
# [0, 2, 6, 9] ->
# [0, 0, 2, 2, 2, 2, 6, 6, 6]
# _r1_ ____r2____ ___r3__
new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
new_num_tokens_per_req_np)
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
# _r1_ ____r2____ ___r3__
        token_offsets = self.token_arange_np[:total_num_tokens] \
            - new_query_start_locs_expanded
# Expand starting positions to match token pattern
# [0, q1, q1 + q2] ->
# [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
# _r1_ _____r2_______ ___________r3____________
old_query_start_locs_expanded = np.repeat(
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
# Final token indices are:
# [0, 1, // req 1
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
token_indices = token_indices_np.tolist()
# token_indices = torch.from_numpy(token_indices_np).to(
# device, non_blocking=True)
# if get_tensor_model_parallel_rank() == 0:
# print('token_indices', token_indices, common_attn_metadata.slot_mapping.shape)
# [0] or [0,1] bs1
# [0, 2, 4, 6] ... or [0, 1, 2, 3, 4, 6, 7] for bs4
# opt slot_mapping slice
#common_attn_metadata.slot_mapping[token_indices] : copy + copy + index_out
        if len(token_indices) == common_attn_metadata.slot_mapping.shape[0]:  # no slicing needed
slot_mapping = common_attn_metadata.slot_mapping
elif len(token_indices) == 1:
slot_mapping = common_attn_metadata.slot_mapping[token_indices[0] : token_indices[0] + 1]
else:
slot_mapping = common_attn_metadata.slot_mapping.new_empty(len(token_indices))
torch.ops.aten.index(common_attn_metadata.slot_mapping, [torch.tensor(token_indices, dtype=torch.int32)], out=slot_mapping)
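        # E.g. (hypothetical): token_indices [0, 1, 2, 3] over a 4-entry
        # slot_mapping takes the no-copy branch; a single index takes the
        # 1-element slice; anything else gathers via aten.index into a
        # preallocated output.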
spec_common_attn_metadata = CommonAttentionMetadata(
query_start_loc=new_query_start_loc_cpu,
# query_start_loc=new_query_start_loc_cpu.to(device,
# non_blocking=True),
seq_lens=new_seq_lens_cpu, #.to(device, non_blocking=True),
query_start_loc_cpu=new_query_start_loc_cpu,
seq_lens_cpu=new_seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata.
num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
            max_query_len=max(new_query_len_per_req),  # new_query_len_per_req.max().item()
            max_seq_len=None,  # max(new_seq_lens_cpu) / new_seq_lens_cpu.max().item()
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=slot_mapping, #common_attn_metadata.slot_mapping[token_indices],
causal=True,
)
return spec_common_attn_metadata, token_indices
# @trace_time('EagleProposer_prepare_inputs')
def prepare_inputs_9_2(
self,
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
num_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
# [0, a, a + b, a + b + c] -> [a, b, c]
# torch
# query_len_per_req = (cu_target_query_lens[1:] -
# cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
# num_tokens_per_req = query_len_per_req - num_rejected_tokens
# list
num_tokens_per_req = [cu_target_query_lens[i+1] - cu_target_query_lens[i] - num_rejected_tokens[i] for i in range(len(cu_target_query_lens)-1)]
# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# torch style
# cu_num_tokens = torch.zeros_like(cu_target_query_lens)
# torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
# list style
cu_num_tokens = [0] * len(cu_target_query_lens)
for i in range(len(cu_target_query_lens)-1):
cu_num_tokens[i+1] = cu_num_tokens[i] + num_tokens_per_req[i]
# token_indices = torch.empty(
# num_tokens,
# dtype=torch.int32,
# device=cu_target_query_lens.device,
# )
# batch_size = num_rejected_tokens.shape[0]
# BLOCK_SIZE = 1024
# prepare_eagle_input_kernel[(batch_size, )](
# token_indices,
# cu_target_query_lens,
# cu_num_tokens,
# BLOCK_SIZE=BLOCK_SIZE,
# )
token_indices = [0] * num_tokens
prepare_eagle_input_python(
token_indices,
cu_target_query_lens,
cu_num_tokens
)
return cu_num_tokens, token_indices
def EagleProposer_init_(
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
self.vllm_config = vllm_config
self.speculative_config = vllm_config.speculative_config
assert self.speculative_config is not None
self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method
self.runner = runner
self.device = device
self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
self.num_speculative_tokens = (
self.speculative_config.num_speculative_tokens)
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
self.token_arange_np = np.arange(self.max_num_tokens)
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.is_multimodal_model = vllm_config.model_config \
.is_multimodal_model
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
self.draft_indexer_metadata_builder: Optional[
AttentionMetadataBuilder] = None
self.attn_layer_names: list[str] = []
self.indexer_layer_names: list[str] = []
self.use_cuda_graph = False
self.cudagraph_batch_sizes = []
# persistent buffers for cuda graph
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=device)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=device)
# self.hidden_states = torch.zeros(
# (self.max_num_tokens, self.hidden_size),
# dtype=self.dtype,
# device=device)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_batch_size = vllm_config.scheduler_config.max_num_seqs
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
self.arange = torch.arange(max_num_slots_for_arange,
device=device,
dtype=torch.int32)
if self.is_multimodal_model:
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
else:
self.inputs_embeds = None
self.backup_next_token_ids = CpuGpuBuffer(
max_batch_size,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True)
# Determine allowed attention backends once during initialization.
self.allowed_attn_types: Optional[tuple] = None
# Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree
self.tree_choices: list[tuple[int,
...]] = ast.literal_eval(spec_token_tree)
tree_depth = len(self.tree_choices[-1])
# Precompute per-level properties of the tree.
num_drafts_per_level = [0] * tree_depth
for node in self.tree_choices:
num_drafts_per_level[len(node) - 1] += 1
self.cu_drafts_per_level = [num_drafts_per_level[0]]
self.child_drafts_per_level = [num_drafts_per_level[0]]
for level in range(1, tree_depth):
self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] +
num_drafts_per_level[level])
self.child_drafts_per_level.append(num_drafts_per_level[level] //
num_drafts_per_level[level - 1])
# Precompute draft position offsets in flattened tree.
self.tree_draft_pos_offsets = torch.arange(
1,
len(self.tree_choices) + 1,
device=device,
dtype=torch.int32,
).repeat(max_batch_size, 1)
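    # E.g. (hypothetical): spec_token_tree "[(0,), (1,), (0, 0), (1, 0)]"
    # has two level-1 drafts, each with one child, so num_drafts_per_level
    # is [2, 2], cu_drafts_per_level [2, 4] and child_drafts_per_level
    # [2, 1].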

View File

View File

@@ -0,0 +1,68 @@
import numpy as np
import torch
import os
from vllm.distributed import get_dcp_group
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
logger = init_logger(__name__)
class BlockTable:
def __init__(
self,
block_size: int,
max_num_reqs: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
):
self.block_size = block_size
self.max_num_reqs = max_num_reqs
# self.max_num_blocks_per_req = max_num_blocks_per_req
max_num_blocks_per_req = (max_num_blocks_per_req + env_blk_grp_size//16 - 1) // (env_blk_grp_size//16) * (env_blk_grp_size//16)
self.max_num_blocks_per_req = max_num_blocks_per_req
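        # The rounding above pads max_num_blocks_per_req up to a multiple
        # of env_blk_grp_size // 16 blocks. E.g. (hypothetical values):
        # with BLOCK_GROUP_SIZE = 8192 the group width is 512 blocks, so a
        # request needing 600 blocks is padded to 1024.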
self.max_num_batched_tokens = max_num_batched_tokens
self.pin_memory = pin_memory
self.device = device
# self.block_table = torch.zeros(
# (max_num_reqs, max_num_blocks_per_req),
# device=self.device,
# dtype=torch.int32,
# )
self.block_table = self._make_buffer(max_num_reqs,
max_num_blocks_per_req,
dtype=torch.int32)
# self.block_table_cpu = torch.zeros(
# (max_num_reqs, max_num_blocks_per_req),
# device="cpu",
# dtype=torch.int32,
# pin_memory=pin_memory,
# )
# self.block_table_np = self.block_table_cpu.numpy()
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
# self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
# dtype=torch.int32,
# device="cpu",
# pin_memory=self.pin_memory)
# self.slot_mapping_np = self.slot_mapping_cpu.numpy()
# self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
# dtype=torch.int32,
# device=self.device)
self.slot_mapping = self._make_buffer(self.max_num_batched_tokens,
dtype=torch.int32)
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0

View File

@@ -0,0 +1,491 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from typing import Optional, cast
import numpy as np
import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
LogitsProcessors,
MoveDirectionality)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
@dataclass
class CachedRequestState:
req_id: str
prompt_token_ids: Optional[list[int]]
mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
output_token_ids: list[int]
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None
deepstack_input_embeds: Optional[torch.Tensor] = None
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
# Temporary back-compatibility for plugins that define model runner
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargsItems]:
return [
MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
if f.data is not None
]
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
if self.prompt_token_ids is None:
raise ValueError(
f"Tried to access token index {idx}, but that token was "
"provided via prompt_embeds, and its ID is unknown.")
return self.prompt_token_ids[idx]
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
return self.output_token_ids[idx - self.num_prompt_tokens]
else:
return -1
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self._req_ids: list[Optional[str]] = []
self.req_id_to_index: dict[str, int] = {}
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
pin_memory=False,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
device="cpu",
dtype=bool,
pin_memory=False)
# Store prompt embeddings per request to avoid OOM from large upfront
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
#patch to add req_deepstack_input_embeds
self.req_deepstack_input_embeds: dict[int, torch.Tensor] = {}
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
num_speculative_tokens=num_speculative_tokens,
)
# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self.top_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: set[str] = set()
self.top_k = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device=device)
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: set[str] = set()
# IDs of requests which do not support spec decoding
self.spec_decode_unsupported_reqs: set[str] = set()
# Frequency penalty related data structures
# self.frequency_penalties = torch.empty((max_num_reqs, ),
# dtype=torch.float,
# device=device)
self.frequency_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.frequency_penalties_cpu = \
self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: set[str] = set()
# Presence penalty related data structures
# self.presence_penalties = torch.empty((max_num_reqs, ),
# dtype=torch.float,
# device=device)
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
)
self.presence_penalties_reqs: set[str] = set()
# Repetition penalty related data structures
# self.repetition_penalties = torch.empty((max_num_reqs, ),
# dtype=torch.float,
# device=device)
self.repetition_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.repetition_penalties_cpu = \
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set()
# Speculative decoding
self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ),
dtype=torch.int64,
device="cpu",
pin_memory=pin_memory)
self.num_accepted_tokens_cpu = \
self.num_accepted_tokens_cpu_tensor.numpy()
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
# Internal representation of per-step batch state changes, used for
# reordering persistent batch and generating logitsprocs batch state
# updates. Should reset each step.
self.batch_update_builder = BatchUpdateBuilder()
# TODO convert this to LogitsProcessor
self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
# the value is False. Since we use masked_fill_ to set -inf.
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
# req_index -> bad_words_token_ids
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
dtype=bool)
self.req_output_token_ids: list[Optional[list[int]]] = []
# Store provided logitsprocs. If none are provided, initialize empty
# data structure
self.logitsprocs = logitsprocs or LogitsProcessors()
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
self.pooling_params: dict[str, PoolingParams] = {}
# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
self.prev_req_id_to_index: Optional[dict[str, int]] = None
def add_request(
self,
request: "CachedRequestState",
) -> int:
req_index = self._register_add_request(request)
req_id = request.req_id
if req_index == len(self._req_ids):
self._req_ids.append(req_id)
self.req_output_token_ids.append(request.output_token_ids)
else:
self._req_ids[req_index] = req_id
self.req_output_token_ids[req_index] = request.output_token_ids
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds)
self.num_prompt_tokens[req_index] = num_prompt_tokens
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
#patch to add req_deepstack_input_embeds
if request.deepstack_input_embeds is not None:
self.req_deepstack_input_embeds[req_index] = request.deepstack_input_embeds
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
self.is_token_ids[req_index, start_idx:end_idx] = True
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params:
if (self.is_spec_decode
and is_spec_decode_unsupported(sampling_params)):
self.spec_decode_unsupported_reqs.add(req_id)
if sampling_params.sampling_type == SamplingType.GREEDY:
# Should avoid division by zero later when apply_temperature.
self.temperature_cpu[req_index] = 0.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = (self.vocab_size
if sampling_params.logprobs == -1
else sampling_params.logprobs)
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
self.vocab_size if sampling_params.prompt_logprobs == -1
else sampling_params.prompt_logprobs)
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
elif pooling_params := request.pooling_params:
self.pooling_params[req_id] = pooling_params
self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids)
else:
raise NotImplementedError("Unrecognized request type")
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
if lora_id not in self.lora_id_to_request_ids:
self.lora_id_to_request_ids[lora_id] = set()
self.request_lora_mapping[req_index] = lora_id
self.lora_id_to_request_ids[lora_id].add(request.req_id)
self.lora_id_to_lora_request[lora_id] = request.lora_request
else:
# No LoRA
self.request_lora_mapping[req_index] = 0
return req_index
def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs
# if not self.all_greedy:
# temperature = copy_slice(self.temperature_cpu_tensor,
# self.temperature, num_reqs)
# else:
# temperature = None
# if not self.no_top_p:
# copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
# if not self.no_top_k:
# copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
# if not self.no_penalties:
# # Since syncing these tensors is expensive only copy them
# # if necessary i.e. if there are requests which require
# # penalties to be applied during sampling.
# copy_slice(self.frequency_penalties_cpu_tensor,
# self.frequency_penalties, num_reqs)
# copy_slice(self.presence_penalties_cpu_tensor,
# self.presence_penalties, num_reqs)
# copy_slice(self.repetition_penalties_cpu_tensor,
# self.repetition_penalties, num_reqs)
needs_prompt_token_ids = (
not self.no_penalties
or self.logits_processing_needs_token_ids[:num_reqs].any())
if needs_prompt_token_ids:
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids = self._make_prompt_token_ids_tensor()
else:
prompt_token_ids = None
allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
return SamplingMetadata(
temperature=self.temperature_cpu_tensor[:num_reqs].to(self.device),
all_greedy=self.all_greedy,
all_random=self.all_random,
# top_p=None if self.no_top_p else self.top_p_cpu_tensor[:num_reqs].to(self.device),
# top_k=None if self.no_top_k else torch.tensor([40 for _ in range(num_reqs)]).to(torch.int32).to(self.device),
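            # NOTE: top_k is pinned to a fixed 40 for every request below
            # (the per-request top_k_cpu_tensor is bypassed), and top_p
            # falls back to an all-ones (no-op) tensor when no request
            # sets it.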
top_p=torch.tensor([1 for _ in range(num_reqs)]).to(torch.float32).to(self.device) if self.no_top_p else self.top_p_cpu_tensor[:num_reqs].to(self.device),
top_k=torch.tensor([40 for _ in range(num_reqs)]).to(torch.int32).to(self.device),
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties_cpu_tensor[:num_reqs].tolist(),
presence_penalties=self.presence_penalties_cpu_tensor[:num_reqs].tolist(),
repetition_penalties=self.repetition_penalties_cpu_tensor[:num_reqs].tolist(),
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs,
temperature_cpu = self.temperature_cpu_tensor[:num_reqs],
top_p_cpu = self.top_p_cpu_tensor[:num_reqs],
top_k_cpu = torch.tensor([40 for _ in range(num_reqs)]).to(torch.int32),
)

File diff suppressed because it is too large

View File

@@ -0,0 +1,340 @@
"""A VACC worker class."""
import gc
import os
from typing import Dict, List, Optional, Set, Tuple, Type, Union
from importlib import util
from typing import Optional
import torch
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import IntermediateTensors
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
get_dtype_size)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput, AsyncModelRunnerOutput
# from vllm.v1.worker.cpu_model_runner import VACCModelRunner
from vllm_vacc.vllm.v1.worker.vacc_model_runner import VACCModelRunner
from vllm.v1.worker.gpu_worker import Worker
from vllm.v1.kv_cache_interface import KVCacheConfig
# from vllm.worker.cache_engine import CacheEngine
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
logger = init_logger(__name__)
from vllm.utils import GiB_bytes
TP_GROUP_ID = 1234
def generate_rank_info_list():
global TP_GROUP_ID
from vllm.distributed import get_tp_group
    # generate rank/device infos for the TP group
get_tp_group().generate_rank_device_infos()
get_tp_group().generate_group_id(TP_GROUP_ID)
def generate_tp_group_id():
global TP_GROUP_ID
from pathlib import Path
import uuid
workspace_path = Path.cwd()
bootinfo_config = f'{workspace_path}/.bootinfos'
bootinfo_inited = os.path.exists(bootinfo_config)
current_bootinfos = "default"
if bootinfo_inited:
try:
with open(bootinfo_config) as w:
current_bootinfos = w.readline()
except Exception as e:
print("[WARN] bootinfo load fail ", e)
if current_bootinfos is not None:
unique_value = uuid.uuid5(uuid.NAMESPACE_URL, current_bootinfos).int
int32_value = unique_value & 0xFFFFFFFF
if int32_value >= 2**31:
int32_value -= 2**32
TP_GROUP_ID = int32_value
# print("current_bootinfos:", current_bootinfos, TP_GROUP_ID)
def init_worker_distributed_environment(
vllm_config: VllmConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
backend: str = "vccl",
) -> None:
"""Initialize the distributed environment."""
parallel_config = vllm_config.parallel_config
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank, backend)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
generate_tp_group_id()
generate_rank_info_list()
def get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
key_cache_entry = num_heads * head_size
# For MLA there is no value cache, since the latent vector
# is joint keys and values.
value_cache_entry = key_cache_entry if not model_config.use_mla else 0
total = num_attention_layers * cache_config.block_size * \
(key_cache_entry + value_cache_entry)
dtype_size = get_dtype_size(dtype)
return dtype_size * total
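# Worked example for get_cache_block_size (hypothetical values): with 32
# attention layers, 8 KV heads of head_size 128, block_size 16 and 2-byte
# (fp16) entries, one block costs 32 * 16 * (8*128 + 8*128) * 2 bytes
# = 2 MiB; with MLA the value half drops out.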
class VACCWorker(Worker):
def __init__(self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False):
super().__init__(vllm_config,
local_rank,
rank,
distributed_init_method,
is_driver_worker=is_driver_worker)
self.parallel_config.disable_custom_all_reduce = True
def init_device(self) -> None:
if self.device_config.device.type == "vacc":
try:
self.device = torch.device(f"vacc:{self.local_rank}")
torch.vacc.set_device(self.device)
gc.collect()
torch.vacc.empty_cache()
except Exception as e:
                raise RuntimeError(
                    f"Device init failed: {e}; "
                    f"self.device: {self.device}; check /dev/* or VACC_VISIBLE_DEVICES")
else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
# Construct the model runner
self.model_runner: VACCModelRunner = VACCModelRunner(
self.vllm_config, self.device)
def sleep(self, level: int = 1) -> None:
logger.warning("sleep mode is not supported on VACC, ignore it.")
pass
def wake_up(self, tags: Optional[list[str]] = None) -> None:
logger.warning("sleep mode is not supported on VACC, ignore it.")
pass
def get_cache_block_size_bytes(self) -> int:
"""Get the size of the KV cache block size in bytes.
"""
return get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
def determine_available_memory(self) -> int:
"""Determine the number of available KV blocks.
Swapping is not yet supported, so always return num_cpu_blocks=0.
We configure num_gpu_blocks to be equal to max_num_seqs.
"""
available_kv_cache_memory_min, num_gpu_blocks = self.determine_available_memory_block()
return available_kv_cache_memory_min
    def determine_available_memory_block(self) -> tuple[int, int]:
        """Determine the available KV cache memory (in bytes) and the
        corresponding number of KV blocks.
        Swapping is not yet supported, so num_cpu_blocks is always 0.
        """
        available_kv_cache_memory = int(os.getenv("VLLM_VACC_KVCACHE_SPACE", "16")) * GiB_bytes
        max_seq_num = int(os.getenv("MAX_SEQ_NUM", 4))
        max_num_gpu_blocks = 0
        if available_kv_cache_memory == 0:
torch.vacc.empty_cache()
torch.vacc.reset_peak_memory_stats()
total_memory = torch.vacc.mem_get_info()[1]
self.model_runner.profile_run()
torch.vacc.synchronize()
peak_memory = torch.vacc.max_memory_allocated()
torch.vacc.empty_cache()
torch_allocated_bytes = torch.vacc.memory_stats(
)["allocated_bytes.all.current"]
total_allocated_bytes = torch.vacc.mem_get_info(
)[1] - torch.vacc.mem_get_info()[0]
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
peak_memory += non_torch_allocations
            available_kv_cache_memory = total_memory * self.cache_config.gpu_memory_utilization - peak_memory
if self.model_config.hf_config.model_type == "deepseek_v3":
            assert self.model_config.max_model_len <= 65536 * self.vllm_config.parallel_config.pipeline_parallel_size, \
                f"unsupported max_model_len: must be <= 65536 * pipeline_parallel_size, got {self.model_config.max_model_len}"
# Rules:
# 1. always reserve N * 8K blocks
# 2. no less than (MAX_SEQ_NUM + 1) * 8K blocks
minimum_num_gpu_blocks_required = (max_seq_num + 1) * env_blk_grp_size // self.cache_config.block_size
max_model_len = (self.model_config.max_model_len + env_blk_grp_size - 1) // env_blk_grp_size * env_blk_grp_size
max_num_gpu_blocks = max_model_len // self.cache_config.block_size
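        # E.g. (hypothetical values): with BLOCK_GROUP_SIZE = 8192,
        # block_size = 16 and MAX_SEQ_NUM = 4, at least
        # (4 + 1) * 8192 / 16 = 2560 blocks are reserved, and a
        # max_model_len of 65536 caps the table at 4096 blocks.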
# limited by available_kv_cache_memory
cache_block_size = self.get_cache_block_size_bytes()
if cache_block_size == 0:
num_gpu_blocks = 0
num_cpu_blocks = 0
else:
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
assert num_gpu_blocks >= minimum_num_gpu_blocks_required, \
f"num_gpu_blocks should >= {minimum_num_gpu_blocks_required} please increase VLLM_VACC_KVCACHE_SPACE"
torch.vacc.empty_cache()
# if self.model_runner.lora_manager:
# self.model_runner.remove_all_loras()
gc.collect()
if max_num_gpu_blocks != 0:
num_gpu_blocks = min(max_num_gpu_blocks, num_gpu_blocks)
num_gpu_blocks = max(num_gpu_blocks, minimum_num_gpu_blocks_required)
available_kv_cache_memory_min = num_gpu_blocks * cache_block_size
return available_kv_cache_memory_min, num_gpu_blocks
def compile_or_warm_up_model(self) -> None:
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
# self.model_runner.warming_up_model()
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
if forward_pass and not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
return output
if not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return None
assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None
def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
"""Return VACCs id binding based on NUMA nodes.
"""
rank_to_cpus = self.local_omp_cpuid
# Setup OpenMP thread affinity based on NUMA nodes automatically
world_size = self.vllm_config.parallel_config.world_size
libnuma_found = util.find_spec("numa") is not None
psutil_found = util.find_spec("psutil") is not None
if libnuma_found and psutil_found:
import psutil
from numa import info
cpu_count = psutil.cpu_count(logical=False)
cpus_allow_list = psutil.Process().cpu_affinity()
numa_size = info.get_num_configured_nodes()
cpu_count_per_numa = cpu_count // numa_size
num_of_reserved_cpu = min(envs.VLLM_VACC_NUM_OF_RESERVED_VACC,
cpu_count_per_numa // 2)
# check allow node_to_cpus list
node_to_cpus = []
for i in range(numa_size):
node_intersect = set(
info.node_to_cpus(i)).intersection(cpus_allow_list)
if bool(node_intersect):
node_to_cpus.append(list(node_intersect))
if world_size > len(node_to_cpus):
                logger.error(
                    "Auto thread-binding failed: world size %d is larger "
                    "than the number of allowed NUMA nodes (%d). "
                    "Please try to bind threads manually.", world_size,
                    len(node_to_cpus))
else:
end = cpu_count_per_numa - num_of_reserved_cpu
rank_to_cpus_list = node_to_cpus[self.rank][:end]
rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
logger.info("auto thread-binding list: %s", rank_to_cpus)
else:
            logger.warning(
                "Auto thread-binding is not supported due to the lack of "
                "the numa and psutil packages; falling back to no "
                "thread-binding. To get better performance, please try to "
                "bind threads manually.")
return rank_to_cpus
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
        if self.vllm_config.model_config.enable_sleep_mode:
            # Lazy import: CuMemAllocator is only needed when sleep mode is
            # enabled (import path assumed from upstream vLLM).
            from vllm.device_allocator.cumem import CuMemAllocator
            allocator = CuMemAllocator.get_instance()
            context = allocator.use_memory_pool(tag="kv_cache")
else:
from contextlib import nullcontext
context = nullcontext()
with context:
self.model_runner.initialize_kv_cache(kv_cache_config)