init
This commit is contained in:
0
vllm_vacc/vllm/v1/__init__.py
Normal file
0
vllm_vacc/vllm/v1/__init__.py
Normal file
BIN
vllm_vacc/vllm/v1/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/__pycache__/request.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/__pycache__/request.cpython-312.pyc
Normal file
Binary file not shown.
0
vllm_vacc/vllm/v1/attention/__init__.py
Normal file
0
vllm_vacc/vllm/v1/attention/__init__.py
Normal file
BIN
vllm_vacc/vllm/v1/attention/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/attention/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
0
vllm_vacc/vllm/v1/attention/backends/__init__.py
Normal file
0
vllm_vacc/vllm/v1/attention/backends/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
982
vllm_vacc/vllm/v1/attention/backends/vacc_attn.py
Normal file
982
vllm_vacc/vllm/v1/attention/backends/vacc_attn.py
Normal file
@@ -0,0 +1,982 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
# from vllm_vacc.vllm.attention.backends.vacc_attn import (VACCAttentionBackendImpl,
|
||||
# VACCAttentionMetadata)
|
||||
# from vllm_vacc.vllm.attention.backends.vacc_attn import VACCAttentionBackendImpl
|
||||
# from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
|
||||
|
||||
|
||||
def _make_alibi_bias(
|
||||
alibi_slopes: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
seq_lens: list[int],
|
||||
) -> list[torch.Tensor]:
|
||||
attn_biases: list[torch.Tensor] = []
|
||||
for seq_len in seq_lens:
|
||||
bias = torch.arange(seq_len, dtype=dtype)
|
||||
# NOTE(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(seq_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
# the bias below more accurately follows the original ALiBi
|
||||
# paper.
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
|
||||
num_heads = alibi_slopes.shape[0]
|
||||
bias = bias[None, :].repeat((num_heads, 1, 1))
|
||||
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
|
||||
inf_mask = torch.empty(
|
||||
(1, seq_len, seq_len),
|
||||
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
|
||||
attn_biases.append((bias + inf_mask).to(dtype))
|
||||
|
||||
return attn_biases
|
||||
|
||||
|
||||
def _make_sliding_window_bias(
|
||||
seq_lens: list[int],
|
||||
window_size: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
) -> list[torch.Tensor]:
|
||||
attn_biases: list[torch.Tensor] = []
|
||||
for seq_len in seq_lens:
|
||||
tensor = torch.full(
|
||||
(1, seq_len, seq_len),
|
||||
dtype=dtype,
|
||||
fill_value=1,
|
||||
)
|
||||
shift = 0
|
||||
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
|
||||
if window_size is not None:
|
||||
mask = torch.triu(mask, diagonal=shift - window_size + 1)
|
||||
mask = torch.log(mask)
|
||||
attn_biases.append(mask.to(dtype))
|
||||
|
||||
return attn_biases
|
||||
|
||||
@dataclass
|
||||
class VACCAttentionMetadata(AttentionMetadata):
|
||||
"""Metadata for VACCAttentionMetadata.
|
||||
"""
|
||||
# Total number of prefill requests.
|
||||
num_prefills: int
|
||||
# Number of prefill tokens.
|
||||
num_prefill_tokens: int
|
||||
# Number of decode tokens. Note that it is equivalent to the number of
|
||||
# decode requests.
|
||||
num_decode_tokens: int
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||
# in block 0, and 1st slot in block 1, respectively.
|
||||
slot_mapping: torch.Tensor
|
||||
"""Metadata for PagedAttention."""
|
||||
# (batch_size,). The length of sequences (entire tokens seen so far) per
|
||||
# sequence.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
"""Metadata for TorchSDPABackend.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
chunked_prefill: bool
|
||||
seq_lens: Optional[list[int]] = None # For non-chunked prefill
|
||||
|
||||
# For chunked prefill only
|
||||
max_query_len: Optional[int] = None
|
||||
max_kv_len: Optional[int] = None
|
||||
prefill_query_start_loc: Optional[torch.Tensor] = None
|
||||
kv_start_loc: Optional[torch.Tensor] = None
|
||||
prefill_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
# For V1 logits index only
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[list[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
# when alibi slopes is used. It is because of the limitation
|
||||
# from xformer API.
|
||||
# will not appear in the __repr__ and __init__
|
||||
self.attn_bias: Optional[list[torch.Tensor]] = None
|
||||
self.encoder_attn_bias: Optional[list[torch.Tensor]] = None
|
||||
self.cross_attn_bias: Optional[list[torch.Tensor]] = None
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
'''
|
||||
return ((self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None))
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
'''
|
||||
return (self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None))
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["VACCAttentionMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_prefill_tokens == 0:
|
||||
return None
|
||||
return self
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["VACCAttentionMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
return self
|
||||
|
||||
def get_seq_lens(
|
||||
self,
|
||||
attn_type: AttentionType,
|
||||
):
|
||||
'''
|
||||
Extract appropriate sequence lengths from attention metadata
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
* Appropriate sequence lengths tensor for query
|
||||
* Appropriate sequence lengths tensor for key & value
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
seq_lens_q = self.seq_lens
|
||||
seq_lens_kv = self.seq_lens
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
seq_lens_q = self.encoder_seq_lens
|
||||
seq_lens_kv = self.encoder_seq_lens
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
seq_lens_q = self.seq_lens
|
||||
seq_lens_kv = self.encoder_seq_lens
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
return seq_lens_q, seq_lens_kv
|
||||
|
||||
def get_attn_bias(
|
||||
self,
|
||||
attn_type: AttentionType,
|
||||
) -> Optional[list[torch.Tensor]]:
|
||||
'''
|
||||
Extract appropriate attention bias from attention metadata
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
* Appropriate attention bias value given the attention type
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
return self.attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
return self.encoder_attn_bias
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
return self.cross_attn_bias
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
def set_attn_bias(
|
||||
self,
|
||||
attn_bias: list[torch.Tensor],
|
||||
attn_type: AttentionType,
|
||||
) -> None:
|
||||
'''
|
||||
Update appropriate attention bias field of attention metadata,
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_bias: The desired attention bias value
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
self.attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
self.encoder_attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
self.cross_attn_bias = attn_bias
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
def get_seq_len_block_table_args(
|
||||
self,
|
||||
attn_type: str,
|
||||
) -> tuple:
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
|
||||
Decoder attn -> select entirely decoder self-attention-related fields
|
||||
Encoder/decoder cross-attn -> select encoder sequence lengths &
|
||||
cross-attn block-tables fields
|
||||
Encoder attn -> select encoder sequence lengths fields & no block tables
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* is_prompt: True if prefill, False otherwise
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
|
||||
* Appropriate sequence-lengths tensor
|
||||
* Appropriate max sequence-length scalar
|
||||
* Appropriate block tables (or None)
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
# Decoder self-attention
|
||||
# Choose max_seq_len based on whether we are in prompt_run
|
||||
return (self.seq_lens_tensor, self.max_decode_seq_len,
|
||||
self.block_tables)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
|
||||
self.cross_block_tables)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# No block tables associated with encoder attention
|
||||
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
|
||||
None)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
# class VACCMetadataBuilder(AttentionMetadataBuilder[VACCAttentionMetadata]):
|
||||
|
||||
# def __init__(self, input_builder: ModelInputForVACCBuilder) -> None:
|
||||
# self.chunked_prefill = input_builder.chunked_prefill
|
||||
# self.input_builder = input_builder
|
||||
|
||||
# def prepare(self):
|
||||
# self.input_data = self.input_builder.input_data
|
||||
|
||||
# def build(self, seq_lens: list[int], query_lens: list[int],
|
||||
# cuda_graph_pad_size: int, batch_size: int) -> VACCAttentionMetadata:
|
||||
# input_data = self.input_data
|
||||
# prefill_seq_lens = seq_lens[0:input_data.num_prefills]
|
||||
# prefill_query_lens = query_lens[0:input_data.num_prefills]
|
||||
# slot_mapping = torch.tensor(input_data.slot_mapping,
|
||||
# dtype=torch.int32,
|
||||
# device=self.input_builder.device)
|
||||
|
||||
# # For chunked-prefill
|
||||
# if self.chunked_prefill and input_data.num_prefill_tokens != 0:
|
||||
# prefill_block_tables = make_tensor_with_pad(
|
||||
# self.input_data.prefill_block_tables,
|
||||
# pad=0,
|
||||
# dtype=torch.int32,
|
||||
# device=self.input_builder.device,
|
||||
# )
|
||||
# query_lens_tensor = torch.tensor(prefill_query_lens,
|
||||
# dtype=torch.int32,
|
||||
# device=self.input_builder.device)
|
||||
# kv_lens_tensor = torch.tensor(prefill_seq_lens,
|
||||
# dtype=torch.int32,
|
||||
# device=self.input_builder.device)
|
||||
# query_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
# dtype=torch.int32,
|
||||
# device=self.input_builder.device)
|
||||
# kv_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
# dtype=torch.int32,
|
||||
# device=self.input_builder.device)
|
||||
# torch.cumsum(query_lens_tensor,
|
||||
# dim=0,
|
||||
# dtype=torch.int32,
|
||||
# out=query_start_loc[1:])
|
||||
# torch.cumsum(kv_lens_tensor,
|
||||
# dim=0,
|
||||
# dtype=torch.int32,
|
||||
# out=kv_start_loc[1:])
|
||||
# max_query_len = max(prefill_query_lens)
|
||||
# max_kv_len = max(prefill_seq_lens)
|
||||
# else:
|
||||
# prefill_block_tables = None
|
||||
# query_start_loc = None
|
||||
# kv_start_loc = None
|
||||
# max_query_len = None
|
||||
# max_kv_len = None
|
||||
|
||||
# # For paged attention
|
||||
# if input_data.num_decode_tokens != 0:
|
||||
# seq_lens_tensor = torch.tensor(
|
||||
# input_data.seq_lens[input_data.num_prefills:],
|
||||
# dtype=torch.int32,
|
||||
# device=self.input_builder.device,
|
||||
# )
|
||||
# block_tables = make_tensor_with_pad(
|
||||
# self.input_data.decode_block_tables,
|
||||
# pad=0,
|
||||
# dtype=torch.int32,
|
||||
# device=self.input_builder.device,
|
||||
# )
|
||||
# # lowest_dim_size = block_tables.size(-1)
|
||||
# # if lowest_dim_size < 1024:
|
||||
# # padding_amount = 1024 - lowest_dim_size
|
||||
# # padding = torch.zeros(*block_tables.size()[:-1], padding_amount, dtype=block_tables.dtype, device=block_tables.device)
|
||||
# # block_tables = torch.cat((block_tables, padding), dim=-1)
|
||||
# else:
|
||||
# block_tables = torch.tensor([])
|
||||
# seq_lens_tensor = torch.tensor(
|
||||
# input_data.seq_lens[:input_data.num_prefills],
|
||||
# dtype=torch.int32,
|
||||
# device=self.input_builder.device,
|
||||
# )
|
||||
|
||||
# # For multi-modal models
|
||||
# placeholder_index_maps = None
|
||||
# if len(input_data.multi_modal_inputs_list) != 0:
|
||||
# placeholder_index_maps = {
|
||||
# modality: placeholder_map.index_map()
|
||||
# for modality, placeholder_map in
|
||||
# input_data.multi_modal_placeholder_maps.items()
|
||||
# }
|
||||
|
||||
# attn_metadata = VACCAttentionMetadata(
|
||||
# chunked_prefill=self.chunked_prefill,
|
||||
# seq_lens=seq_lens, #prefill_seq_lens,
|
||||
# seq_lens_tensor=seq_lens_tensor,
|
||||
# max_query_len=max_query_len,
|
||||
# max_kv_len=max_kv_len,
|
||||
# query_start_loc=query_start_loc,
|
||||
# kv_start_loc=kv_start_loc,
|
||||
# max_decode_seq_len=None,
|
||||
# num_prefills=input_data.num_prefills,
|
||||
# num_prefill_tokens=input_data.num_prefill_tokens,
|
||||
# num_decode_tokens=input_data.num_decode_tokens,
|
||||
# block_tables=block_tables,
|
||||
# prefill_block_tables=prefill_block_tables,
|
||||
# slot_mapping=slot_mapping,
|
||||
# multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
# enable_kv_scales_calculation=False,
|
||||
# )
|
||||
# return attn_metadata
|
||||
|
||||
def fp32_attention(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
mask,
|
||||
norm_factor,
|
||||
out_type=None,
|
||||
):
|
||||
ori_type = out_type if out_type is not None else query_layer.dtype
|
||||
query_layer = query_layer.to(torch.float32)
|
||||
key_layer = key_layer.to(torch.float32)
|
||||
value_layer = value_layer.to(torch.float32)
|
||||
|
||||
# GQA
|
||||
if query_layer.size(1) != key_layer.size(1):
|
||||
if query_layer.size(1) % key_layer.size(1) != 0:
|
||||
assert False
|
||||
groups = query_layer.size(1) // key_layer.size(1)
|
||||
key_layer = torch.repeat_interleave(key_layer, groups, dim=1)
|
||||
value_layer = torch.repeat_interleave(value_layer, groups, dim=1)
|
||||
|
||||
matmul_result = torch.bmm(
|
||||
query_layer.transpose(0, 1), # [b * np, sq, hn]
|
||||
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
|
||||
) * norm_factor
|
||||
mask_output = matmul_result
|
||||
if mask != None:
|
||||
mask = mask if mask.dim() >= 3 else mask.unsqueeze(0)
|
||||
mask_output = matmul_result.masked_fill_(mask, -10000.0) # [b * np, sq, sk]
|
||||
probs = torch.nn.Softmax(dim=-1)(mask_output)
|
||||
context_layer = torch.bmm(probs, value_layer.transpose(0, 1))
|
||||
|
||||
return context_layer.transpose(0, 1).to(ori_type)
|
||||
|
||||
|
||||
class VACCAttentionBackendImpl(AttentionImpl[VACCAttentionMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
|
||||
|
||||
# if logits_soft_cap is not None:
|
||||
# logger.warning_once("Torch SPDA does not support logits soft cap. "
|
||||
# "Outputs may be slightly off.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.sliding_window = sliding_window
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.need_mask = (self.alibi_slopes is not None
|
||||
or self.sliding_window is not None)
|
||||
|
||||
if kv_cache_dtype != "auto":
|
||||
raise NotImplementedError(
|
||||
"Torch SDPA backend does not support FP8 KV cache. "
|
||||
"Please use xFormers backend instead.")
|
||||
self.attn_type = attn_type
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: VACCAttentionMetadata, # type: ignore
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with torch SDPA and PagedAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
attn_type = self.attn_type
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
assert value is not None
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
else:
|
||||
assert value is None
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
||||
# KV-cache during decoder-self- or
|
||||
# encoder-decoder-cross-attention, but not
|
||||
# during encoder attention.
|
||||
#
|
||||
# Even if there are no new key/value pairs to cache,
|
||||
# we still need to break out key_cache and value_cache
|
||||
# i.e. for later use by paged attention
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
|
||||
if (key is not None) and (value is not None):
|
||||
if attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Update cross-attention KV cache (prefill-only)
|
||||
# During cross-attention decode, key & value will be None,
|
||||
# preventing this IF-statement branch from running
|
||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
else:
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale, layer._v_scale)
|
||||
|
||||
if attn_type != AttentionType.ENCODER:
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
# Encoder/decoder cross-attention requires no chunked
|
||||
# prefill (100% prefill or 100% decode tokens, no mix)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
else:
|
||||
# Encoder attention - chunked prefill is not applicable;
|
||||
# derive token-count from query shape & and treat them
|
||||
# as 100% prefill tokens
|
||||
assert attn_metadata.num_encoder_tokens is not None
|
||||
num_prefill_tokens = attn_metadata.num_encoder_tokens
|
||||
num_decode_tokens = 0
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Only enforce this shape-constraint for decoder
|
||||
# self-attention
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
output = torch.empty_like(query)
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
if (kv_cache.numel() == 0
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
self._run_vacc_forward(
|
||||
output,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
prefill_meta,
|
||||
attn_type=attn_type)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert not self.need_mask
|
||||
import intel_extension_for_pytorch.llm.modules as ipex_modules
|
||||
output = torch.empty_like(query)
|
||||
ipex_modules.PagedAttention.flash_attn_varlen_func(
|
||||
output[:prefill_meta.num_prefill_tokens, :, :],
|
||||
query[:prefill_meta.num_prefill_tokens, :, :],
|
||||
key_cache,
|
||||
value_cache,
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.kv_start_loc,
|
||||
prefill_meta.max_query_len,
|
||||
prefill_meta.max_kv_len,
|
||||
self.scale,
|
||||
True,
|
||||
prefill_meta.prefill_block_tables,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
# Decoding run.
|
||||
# (
|
||||
# seq_lens_arg,
|
||||
# max_seq_len_arg,
|
||||
# block_tables_arg,
|
||||
# ) = decode_meta.get_seq_len_block_table_args(attn_type)
|
||||
|
||||
# Note:
|
||||
# decode attention still use SDPA method
|
||||
# reshape k/v_cache to (num_block_grp, block_grp_size, head, hidden_size)
|
||||
k_cache = key_cache.view(-1, env_blk_grp_size, key_cache.shape[2], key_cache.shape[3])
|
||||
v_cache = value_cache.view(-1, env_blk_grp_size, value_cache.shape[2], value_cache.shape[3])
|
||||
block_per_group = env_blk_grp_size // 16
|
||||
# convert block_tables to 8K group index
|
||||
block_tables = (decode_meta.block_tables // block_per_group).to(torch.int32)
|
||||
attn_outs = []
|
||||
for i in range(len(decode_meta.seq_lens_tensor)):
|
||||
seq_len = decode_meta.seq_lens_tensor[i]
|
||||
k_slices = k_cache[block_tables[i], ...]
|
||||
k = \
|
||||
torch.cat([k_slices[i, ...] for i in range(len(block_tables[i]))], dim=0)[:seq_len]
|
||||
v_slices = v_cache[block_tables[i], ...]
|
||||
v = \
|
||||
torch.cat([v_slices[i, ...] for i in range(len(block_tables[i]))], dim=0)[:seq_len]
|
||||
q = query[i : i + 1, ...]
|
||||
if q.dtype == torch.bfloat16:
|
||||
attn_out = fp32_attention(
|
||||
q.cpu(),
|
||||
k.cpu(),
|
||||
v.cpu(),
|
||||
None,
|
||||
self.scale
|
||||
).to(query.dtype).to(query.device)
|
||||
else:
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale,
|
||||
)
|
||||
attn_outs.append(attn_out)
|
||||
output = torch.cat(attn_outs, dim=0)
|
||||
# '''
|
||||
|
||||
# PagedAttention.forward_decode(
|
||||
# output[attn_metadata.num_prefill_tokens:, :, :],
|
||||
# query[attn_metadata.num_prefill_tokens:, :, :],
|
||||
# key_cache,
|
||||
# value_cache,
|
||||
# block_tables_arg,
|
||||
# seq_lens_arg,
|
||||
# max_seq_len_arg,
|
||||
# self.kv_cache_dtype,
|
||||
# self.num_kv_heads,
|
||||
# self.scale,
|
||||
# self.alibi_slopes,
|
||||
# layer._k_scale,
|
||||
# layer._v_scale,
|
||||
# )
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
def _run_vacc_forward(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: VACCAttentionMetadata,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
):
|
||||
# if self.num_kv_heads != self.num_heads:
|
||||
# key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
# value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
attn_masks = attn_metadata.get_attn_bias(attn_type)
|
||||
if attn_masks is None:
|
||||
if self.alibi_slopes is not None:
|
||||
attn_masks = _make_alibi_bias(
|
||||
self.alibi_slopes, query.dtype,
|
||||
attn_metadata.seq_lens) # type: ignore
|
||||
elif self.sliding_window is not None:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
attn_masks = _make_sliding_window_bias(
|
||||
attn_metadata.seq_lens, self.sliding_window,
|
||||
query.dtype) # type: ignore
|
||||
else:
|
||||
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
|
||||
attn_masks = [None] * len(seq_lens)
|
||||
attn_metadata.set_attn_bias(attn_masks, attn_type)
|
||||
|
||||
causal_attn = (attn_type == AttentionType.DECODER)
|
||||
|
||||
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
|
||||
start_q, start_kv = 0, 0
|
||||
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
|
||||
attn_masks):
|
||||
end_q = start_q + seq_len_q
|
||||
end_kv = start_kv + seq_len_kv
|
||||
sub_out=torch.vacc.scaled_dot_product_attention(
|
||||
query[start_q:end_q,:, :].to(torch.float16) * (self.scale),
|
||||
key[start_kv:end_kv,:, :].to(torch.float16),
|
||||
value[start_kv:end_kv,:, :].contiguous().to(torch.float16),
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=True if attn_type == AttentionType.DECODER else False, #causal_attn and not self.need_mask,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=1)
|
||||
output[ start_q:end_q,:, :] = sub_out
|
||||
start_q, start_kv = end_q, end_kv
|
||||
return output
|
||||
|
||||
|
||||
class VACCAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = False
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes.")
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TORCH_SDPA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["VACCAttentionBackendImpl"]:
|
||||
return VACCAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return VACCAttentionMetadata
|
||||
|
||||
# @staticmethod
|
||||
# def get_state_cls() -> type["CommonAttentionState"]:
|
||||
# return CommonAttentionState
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["VACCAttentionMetadataBuilderV1"]:
|
||||
return VACCAttentionMetadataBuilderV1
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str:str,
|
||||
) -> tuple[int, ...]:
|
||||
# return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
# num_kv_heads, head_size)
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class VACCAttentionMetadataBuilderV1(AttentionMetadataBuilder[VACCAttentionMetadata]):
|
||||
|
||||
# def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
|
||||
# block_table: BlockTable) -> None:
|
||||
# self.runner = runner
|
||||
# self.block_table = block_table
|
||||
|
||||
# # For reorder
|
||||
# self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
|
||||
# dtype=np.int64)
|
||||
# self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
|
||||
# dtype=np.int64)
|
||||
# self.num_prompt_req: int = 0
|
||||
|
||||
# self.seq_start_loc_cpu = torch.zeros(
|
||||
# runner.max_num_reqs + 1,
|
||||
# dtype=torch.int32,
|
||||
# device="cpu",
|
||||
# )
|
||||
# self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
|
||||
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device) -> None:
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
|
||||
# For reorder
|
||||
self.reorder_prompt_req_index_list = np.empty(
|
||||
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
|
||||
self.reorder_decode_req_index_list = np.empty(
|
||||
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
|
||||
|
||||
self.num_prompt_req: int = 0
|
||||
|
||||
self.seq_start_loc_cpu = torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_seqs + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
|
||||
|
||||
|
||||
# def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
# vllm_config: VllmConfig, device: torch.device):
|
||||
# super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
# self.block_size = kv_cache_spec.block_size
|
||||
|
||||
# model_config = vllm_config.model_config
|
||||
# self.num_heads_q = model_config.get_num_attention_heads(
|
||||
# vllm_config.parallel_config)
|
||||
# self.num_heads_kv = model_config.get_num_kv_heads(
|
||||
# vllm_config.parallel_config)
|
||||
# self.headdim = model_config.get_head_size()
|
||||
|
||||
|
||||
|
||||
|
||||
def reorder_batch(self, input_batch: InputBatch,
|
||||
scheduler_output: SchedulerOutput) -> bool:
|
||||
prompt_list_idx = 0
|
||||
decode_list_idx = 0
|
||||
for req_index in range(input_batch.num_reqs):
|
||||
if input_batch.num_computed_tokens_cpu[
|
||||
req_index] < input_batch.num_prompt_tokens[req_index]:
|
||||
# prompt stage
|
||||
self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
|
||||
prompt_list_idx += 1
|
||||
else:
|
||||
# decode stage
|
||||
self.reorder_decode_req_index_list[decode_list_idx] = req_index
|
||||
decode_list_idx += 1
|
||||
assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
|
||||
|
||||
# Update prompt requests number
|
||||
self.num_prompt_req = prompt_list_idx
|
||||
|
||||
reorder_req_num = 0
|
||||
for req_index in range(decode_list_idx):
|
||||
if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
|
||||
reorder_req_num += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if reorder_req_num == 0:
|
||||
return False
|
||||
|
||||
reorder_prompt_list = (
|
||||
self.reorder_prompt_req_index_list[:prompt_list_idx]
|
||||
[-reorder_req_num:])
|
||||
reorder_decode_list = (
|
||||
self.reorder_decode_req_index_list[:decode_list_idx]
|
||||
[:reorder_req_num])
|
||||
assert reorder_decode_list.size == reorder_prompt_list.size
|
||||
|
||||
for idx in range(reorder_req_num):
|
||||
prompt_req_index = reorder_prompt_list[idx].item()
|
||||
decode_req_index = reorder_decode_list[idx].item()
|
||||
input_batch.swap_states(prompt_req_index, decode_req_index)
|
||||
|
||||
return True
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
# seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
# runner = self.runner
|
||||
# block_table = self.block_table
|
||||
# seq_lens = runner.seq_lens[:num_reqs]
|
||||
num_prompt_req = self.num_prompt_req
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
|
||||
num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
|
||||
num_prefill_tokens)
|
||||
|
||||
# print('query_start_loc_cpu', query_start_loc_cpu)
|
||||
# print('num_prompt_req', num_prompt_req)
|
||||
# print('num_reqs', num_reqs)
|
||||
|
||||
# num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
|
||||
# num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
|
||||
# ) - num_prefill_tokens
|
||||
|
||||
# block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
# block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
# non_blocking=True)
|
||||
|
||||
# slot_mapping = block_table.slot_mapping[:num_actual_tokens] #.long()
|
||||
|
||||
# block_table_tensor = block_table.get_device_tensor()
|
||||
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
|
||||
block_num_per_group = env_blk_grp_size // 16
|
||||
block_table_tensor_new = block_table_tensor[:num_reqs-num_prompt_req, ::block_num_per_group].contiguous()
|
||||
# [bs, seq//16] => [bs, seq//16//block_num_per_group, block_num_per_group]
|
||||
# => [:num_reqs, :, 0] 提取前reqs行,并且把 block_num_per_group 的倍数提取出
|
||||
|
||||
attn_metadata = VACCAttentionMetadata(
|
||||
num_prefills=num_prompt_req,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens_tensor=seq_lens, # decode
|
||||
max_decode_seq_len=None, # decode
|
||||
block_tables=block_table_tensor_new, # decode
|
||||
chunked_prefill=False,
|
||||
# max_query_len=max_query_len,
|
||||
# max_kv_len=max_prefill_seq_len,
|
||||
# prefill_query_start_loc=runner.
|
||||
# query_start_loc_cpu[:num_prompt_req + 1], # prefill
|
||||
# kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
|
||||
# 1], # prefill
|
||||
prefill_block_tables=block_table_tensor[:
|
||||
num_prompt_req], # prefill
|
||||
query_start_loc=query_start_loc_cpu[:num_reqs +
|
||||
1], # for logits index
|
||||
# multi_modal_placeholder_index_maps=None,
|
||||
# enable_kv_scales_calculation=False,
|
||||
)
|
||||
|
||||
return attn_metadata
|
||||
953
vllm_vacc/vllm/v1/attention/backends/vacc_mla.py
Normal file
953
vllm_vacc/vllm/v1/attention/backends/vacc_mla.py
Normal file
@@ -0,0 +1,953 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Generic, List, Optional, Dict, TypeVar, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata, AttentionType,
|
||||
)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, round_down
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonMetadata)
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
get_per_layer_parameters,
|
||||
infer_global_hyperparameters,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonImpl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
|
||||
M = TypeVar("M", bound=MLACommonMetadata)
|
||||
|
||||
def vacc_paged_attention_naive(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
# seq_lens: torch.Tensor,
|
||||
seq_lens: int,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
sm_scale = -1
|
||||
) -> torch.Tensor:
|
||||
|
||||
# gurantee batch=1 perf
|
||||
if len(seq_lens) == 1:
|
||||
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[0]]
|
||||
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[0]]
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=sm_scale
|
||||
)
|
||||
else:
|
||||
# t0 = time.time()
|
||||
attn_outs = []
|
||||
for i in range(len(seq_lens)):
|
||||
k_slices = key_cache[block_table[i], :, :, :]
|
||||
k = torch.cat([k_slices[i, :, :, :].unsqueeze(1) for i in range(len(block_table[i]))], dim=0)
|
||||
k = k.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
|
||||
v_slices = value_cache[block_table[i], :, :, :]
|
||||
v = torch.cat([v_slices[i, :, :, :].unsqueeze(1) for i in range(len(block_table[i]))], dim=0)
|
||||
v = v.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
|
||||
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=query[i:i+1,:,:],
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=sm_scale
|
||||
)
|
||||
attn_outs.append(attn_out)
|
||||
|
||||
attn_out = torch.cat(attn_outs, dim=0)
|
||||
# print(f'{os.getpid()} call spda(seq: {seq_lens}) time: {time.time() - t0}')
|
||||
return attn_out
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonPrefillMetadata:
|
||||
""" Prefill Specific Metadata """
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling chunked prefill
|
||||
cu_seq_lens: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
seq_tot: list[int]
|
||||
max_seq_lens: list[int]
|
||||
workspace: torch.Tensor
|
||||
|
||||
block_tables: torch.Tensor #block_table => block_tables 兼容v0
|
||||
query_start_loc: torch.Tensor
|
||||
# max_query_len: int
|
||||
seq_lens: list[int]
|
||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonDecodeMetadata:
|
||||
block_tables: torch.Tensor #block_table => block_tables 兼容v0
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
class VACCMLAMetadata(MLACommonMetadata):
|
||||
"""Metadata for VACCMLAMetadata.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["VACCMLAMetadata"] = None
|
||||
_cached_decode_metadata: Optional["VACCMLAMetadata"] = None
|
||||
|
||||
num_prefill_tokens: int
|
||||
|
||||
num_kv_splits: int = 4 # TODO(lucas) add heuristic
|
||||
attn_logits: Optional[torch.Tensor] = None
|
||||
req_idx: Optional[torch.Tensor] = None
|
||||
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
supported_head_sizes = VACCMLABackend.get_supported_head_sizes()
|
||||
if self.head_dim is not None and self.head_dim \
|
||||
not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
f"received {self.head_dim}.")
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["VACCMLAMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[:self.num_prefill_tokens])
|
||||
|
||||
self._cached_prefill_metadata = VACCMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
input_positions=input_positions,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=None,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["VACCMLAMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[self.num_prefill_tokens:])
|
||||
|
||||
self._cached_decode_metadata = VACCMLAMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=self.seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
# Batch may be composed of prefill|decodes, adjust query start
|
||||
# indices to refer to the start of decodes. E.g.
|
||||
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
||||
self.query_start_loc[self.num_prefills])
|
||||
if self.query_start_loc is not None else None,
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
input_positions=input_positions,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
def advance_step(self,
|
||||
model_input: "ModelInputForVACCWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
turn_prefills_into_decodes: bool = False):
|
||||
"""
|
||||
Update metadata in-place to advance one decode step.
|
||||
"""
|
||||
# When using cudagraph, the num_seqs is padded to the next captured
|
||||
# batch sized, but num_queries tracks the actual number of requests in
|
||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
||||
if num_seqs != num_queries:
|
||||
assert num_seqs > num_queries
|
||||
assert self.use_cuda_graph
|
||||
|
||||
if turn_prefills_into_decodes:
|
||||
# When Mutli-Step is enabled with Chunked-Prefill, prefills and
|
||||
# decodes are scheduled together. In the first step, all the
|
||||
# prefills turn into decodes. This update reflects that
|
||||
# conversion.
|
||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
||||
self.num_decode_tokens += self.num_prefills
|
||||
self.num_prefills = 0
|
||||
# self.num_prefill_tokens = 0
|
||||
# self.max_prefill_seq_len = 0
|
||||
self.max_query_len = 1
|
||||
|
||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
||||
else:
|
||||
assert self.seq_lens is not None
|
||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
||||
|
||||
assert self.num_prefills == 0
|
||||
assert self.num_prefill_tokens == 0
|
||||
assert self.num_decode_tokens == num_seqs
|
||||
assert self.slot_mapping.shape == (num_seqs, )
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert len(self.seq_lens) == num_seqs
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
||||
# assert self.max_query_len == 1
|
||||
# assert self.max_prefill_seq_len == 0
|
||||
|
||||
assert self.query_start_loc is not None
|
||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
||||
assert self.seq_start_loc is not None
|
||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
||||
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.context_lens_tensor.shape == (num_queries, )
|
||||
|
||||
assert self.block_tables is not None
|
||||
assert self.block_tables.shape[0] == num_seqs
|
||||
|
||||
# Update query lengths. Note that we update only queries and not seqs,
|
||||
# since tensors may be padded due to captured cuda graph batch size
|
||||
for i in range(num_queries):
|
||||
self.seq_lens[i] += 1
|
||||
# self.max_decode_seq_len = None
|
||||
|
||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=model_input.input_tokens,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
input_positions=model_input.input_positions,
|
||||
seq_lens=self.seq_lens_tensor,
|
||||
slot_mapping=self.slot_mapping,
|
||||
block_tables=self.block_tables)
|
||||
|
||||
|
||||
|
||||
class VACCMLAImpl(MLACommonImpl[VACCMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
# blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**kwargs) -> None:
|
||||
# super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
# alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
# logits_soft_cap, attn_type,
|
||||
# kv_sharing_target_layer_name, **kwargs)
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
# print('kwargs', kwargs) # 手动写老版本的继承
|
||||
self.q_lora_rank = kwargs['q_lora_rank'] if 'q_lora_rank' in kwargs else None
|
||||
self.kv_lora_rank = kwargs['kv_lora_rank'] if 'kv_lora_rank' in kwargs else None
|
||||
self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] if 'qk_nope_head_dim' in kwargs else None
|
||||
self.qk_head_dim = kwargs['qk_head_dim'] if 'qk_head_dim' in kwargs else None
|
||||
self.qk_head_dim = kwargs['qk_head_dim'] if 'qk_head_dim' in kwargs else None
|
||||
self.v_head_dim = kwargs['v_head_dim'] if 'v_head_dim' in kwargs else None
|
||||
self.rotary_emb = kwargs['rotary_emb'] if 'rotary_emb' in kwargs else None
|
||||
self.q_proj = kwargs['q_proj'] if 'q_proj' in kwargs else None
|
||||
self.kv_b_proj = kwargs['kv_b_proj'] if 'kv_b_proj' in kwargs else None
|
||||
self.o_proj = kwargs['o_proj'] if 'o_proj' in kwargs else None
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"VACCMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"VACCMLAImpl")
|
||||
|
||||
def extract_weights(self):
|
||||
weights = {}
|
||||
if hasattr(self, 'W_Q'):
|
||||
weights["W_Q"] = self.W_Q
|
||||
if hasattr(self, 'W_Q_scales'):
|
||||
weights["W_Q_scales"] = self.W_Q_scales
|
||||
if hasattr(self, 'W_QR'):
|
||||
weights['W_QR'] = self.W_QR
|
||||
if hasattr(self, 'W_QR_scales'):
|
||||
weights["W_QR_scales"] = self.W_QR_scales
|
||||
if hasattr(self, 'W_Q_QR'):
|
||||
weights["W_Q_QR"] = self.W_Q_QR
|
||||
if hasattr(self, 'W_Q_QR_scales'):
|
||||
weights["W_Q_QR_scales"] = self.W_Q_QR_scales
|
||||
if hasattr(self, 'W_UK'):
|
||||
weights['W_UK'] = self.W_UK
|
||||
if hasattr(self, 'W_UK_scales'):
|
||||
weights['W_UK_scales'] = self.W_UK_scales
|
||||
if hasattr(self, 'W_Q_UK_scales'):
|
||||
weights['W_Q_UK_scales'] = self.W_Q_UK_scales
|
||||
if hasattr(self, 'W_UV'):
|
||||
weights['W_UV'] = self.W_UV
|
||||
if hasattr(self, 'W_UV_scales'):
|
||||
weights['W_UV_scales'] = self.W_UV_scales
|
||||
if hasattr(self, 'W_UV_O'):
|
||||
weights['W_UV_O'] = self.W_UV_O
|
||||
if hasattr(self, 'W_UV_O_scales'):
|
||||
weights['W_UV_O_scales'] = self.W_UV_O_scales
|
||||
return weights
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: VACCMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert isinstance(attn_metadata, VACCMLAV1Metadata)
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0]\
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
v = v.contiguous()
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim
|
||||
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||
# value=0)
|
||||
# attn_output = torch.vacc.scaled_dot_product_attention(
|
||||
# query=q,
|
||||
# key=k,
|
||||
# value=v_padded,
|
||||
# attn_mask=None,
|
||||
# dropout_p=0,
|
||||
# is_causal=True,
|
||||
# is_train=False,
|
||||
# recompute=False,
|
||||
# flash_attention=True,
|
||||
# sm_scale=self.scale
|
||||
# )
|
||||
|
||||
# attn_output = attn_output\
|
||||
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
|
||||
# .reshape(-1, self.num_heads * v.shape[-1])
|
||||
seq_lens = attn_metadata.prefill_metadata.seq_lens
|
||||
if len(seq_lens) == 1:
|
||||
# Vacc supports different head dim of v and qk.
|
||||
attn_output = torch.vacc.scaled_dot_product_attention(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=True,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale
|
||||
)
|
||||
attn_out = attn_output.view(-1, self.num_heads * v.shape[-1])
|
||||
else:
|
||||
attn_outs = []
|
||||
start = 0
|
||||
for seq in seq_lens:
|
||||
end = start + seq
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=q[start:end, :],
|
||||
key=k[start:end, :],
|
||||
value=v[start:end, :],
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=True,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale
|
||||
)
|
||||
start = end
|
||||
attn_outs.append(attn_out)
|
||||
attn_out = torch.cat(attn_outs, dim=0).view(-1, self.num_heads * v.shape[-1])
|
||||
|
||||
return self.o_proj(attn_out)[0]
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: VACCMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
assert decode_meta is not None
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
# Add a head dim of 1
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
# print(f"kv_c_and_k_pe_cache: {kv_c_and_k_pe_cache.shape} ")
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
|
||||
|
||||
# Run MQA using paged_attention
|
||||
# o = torch.vacc.paged_attention(
|
||||
# query=q,
|
||||
# key_cache=kv_c_and_k_pe_cache,
|
||||
# value_cache=kv_c_cache,
|
||||
# block_table=decode_meta.block_tables,
|
||||
# seq_len=decode_meta.seq_lens_tensor,
|
||||
# out=o,
|
||||
# sm_scale=self.scale
|
||||
# )
|
||||
|
||||
# Run MQA using spda
|
||||
# t0 = time.time()
|
||||
o = vacc_paged_attention_naive(
|
||||
q,
|
||||
kv_c_and_k_pe_cache,
|
||||
kv_c_cache,
|
||||
block_table = decode_meta.block_tables,
|
||||
# seq_lens = decode_meta.seq_lens_tensor,
|
||||
seq_lens=decode_meta.seq_lens,
|
||||
out = o,
|
||||
sm_scale=self.scale)
|
||||
|
||||
return self._v_up_proj_and_o_proj(o)
|
||||
|
||||
# patch from MLACommonBackend
|
||||
class VACCMLABackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return VACCMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["VACCMLAImpl"]:
|
||||
return VACCMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["MLAVaccMetadataBuilder"]:
|
||||
return MLAVaccMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes.")
|
||||
|
||||
|
||||
D = TypeVar("D", bound=MLACommonDecodeMetadata)
|
||||
# patch from MLACommonMetadata
|
||||
@dataclass
|
||||
class VACCMLAV1Metadata(Generic[D]):
|
||||
"""Metadata for MLACommon.
|
||||
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
|
||||
num_prefill_tokens: int
|
||||
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
decode_metadata: Optional[D] = None
|
||||
prefill_metadata: Optional[MLACommonPrefillMetadata] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.head_dim is not None:
|
||||
MLACommonBackend.validate_head_size(self.head_dim)
|
||||
|
||||
|
||||
|
||||
class MLAVaccMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
reorder_batch_threshold: ClassVar[int] = 2
|
||||
#TODO 区分 prefill decode 阈值
|
||||
def __init__(self,
|
||||
# runner: "GPUModelRunner",
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: Optional[type[M]] = None):
|
||||
self._global_hyperparameters = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(vllm_config, layer_names,
|
||||
MLACommonImpl))
|
||||
self.metadata_cls = metadata_cls \
|
||||
if metadata_cls is not None else VACCMLAV1Metadata
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
cache_config = vllm_config.cache_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.device = device
|
||||
# self.runner = runner
|
||||
# scheduler_config = runner.scheduler_config
|
||||
# model_config = runner.model_config
|
||||
# cache_config = runner.cache_config
|
||||
self.chunked_prefill_enabled = False
|
||||
self.num_heads = self.model_config.get_num_attention_heads(
|
||||
parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.aot_schedule = current_platform.is_cuda()
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
|
||||
# Dont try to access the runner on AMD
|
||||
if self.aot_schedule:
|
||||
self.page_size = self.kv_cache_spec.block_size
|
||||
|
||||
# if self.chunked_prefill_enabled:
|
||||
# self.chunked_prefill_workspace_size = min(
|
||||
# # Max sure there is enough for 8 full length request or at least
|
||||
# # 4 pages of cache per request
|
||||
# max(
|
||||
# 8 * model_config.max_model_len, 4 *
|
||||
# scheduler_config.max_num_seqs * cache_config.block_size),
|
||||
# # For long-context models try not to over-allocate limiting
|
||||
# # kv-cache space, limiting it to 64k tokens,
|
||||
# # which would result in the workspace being:
|
||||
# # 2*(576)*(64*1024) = 144mb
|
||||
# # (assuming 576 MLA head dim, and fp16)
|
||||
# # which would result in up-projected context being
|
||||
# # 2*(192*128)*(64*1024) = 3gb
|
||||
# # (assuming 192 QK head dim, 128 heads, and fp16)
|
||||
# 128 * 1024)
|
||||
# assert self.chunked_prefill_workspace_size >= \
|
||||
# scheduler_config.max_num_seqs * cache_config.block_size
|
||||
# self.chunked_prefill_workspace = torch.empty(
|
||||
# (self.chunked_prefill_workspace_size,
|
||||
# model_config.get_head_size()),
|
||||
# dtype=model_config.dtype,
|
||||
# device=runner.device,
|
||||
# )
|
||||
# self.block_table = block_table
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are and
|
||||
# the front and the "prefill" requests are at the using the least amount
|
||||
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
# where attention is likely memory-bound and "prefill" to mean requests
|
||||
# where attention is likely compute-bound, TODO(lucas): figure out a
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if its not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the TritonMLA._forward_decode only supports
|
||||
# num_tokens = 1
|
||||
if scheduler_output.scheduled_cached_reqs.num_computed_tokens != []:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
self._num_decodes = num_decodes
|
||||
self._num_prefills = num_prefills
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor):
|
||||
return MLACommonDecodeMetadata(
|
||||
block_tables=block_table_tensor,
|
||||
seq_lens=seq_lens,
|
||||
)
|
||||
|
||||
# def build_for_cudagraph_capture(
|
||||
# self, common_attn_metadata: CommonAttentionMetadata) -> M:
|
||||
# """
|
||||
# This method builds the metadata for full cudagraph capture.
|
||||
# Currently, only decode is supported for full cudagraphs with MLA.
|
||||
# """
|
||||
# m = common_attn_metadata
|
||||
# assert m.num_reqs == m.num_actual_tokens, \
|
||||
# "MLA only supports decode-only full CUDAGraph capture. " \
|
||||
# "Make sure all cudagraph capture sizes <= max_num_seq."
|
||||
|
||||
# # m.max_query_len = 1 # decode-only
|
||||
|
||||
# # Update state usually set in reorder_batch.
|
||||
# self._num_decodes = m.num_reqs
|
||||
# self._num_decode_tokens = m.num_actual_tokens
|
||||
# self._num_prefills = 0
|
||||
# self._num_prefill_tokens = 0
|
||||
# return self.build(0, m)
|
||||
|
||||
def append_seqlen(self, seq_len: list[int], all_len: int):
|
||||
# print('append_seqlen seq_len', seq_len)
|
||||
# print('append_seqlen all_len', all_len)
|
||||
if all_len > len(seq_len) and all_len % len(seq_len) == 0:
|
||||
new_seq_len = []
|
||||
mtp_num = all_len // len(seq_len)
|
||||
for start_len in seq_len:
|
||||
for i in range(1,1+mtp_num):
|
||||
new_seq_len.append(start_len-mtp_num+i)
|
||||
return new_seq_len
|
||||
return seq_len
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> M:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
# max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
# assert self._num_decodes + self._num_prefills == num_reqs
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||
# it blocks on all previous kernels.
|
||||
# device = self.runner.device
|
||||
# block_table = self.block_table
|
||||
# block_table_tensor = block_table.get_device_tensor()#[:num_reqs]
|
||||
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
#decode common_attn_metadata: CommonAttentionMetadata(query_start_loc=tensor([0, 2], device='vacc:20', dtype=torch.int32), query_start_loc_cpu=tensor([0, 2], dtype=torch.int32), seq_lens=[40], seq_lens_cpu=[40, 0, 0, 0], num_computed_tokens_cpu=tensor([38], dtype=torch.int32), num_reqs=1, num_actual_tokens=2, max_query_len=2, max_seq_len=40, block_table_tensor=tensor([[1536, 1537, 1538, ..., 0, 0, 0]], device='vacc:20',
|
||||
|
||||
# block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
# block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
# non_blocking=True)
|
||||
# # block_table.slot_mapping[num_actual_tokens:].fill_(-1)
|
||||
# slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
# num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
# split_decodes_and_prefills(common_attn_metadata,
|
||||
# decode_threshold=self.reorder_batch_threshold)
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = 0, 0, 0, 0
|
||||
token_nums = common_attn_metadata.query_start_loc_cpu[1:] - common_attn_metadata.query_start_loc_cpu[:-1]
|
||||
if token_nums.max().item() > self.reorder_batch_threshold:
|
||||
num_prefills = num_reqs
|
||||
num_prefill_tokens = num_actual_tokens
|
||||
else:
|
||||
num_decodes = num_reqs
|
||||
num_decode_tokens = num_actual_tokens
|
||||
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
reqs_start = num_decodes # prefill_start
|
||||
|
||||
# context_lens_cpu = self.runner.input_batch.\
|
||||
# num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
|
||||
# max_context_len_cpu = context_lens_cpu.max().item()
|
||||
# num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
|
||||
chunked_context_metadata = None
|
||||
# if self.chunked_prefill_enabled and self._num_prefills > 0 \
|
||||
# and max_context_len_cpu > 0:
|
||||
# dont support chunked prefill
|
||||
|
||||
|
||||
prefill_metadata = MLACommonPrefillMetadata(
|
||||
block_tables=block_table_tensor[reqs_start:reqs_start+num_prefills, ...],
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
# max_query_len=None,
|
||||
chunked_context=chunked_context_metadata,
|
||||
seq_lens=seq_lens[:num_prefills],
|
||||
)
|
||||
|
||||
if not isinstance(seq_lens, list):
|
||||
# TODO init set list in init: vllm/v1/spec_decode/eagle.py
|
||||
seq_lens = seq_lens.tolist()
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
block_num_per_group = env_blk_grp_size // 16
|
||||
block_table_tensor_new = block_table_tensor[:num_decodes, ::block_num_per_group].contiguous()
|
||||
|
||||
seq_lens_new = self.append_seqlen(seq_lens[:slot_mapping.shape[-1]], slot_mapping.shape[-1])
|
||||
|
||||
if slot_mapping.shape[-1] > num_decodes:
|
||||
mtp_numbers = [query_start_loc[i+1]-query_start_loc[i] for i in range(len(query_start_loc)-1)] #query_start_loc[1:] - query_start_loc[:-1]
|
||||
|
||||
block_table_tensor_list = []
|
||||
for bi,mtp_number in enumerate(mtp_numbers):
|
||||
for _ in range(mtp_number):
|
||||
block_table_tensor_list.append(block_table_tensor_new[bi:bi+1])
|
||||
block_table_tensor_new = torch.concatenate(block_table_tensor_list, 0)
|
||||
|
||||
decode_metadata = self._build_decode(
|
||||
block_table_tensor=block_table_tensor_new,
|
||||
seq_lens=seq_lens_new,
|
||||
)
|
||||
|
||||
return self.metadata_cls(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
# prefill_seq_lens=seq_lens[:self._num_prefills].tolist(), # device to host, todo optimiz
|
||||
# MLACommonMetadata Chunk prefill specific
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
prefill_metadata=prefill_metadata,
|
||||
decode_metadata=decode_metadata,
|
||||
)
|
||||
|
||||
0
vllm_vacc/vllm/v1/core/__init__.py
Normal file
0
vllm_vacc/vllm/v1/core/__init__.py
Normal file
BIN
vllm_vacc/vllm/v1/core/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/core/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/core/__pycache__/block_pool.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/core/__pycache__/block_pool.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
78
vllm_vacc/vllm/v1/core/block_pool.py
Normal file
78
vllm_vacc/vllm/v1/core/block_pool.py
Normal file
@@ -0,0 +1,78 @@
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, Optional
|
||||
|
||||
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
|
||||
BlockStored, KVCacheEvent)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock,
|
||||
generate_block_hash_extra_keys,
|
||||
hash_block_tokens)
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.core.block_pool import BlockHashToBlockMap
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class BlockPool:
|
||||
"""BlockPool that manages KVCacheBlocks.
|
||||
It provides methods to allocate, free and cache the kv cache blocks. The
|
||||
free_block_queue stores the free blocks in eviction order to enable
|
||||
allocation, free, and cache eviction. The cached_block_hash_to_block
|
||||
maps between block hash and cached block to support finding cached blocks
|
||||
by their block hash.
|
||||
|
||||
Args:
|
||||
num_gpu_blocks: The number of blocks in the pool.
|
||||
enable_caching: Whether to enable prefix caching.
|
||||
enable_kv_cache_events: Whether to enable kv cache events.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_gpu_blocks: int,
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool = False,
|
||||
):
|
||||
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
self.enable_caching = enable_caching
|
||||
# All kv-cache blocks.
|
||||
self.blocks: list[KVCacheBlock] = [
|
||||
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
|
||||
]
|
||||
# Free block queue that constructs and manipulates a doubly linked
|
||||
# list of free blocks (including eviction candidates when caching is
|
||||
# enabled).
|
||||
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
|
||||
|
||||
# Cache for block lookup
|
||||
self.cached_block_hash_to_block: BlockHashToBlockMap = \
|
||||
BlockHashToBlockMap()
|
||||
|
||||
# To represent a placeholder block with block_id=0.
|
||||
# The ref_cnt of null_block is not maintained, needs special care to
|
||||
# avoid freeing it.
|
||||
# self.null_block = self.free_block_queue.popleft()
|
||||
# self.null_block.is_null = True
|
||||
self.null_block = self.free_block_queue.fake_free_list_head.next_free_block #self.free_block_queue.popleft()
|
||||
self.null_block.is_null = False
|
||||
|
||||
self.enable_kv_cache_events = enable_kv_cache_events
|
||||
self.kv_event_queue: list[KVCacheEvent] = []
|
||||
|
||||
def get_usage(self) -> float:
|
||||
"""Get the KV cache usage.
|
||||
|
||||
Returns:
|
||||
The KV cache usage (between 0.0 and 1.0).
|
||||
"""
|
||||
|
||||
total_gpu_blocks = self.num_gpu_blocks
|
||||
if not total_gpu_blocks:
|
||||
return 0
|
||||
|
||||
return 1.0 - (self.get_num_free_blocks() / total_gpu_blocks)
|
||||
156
vllm_vacc/vllm/v1/core/kv_cache_utils.py
Normal file
156
vllm_vacc/vllm/v1/core/kv_cache_utils.py
Normal file
@@ -0,0 +1,156 @@
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.v1.core.kv_cache_utils import logger
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
|
||||
def query_device_avaliable_memory(vllm_config):
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch_vacc
|
||||
|
||||
available_kv_cache_memory= int(os.getenv("VLLM_VACC_KVCACHE_SPACE", "16")) * GiB_bytes
|
||||
if available_kv_cache_memory ==0:
|
||||
torch.vacc.empty_cache()
|
||||
torch.vacc.reset_peak_memory_stats()
|
||||
total_memory = torch.vacc.mem_get_info()[1]
|
||||
torch.vacc.synchronize()
|
||||
peak_memory = torch.vacc.max_memory_allocated()
|
||||
torch.vacc.empty_cache()
|
||||
torch_allocated_bytes = torch.vacc.memory_stats(
|
||||
)["allocated_bytes.all.current"]
|
||||
total_allocated_bytes = torch.vacc.mem_get_info(
|
||||
)[1] - torch.vacc.mem_get_info()[0]
|
||||
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
|
||||
if non_torch_allocations > 0:
|
||||
peak_memory += non_torch_allocations
|
||||
available_kv_cache_memory=total_memory * vllm_config.cache_config.gpu_memory_utilization - peak_memory
|
||||
|
||||
return available_kv_cache_memory
|
||||
|
||||
|
||||
def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
|
||||
available_memory: int, page_size: int) -> int:
|
||||
"""
|
||||
Get the number of kv cache blocks.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
num_layers: The number of layers
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
page_size: The page size of the KV cache.
|
||||
"""
|
||||
num_blocks = int(available_memory // page_size // num_layers)
|
||||
num_blocks = max(num_blocks, 0)
|
||||
if vllm_config.cache_config.num_gpu_blocks_override is not None:
|
||||
num_gpu_blocks_override = \
|
||||
vllm_config.cache_config.num_gpu_blocks_override
|
||||
logger.info(
|
||||
"Overriding num_gpu_blocks=%d with "
|
||||
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
|
||||
num_blocks = num_gpu_blocks_override
|
||||
|
||||
block_num_per_group = env_blk_grp_size // 16
|
||||
num_blocks = num_blocks // block_num_per_group * block_num_per_group
|
||||
|
||||
return num_blocks
|
||||
|
||||
def estimate_max_model_len(vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> int:
|
||||
"""
|
||||
Estimates the maximum model length that can fit in the available memory
|
||||
using binary search.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
|
||||
Returns:
|
||||
The estimated maximum model length that can fit in the available memory.
|
||||
"""
|
||||
from vllm.v1.core.kv_cache_utils import max_memory_usage_bytes
|
||||
available_kv_cache_memory = query_device_avaliable_memory(vllm_config)
|
||||
|
||||
# Define a function to check if a given model length fits in memory
|
||||
def fits_in_memory(model_len: int) -> bool:
|
||||
# Modify the max_model_len for this calculation
|
||||
vllm_config.model_config.max_model_len = model_len
|
||||
# Calculate memory needed for the given model length
|
||||
memory_needed = max_memory_usage_bytes(vllm_config,
|
||||
kv_cache_spec.values())
|
||||
# 增加vacc buffer 判断,与原始设计不一致
|
||||
# 如果所需memory 小于总的预分配空间/剩余物理空间,支持其继续分配,
|
||||
# 无需受到score模型max_model_len个预分配空间限制
|
||||
return memory_needed <= max(available_memory, available_kv_cache_memory)
|
||||
|
||||
# Binary search for the maximum model length
|
||||
current_max = vllm_config.model_config.max_model_len
|
||||
left, right = 1, current_max
|
||||
|
||||
# If even the smallest model length doesn't fit, return 0
|
||||
if not fits_in_memory(left):
|
||||
return 0
|
||||
|
||||
# Binary search for the maximum model length that fits
|
||||
result = 1
|
||||
while left <= right:
|
||||
mid = (left + right) // 2
|
||||
if fits_in_memory(mid):
|
||||
result = mid
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid - 1
|
||||
return result
|
||||
|
||||
def check_enough_kv_cache_memory(vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int):
|
||||
"""
|
||||
Checks whether `available_memory` is enough for the KV cache to hold at
|
||||
least one request with the model's max_model_len.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
|
||||
Raises:
|
||||
ValueError: If there is not enough memory available for the KV cache.
|
||||
"""
|
||||
from vllm.v1.core.kv_cache_utils import max_memory_usage_bytes
|
||||
|
||||
if available_memory <= 0:
|
||||
raise ValueError("No available memory for the cache blocks. "
|
||||
"Try increasing `gpu_memory_utilization` when "
|
||||
"initializing the engine.")
|
||||
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())
|
||||
available_kv_cache_memory = query_device_avaliable_memory(vllm_config)
|
||||
|
||||
# 增加vacc buffer 判断,与原始设计不一致
|
||||
# 如果所需memory 小于总的预分配空间/剩余物理空间,支持其继续分配,
|
||||
# 无需受到score模型max_model_len个预分配空间限制
|
||||
if needed_memory > max(available_memory, available_kv_cache_memory):
|
||||
# Estimate the maximum model length that can fit in the available memory
|
||||
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
print("exec max model len is:", estimated_max_len)
|
||||
estimated_msg = ""
|
||||
if estimated_max_len > 0:
|
||||
estimated_msg = (
|
||||
"Based on the available memory, "
|
||||
f"the estimated maximum model length is {estimated_max_len}.")
|
||||
|
||||
raise ValueError(
|
||||
f"To serve at least one request with the models's max seq len "
|
||||
f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV "
|
||||
f"cache is needed, which is larger than the available KV cache "
|
||||
f"memory ({available_memory/GiB_bytes:.2f} GiB). "
|
||||
f"{estimated_msg} "
|
||||
f"Try increasing `gpu_memory_utilization` or decreasing "
|
||||
f"`max_model_len` when initializing the engine.")
|
||||
0
vllm_vacc/vllm/v1/core/sched/__init__.py
Normal file
0
vllm_vacc/vllm/v1/core/sched/__init__.py
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/core/sched/__pycache__/output.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/core/sched/__pycache__/output.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
89
vllm_vacc/vllm/v1/core/sched/output.py
Normal file
89
vllm_vacc/vllm/v1/core/sched/output.py
Normal file
@@ -0,0 +1,89 @@
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm._bc_linter import bc_linter_include
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata)
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
|
||||
@bc_linter_include
|
||||
@dataclass
|
||||
class NewRequestData:
|
||||
|
||||
req_id: str
|
||||
prompt_token_ids: Optional[list[int]]
|
||||
mm_features: list[MultiModalFeatureSpec]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
block_ids: tuple[list[int], ...]
|
||||
num_computed_tokens: int
|
||||
lora_request: Optional[LoRARequest]
|
||||
prompt_embeds: Optional[torch.Tensor] = None
|
||||
deepstack_input_embeds: Optional[torch.Tensor] = None #patch
|
||||
|
||||
@classmethod
|
||||
def from_request(
|
||||
cls,
|
||||
request: Request,
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> NewRequestData:
|
||||
return cls(
|
||||
req_id=request.request_id,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
mm_features=request.mm_features,
|
||||
sampling_params=request.sampling_params,
|
||||
pooling_params=request.pooling_params,
|
||||
block_ids=block_ids,
|
||||
num_computed_tokens=request.num_computed_tokens,
|
||||
lora_request=request.lora_request,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
deepstack_input_embeds=request.deepstack_input_embeds,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
prompt_embeds_shape = (self.prompt_embeds.shape
|
||||
if self.prompt_embeds else None)
|
||||
return (f"NewRequestData("
|
||||
f"req_id={self.req_id},"
|
||||
f"prompt_token_ids={self.prompt_token_ids},"
|
||||
f"mm_features={self.mm_features},"
|
||||
f"sampling_params={self.sampling_params},"
|
||||
f"block_ids={self.block_ids},"
|
||||
f"num_computed_tokens={self.num_computed_tokens},"
|
||||
f"lora_request={self.lora_request},"
|
||||
f"prompt_embeds_shape={prompt_embeds_shape}"
|
||||
")")
|
||||
|
||||
# Version of __repr__ with the prompt data obfuscated
|
||||
def anon_repr(self) -> str:
|
||||
prompt_token_ids_len = len(
|
||||
self.prompt_token_ids
|
||||
) if self.prompt_token_ids is not None else None
|
||||
prompt_embeds_shape = (self.prompt_embeds.shape
|
||||
if self.prompt_embeds else None)
|
||||
return (f"NewRequestData("
|
||||
f"req_id={self.req_id},"
|
||||
f"prompt_token_ids_len={prompt_token_ids_len},"
|
||||
f"mm_features={self.mm_features},"
|
||||
f"sampling_params={self.sampling_params},"
|
||||
f"block_ids={self.block_ids},"
|
||||
f"num_computed_tokens={self.num_computed_tokens},"
|
||||
f"lora_request={self.lora_request},"
|
||||
f"prompt_embeds_shape={prompt_embeds_shape}"
|
||||
")")
|
||||
930
vllm_vacc/vllm/v1/core/sched/scheduler.py
Normal file
930
vllm_vacc/vllm/v1/core/sched/scheduler.py
Normal file
@@ -0,0 +1,930 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
||||
KVConnectorRole)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
|
||||
compute_encoder_budget)
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
|
||||
create_request_queue)
|
||||
from vllm.v1.core.sched.utils import check_stop
|
||||
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
|
||||
EngineCoreOutputs)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
prefill_first = True
|
||||
|
||||
def _is_in_prefill(req: Request) -> bool:
|
||||
# 还没把 prompt 全算完
|
||||
return req.num_computed_tokens < req.num_prompt_tokens
|
||||
|
||||
def align_up(value: int, alignment: int) -> int:
|
||||
if alignment <= 0:
|
||||
return value
|
||||
return ((value + alignment - 1) // alignment) * alignment
|
||||
|
||||
class Scheduler(SchedulerInterface):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
structured_output_manager: StructuredOutputManager,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
include_finished_set: bool = False,
|
||||
log_stats: bool = False,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.kv_events_config = vllm_config.kv_events_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.log_stats = log_stats
|
||||
self.structured_output_manager = structured_output_manager
|
||||
self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder
|
||||
|
||||
# include_finished_set controls whether a separate set of finished
|
||||
# request ids should be included in the EngineCoreOutputs returned
|
||||
# by update_from_outputs(). This is currently used in the multi-engine
|
||||
# case to track request lifetimes efficiently.
|
||||
self.finished_req_ids_dict: Optional[dict[int, set[str]]] = (
|
||||
defaultdict(set) if include_finished_set else None)
|
||||
|
||||
# Scheduling constraints.
|
||||
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
|
||||
self.max_num_scheduled_tokens = \
|
||||
self.scheduler_config.max_num_batched_tokens
|
||||
self.max_model_len = self.scheduler_config.max_model_len
|
||||
self.enable_kv_cache_events = (
|
||||
self.kv_events_config is not None
|
||||
and self.kv_events_config.enable_kv_cache_events)
|
||||
|
||||
# Create KVConnector for the Scheduler. Note that each Worker
|
||||
# will have a corresponding KVConnector with Role=WORKER.
|
||||
# KV Connector pushes/pull of remote KVs for P/D and offloading.
|
||||
self.connector = None
|
||||
if self.vllm_config.kv_transfer_config is not None:
|
||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||
"Multiple KV cache groups are not currently supported "
|
||||
"with KV connectors")
|
||||
assert not self.is_encoder_decoder, (
|
||||
"Encoder-decoder models are not currently supported "
|
||||
"with KV connectors")
|
||||
self.connector = KVConnectorFactory.create_connector(
|
||||
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
|
||||
|
||||
self.kv_event_publisher = EventPublisherFactory.create(
|
||||
self.kv_events_config,
|
||||
self.parallel_config.data_parallel_rank,
|
||||
)
|
||||
|
||||
num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||
assert num_gpu_blocks is not None and num_gpu_blocks > 0
|
||||
|
||||
self.block_size = self.cache_config.block_size
|
||||
self.max_total_kv_tokens: Optional[int] = None
|
||||
if self.block_size is not None:
|
||||
self.max_total_kv_tokens = (
|
||||
self.kv_cache_config.num_blocks * self.block_size)
|
||||
self.dcp_world_size = \
|
||||
vllm_config.parallel_config.decode_context_parallel_size
|
||||
|
||||
# Note(hc): The scheduler’s block_size must be multiplied
|
||||
# by dcp_world_size, since block hashes are computed on the
|
||||
# original full token sequence at a granularity of
|
||||
# original_block_size × dcp_world_size.
|
||||
if self.dcp_world_size > 1:
|
||||
self.block_size *= self.dcp_world_size
|
||||
|
||||
# req_id -> Request
|
||||
self.requests: dict[str, Request] = {}
|
||||
# Scheduling policy
|
||||
if self.scheduler_config.policy == "priority":
|
||||
self.policy = SchedulingPolicy.PRIORITY
|
||||
elif self.scheduler_config.policy == "fcfs":
|
||||
self.policy = SchedulingPolicy.FCFS
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown scheduling policy: {self.scheduler_config.policy}")
|
||||
# Priority queues for requests.
|
||||
self.waiting = create_request_queue(self.policy)
|
||||
self.running: list[Request] = []
|
||||
# Pending barrier groups: gid -> list[Request]
|
||||
self._barrier_groups: dict[str, list[Request]] = {}
|
||||
# Expected group sizes: gid -> expected size
|
||||
self._barrier_expected: dict[str, int] = {}
|
||||
# Force decode-only scheduling in the next step when prefills block.
|
||||
self.force_decode_next_step = False
|
||||
|
||||
# The request IDs that are finished in between the previous and the
|
||||
# current steps. This is used to notify the workers about the finished
|
||||
# requests so that they can free the cached states for those requests.
|
||||
# This is flushed at the end of each scheduling step.
|
||||
self.finished_req_ids: set[str] = set()
|
||||
|
||||
# KV Connector: requests in process of async KV loading or recving
|
||||
self.finished_recving_kv_req_ids: set[str] = set()
|
||||
|
||||
# Encoder-related.
|
||||
# Calculate encoder cache size if applicable
|
||||
# NOTE: For now we use the same budget for both compute and space.
|
||||
# This can be changed when we make encoder cache for embedding caching
|
||||
# across requests.
|
||||
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
mm_registry=mm_registry,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
|
||||
# projector if needed) for MM models as well as encoder-decoder
|
||||
# transformers.
|
||||
self.max_num_encoder_input_tokens = encoder_compute_budget
|
||||
# NOTE: For the models without encoder (e.g., text-only models),
|
||||
# the encoder cache will not be initialized because cache size is 0
|
||||
# for these models.
|
||||
self.encoder_cache_manager = EncoderCacheManager(
|
||||
cache_size=encoder_cache_size)
|
||||
|
||||
speculative_config = vllm_config.speculative_config
|
||||
self.use_eagle = False
|
||||
self.num_spec_tokens = self.num_lookahead_tokens = 0
|
||||
if speculative_config:
|
||||
self.num_spec_tokens = speculative_config.num_speculative_tokens
|
||||
if speculative_config.use_eagle():
|
||||
self.use_eagle = True
|
||||
self.num_lookahead_tokens = self.num_spec_tokens
|
||||
|
||||
# Create the KV cache manager.
|
||||
self.kv_cache_manager = KVCacheManager(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=self.max_model_len,
|
||||
enable_caching=self.cache_config.enable_prefix_caching,
|
||||
use_eagle=self.use_eagle,
|
||||
log_stats=self.log_stats,
|
||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
)
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
|
||||
def add_request(self, request: Request) -> None:
|
||||
# self.waiting.add_request(request)
|
||||
# Always track in global dict first
|
||||
self.requests[request.request_id] = request
|
||||
gid = getattr(request, "barrier_group_id", None)
|
||||
gsz = getattr(request, "barrier_group_size", 0) or 0
|
||||
|
||||
if gid and gsz > 1:
|
||||
# Stage into pending barrier group
|
||||
group_list = self._barrier_groups.get(gid)
|
||||
if group_list is None:
|
||||
self._barrier_groups[gid] = group_list = []
|
||||
self._barrier_expected[gid] = gsz
|
||||
else:
|
||||
# Reconcile inconsistent sizes by taking the max
|
||||
self._barrier_expected[gid] = max(self._barrier_expected[gid],
|
||||
gsz)
|
||||
group_list.append(request)
|
||||
if self.log_stats:
|
||||
request.record_event(EngineCoreEventType.QUEUED)
|
||||
|
||||
# Flush group when complete
|
||||
if len(group_list) >= self._barrier_expected[gid]:
|
||||
# Preserve arrival order
|
||||
for r in sorted(group_list, key=lambda x: x.arrival_time):
|
||||
self.waiting.add_request(r)
|
||||
# Cleanup
|
||||
del self._barrier_groups[gid]
|
||||
del self._barrier_expected[gid]
|
||||
return
|
||||
|
||||
# No barrier group; enqueue directly
|
||||
self.waiting.add_request(request)
|
||||
if self.log_stats:
|
||||
request.record_event(EngineCoreEventType.QUEUED)
|
||||
|
||||
def _schedule_running_requests_for_mode(
|
||||
self,
|
||||
prefill_mode: bool,
|
||||
token_budget: int,
|
||||
encoder_compute_budget: int,
|
||||
scheduled_timestamp: float,
|
||||
scheduled_running_reqs: list[Request],
|
||||
preempted_reqs: list[Request],
|
||||
req_to_new_blocks: dict[str, KVCacheBlocks],
|
||||
num_scheduled_tokens: dict[str, int],
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]],
|
||||
scheduled_encoder_inputs: dict[str, list[int]],
|
||||
scan_index: int,
|
||||
next_req_index: int,
|
||||
skip_request_ids: set[str],
|
||||
) -> tuple[int, int, int, int]:
|
||||
"""Schedule running requests under the given mode.
|
||||
|
||||
Args:
|
||||
scan_index: The position in ``self.running`` to start scanning from.
|
||||
next_req_index: The next logical "running index" used when we
|
||||
populate ``structured_output_request_ids``.
|
||||
"""
|
||||
while scan_index < len(self.running) and token_budget > 0:
|
||||
request = self.running[scan_index]
|
||||
|
||||
if request.request_id in skip_request_ids:
|
||||
scan_index += 1
|
||||
continue
|
||||
|
||||
if prefill_mode and not _is_in_prefill(request):
|
||||
scan_index += 1
|
||||
continue
|
||||
|
||||
if (not prefill_mode) and _is_in_prefill(request):
|
||||
scan_index += 1
|
||||
continue
|
||||
|
||||
num_new_tokens = max(
|
||||
request.num_tokens_with_spec - request.num_computed_tokens, 0)
|
||||
|
||||
if (0 < self.scheduler_config.long_prefill_token_threshold <
|
||||
num_new_tokens):
|
||||
num_new_tokens = (
|
||||
self.scheduler_config.long_prefill_token_threshold)
|
||||
|
||||
# chunked prefill has to be enabled explicitly to allow pooling
|
||||
# requests to be chunked
|
||||
if (not self.scheduler_config.chunked_prefill_enabled
|
||||
and num_new_tokens > token_budget):
|
||||
break
|
||||
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
|
||||
# Make sure the input position does not exceed the max model len.
|
||||
# This is necessary when using spec decoding.
|
||||
num_new_tokens = min(
|
||||
num_new_tokens,
|
||||
self.max_model_len - 1 - request.num_computed_tokens)
|
||||
|
||||
# Schedule encoder inputs.
|
||||
encoder_inputs_to_schedule = None
|
||||
# new_encoder_budget = encoder_budget
|
||||
# if request.has_encoder_inputs:
|
||||
# (encoder_inputs_to_schedule, num_new_tokens,
|
||||
# new_encoder_budget) = self._try_schedule_encoder_inputs(
|
||||
# request, request.num_computed_tokens, num_new_tokens,
|
||||
# encoder_budget)
|
||||
new_encoder_compute_budget = encoder_compute_budget
|
||||
if request.has_encoder_inputs:
|
||||
(encoder_inputs_to_schedule, num_new_tokens,
|
||||
new_encoder_compute_budget
|
||||
) = self._try_schedule_encoder_inputs(
|
||||
request, request.num_computed_tokens, num_new_tokens,
|
||||
encoder_compute_budget)
|
||||
|
||||
if num_new_tokens == 0:
|
||||
scan_index += 1
|
||||
continue
|
||||
|
||||
# num_draft_tokens = max(
|
||||
# num_new_tokens + request.num_computed_tokens -
|
||||
# request.num_tokens, 0)
|
||||
|
||||
while True:
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens,
|
||||
# num_draft_tokens=num_draft_tokens,
|
||||
num_lookahead_tokens=self.num_lookahead_tokens)
|
||||
if new_blocks is None:
|
||||
if self.policy == SchedulingPolicy.PRIORITY:
|
||||
preempted_req = max(
|
||||
self.running,
|
||||
key=lambda r: (r.priority, r.arrival_time),
|
||||
)
|
||||
self.running.remove(preempted_req)
|
||||
if preempted_req in scheduled_running_reqs:
|
||||
scheduled_running_reqs.remove(preempted_req)
|
||||
|
||||
else:
|
||||
preempted_req = self.running.pop()
|
||||
|
||||
self.kv_cache_manager.free(preempted_req)
|
||||
self.encoder_cache_manager.free(preempted_req)
|
||||
preempted_req.status = RequestStatus.PREEMPTED
|
||||
if _is_in_prefill(preempted_req):
|
||||
preempted_req.num_computed_tokens = 0
|
||||
if self.log_stats:
|
||||
preempted_req.record_event(
|
||||
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
|
||||
|
||||
self.waiting.prepend_request(preempted_req)
|
||||
preempted_reqs.append(preempted_req)
|
||||
if preempted_req == request:
|
||||
can_schedule = False
|
||||
break
|
||||
else:
|
||||
can_schedule = True
|
||||
break
|
||||
if not can_schedule:
|
||||
break
|
||||
assert new_blocks is not None
|
||||
|
||||
scheduled_running_reqs.append(request)
|
||||
req_to_new_blocks[request.request_id] = new_blocks
|
||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
next_req_index += 1
|
||||
|
||||
if request.spec_token_ids:
|
||||
num_scheduled_spec_tokens = (num_new_tokens +
|
||||
request.num_computed_tokens -
|
||||
request.num_tokens)
|
||||
if num_scheduled_spec_tokens > 0:
|
||||
del request.spec_token_ids[num_scheduled_spec_tokens:]
|
||||
scheduled_spec_decode_tokens[request.request_id] = (
|
||||
request.spec_token_ids)
|
||||
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
# encoder_budget = new_encoder_budget
|
||||
encoder_compute_budget = new_encoder_compute_budget
|
||||
|
||||
scan_index += 1
|
||||
|
||||
return token_budget, encoder_compute_budget, scan_index, next_req_index
|
||||
|
||||
def finish_requests(
|
||||
self,
|
||||
request_ids: Union[str, Iterable[str]],
|
||||
finished_status: RequestStatus,
|
||||
) -> None:
|
||||
"""Handles the finish signal from outside the scheduler.
|
||||
|
||||
For example, the API server can abort a request when the client
|
||||
disconnects.
|
||||
"""
|
||||
assert RequestStatus.is_finished(finished_status)
|
||||
if isinstance(request_ids, str):
|
||||
request_ids = (request_ids, )
|
||||
else:
|
||||
request_ids = set(request_ids)
|
||||
|
||||
running_requests_to_remove = []
|
||||
waiting_requests_to_remove = []
|
||||
valid_requests = []
|
||||
|
||||
# First pass: collect requests to remove from queues
|
||||
for req_id in request_ids:
|
||||
request = self.requests.get(req_id)
|
||||
if request is None:
|
||||
# Invalid request ID.
|
||||
continue
|
||||
|
||||
valid_requests.append(request)
|
||||
if request.status == RequestStatus.RUNNING:
|
||||
running_requests_to_remove.append(request)
|
||||
else:
|
||||
waiting_requests_to_remove.append(request)
|
||||
|
||||
# Remove all requests from queues at once for better efficiency
|
||||
for request in running_requests_to_remove:
|
||||
self.running.remove(request)
|
||||
if waiting_requests_to_remove:
|
||||
self.waiting.remove_requests(waiting_requests_to_remove)
|
||||
|
||||
# Handle requests that are pending in barrier groups:
|
||||
# if any item from a pending group is finished/aborted, flush the rest.
|
||||
for request in valid_requests:
|
||||
gid = getattr(request, "barrier_group_id", None)
|
||||
if not gid:
|
||||
continue
|
||||
group_list = self._barrier_groups.get(gid)
|
||||
if not group_list:
|
||||
continue
|
||||
# Remove this request from pending group if present
|
||||
try:
|
||||
if request in group_list:
|
||||
group_list.remove(request)
|
||||
except ValueError:
|
||||
pass
|
||||
# Flush remaining members (if any) to waiting, then cleanup.
|
||||
if group_list:
|
||||
for r in sorted(group_list, key=lambda x: x.arrival_time):
|
||||
self.waiting.add_request(r)
|
||||
self._barrier_groups.pop(gid, None)
|
||||
self._barrier_expected.pop(gid, None)
|
||||
|
||||
|
||||
|
||||
# Second pass: set status and free requests
|
||||
for request in valid_requests:
|
||||
request.status = finished_status
|
||||
self._free_request(request)
|
||||
|
||||
def _estimate_future_kv_tokens(self, request: Request) -> int:
|
||||
"""Estimate the KV tokens a request may need for its full lifetime."""
|
||||
alignment = min(env_blk_grp_size, self.max_model_len)
|
||||
alignment = max(alignment, 1)
|
||||
|
||||
prompt_tokens = request.num_prompt_tokens # prefill length
|
||||
prompt_reserved = align_up(prompt_tokens, alignment)
|
||||
|
||||
remaining_room = max(self.max_model_len - prompt_tokens, 0)
|
||||
max_decode_tokens = request.max_tokens # output length
|
||||
if max_decode_tokens < 0:
|
||||
max_decode_tokens = remaining_room
|
||||
else:
|
||||
max_decode_tokens = min(max_decode_tokens, remaining_room)
|
||||
|
||||
decode_budget = max_decode_tokens
|
||||
if self.num_lookahead_tokens > 0:
|
||||
decode_budget = min(decode_budget + self.num_lookahead_tokens,
|
||||
remaining_room)
|
||||
|
||||
remaining_capacity_in_prompt = max(prompt_reserved - prompt_tokens, 0)
|
||||
decode_after_prompt = max(decode_budget - remaining_capacity_in_prompt,
|
||||
0)
|
||||
if decode_after_prompt == 0:
|
||||
total_reserved = prompt_reserved
|
||||
else:
|
||||
total_reserved = prompt_reserved + \
|
||||
align_up(decode_after_prompt, alignment)
|
||||
|
||||
return min(total_reserved, self.max_model_len)
|
||||
|
||||
def _compute_total_future_kv_tokens(
|
||||
self, requests: Iterable[Request]
|
||||
) -> int:
|
||||
return sum(self._estimate_future_kv_tokens(req) for req in requests)
|
||||
|
||||
def make_stats(
|
||||
self,
|
||||
spec_decoding_stats: Optional[SpecDecodingStats] = None,
|
||||
kv_connector_stats: Optional["KVConnectorStats"] = None,
|
||||
) -> Optional[SchedulerStats]:
|
||||
if not self.log_stats:
|
||||
return None
|
||||
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
|
||||
assert prefix_cache_stats is not None
|
||||
return SchedulerStats(num_running_reqs=len(self.running),
|
||||
num_waiting_reqs=len(self.waiting),
|
||||
running_seqlens = [len(running_i._all_token_ids) for running_i in self.running],
|
||||
kv_cache_usage=self.kv_cache_manager.usage,
|
||||
prefix_cache_stats=prefix_cache_stats,
|
||||
spec_decoding_stats=spec_decoding_stats,
|
||||
num_corrupted_reqs=sum(req.is_output_corrupted
|
||||
for req in self.running),
|
||||
kv_connector_stats=kv_connector_stats.data
|
||||
if kv_connector_stats else None)
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
# NOTE(woosuk) on the scheduling algorithm:
|
||||
# There's no "decoding phase" nor "prefill phase" in the scheduler.
|
||||
# Each request just has the num_computed_tokens and
|
||||
# num_tokens_with_spec. num_tokens_with_spec =
|
||||
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
|
||||
# At each step, the scheduler tries to assign tokens to the requests
|
||||
# so that each request's num_computed_tokens can catch up its
|
||||
# num_tokens_with_spec. This is general enough to cover
|
||||
# chunked prefills, prefix caching, speculative decoding,
|
||||
# and the "jump decoding" optimization in the future.
|
||||
|
||||
scheduled_new_reqs: list[Request] = []
|
||||
scheduled_resumed_reqs: list[Request] = []
|
||||
scheduled_running_reqs: list[Request] = []
|
||||
preempted_reqs: list[Request] = []
|
||||
|
||||
# NOTE: structured_output_request_ids maps
|
||||
# a request's (request that uses structured output)
|
||||
# request_id to the running request index.
|
||||
# This will helps us determine to slice the grammar bitmask
|
||||
# and only applies valid mask for requests that
|
||||
# uses structured decoding.
|
||||
|
||||
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
||||
num_scheduled_tokens: dict[str, int] = {}
|
||||
token_budget = self.max_num_scheduled_tokens
|
||||
# Encoder-related.
|
||||
scheduled_encoder_inputs: dict[str, list[int]] = {}
|
||||
encoder_compute_budget = self.max_num_encoder_input_tokens
|
||||
# Spec decode-related.
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
|
||||
|
||||
# For logging.
|
||||
scheduled_timestamp = time.monotonic()
|
||||
|
||||
# Determine whether we should prioritize prefills in this step.
|
||||
has_running_prefill = any(_is_in_prefill(r) for r in self.running)
|
||||
has_waiting_prefill = False
|
||||
if self.waiting and prefill_first:
|
||||
for waiting_req in self.waiting:
|
||||
if _is_in_prefill(waiting_req):
|
||||
has_waiting_prefill = True
|
||||
break
|
||||
prefill_mode = prefill_first and (has_running_prefill or has_waiting_prefill or not self.running) and len(self.running) < self.max_num_running_reqs
|
||||
if self.force_decode_next_step:
|
||||
prefill_mode = False
|
||||
|
||||
# First, schedule the RUNNING requests.
|
||||
req_index = 0
|
||||
token_budget, encoder_budget, _, req_index = (
|
||||
self._schedule_running_requests_for_mode(
|
||||
prefill_mode,
|
||||
token_budget,
|
||||
encoder_compute_budget,
|
||||
scheduled_timestamp,
|
||||
scheduled_running_reqs,
|
||||
preempted_reqs,
|
||||
req_to_new_blocks,
|
||||
num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs,
|
||||
0,
|
||||
req_index,
|
||||
set(),
|
||||
))
|
||||
# Record the LoRAs in scheduled_running_reqs
|
||||
scheduled_loras: set[int] = set()
|
||||
if self.lora_config:
|
||||
scheduled_loras = set(
|
||||
req.lora_request.lora_int_id for req in scheduled_running_reqs
|
||||
if req.lora_request and req.lora_request.lora_int_id > 0)
|
||||
assert len(scheduled_loras) <= self.lora_config.max_loras
|
||||
|
||||
# Use a temporary RequestQueue to collect requests that need to be
|
||||
# skipped and put back at the head of the waiting queue later
|
||||
skipped_waiting_requests = create_request_queue(self.policy)
|
||||
# prefill_allocation_failed = False
|
||||
|
||||
# Next, schedule the WAITING requests.
|
||||
if not preempted_reqs:
|
||||
future_kv_token_usage = 0
|
||||
prefill_all_len = 0
|
||||
if self.max_total_kv_tokens is not None:
|
||||
future_kv_token_usage = (
|
||||
self._compute_total_future_kv_tokens(self.running))
|
||||
|
||||
while self.waiting and token_budget > 0:
|
||||
if len(self.running) == self.max_num_running_reqs:
|
||||
break
|
||||
|
||||
request = self.waiting.peek_request()
|
||||
request_future_tokens: Optional[int] = None
|
||||
if self.max_total_kv_tokens is not None:
|
||||
request_future_tokens = self._estimate_future_kv_tokens(
|
||||
request)
|
||||
if prefill_mode:
|
||||
prefill_all_len += request.num_prompt_tokens
|
||||
if prefill_mode and ((future_kv_token_usage +
|
||||
request_future_tokens
|
||||
> self.max_total_kv_tokens) or prefill_all_len > LLM_MAX_PREFILL_SEQ_LEN):
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
# prefill_allocation_failed = True
|
||||
self.force_decode_next_step = True
|
||||
continue
|
||||
in_prefill = _is_in_prefill(request)
|
||||
if prefill_mode and not in_prefill:
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
if (not prefill_mode) and in_prefill:
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
# KVTransfer: skip request if still waiting for remote kvs.
|
||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||
is_ready = self._update_waiting_for_remote_kv(request)
|
||||
if is_ready:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
logger.debug(
|
||||
"%s is still in WAITING_FOR_REMOTE_KVS state.",
|
||||
request.request_id)
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
# Skip request if the structured output request is still waiting
|
||||
# for FSM compilation.
|
||||
if request.status == RequestStatus.WAITING_FOR_FSM:
|
||||
structured_output_req = request.structured_output_request
|
||||
if structured_output_req and structured_output_req.grammar:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
# constraint.
|
||||
if (self.lora_config and request.lora_request and
|
||||
(len(scheduled_loras) == self.lora_config.max_loras and
|
||||
request.lora_request.lora_int_id not in scheduled_loras)):
|
||||
# Scheduling would exceed max_loras, skip.
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
num_external_computed_tokens = 0
|
||||
load_kv_async = False
|
||||
|
||||
# Get already-cached tokens.
|
||||
if request.num_computed_tokens == 0:
|
||||
# Get locally-cached tokens.
|
||||
new_computed_blocks, num_new_local_computed_tokens = \
|
||||
self.kv_cache_manager.get_computed_blocks(
|
||||
request)
|
||||
|
||||
# Get externally-cached tokens if using a KVConnector.
|
||||
if self.connector is not None:
|
||||
num_external_computed_tokens, load_kv_async = (
|
||||
self.connector.get_num_new_matched_tokens(
|
||||
request, num_new_local_computed_tokens))
|
||||
if num_external_computed_tokens is None:
|
||||
# The request cannot be scheduled because
|
||||
# the KVConnector couldn't determine
|
||||
# the number of matched tokens.
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
# Total computed tokens (local + external).
|
||||
num_computed_tokens = (num_new_local_computed_tokens +
|
||||
num_external_computed_tokens)
|
||||
# KVTransfer: WAITING reqs have num_computed_tokens > 0
|
||||
# after async KV recvs are completed.
|
||||
else:
|
||||
new_computed_blocks = (
|
||||
self.kv_cache_manager.create_empty_block_list())
|
||||
num_new_local_computed_tokens = 0
|
||||
num_computed_tokens = request.num_computed_tokens
|
||||
|
||||
encoder_inputs_to_schedule = None
|
||||
# new_encoder_budget = encoder_budget
|
||||
new_encoder_compute_budget = encoder_compute_budget
|
||||
|
||||
# KVTransfer: loading remote KV, do not allocate for new work.
|
||||
if load_kv_async:
|
||||
assert num_external_computed_tokens > 0
|
||||
num_new_tokens = 0
|
||||
# Number of tokens to be scheduled.
|
||||
else:
|
||||
# We use `request.num_tokens` instead of
|
||||
# `request.num_prompt_tokens` to consider the resumed
|
||||
# requests, which have output tokens.
|
||||
num_new_tokens = request.num_tokens - num_computed_tokens
|
||||
if (0 < self.scheduler_config.long_prefill_token_threshold
|
||||
< num_new_tokens):
|
||||
num_new_tokens = (
|
||||
self.scheduler_config.long_prefill_token_threshold)
|
||||
|
||||
# chunked prefill has to be enabled explicitly to allow
|
||||
# pooling requests to be chunked
|
||||
if not self.scheduler_config.chunked_prefill_enabled and \
|
||||
num_new_tokens > token_budget:
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
assert num_new_tokens > 0
|
||||
|
||||
# Schedule encoder inputs.
|
||||
if request.has_encoder_inputs:
|
||||
(encoder_inputs_to_schedule, num_new_tokens,
|
||||
new_encoder_compute_budget
|
||||
) = self._try_schedule_encoder_inputs(
|
||||
request, num_computed_tokens, num_new_tokens,
|
||||
encoder_compute_budget)
|
||||
if num_new_tokens == 0:
|
||||
# The request cannot be scheduled.
|
||||
break
|
||||
|
||||
# Handles an edge case when P/D Disaggregation
|
||||
# is used with Spec Decoding where an
|
||||
# extra block gets allocated which
|
||||
# creates a mismatch between the number
|
||||
# of local and remote blocks.
|
||||
effective_lookahead_tokens = (0 if request.num_computed_tokens
|
||||
== 0 else
|
||||
self.num_lookahead_tokens)
|
||||
|
||||
# Determine if we need to allocate cross-attention blocks.
|
||||
if self.is_encoder_decoder and request.has_encoder_inputs:
|
||||
# TODO(russellb): For Whisper, we know that the input is
|
||||
# always padded to the maximum length. If we support other
|
||||
# encoder-decoder models, this will need to be updated if we
|
||||
# want to only allocate what is needed.
|
||||
num_encoder_tokens =\
|
||||
self.scheduler_config.max_num_encoder_input_tokens
|
||||
else:
|
||||
num_encoder_tokens = 0
|
||||
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens + num_external_computed_tokens,
|
||||
num_new_local_computed_tokens,
|
||||
new_computed_blocks,
|
||||
num_lookahead_tokens=effective_lookahead_tokens,
|
||||
delay_cache_blocks=load_kv_async,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
)
|
||||
if new_blocks is None:
|
||||
# The request cannot be scheduled in this step.
|
||||
if prefill_mode:
|
||||
# prefill_allocation_failed = True
|
||||
self.force_decode_next_step = True
|
||||
break
|
||||
|
||||
# KVTransfer: the connector uses this info to determine
|
||||
# if a load is needed. Note that
|
||||
# This information is used to determine if a load is
|
||||
# needed for this request.
|
||||
if self.connector is not None:
|
||||
self.connector.update_state_after_alloc(
|
||||
request,
|
||||
new_computed_blocks + new_blocks,
|
||||
num_external_computed_tokens,
|
||||
)
|
||||
|
||||
# Request was already popped from self.waiting
|
||||
# unless it was re-added above due to new_blocks being None.
|
||||
request = self.waiting.pop_request()
|
||||
if load_kv_async:
|
||||
# If loading async, allocate memory and put request
|
||||
# into the WAITING_FOR_REMOTE_KV state.
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
continue
|
||||
|
||||
req_index += 1
|
||||
self.running.append(request)
|
||||
if self.max_total_kv_tokens is not None:
|
||||
if request_future_tokens is None:
|
||||
request_future_tokens = (
|
||||
self._estimate_future_kv_tokens(request))
|
||||
future_kv_token_usage += request_future_tokens
|
||||
if self.log_stats:
|
||||
request.record_event(EngineCoreEventType.SCHEDULED,
|
||||
scheduled_timestamp)
|
||||
if request.status == RequestStatus.WAITING:
|
||||
scheduled_new_reqs.append(request)
|
||||
elif request.status == RequestStatus.PREEMPTED:
|
||||
scheduled_resumed_reqs.append(request)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Invalid request status: {request.status}")
|
||||
|
||||
if self.lora_config and request.lora_request:
|
||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||
# req_to_new_block_ids[request.request_id] = (
|
||||
# self.kv_cache_manager.get_block_ids(request.request_id))
|
||||
req_to_new_blocks[request.request_id] = new_blocks
|
||||
|
||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
request.status = RequestStatus.RUNNING
|
||||
request.num_computed_tokens = num_computed_tokens
|
||||
# Count the number of prefix cached tokens.
|
||||
if request.num_cached_tokens < 0:
|
||||
request.num_cached_tokens = num_computed_tokens
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
# Allocate the encoder cache.
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
# encoder_budget = new_encoder_budget
|
||||
encoder_compute_budget = new_encoder_compute_budget
|
||||
|
||||
# Put back any skipped requests at the head of the waiting queue
|
||||
if skipped_waiting_requests:
|
||||
self.waiting.prepend_requests(skipped_waiting_requests)
|
||||
|
||||
if not prefill_mode:
|
||||
self.force_decode_next_step = False
|
||||
|
||||
# Check if the scheduling constraints are satisfied.
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
|
||||
assert token_budget >= 0
|
||||
assert len(self.running) <= self.max_num_running_reqs
|
||||
# Since some requests in the RUNNING queue may not be scheduled in
|
||||
# this step, the total number of scheduled requests can be smaller than
|
||||
# len(self.running).
|
||||
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
|
||||
len(scheduled_running_reqs) <= len(self.running))
|
||||
|
||||
# Get the longest common prefix among all requests in the running queue.
|
||||
# This can be potentially used for cascade attention.
|
||||
num_common_prefix_blocks = [0] * len(
|
||||
self.kv_cache_config.kv_cache_groups)
|
||||
if self.running:
|
||||
any_request = self.running[0]
|
||||
num_common_prefix_blocks = (
|
||||
self.kv_cache_manager.get_num_common_prefix_blocks(
|
||||
any_request, len(self.running)))
|
||||
|
||||
# Construct the scheduler output.
|
||||
# new_reqs_data = [
|
||||
# NewRequestData.from_request(req,
|
||||
# req_to_new_block_ids[req.request_id])
|
||||
# for req in scheduled_new_reqs
|
||||
# ]
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs,
|
||||
scheduled_resumed_reqs,
|
||||
num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens,
|
||||
# req_to_new_block_ids,
|
||||
req_to_new_blocks,
|
||||
)
|
||||
scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs +
|
||||
scheduled_resumed_reqs)
|
||||
structured_output_request_ids, grammar_bitmask = (
|
||||
self.get_grammar_bitmask(scheduled_requests,
|
||||
scheduled_spec_decode_tokens))
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
scheduled_cached_reqs=cached_reqs_data,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs=scheduled_encoder_inputs,
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids,
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||
structured_output_request_ids=structured_output_request_ids,
|
||||
grammar_bitmask=grammar_bitmask,
|
||||
)
|
||||
|
||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||
# 1. Plan the KV cache store
|
||||
# 2. Wrap up all the KV cache load / save ops into an opaque object
|
||||
# 3. Clear the internal states of the connector
|
||||
if self.connector is not None:
|
||||
meta = self.connector.build_connector_meta(scheduler_output)
|
||||
scheduler_output.kv_connector_metadata = meta
|
||||
|
||||
events = self.kv_cache_manager.take_events()
|
||||
|
||||
# collect KV cache events from connector
|
||||
if self.connector is not None:
|
||||
print('if self.connector is not None:')
|
||||
connector_events = self.connector.take_events()
|
||||
if connector_events:
|
||||
if events is None:
|
||||
events = list(connector_events)
|
||||
else:
|
||||
events.extend(connector_events)
|
||||
|
||||
# publish collected KV cache events
|
||||
|
||||
if events:
|
||||
batch = KVEventBatch(ts=time.time(), events=events)
|
||||
self.kv_event_publisher.publish(batch)
|
||||
|
||||
self._update_after_schedule(scheduler_output)
|
||||
return scheduler_output
|
||||
|
||||
|
||||
61
vllm_vacc/vllm/v1/core/single_type_kv_cache_manager.py
Normal file
61
vllm_vacc/vllm/v1/core/single_type_kv_cache_manager.py
Normal file
@@ -0,0 +1,61 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Callable
|
||||
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
||||
MambaSpec, SlidingWindowSpec)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
import os
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
|
||||
class SingleTypeKVCacheManager(ABC):
|
||||
|
||||
def free(self, request_id: str) -> None:
|
||||
"""
|
||||
Free the blocks for the request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
"""
|
||||
# Default to [] in case a request is freed (aborted) before alloc.
|
||||
req_blocks = self.req_to_blocks.pop(request_id, [])
|
||||
|
||||
# Free blocks in reverse order so that the tail blocks are
|
||||
# freed first.
|
||||
# ordered_blocks = reversed(req_blocks)
|
||||
self.block_pool.free_blocks(req_blocks)
|
||||
self.num_cached_block.pop(request_id, None)
|
||||
|
||||
|
||||
def allocate_new_blocks(self, request_id: str,
|
||||
num_tokens: int) -> list[KVCacheBlock]:
|
||||
"""
|
||||
Allocate new blocks for the request to give it at least `num_tokens`
|
||||
token slots.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_tokens: The total number of tokens that need a slot (including
|
||||
tokens that are already allocated).
|
||||
|
||||
Returns:
|
||||
The new allocated blocks.
|
||||
"""
|
||||
req_blocks = self.req_to_blocks[request_id]
|
||||
num_required_blocks = cdiv(num_tokens, self.block_size)
|
||||
|
||||
block_size_number_per_group = env_blk_grp_size // self.block_size #512
|
||||
num_required_blocks = (num_required_blocks + block_size_number_per_group - 1) // block_size_number_per_group * block_size_number_per_group
|
||||
num_new_blocks = num_required_blocks - len(req_blocks)
|
||||
if num_new_blocks <= 0:
|
||||
return []
|
||||
else:
|
||||
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
|
||||
req_blocks.extend(new_blocks)
|
||||
return new_blocks
|
||||
94
vllm_vacc/vllm/v1/engine/__init__.py
Normal file
94
vllm_vacc/vllm/v1/engine/__init__.py
Normal file
@@ -0,0 +1,94 @@
|
||||
|
||||
import enum
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
from vllm.v1.engine import EngineCoreOutput, UtilityOutput
|
||||
|
||||
from vllm_vacc.vllm.v1.metrics.stats import SchedulerStats
|
||||
|
||||
# These are possible values of RequestOutput.finish_reason,
|
||||
# so form part of the external API.
|
||||
FINISH_REASON_STRINGS = ("stop", "length", "abort")
|
||||
|
||||
class EngineCoreRequest(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
gc=False): # type: ignore[call-arg]
|
||||
|
||||
request_id: str
|
||||
prompt_token_ids: Optional[list[int]]
|
||||
mm_features: Optional[list[MultiModalFeatureSpec]]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
eos_token_id: Optional[int]
|
||||
arrival_time: float
|
||||
lora_request: Optional[LoRARequest]
|
||||
cache_salt: Optional[str]
|
||||
data_parallel_rank: Optional[int]
|
||||
prompt_embeds: Optional[torch.Tensor] = None
|
||||
deepstack_input_embeds: Optional[torch.Tensor] = None
|
||||
# Index of the client, used to ensure outputs are sent back to the same
|
||||
# client for this request when scaling out the front-end.
|
||||
client_index: int = 0
|
||||
|
||||
# Used in DP case to indicate which wave of requests this is expected to
|
||||
# belong to, to cover a race condition where the request is sent before
|
||||
# a wave finished notification is received.
|
||||
current_wave: int = 0
|
||||
priority: int = 0
|
||||
|
||||
trace_headers: Optional[Mapping[str, str]] = None
|
||||
|
||||
class EngineCoreOutputs(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
gc=False): # type: ignore[call-arg]
|
||||
|
||||
# NOTE(Nick): We could consider ways to make this more compact,
|
||||
# e.g. columnwise layout
|
||||
|
||||
engine_index: int = 0
|
||||
|
||||
# [num_reqs]
|
||||
outputs: list[EngineCoreOutput] = []
|
||||
scheduler_stats: Optional[SchedulerStats] = None
|
||||
timestamp: float = 0.0
|
||||
|
||||
utility_output: Optional[UtilityOutput] = None
|
||||
finished_requests: Optional[set[str]] = None
|
||||
|
||||
# In DP case, used to signal that the current wave of requests
|
||||
# has finished and the engines are paused.
|
||||
wave_complete: Optional[int] = None
|
||||
# In DP case, used to signal that a request was received for an
|
||||
# "old" wave, so the next wave needs to be started in other engines.
|
||||
start_wave: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.timestamp == 0.0:
|
||||
self.timestamp = time.monotonic()
|
||||
|
||||
class EngineCoreRequestType(enum.Enum):
|
||||
"""
|
||||
Request types defined as hex byte strings, so it can be sent over sockets
|
||||
without separate encoding step.
|
||||
"""
|
||||
ADD = b'\x00'
|
||||
ABORT = b'\x01'
|
||||
START_DP_WAVE = b'\x02'
|
||||
UTILITY = b'\x03'
|
||||
# Sentinel used within EngineCoreProc.
|
||||
EXECUTOR_FAILED = b'\x04'
|
||||
ADD_BULK = b'\x05'
|
||||
BIN
vllm_vacc/vllm/v1/engine/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/engine/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/engine/__pycache__/async_llm.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/engine/__pycache__/async_llm.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/engine/__pycache__/core.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/engine/__pycache__/core.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/engine/__pycache__/core_client.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/engine/__pycache__/core_client.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/engine/__pycache__/llm_engine.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/engine/__pycache__/llm_engine.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/engine/__pycache__/processor.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/engine/__pycache__/processor.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
108
vllm_vacc/vllm/v1/engine/async_llm.py
Normal file
108
vllm_vacc/vllm/v1/engine/async_llm.py
Normal file
@@ -0,0 +1,108 @@
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import StatLoggerFactory
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm_vacc.vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.output_processor import (OutputProcessor,
|
||||
RequestOutputCollector)
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.async_llm import logger
|
||||
|
||||
|
||||
class AsyncLLM(EngineClient):
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
enable_log_requests: bool = False,
|
||||
disable_log_stats: bool = False,
|
||||
client_addresses: Optional[dict[str, str]] = None,
|
||||
client_count: int = 1,
|
||||
client_index: int = 0,
|
||||
disable_log_requests: bool = True, # Deprecated, will be removed
|
||||
) -> "AsyncLLM":
|
||||
# vacc support spec_num = 1
|
||||
from .vllm_config_checker import check_spec_model
|
||||
check_spec_model(vllm_config)
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError(
|
||||
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
|
||||
"This should not happen. As a workaround, try using "
|
||||
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
|
||||
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
|
||||
|
||||
# Create the LLMEngine.
|
||||
from vllm.v1.engine.async_llm import AsyncLLM as DefaultAsyncLLM
|
||||
async_cls = DefaultAsyncLLM
|
||||
return async_cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
start_engine_loop=start_engine_loop,
|
||||
stat_loggers=stat_loggers,
|
||||
log_requests=enable_log_requests,
|
||||
log_stats=not disable_log_stats,
|
||||
usage_context=usage_context,
|
||||
client_addresses=client_addresses,
|
||||
client_count=client_count,
|
||||
client_index=client_index,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
) -> "AsyncLLM":
|
||||
"""Create an AsyncLLM from the EngineArgs."""
|
||||
# Create the engine configs.
|
||||
vllm_config = engine_args.create_engine_config(usage_context)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
# vacc support spec_num = 1
|
||||
from .vllm_config_checker import check_spec_model
|
||||
check_spec_model(vllm_config)
|
||||
|
||||
# Create the AsyncLLM.
|
||||
from vllm.v1.engine.async_llm import AsyncLLM as DefaultAsyncLLM
|
||||
async_cls = DefaultAsyncLLM
|
||||
return async_cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
)
|
||||
|
||||
async def _add_request(self, request: EngineCoreRequest,
|
||||
prompt: Optional[str],
|
||||
parent_req: Optional[ParentRequest], index: int,
|
||||
queue: RequestOutputCollector):
|
||||
|
||||
# Add the request to OutputProcessor (this process).
|
||||
self.output_processor.add_request(request, prompt, parent_req, index,
|
||||
queue)
|
||||
|
||||
# Add the EngineCoreRequest to EngineCore (separate process).
|
||||
await self.engine_core.add_request_async(request)
|
||||
|
||||
if self.log_requests:
|
||||
if request.prompt_token_ids is not None:
|
||||
logger.info("Added request: %s, prompt length: %s", request.request_id, len(request.prompt_token_ids))
|
||||
else:
|
||||
logger.info("Added request %s.", request.request_id)
|
||||
209
vllm_vacc/vllm/v1/engine/core.py
Normal file
209
vllm_vacc/vllm/v1/engine/core.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import Future
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import make_zmq_socket
|
||||
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.engine import (EngineCoreRequest,EngineCoreRequestType)
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash,
|
||||
generate_scheduler_kv_cache_config,
|
||||
get_kv_cache_configs,
|
||||
get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_S = 2.5
|
||||
HANDSHAKE_TIMEOUT_MINS = 5
|
||||
|
||||
_R = TypeVar('_R') # Return type for collective_rpc
|
||||
|
||||
class EngineCoreProc(EngineCore):
|
||||
"""ZMQ-wrapper for running EngineCore in background process."""
|
||||
|
||||
ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'
|
||||
def process_input_sockets(self, input_addresses: list[str],
|
||||
coord_input_address: Optional[str],
|
||||
identity: bytes, ready_event: threading.Event):
|
||||
"""Input socket IO thread."""
|
||||
|
||||
# Msgpack serialization decoding.
|
||||
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
||||
generic_decoder = MsgpackDecoder()
|
||||
bulk_add_decoder = MsgpackDecoder(list[EngineCoreRequest])
|
||||
|
||||
with ExitStack() as stack, zmq.Context() as ctx:
|
||||
input_sockets = [
|
||||
stack.enter_context(
|
||||
make_zmq_socket(ctx,
|
||||
input_address,
|
||||
zmq.DEALER,
|
||||
identity=identity,
|
||||
bind=False))
|
||||
for input_address in input_addresses
|
||||
]
|
||||
if coord_input_address is None:
|
||||
coord_socket = None
|
||||
else:
|
||||
coord_socket = stack.enter_context(
|
||||
make_zmq_socket(ctx,
|
||||
coord_input_address,
|
||||
zmq.XSUB,
|
||||
identity=identity,
|
||||
bind=False))
|
||||
# Send subscription message to coordinator.
|
||||
coord_socket.send(b'\x01')
|
||||
|
||||
# Register sockets with poller.
|
||||
poller = zmq.Poller()
|
||||
for input_socket in input_sockets:
|
||||
# Send initial message to each input socket - this is required
|
||||
# before the front-end ROUTER socket can send input messages
|
||||
# back to us.
|
||||
input_socket.send(b'')
|
||||
poller.register(input_socket, zmq.POLLIN)
|
||||
|
||||
if coord_socket is not None:
|
||||
poller.register(coord_socket, zmq.POLLIN)
|
||||
|
||||
ready_event.set()
|
||||
del ready_event
|
||||
while True:
|
||||
for input_socket, _ in poller.poll():
|
||||
# (RequestType, RequestData)
|
||||
type_frame, *data_frames = input_socket.recv_multipart(
|
||||
copy=False)
|
||||
request_type = EngineCoreRequestType(
|
||||
bytes(type_frame.buffer))
|
||||
|
||||
if request_type == EngineCoreRequestType.ADD_BULK:
|
||||
# 关键:按 list[EngineCoreRequest] 解码,然后在接收线程就地 fan-out
|
||||
requests = bulk_add_decoder.decode(data_frames)
|
||||
for r in requests:
|
||||
r = self.preprocess_add_request(r)
|
||||
self.input_queue.put_nowait((EngineCoreRequestType.ADD, r))
|
||||
continue
|
||||
|
||||
# Deserialize the request data.
|
||||
if request_type == EngineCoreRequestType.ADD:
|
||||
request = add_request_decoder.decode(data_frames)
|
||||
request = self.preprocess_add_request(request)
|
||||
else:
|
||||
request = generic_decoder.decode(data_frames)
|
||||
|
||||
# Push to input queue for core busy loop.
|
||||
self.input_queue.put_nowait((request_type, request))
|
||||
|
||||
|
||||
class EngineCore:
|
||||
"""Inner loop of vLLM's Engine."""
|
||||
|
||||
def _initialize_kv_caches(
|
||||
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
|
||||
start = time.time()
|
||||
|
||||
# Get all kv cache needed by the model
|
||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
||||
|
||||
# get_kv_cache_specs in model_runner
|
||||
# for layer_name, layer in vllm_config.compilation_config.static_forward_context.items():
|
||||
# print(f'layer_name = {layer_name}; layer = {layer}')
|
||||
# # 只有 moe layer, 拿不到attention layer?
|
||||
# # TODO for no kv cahe model
|
||||
|
||||
|
||||
# has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
|
||||
|
||||
# if has_kv_cache:
|
||||
# if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
|
||||
# dp_group = getattr(self, "dp_group", None)
|
||||
# assert dp_group is not None
|
||||
# self.available_gpu_memory_for_kv_cache = \
|
||||
# ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
|
||||
# available_gpu_memory = [
|
||||
# self.available_gpu_memory_for_kv_cache
|
||||
# ] * len(kv_cache_specs)
|
||||
# else:
|
||||
# # Profiles the peak memory usage of the model to determine how
|
||||
# # much memory can be allocated for kv cache.
|
||||
# available_gpu_memory = (
|
||||
# self.model_executor.determine_available_memory())
|
||||
# self.available_gpu_memory_for_kv_cache = \
|
||||
# available_gpu_memory[0]
|
||||
# else:
|
||||
# # Attention free models don't need memory for kv cache
|
||||
# available_gpu_memory = [0] * len(kv_cache_specs)
|
||||
|
||||
memory_blocks = self.model_executor.determine_available_memory_block() # [(memory, blocks) * rank_number]
|
||||
available_gpu_memory = [memory_block[0] for memory_block in memory_blocks]
|
||||
num_gpu_blocks = memory_blocks[0][1]
|
||||
|
||||
assert len(kv_cache_specs) == len(available_gpu_memory)
|
||||
|
||||
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
|
||||
available_gpu_memory)
|
||||
|
||||
### patch here to support long seq_length for mtp
|
||||
for kv_cache_config in kv_cache_configs:
|
||||
for ii in range(len(kv_cache_config.kv_cache_tensors)):
|
||||
kv_cache_config.kv_cache_tensors[ii].size = kv_cache_config.kv_cache_tensors[ii].size * num_gpu_blocks // kv_cache_config.num_blocks
|
||||
|
||||
kv_cache_config.num_blocks = num_gpu_blocks
|
||||
### patch here to support long seq_length for mtp end
|
||||
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
|
||||
kv_cache_configs)
|
||||
num_gpu_blocks = scheduler_kv_cache_config.num_blocks
|
||||
num_cpu_blocks = 0
|
||||
|
||||
# Initialize kv cache and warmup the execution
|
||||
self.model_executor.initialize_from_config(kv_cache_configs)
|
||||
|
||||
elapsed = time.time() - start
|
||||
logger.info(("init engine (profile, create kv cache, "
|
||||
"warmup model) took %.2f seconds"), elapsed)
|
||||
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
|
||||
|
||||
def preprocess_add_request(
|
||||
self, request: EngineCoreRequest) -> tuple[Request, int]:
|
||||
"""Preprocess the request.
|
||||
|
||||
This function could be directly used in input processing thread to allow
|
||||
request initialization running in parallel with Model forward
|
||||
"""
|
||||
# Note on thread safety: no race condition.
|
||||
# `mm_receiver_cache` is reset at the end of LLMEngine init,
|
||||
# and will only be accessed in the input processing thread afterwards.
|
||||
if self.mm_receiver_cache is not None and request.mm_features:
|
||||
request.mm_features = (
|
||||
self.mm_receiver_cache.get_and_update_features(
|
||||
request.mm_features))
|
||||
|
||||
req = Request.from_engine_core_request(request,
|
||||
self.request_block_hasher)
|
||||
if req.use_structured_output:
|
||||
# Note on thread safety: no race condition.
|
||||
# `grammar_init` is only invoked in input processing thread. For
|
||||
# `structured_output_manager`, each request is independent and
|
||||
# grammar compilation is async. Scheduler always checks grammar
|
||||
# compilation status before scheduling request.
|
||||
self.structured_output_manager.grammar_init(req)
|
||||
return req, request.current_wave
|
||||
76
vllm_vacc/vllm/v1/engine/core_client.py
Normal file
76
vllm_vacc/vllm/v1/engine/core_client.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import queue
|
||||
import sys
|
||||
import uuid
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Optional, TypeVar, Union, List
|
||||
|
||||
import msgspec.msgpack
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import get_open_zmq_inproc_path, make_zmq_socket
|
||||
from vllm.v1.engine import (EngineCoreOutputs,
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm_vacc.vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.engine.exceptions import EngineDeadError
|
||||
from vllm.v1.engine.utils import (CoreEngineActorManager,
|
||||
CoreEngineProcManager, launch_core_engines)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
||||
from vllm.v1.engine.core_client import MPClient
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
AnyFuture = Union[asyncio.Future[Any], Future[Any]]
|
||||
|
||||
_R = TypeVar('_R') # Return type for collective_rpc
|
||||
|
||||
EngineIdentity = bytes
|
||||
|
||||
|
||||
|
||||
class EngineCoreClient(ABC):
|
||||
"""
|
||||
EngineCoreClient: subclasses handle different methods for pushing
|
||||
and pulling from the EngineCore for asyncio / multiprocessing.
|
||||
|
||||
Subclasses:
|
||||
* InprocClient: In process EngineCore (for V0-style LLMEngine use)
|
||||
* SyncMPClient: ZMQ + background proc EngineCore (for LLM)
|
||||
* AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM)
|
||||
"""
|
||||
@abstractmethod
|
||||
def _send_input(self, req_type: EngineCoreRequestType, payload: Any) -> None:
|
||||
"""Send a request to EngineCore."""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_requests(self, requests: List["EngineCoreRequest"]) -> None:
|
||||
"""一次性发多条 ADD(ADD_BULK)"""
|
||||
if not requests:
|
||||
return
|
||||
self._send_input(EngineCoreRequestType.ADD_BULK, requests)
|
||||
|
||||
class SyncMPClient(MPClient):
|
||||
"""Synchronous client for multi-proc EngineCore."""
|
||||
def add_requests(self, requests: List[EngineCoreRequest]) -> None:
|
||||
if not requests:
|
||||
return
|
||||
if self.is_dp: # 与 add_request 保持一致
|
||||
self.engines_running = True
|
||||
self._send_input(EngineCoreRequestType.ADD_BULK, requests)
|
||||
|
||||
176
vllm_vacc/vllm/v1/engine/llm_engine.py
Normal file
176
vllm_vacc/vllm/v1/engine/llm_engine.py
Normal file
@@ -0,0 +1,176 @@
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import copy
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
# from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Device
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
|
||||
StatLoggerFactory)
|
||||
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
class LLMEngine:
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
disable_log_stats: bool = False,
|
||||
) -> "LLMEngine":
|
||||
# vacc support spec_num = 1
|
||||
from .vllm_config_checker import check_spec_model
|
||||
check_spec_model(vllm_config)
|
||||
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
|
||||
default_cls = DefaultLLM
|
||||
return default_cls(vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=(not disable_log_stats),
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: EngineArgs,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
enable_multiprocessing: bool = False,
|
||||
) -> "LLMEngine":
|
||||
"""Creates an LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
vllm_config = engine_args.create_engine_config(usage_context)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
# vacc support spec_num = 1
|
||||
from .vllm_config_checker import check_spec_model
|
||||
check_spec_model(vllm_config)
|
||||
|
||||
if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
|
||||
logger.debug("Enabling multiprocessing for LLMEngine.")
|
||||
enable_multiprocessing = True
|
||||
|
||||
# Create the LLMEngine.
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
|
||||
default_cls = DefaultLLM
|
||||
return default_cls(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
multiprocess_mode=enable_multiprocessing)
|
||||
|
||||
"""Legacy LLMEngine for backwards compatibility."""
|
||||
def add_requests(
|
||||
self,
|
||||
items: list[tuple[
|
||||
str, # request_id
|
||||
PromptType, # prompt
|
||||
Union[SamplingParams, PoolingParams], # params
|
||||
Optional[float], # arrival_time
|
||||
Optional[LoRARequest], # lora_request
|
||||
Optional[dict], # tokenization_kwargs
|
||||
Optional[dict], # trace_headers
|
||||
# Optional[PromptAdapterRequest], # prompt_adapter_request
|
||||
int, # priority
|
||||
]],
|
||||
) -> None:
|
||||
"""批量把请求送入 EngineCore,一次性触发 ADD_BULK。"""
|
||||
core_reqs: list[EngineCoreRequest] = []
|
||||
|
||||
for (request_id, prompt, params, arrival_time, lora_request,
|
||||
tokenization_kwargs, trace_headers,
|
||||
priority) in items:
|
||||
|
||||
# 复用现有逐条流程的解析入口
|
||||
prompt_str, request = self.processor.process_inputs(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
params=params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
# prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
n = params.n if isinstance(params, SamplingParams) else 1
|
||||
|
||||
if n == 1:
|
||||
# Make a new RequestState and queue.
|
||||
self.output_processor.add_request(request, prompt_str, None, 0)
|
||||
# Add the request to EngineCore.
|
||||
core_reqs.append(request)
|
||||
continue
|
||||
# self.engine_core.add_request(request)
|
||||
# return
|
||||
|
||||
# Fan out child requests (for n>1).
|
||||
parent_req = ParentRequest(request_id, params)
|
||||
for idx in range(n):
|
||||
request_id, params = parent_req.get_child_info(idx)
|
||||
child_request = request if idx == n - 1 else copy(request)
|
||||
child_request.request_id = request_id
|
||||
child_request.sampling_params = params
|
||||
|
||||
# Make a new RequestState and queue.
|
||||
self.output_processor.add_request(child_request, prompt_str,
|
||||
parent_req, idx)
|
||||
# Add the request to EngineCore.
|
||||
# self.engine_core.add_request(child_request)
|
||||
# print("add_requests: child_request id=", child_request.request_id)
|
||||
core_reqs.append(child_request)
|
||||
|
||||
# output_processor 需要为每个“实际进入引擎的 req_id”建索引。
|
||||
# 如果是 SamplingParams 且 n>1/best_of>1,要做 parent-children 拆分;
|
||||
# 否则直接登记单条。
|
||||
# if isinstance(params, SamplingParams) and (
|
||||
# (params.n is not None and params.n > 1) or
|
||||
# (getattr(params, "best_of", 1) and getattr(params, "best_of", 1) > 1)
|
||||
# ):
|
||||
# parent = self.parallel_sampler.create_parent(request_id, params)
|
||||
# # 注意:最后一个 child 可以直接复用 request,其余用 copy
|
||||
# children = self.parallel_sampler.materialize_children(parent, request)
|
||||
# for child_idx, child in enumerate(children):
|
||||
# self.output_processor.add_request(
|
||||
# request=child,
|
||||
# prompt_str=prompt_str,
|
||||
# parent=parent,
|
||||
# child_index=child_idx,
|
||||
# )
|
||||
# core_reqs.append(child)
|
||||
# else:
|
||||
# self.output_processor.add_request(request, prompt_str)
|
||||
# core_reqs.append(request)
|
||||
|
||||
# 关键:一次性下发给 Core。EngineCoreClient 会发送 ADD_BULK。
|
||||
print('self.engine_core', self.engine_core)
|
||||
self.engine_core.add_requests(core_reqs)
|
||||
|
||||
|
||||
184
vllm_vacc/vllm/v1/engine/processor.py
Normal file
184
vllm_vacc/vllm/v1/engine/processor.py
Normal file
@@ -0,0 +1,184 @@
|
||||
|
||||
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
|
||||
from vllm.inputs.parse import split_enc_dec_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.cache import processor_cache_from_config
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
|
||||
from vllm.multimodal.processing import EncDecMultiModalProcessor
|
||||
from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.structured_output.backend_guidance import (
|
||||
validate_guidance_grammar)
|
||||
from vllm.v1.structured_output.backend_lm_format_enforcer import (
|
||||
validate_structured_output_request_lm_format_enforcer)
|
||||
from vllm.v1.structured_output.backend_outlines import (
|
||||
validate_structured_output_request_outlines)
|
||||
from vllm.v1.structured_output.backend_xgrammar import (
|
||||
validate_xgrammar_grammar)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Processor:
|
||||
|
||||
def process_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> tuple[Optional[str], EngineCoreRequest]:
|
||||
|
||||
# TODO(woosuk): Support pooling models.
|
||||
self._validate_lora(lora_request)
|
||||
self._validate_params(params)
|
||||
|
||||
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
|
||||
data_parallel_size):
|
||||
raise ValueError(f"data_parallel_rank {data_parallel_rank} "
|
||||
f"is out of range [0, {data_parallel_size}).")
|
||||
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
# Optionally generate multimodal hash overrides to avoid hashing
|
||||
# multimodal data items by their content as their identifiers.
|
||||
|
||||
# NOTE: when users explicitly turn off BOTH prefix caching and input
|
||||
# processing caching, no multimodal features or embeddings will be
|
||||
# reused across requests, therefore identifying multimodal data items
|
||||
# by their content is no longer necessary, and we create uuids with
|
||||
# request id-modality-index as multimodal hash overrides.
|
||||
if (self.model_config.multimodal_config and
|
||||
self.model_config.multimodal_config.mm_processor_cache_gb == 0
|
||||
and not self.cache_config.enable_prefix_caching):
|
||||
mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
|
||||
else:
|
||||
# Otherwise, use user-provided uuids as multimodal hash overrides
|
||||
# if provided.
|
||||
self._validate_multi_modal_uuids(prompt)
|
||||
if isinstance(prompt, dict):
|
||||
mm_uuids = prompt.get("multi_modal_uuids")
|
||||
else:
|
||||
mm_uuids = None
|
||||
|
||||
# Process inputs, which includes:
|
||||
# 1. Tokenize text prompt, with LoRA request if one exists.
|
||||
# 2. For multimodal models with a merged preprocessor, preprocess
|
||||
# multimodal data and expand prompt token ids accordingly.
|
||||
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.validate_request(
|
||||
prompt=prompt,
|
||||
params=params,
|
||||
processed_inputs=processed_inputs,
|
||||
)
|
||||
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id()
|
||||
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||
self._validate_model_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
# Mypy does not always properly infer the types of some elements of
|
||||
# discriminated unions of TypedDicts, because of how it handles
|
||||
# inheritance of TypedDict. If we explicitly extract the items we want
|
||||
# we can avoid type errors from using `dict.get` later in the method.
|
||||
prompt_str: Optional[str] = None if decoder_inputs[
|
||||
"type"] == "embeds" else decoder_inputs.get("prompt")
|
||||
prompt_token_ids = decoder_inputs[
|
||||
"prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None
|
||||
prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[
|
||||
"type"] == "embeds" else None
|
||||
deepstack_input_embeds = decoder_inputs["deepstack_input_embeds"] if decoder_inputs[
|
||||
"type"] == "embeds" else None
|
||||
|
||||
# for deepstack_input_embeds in llm.generate method
|
||||
if isinstance(deepstack_input_embeds, dict):
|
||||
all_tensors = []
|
||||
for key in deepstack_input_embeds:
|
||||
if isinstance(deepstack_input_embeds[key], torch.Tensor):
|
||||
all_tensors.append(deepstack_input_embeds[key].unsqueeze(0))
|
||||
if len(all_tensors) > 0:
|
||||
deepstack_input_embeds = torch.concatenate(all_tensors, 0)
|
||||
|
||||
sampling_params = None
|
||||
pooling_params = None
|
||||
if isinstance(params, SamplingParams):
|
||||
# TODO: can we avoid cloning here in multiproc case?
|
||||
sampling_params = params.clone()
|
||||
# If unset max tokens, then generate up to the max_model_len.
|
||||
if sampling_params.max_tokens is None:
|
||||
seq_len = length_from_prompt_token_ids_or_embeds(
|
||||
prompt_token_ids, prompt_embeds)
|
||||
sampling_params.max_tokens = \
|
||||
self.model_config.max_model_len - seq_len
|
||||
sampling_params.update_from_generation_config(
|
||||
self.generation_config_fields, eos_token_id)
|
||||
if self.tokenizer is not None:
|
||||
sampling_params.update_from_tokenizer(self.tokenizer)
|
||||
else:
|
||||
pooling_params = params.clone()
|
||||
|
||||
# Multimodal related.
|
||||
mm_features: Optional[list[MultiModalFeatureSpec]] = None
|
||||
|
||||
if decoder_inputs["type"] == "multimodal":
|
||||
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
|
||||
decoder_mm_positions = decoder_inputs["mm_placeholders"]
|
||||
decoder_mm_hashes = decoder_inputs["mm_hashes"]
|
||||
|
||||
# Merge and flatten multimodal placeholders, hashes and inputs
|
||||
# from dictionaries to lists, and sort them by each item's position
|
||||
# in the input sequence.
|
||||
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
|
||||
|
||||
mm_features = []
|
||||
for modality, idx in sorted_mm_idxs:
|
||||
mm_features.append(
|
||||
MultiModalFeatureSpec(
|
||||
data=decoder_mm_inputs[modality][idx],
|
||||
modality=modality,
|
||||
identifier=decoder_mm_hashes[modality][idx],
|
||||
mm_position=decoder_mm_positions[modality][idx]))
|
||||
|
||||
return prompt_str, EngineCoreRequest(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_embeds=prompt_embeds,
|
||||
deepstack_input_embeds=deepstack_input_embeds,
|
||||
mm_features=mm_features,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=pooling_params,
|
||||
eos_token_id=eos_token_id,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
cache_salt=decoder_inputs.get("cache_salt"),
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
17
vllm_vacc/vllm/v1/engine/vllm_config_checker.py
Normal file
17
vllm_vacc/vllm/v1/engine/vllm_config_checker.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# check spec model config
|
||||
# write spec model func to config file
|
||||
def check_spec_model(vllm_config):
|
||||
# add spec tag
|
||||
speculative_mode = hasattr(vllm_config, 'speculative_config')
|
||||
if speculative_mode and \
|
||||
hasattr(vllm_config.speculative_config, 'num_speculative_tokens') and \
|
||||
vllm_config.speculative_config.num_speculative_tokens != 1:
|
||||
raise ValueError(f'run_mp_engine: only support num_speculative_tokens == 1, but get {vllm_config.speculative_config.num_speculative_tokens}')
|
||||
|
||||
default_model_infos = "default"
|
||||
if speculative_mode:
|
||||
if hasattr(vllm_config.speculative_config, 'method'):
|
||||
default_model_infos = vllm_config.speculative_config.method
|
||||
|
||||
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
|
||||
vllm_vacc_config_manager().update_model_infos(default_model_infos)
|
||||
0
vllm_vacc/vllm/v1/executor/__init__.py
Normal file
0
vllm_vacc/vllm/v1/executor/__init__.py
Normal file
BIN
vllm_vacc/vllm/v1/executor/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/executor/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/executor/__pycache__/abstract.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/executor/__pycache__/abstract.cpython-312.pyc
Normal file
Binary file not shown.
24
vllm_vacc/vllm/v1/executor/abstract.py
Normal file
24
vllm_vacc/vllm/v1/executor/abstract.py
Normal file
@@ -0,0 +1,24 @@
|
||||
|
||||
from concurrent.futures import Future
|
||||
from typing import Callable, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.uniproc_executor import ( # noqa
|
||||
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
|
||||
from vllm.executor.uniproc_executor import ( # noqa
|
||||
UniProcExecutor as UniProcExecutorV0)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
FailureCallback = Callable[[], None]
|
||||
|
||||
|
||||
class Executor(ExecutorBase):
|
||||
|
||||
def determine_available_memory_block(self) -> list[(int, int)]: # in bytes
|
||||
output = self.collective_rpc("determine_available_memory_block")
|
||||
return output
|
||||
0
vllm_vacc/vllm/v1/metrics/__init__.py
Normal file
0
vllm_vacc/vllm/v1/metrics/__init__.py
Normal file
BIN
vllm_vacc/vllm/v1/metrics/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/metrics/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/metrics/__pycache__/loggers.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/metrics/__pycache__/loggers.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/metrics/__pycache__/stats.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/metrics/__pycache__/stats.cpython-312.pyc
Normal file
Binary file not shown.
49
vllm_vacc/vllm/v1/metrics/loggers.py
Normal file
49
vllm_vacc/vllm/v1/metrics/loggers.py
Normal file
@@ -0,0 +1,49 @@
|
||||
|
||||
import time
|
||||
|
||||
from vllm.v1.metrics.loggers import StatLoggerBase
|
||||
from vllm.v1.metrics.loggers import logger
|
||||
|
||||
class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
def log(self):
|
||||
now = time.monotonic()
|
||||
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
|
||||
generation_throughput = self._get_throughput(
|
||||
self.num_generation_tokens, now)
|
||||
|
||||
self._reset(now)
|
||||
|
||||
scheduler_stats = self.last_scheduler_stats
|
||||
|
||||
log_fn = logger.info
|
||||
if not any(
|
||||
(prompt_throughput, generation_throughput,
|
||||
self.last_prompt_throughput, self.last_generation_throughput)):
|
||||
# Avoid log noise on an idle production system
|
||||
log_fn = logger.debug
|
||||
|
||||
self.last_generation_throughput = generation_throughput
|
||||
self.last_prompt_throughput = prompt_throughput
|
||||
|
||||
# Format and print output.
|
||||
log_fn(
|
||||
"Engine %03d: "
|
||||
"Avg prompt throughput: %.1f tokens/s, "
|
||||
"Avg generation throughput: %.1f tokens/s, "
|
||||
"Running: %d reqs, Waiting: %d reqs, "
|
||||
"GPU KV cache usage: %.1f%%, "
|
||||
"Prefix cache hit rate: %.1f%%, "
|
||||
"running seqlens: %s ",
|
||||
self.engine_index,
|
||||
prompt_throughput,
|
||||
generation_throughput,
|
||||
scheduler_stats.num_running_reqs,
|
||||
scheduler_stats.num_waiting_reqs,
|
||||
scheduler_stats.kv_cache_usage * 100,
|
||||
self.prefix_caching_metrics.hit_rate * 100,
|
||||
str(scheduler_stats.running_seqlens),
|
||||
)
|
||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||
self.kv_transfer_logging.log(log_fn=log_fn)
|
||||
|
||||
32
vllm_vacc/vllm/v1/metrics/stats.py
Normal file
32
vllm_vacc/vllm/v1/metrics/stats.py
Normal file
@@ -0,0 +1,32 @@
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerStats:
|
||||
"""Stats associated with the scheduler."""
|
||||
|
||||
num_running_reqs: int = 0
|
||||
num_waiting_reqs: int = 0
|
||||
|
||||
running_seqlens: list[int] = None
|
||||
|
||||
# These are used for internal DP load-balancing.
|
||||
step_counter: int = 0
|
||||
current_wave: int = 0
|
||||
|
||||
kv_cache_usage: float = 0.0
|
||||
|
||||
prefix_cache_stats: PrefixCacheStats = field(
|
||||
default_factory=PrefixCacheStats)
|
||||
|
||||
spec_decoding_stats: Optional[SpecDecodingStats] = None
|
||||
kv_connector_stats: Optional[dict[str, Any]] = None
|
||||
|
||||
num_corrupted_reqs: int = 0
|
||||
|
||||
355
vllm_vacc/vllm/v1/request.py
Normal file
355
vllm_vacc/vllm/v1/request.py
Normal file
@@ -0,0 +1,355 @@
|
||||
|
||||
import enum
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
||||
EngineCoreRequest, FinishReason)
|
||||
from vllm.v1.structured_output.request import StructuredOutputRequest
|
||||
from vllm.v1.utils import ConstantList
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
|
||||
|
||||
class Request:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[list[int]],
|
||||
sampling_params: Optional[SamplingParams],
|
||||
pooling_params: Optional[PoolingParams],
|
||||
eos_token_id: Optional[int],
|
||||
client_index: int = 0,
|
||||
arrival_time: Optional[float] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
mm_features: Optional[list[MultiModalFeatureSpec]] = None,
|
||||
lora_request: Optional["LoRARequest"] = None,
|
||||
structured_output_request: Optional["StructuredOutputRequest"] = None,
|
||||
cache_salt: Optional[str] = None,
|
||||
priority: int = 0,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
block_hasher: Optional[Callable[["Request"],
|
||||
list["BlockHash"]]] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.client_index = client_index
|
||||
self.priority = priority
|
||||
self.sampling_params = sampling_params
|
||||
self.pooling_params = pooling_params
|
||||
# Because of LoRA, the eos token id can be different for each request.
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.structured_output_request = structured_output_request
|
||||
self.arrival_time = arrival_time if arrival_time is not None else \
|
||||
time.time()
|
||||
|
||||
self.status = RequestStatus.WAITING
|
||||
self.use_structured_output = False
|
||||
self.events: list[EngineCoreEvent] = []
|
||||
self.stop_reason: Union[int, str, None] = None
|
||||
|
||||
# P/D: Connector-specific KV transfer parameters.
|
||||
self.kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
||||
if pooling_params is not None:
|
||||
# Pooling models.
|
||||
self.max_tokens = 1
|
||||
elif sampling_params is not None:
|
||||
# Generative models.
|
||||
assert sampling_params.max_tokens is not None
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
if sampling_params.structured_outputs is not None:
|
||||
self.status = RequestStatus.WAITING_FOR_FSM
|
||||
self.use_structured_output = True
|
||||
|
||||
if sampling_params.extra_args is not None:
|
||||
self.kv_transfer_params = \
|
||||
sampling_params.extra_args.get("kv_transfer_params")
|
||||
else:
|
||||
raise ValueError(
|
||||
"sampling_params and pooling_params can't both be unset")
|
||||
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.prompt_embeds = prompt_embeds
|
||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
prompt_token_ids, prompt_embeds)
|
||||
self._output_token_ids: list[int] = []
|
||||
self._all_token_ids: list[int] = self.prompt_token_ids.copy(
|
||||
) if self.prompt_token_ids is not None else [0
|
||||
] * self.num_prompt_tokens
|
||||
self.num_output_placeholders = 0 # Used in async scheduling.
|
||||
self.spec_token_ids: list[int] = []
|
||||
self.num_computed_tokens = 0
|
||||
self.cache_salt: Optional[str] = cache_salt
|
||||
|
||||
# Multi-modal related
|
||||
self.mm_features = mm_features or []
|
||||
self.num_encoder_inputs = len(self.mm_features)
|
||||
self.has_encoder_inputs = self.num_encoder_inputs > 0
|
||||
|
||||
# Read-only views
|
||||
# Prevent directly appending to these lists since
|
||||
# they should also be updated simultaneously.
|
||||
self.output_token_ids = ConstantList(self._output_token_ids)
|
||||
self.all_token_ids = ConstantList(self._all_token_ids)
|
||||
# trace_headers
|
||||
self.trace_headers = trace_headers
|
||||
# State
|
||||
# The number of tokens with prefix cache hits.
|
||||
self.num_cached_tokens = -1
|
||||
|
||||
# The number of NaNs in logits. A value greater than 0
|
||||
# indicates that the output is corrupted
|
||||
self.num_nans_in_logits = 0
|
||||
|
||||
self.block_hashes: list[BlockHash] = []
|
||||
self.get_hash_new_full_blocks: Optional[Callable[
|
||||
[], list[BlockHash]]] = None
|
||||
if block_hasher is not None:
|
||||
self.get_hash_new_full_blocks = partial(block_hasher, self)
|
||||
self.block_hashes = self.get_hash_new_full_blocks()
|
||||
|
||||
@classmethod
|
||||
def from_engine_core_request(
|
||||
cls, request: EngineCoreRequest,
|
||||
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
|
||||
) -> "Request":
|
||||
return cls(
|
||||
request_id=request.request_id,
|
||||
client_index=request.client_index,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
mm_features=request.mm_features,
|
||||
sampling_params=request.sampling_params,
|
||||
pooling_params=request.pooling_params,
|
||||
eos_token_id=request.eos_token_id,
|
||||
arrival_time=request.arrival_time,
|
||||
lora_request=request.lora_request,
|
||||
structured_output_request=StructuredOutputRequest(
|
||||
sampling_params=request.sampling_params) \
|
||||
if request.sampling_params else None,
|
||||
cache_salt=request.cache_salt,
|
||||
priority=request.priority,
|
||||
trace_headers=request.trace_headers,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
|
||||
def append_output_token_ids(
|
||||
self,
|
||||
token_ids: Union[int, list[int]],
|
||||
) -> None:
|
||||
if isinstance(token_ids, int):
|
||||
self._output_token_ids.append(token_ids)
|
||||
self._all_token_ids.append(token_ids)
|
||||
else:
|
||||
self._output_token_ids.extend(token_ids)
|
||||
self._all_token_ids.extend(token_ids)
|
||||
|
||||
if self.get_hash_new_full_blocks is not None:
|
||||
self.block_hashes.extend(self.get_hash_new_full_blocks())
|
||||
|
||||
@property
|
||||
def is_output_corrupted(self) -> bool:
|
||||
return self.num_nans_in_logits > 0
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return len(self._all_token_ids)
|
||||
|
||||
@property
|
||||
def num_tokens_with_spec(self) -> int:
|
||||
return len(self._all_token_ids) + len(self.spec_token_ids)
|
||||
|
||||
@property
|
||||
def num_output_tokens(self) -> int:
|
||||
return len(self._output_token_ids)
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return RequestStatus.is_finished(self.status)
|
||||
|
||||
def get_finished_reason(self) -> Union[FinishReason, None]:
|
||||
return RequestStatus.get_finished_reason(self.status)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
assert input_id < len(self.mm_features)
|
||||
num_tokens = self.mm_features[input_id].mm_position.length
|
||||
return num_tokens
|
||||
|
||||
def record_event(
|
||||
self,
|
||||
event_type: EngineCoreEventType,
|
||||
timestamp: Optional[float] = None,
|
||||
) -> None:
|
||||
self.events.append(EngineCoreEvent.new_event(event_type, timestamp))
|
||||
|
||||
def take_events(self) -> Optional[list[EngineCoreEvent]]:
|
||||
if not self.events:
|
||||
return None
|
||||
events, self.events = self.events, []
|
||||
return events
|
||||
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[list[int]],
|
||||
sampling_params: Optional[SamplingParams],
|
||||
pooling_params: Optional[PoolingParams],
|
||||
eos_token_id: Optional[int],
|
||||
client_index: int = 0,
|
||||
arrival_time: Optional[float] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
deepstack_input_embeds: Optional[torch.Tensor] = None,
|
||||
mm_features: Optional[list[MultiModalFeatureSpec]] = None,
|
||||
lora_request: Optional["LoRARequest"] = None,
|
||||
structured_output_request: Optional["StructuredOutputRequest"] = None,
|
||||
cache_salt: Optional[str] = None,
|
||||
priority: int = 0,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
block_hasher: Optional[Callable[["Request"],
|
||||
list["BlockHash"]]] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.client_index = client_index
|
||||
self.priority = priority
|
||||
self.sampling_params = sampling_params
|
||||
self.pooling_params = pooling_params
|
||||
# Because of LoRA, the eos token id can be different for each request.
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.structured_output_request = structured_output_request
|
||||
self.arrival_time = arrival_time if arrival_time is not None else \
|
||||
time.time()
|
||||
|
||||
self.status = RequestStatus.WAITING
|
||||
self.use_structured_output = False
|
||||
self.events: list[EngineCoreEvent] = []
|
||||
self.stop_reason: Union[int, str, None] = None
|
||||
|
||||
# P/D: Connector-specific KV transfer parameters.
|
||||
self.kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
||||
if pooling_params is not None:
|
||||
# Pooling models.
|
||||
self.max_tokens = 1
|
||||
elif sampling_params is not None:
|
||||
# Generative models.
|
||||
assert sampling_params.max_tokens is not None
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
if sampling_params.guided_decoding is not None:
|
||||
self.status = RequestStatus.WAITING_FOR_FSM
|
||||
self.use_structured_output = True
|
||||
|
||||
if sampling_params.extra_args is not None:
|
||||
self.kv_transfer_params = \
|
||||
sampling_params.extra_args.get("kv_transfer_params")
|
||||
else:
|
||||
raise ValueError(
|
||||
"sampling_params and pooling_params can't both be unset")
|
||||
|
||||
|
||||
# Strict barrier batching metadata (optional)
|
||||
self.barrier_group_id: Optional[str] = None
|
||||
self.barrier_group_size: int = 0
|
||||
if sampling_params is not None and sampling_params.extra_args:
|
||||
try:
|
||||
if "barrier_group_id" in sampling_params.extra_args:
|
||||
self.barrier_group_id = str(
|
||||
sampling_params.extra_args["barrier_group_id"])
|
||||
if "barrier_group_size" in sampling_params.extra_args:
|
||||
self.barrier_group_size = int(
|
||||
sampling_params.extra_args["barrier_group_size"])
|
||||
except Exception:
|
||||
# Be tolerant to malformed extra_args; just ignore on failure.
|
||||
self.barrier_group_id = None
|
||||
self.barrier_group_size = 0
|
||||
|
||||
|
||||
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.prompt_embeds = prompt_embeds
|
||||
self.deepstack_input_embeds = deepstack_input_embeds
|
||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
prompt_token_ids, prompt_embeds)
|
||||
self._output_token_ids: list[int] = []
|
||||
self._all_token_ids: list[int] = self.prompt_token_ids.copy(
|
||||
) if self.prompt_token_ids is not None else [0
|
||||
] * self.num_prompt_tokens
|
||||
self.num_output_placeholders = 0 # Used in async scheduling.
|
||||
self.spec_token_ids: list[int] = []
|
||||
self.num_computed_tokens = 0
|
||||
self.cache_salt: Optional[str] = cache_salt
|
||||
|
||||
# Multi-modal related
|
||||
self.mm_features = mm_features or []
|
||||
self.num_encoder_inputs = len(self.mm_features)
|
||||
self.has_encoder_inputs = self.num_encoder_inputs > 0
|
||||
|
||||
# Read-only views
|
||||
# Prevent directly appending to these lists since
|
||||
# they should also be updated simultaneously.
|
||||
self.output_token_ids = ConstantList(self._output_token_ids)
|
||||
self.all_token_ids = ConstantList(self._all_token_ids)
|
||||
# trace_headers
|
||||
self.trace_headers = trace_headers
|
||||
# State
|
||||
# The number of tokens with prefix cache hits.
|
||||
self.num_cached_tokens = -1
|
||||
|
||||
# The number of NaNs in logits. A value greater than 0
|
||||
# indicates that the output is corrupted
|
||||
self.num_nans_in_logits = 0
|
||||
|
||||
self.block_hashes: list[BlockHash] = []
|
||||
self.get_hash_new_full_blocks: Optional[Callable[
|
||||
[], list[BlockHash]]] = None
|
||||
if block_hasher is not None:
|
||||
self.get_hash_new_full_blocks = partial(block_hasher, self)
|
||||
self.block_hashes = self.get_hash_new_full_blocks()
|
||||
|
||||
@classmethod
|
||||
def from_engine_core_request(
|
||||
cls, request: EngineCoreRequest,
|
||||
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
|
||||
) -> "Request":
|
||||
return cls(
|
||||
request_id=request.request_id,
|
||||
client_index=request.client_index,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
deepstack_input_embeds=request.deepstack_input_embeds,
|
||||
mm_features=request.mm_features,
|
||||
sampling_params=request.sampling_params,
|
||||
pooling_params=request.pooling_params,
|
||||
eos_token_id=request.eos_token_id,
|
||||
arrival_time=request.arrival_time,
|
||||
lora_request=request.lora_request,
|
||||
structured_output_request=StructuredOutputRequest(
|
||||
sampling_params=request.sampling_params) \
|
||||
if request.sampling_params else None,
|
||||
cache_salt=request.cache_salt,
|
||||
priority=request.priority,
|
||||
trace_headers=request.trace_headers,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
|
||||
def record_event(
|
||||
self,
|
||||
event_type: EngineCoreEventType,
|
||||
timestamp: Optional[float] = None,
|
||||
) -> None:
|
||||
self.events.append(EngineCoreEvent.new_event(event_type, timestamp))
|
||||
0
vllm_vacc/vllm/v1/sample/__init__.py
Normal file
0
vllm_vacc/vllm/v1/sample/__init__.py
Normal file
BIN
vllm_vacc/vllm/v1/sample/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/sample/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm_vacc/vllm/v1/sample/__pycache__/metadata.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/sample/__pycache__/metadata.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm_vacc/vllm/v1/sample/__pycache__/sampler.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/sample/__pycache__/sampler.cpython-312.pyc
Normal file
Binary file not shown.
91
vllm_vacc/vllm/v1/sample/cached_pooler.py
Normal file
91
vllm_vacc/vllm/v1/sample/cached_pooler.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import torch
|
||||
import torch_vacc
|
||||
|
||||
class BinCountTensorPooler:
|
||||
"""
|
||||
sampler类中(bin-count tensor)的缓冲器。
|
||||
核心功能:为apply_penalties 提供创建好的bin-count, 为避免重复重头计算。
|
||||
注意: 如果系统中的req_id未能确保在max_count的数量下的唯一性,要考虑是否适用该pooler
|
||||
|
||||
pooler = BinCountTensorPooler(15169, "cpu")
|
||||
list_1 = pooler.request_tensors(['1','3','aad'])
|
||||
bin_count_buffers = pooler.request_tensors(req_ids)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
max_cache_count: int = 20,
|
||||
):
|
||||
# 缓存容器:用列表维护请求ID与对应张量的顺序(FIFO淘汰)
|
||||
self._cached_tensors: list[torch.Tensor] = []
|
||||
self._cached_req_ids: list[str] = []
|
||||
|
||||
# 张量维度(+1 通常用于padding/特殊标记)
|
||||
self.vocab_size: int = vocab_size + 1
|
||||
self.device: torch.device = device
|
||||
self.max_cache_count: int = max_cache_count
|
||||
|
||||
def request_tensors(self, req_ids: list[str]) -> list[torch.Tensor]:
|
||||
"""
|
||||
批量请求张量:已缓存的直接返回,未缓存的创建并加入缓存。
|
||||
|
||||
Args:
|
||||
req_ids: 待请求的请求ID列表
|
||||
|
||||
Returns:
|
||||
与req_ids顺序对应的bin-count tensor
|
||||
"""
|
||||
if not req_ids:
|
||||
return [] # 空输入快速返回,避免无效循环
|
||||
|
||||
out_tensors = []
|
||||
for req_id in req_ids:
|
||||
if req_id not in self._cached_req_ids:
|
||||
# 注册新的req, 增加cached-tensor
|
||||
self._add_new_cache(req_id)
|
||||
|
||||
# 快速获取缓存索引(后续可优化为dict映射,提升性能)
|
||||
cache_idx = self._cached_req_ids.index(req_id)
|
||||
out_tensors.append(self._cached_tensors[cache_idx])
|
||||
|
||||
return out_tensors
|
||||
|
||||
def _add_new_cache(self, req_id: str) -> None:
|
||||
"""
|
||||
新增缓存项:创建张量并加入缓存,超出最大数量时按FIFO淘汰最早项。
|
||||
私有方法:封装内部逻辑,避免外部直接调用
|
||||
"""
|
||||
# FIFO淘汰:缓存满时移除头部(最早加入的项)
|
||||
if len(self._cached_req_ids) >= self.max_cache_count: # >= 更严谨(避免边界值问题)
|
||||
self._cached_tensors.pop(0)
|
||||
self._cached_req_ids.pop(0)
|
||||
|
||||
# 创建二进制计数张量(int32类型,初始全0)
|
||||
bin_count_tensor = torch.zeros(
|
||||
size=(1, self.vocab_size),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
requires_grad=False, # 明确禁用梯度(计数张量无需反向传播)
|
||||
)
|
||||
|
||||
# 同步添加ID和张量(保证两个列表索引一致)
|
||||
self._cached_req_ids.append(req_id)
|
||||
self._cached_tensors.append(bin_count_tensor)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清空所有缓存(可选扩展方法,方便外部手动清理)"""
|
||||
self._cached_tensors.clear()
|
||||
self._cached_req_ids.clear()
|
||||
|
||||
def get_cache_status(self) -> dict:
|
||||
"""获取缓存状态(可选扩展方法,方便监控)"""
|
||||
return {
|
||||
"current_cache_count": len(self._cached_req_ids),
|
||||
"max_cache_count": self.max_cache_count,
|
||||
"cached_req_ids": self._cached_req_ids.copy(), # 返回副本避免外部修改
|
||||
"tensor_shape": (1, self.vocab_size),
|
||||
"tensor_dtype": torch.int32,
|
||||
"device": str(self.device),
|
||||
}
|
||||
48
vllm_vacc/vllm/v1/sample/metadata.py
Normal file
48
vllm_vacc/vllm/v1/sample/metadata.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
|
||||
temperature: Optional[torch.Tensor]
|
||||
all_greedy: bool
|
||||
all_random: bool
|
||||
|
||||
top_p: Optional[torch.Tensor]
|
||||
top_k: Optional[torch.Tensor]
|
||||
|
||||
generators: dict[int, torch.Generator]
|
||||
|
||||
# None means no logprobs, 0 means sampled token logprobs only
|
||||
max_num_logprobs: Optional[int]
|
||||
|
||||
no_penalties: bool
|
||||
prompt_token_ids: Optional[torch.Tensor]
|
||||
frequency_penalties: torch.Tensor
|
||||
presence_penalties: torch.Tensor
|
||||
repetition_penalties: torch.Tensor
|
||||
|
||||
output_token_ids: list[list[int]]
|
||||
|
||||
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
|
||||
# vocab size).
|
||||
allowed_token_ids_mask: Optional[torch.Tensor]
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
bad_words_token_ids: dict[int, list[list[int]]]
|
||||
|
||||
# Loaded logits processors
|
||||
logitsprocs: LogitsProcessors
|
||||
|
||||
temperature_cpu: Optional[torch.Tensor]
|
||||
top_p_cpu: Optional[torch.Tensor]
|
||||
top_k_cpu: Optional[torch.Tensor]
|
||||
|
||||
230
vllm_vacc/vllm/v1/sample/rejection_sampler.py
Normal file
230
vllm_vacc/vllm/v1/sample/rejection_sampler.py
Normal file
@@ -0,0 +1,230 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.sample.rejection_sampler import generate_uniform_probs, compute_probs, rejection_random_sample_kernel, sample_recovered_tokens
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
|
||||
GREEDY_TEMPERATURE: tl.constexpr = -1
|
||||
# Maximum number of speculative draft tokens allowed per request in a single
|
||||
# step. This value is chosen to be large enough to handle typical use cases.
|
||||
MAX_SPEC_LEN = 32
|
||||
|
||||
|
||||
def rejection_greedy_sample_python(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
target_argmax_ptr, # [num_tokens]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
is_greedy_ptr, # [batch_size] or None
|
||||
max_spec_len,
|
||||
num_warps
|
||||
):
|
||||
# print('max_spec_len', max_spec_len)
|
||||
if max_spec_len == 1:
|
||||
for bi in range(output_token_ids_ptr.shape[0]):
|
||||
output_token_ids_ptr[bi, 0] = target_argmax_ptr[bi]
|
||||
if target_argmax_ptr[bi].item() == draft_token_ids_ptr[bi].item():
|
||||
output_token_ids_ptr[bi, 1] = bonus_token_ids_ptr[bi]
|
||||
else:
|
||||
raise ValueError('TODO mtp k > 1')
|
||||
|
||||
|
||||
class RejectionSampler(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
metadata: SpecDecodeMetadata,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
# [num_tokens, vocab_size]
|
||||
target_logits: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Args:
|
||||
metadata:
|
||||
Metadata for spec decoding.
|
||||
draft_probs (Optional[torch.Tensor]):
|
||||
Probability distribution for the draft tokens. Shape is
|
||||
[num_tokens, vocab_size]. Can be None if probabilities are
|
||||
not provided, which is the case for ngram spec decode.
|
||||
target_logits (torch.Tensor):
|
||||
Target model's logits probability distribution.
|
||||
Shape is [num_tokens, vocab_size]. Here, probabilities from
|
||||
different requests are flattened into a single tensor because
|
||||
this is the shape of the output logits.
|
||||
NOTE: `target_logits` can be updated in place to save memory.
|
||||
bonus_token_ids_tensor (torch.Tensor):
|
||||
A tensor containing bonus tokens. Shape is [batch_size, 1].
|
||||
Bonus tokens are added to the end of the sequence if all
|
||||
proposed tokens are accepted. We generate the bonus tokens
|
||||
outside of the rejection sampler with the default sampling
|
||||
strategy. It allows for more flexibility in the sampling
|
||||
process such as top_p, top_k sampling.
|
||||
sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
|
||||
Additional metadata needed for sampling, such as temperature,
|
||||
top-k/top-p parameters, or other relevant information.
|
||||
Returns:
|
||||
output_token_ids (torch.Tensor):
|
||||
A tensor containing the final output token IDs.
|
||||
'''
|
||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||
# [num_tokens, vocab_size]
|
||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
||||
# `compute_probs` function.
|
||||
|
||||
# print(sampling_metadata)
|
||||
# rank_id = get_tensor_model_parallel_rank()
|
||||
if metadata.max_spec_len == 1:
|
||||
output_token_ids = torch.vacc.rejection_sampler_v1(
|
||||
target_logits.to(torch.float32),
|
||||
metadata.draft_token_ids,
|
||||
bonus_token_ids,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.top_p,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.all_greedy,
|
||||
sampling_metadata.all_random,
|
||||
sampling_metadata.generators
|
||||
)
|
||||
else:
|
||||
target_probs = compute_probs(
|
||||
target_logits.to(torch.float32),
|
||||
metadata.cu_num_draft_tokens,
|
||||
sampling_metadata,
|
||||
)
|
||||
output_token_ids = rejection_sample(
|
||||
metadata.draft_token_ids,
|
||||
metadata.num_draft_tokens,
|
||||
metadata.max_spec_len,
|
||||
metadata.cu_num_draft_tokens,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
sampling_metadata,
|
||||
)
|
||||
|
||||
# output_token_ids_cpu = output_token_ids.cpu().tolist()
|
||||
# output_token_ids_dev_cpu = output_token_ids_dev.cpu().tolist()
|
||||
# for i in range(len(output_token_ids_cpu)):
|
||||
# for j in range(len(output_token_ids_cpu[0])):
|
||||
# if output_token_ids_cpu[i][j] != output_token_ids_dev_cpu[i][j]:
|
||||
# # print(output_token_ids_cpu)
|
||||
# # print(output_token_ids_dev_cpu)
|
||||
# print("cpu \n", output_token_ids, "vacc \n" , output_token_ids_dev)
|
||||
# exit()
|
||||
# print("cpu \n", output_token_ids, "vacc \n" , output_token_ids_dev)
|
||||
return output_token_ids
|
||||
|
||||
def rejection_sample(
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
# [batch_size]
|
||||
num_draft_tokens: list[int],
|
||||
max_spec_len: int,
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert draft_token_ids.ndim == 1
|
||||
assert draft_probs is None or draft_probs.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
assert target_probs.ndim == 2
|
||||
|
||||
batch_size = len(num_draft_tokens)
|
||||
num_tokens = draft_token_ids.shape[0]
|
||||
vocab_size = target_probs.shape[-1]
|
||||
device = target_probs.device
|
||||
assert draft_token_ids.is_contiguous()
|
||||
assert draft_probs is None or draft_probs.is_contiguous()
|
||||
assert target_probs.is_contiguous()
|
||||
assert bonus_token_ids.is_contiguous()
|
||||
assert target_probs.shape == (num_tokens, vocab_size)
|
||||
|
||||
# Create output buffer.
|
||||
output_token_ids = torch.empty(
|
||||
(batch_size, max_spec_len + 1),
|
||||
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
|
||||
device=device,
|
||||
)
|
||||
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
|
||||
|
||||
if sampling_metadata.all_greedy:
|
||||
is_greedy = None
|
||||
else:
|
||||
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
|
||||
if not sampling_metadata.all_random:
|
||||
# Rejection sampling for greedy sampling requests.
|
||||
target_argmax = target_probs.argmax(dim=-1)
|
||||
# rejection_greedy_sample_kernel[(batch_size, )](
|
||||
rejection_greedy_sample_python(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
num_warps=1,
|
||||
)
|
||||
if sampling_metadata.all_greedy:
|
||||
return output_token_ids
|
||||
else:
|
||||
# TODO
|
||||
raise ValueError('not support yet')
|
||||
|
||||
# Generate uniform probabilities for rejection sampling.
|
||||
# [num_tokens]
|
||||
uniform_probs = generate_uniform_probs(
|
||||
num_tokens,
|
||||
num_draft_tokens,
|
||||
sampling_metadata.generators,
|
||||
device,
|
||||
)
|
||||
|
||||
# Sample recovered tokens for each position.
|
||||
# [num_tokens]
|
||||
recovered_token_ids = sample_recovered_tokens(
|
||||
max_spec_len,
|
||||
num_draft_tokens,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
sampling_metadata,
|
||||
device,
|
||||
)
|
||||
|
||||
# Rejection sampling for random sampling requests.
|
||||
rejection_random_sample_kernel[(batch_size, )](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
recovered_token_ids,
|
||||
uniform_probs,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
num_warps=1,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
276
vllm_vacc/vllm/v1/sample/sampler.py
Normal file
276
vllm_vacc/vllm/v1/sample/sampler.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A layer that samples the next tokens from the model's outputs."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
||||
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||
|
||||
from .cached_pooler import BinCountTensorPooler
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
Convert the different list data structures to tensors.
|
||||
"""
|
||||
output_tokens_tensor = make_tensor_with_pad(
|
||||
output_token_ids,
|
||||
# Use the value of vocab_size as a pad since we don't have a
|
||||
# token_id of this value.
|
||||
pad=vocab_size,
|
||||
device="cpu",
|
||||
dtype=torch.int32, # init with int32
|
||||
pin_memory=is_pin_memory_available(),
|
||||
)
|
||||
return output_tokens_tensor.to(device, non_blocking=True)
|
||||
|
||||
class Sampler(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
is_first_calculate,
|
||||
req_ids: list[str],
|
||||
) -> SamplerOutput:
|
||||
# print(sampling_metadata.generators, len(sampling_metadata.generators), sampling_metadata.temperature_cpu)
|
||||
# sampling_metadata.temperature = sampling_metadata.temperature.to(logits.device)
|
||||
# sampling_metadata.top_p = sampling_metadata.top_p.to(logits.device)
|
||||
# sampling_metadata.top_k = sampling_metadata.top_k.to(logits.device)
|
||||
# NOTE(woosuk): Use the original logits (before any penalties or
|
||||
# temperature scaling) for the top-k logprobs.
|
||||
# This is different from the V0 sampler, which uses the logits that
|
||||
# is used for sampling (after penalties and temperature scaling).
|
||||
# TODO(rob): provide option for logprobs post sampling.
|
||||
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
raw_logprobs = self.compute_logprobs(logits)
|
||||
|
||||
# Use float32 for the logits.
|
||||
logits = logits.to(torch.float32)
|
||||
# Apply allowed token ids.
|
||||
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
|
||||
# Apply bad words exclusion.
|
||||
logits = self.apply_bad_words(logits, sampling_metadata)
|
||||
|
||||
# Apply logits processors which can impact greedy sampling
|
||||
for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
|
||||
logits = processor.apply(logits)
|
||||
|
||||
# Apply penalties (e.g., min_tokens, freq_penalties).
|
||||
# logits = self.apply_penalties(logits, sampling_metadata)
|
||||
|
||||
|
||||
|
||||
if not sampling_metadata.no_penalties:
|
||||
if (not hasattr(self, "bin_count_pooler")):
|
||||
self.bin_count_pooler = BinCountTensorPooler(logits.shape[-1], logits.device)
|
||||
assert sampling_metadata.prompt_token_ids is not None
|
||||
buf_bin_buffer = self.bin_count_pooler.request_tensors(req_ids)
|
||||
batch, vocab_size = logits.shape
|
||||
logits_res = []
|
||||
|
||||
for i in range(batch):
|
||||
output_tokens_t = _convert_to_tensors([sampling_metadata.output_token_ids[i]], vocab_size, logits.device).to(torch.int32)
|
||||
output_tokens_t_lastone = output_tokens_t[:, -1:]
|
||||
if logits.shape[0] > 1:
|
||||
logits_i = torch.vacc.apply_penalties(logits[i:i+1],
|
||||
output_tokens_t_lastone,
|
||||
buf_bin_buffer[i],
|
||||
vocab_size,
|
||||
output_tokens_t_lastone.shape[-1],
|
||||
[sampling_metadata.frequency_penalties[i]],
|
||||
[sampling_metadata.presence_penalties[i]],
|
||||
is_first_calculate)
|
||||
else:
|
||||
logits_i = torch.vacc.apply_penalties(logits,
|
||||
output_tokens_t_lastone,
|
||||
buf_bin_buffer[i],
|
||||
vocab_size,
|
||||
output_tokens_t_lastone.shape[-1],
|
||||
[sampling_metadata.frequency_penalties[i]],
|
||||
[sampling_metadata.presence_penalties[i]],
|
||||
is_first_calculate)
|
||||
logits_res.append(logits_i)
|
||||
|
||||
if len(logits_res) > 1:
|
||||
logits = torch.concat(logits_res)
|
||||
else:
|
||||
logits = logits_res[0]
|
||||
|
||||
# Sample the next token.
|
||||
|
||||
# sampled = self.sample(logits, sampling_metadata)
|
||||
sampled, _ = torch.vacc.sampler_v1(logits, sampling_metadata.top_p_cpu, sampling_metadata.top_k_cpu, sampling_metadata.temperature_cpu, int(sampling_metadata.all_greedy), int(sampling_metadata.all_random), sampling_metadata.generators)
|
||||
|
||||
# Gather the logprobs of the topk and sampled token (if requested).
|
||||
# Get logprobs and rank tensors (if requested)
|
||||
logprobs_tensors = None if num_logprobs is None else \
|
||||
self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled.long())
|
||||
|
||||
|
||||
# These are GPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.unsqueeze(-1),
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
def apply_temperature(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
temp: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Use in-place division to avoid creating a new tensor.
|
||||
return logits.div_(temp.unsqueeze(dim=1))
|
||||
|
||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.argmax(dim=-1).view(-1)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Sample logits based on sampling metadata.
|
||||
|
||||
The various logits processing functions called in this method
|
||||
may update the logits tensor in-place.
|
||||
"""
|
||||
|
||||
assert not (sampling_metadata.all_greedy
|
||||
and sampling_metadata.all_random)
|
||||
if sampling_metadata.all_random:
|
||||
greedy_sampled = None
|
||||
else:
|
||||
greedy_sampled = self.greedy_sample(logits)
|
||||
if sampling_metadata.all_greedy:
|
||||
return greedy_sampled
|
||||
|
||||
assert sampling_metadata.temperature is not None
|
||||
|
||||
# Apply temperature.
|
||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||
|
||||
# Apply logits processors that only apply to random sampling
|
||||
# (argmax invariant)
|
||||
for processor in sampling_metadata.logitsprocs.argmax_invariant:
|
||||
logits = processor.apply(logits)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
random_sampled = self.topk_topp_sampler(
|
||||
logits,
|
||||
sampling_metadata.generators,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
|
||||
if greedy_sampled is None:
|
||||
return random_sampled
|
||||
|
||||
sampled = torch.where(
|
||||
sampling_metadata.temperature < _SAMPLING_EPS,
|
||||
greedy_sampled,
|
||||
random_sampled,
|
||||
out=greedy_sampled, # Reuse tensor
|
||||
)
|
||||
return sampled
|
||||
|
||||
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
def gather_logprobs(
|
||||
self,
|
||||
logprobs: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
"""
|
||||
Gather logprobs for topk and sampled/prompt token.
|
||||
|
||||
Args:
|
||||
logprobs: (num tokens) x (vocab) tensor
|
||||
num_logprobs: minimum number of logprobs to
|
||||
retain per token
|
||||
token_ids: prompt tokens (if prompt logprobs)
|
||||
or sampled tokens (if sampled
|
||||
logprobs); 1D token ID tensor
|
||||
with (num tokens) elements
|
||||
Must be int64.
|
||||
|
||||
Returns:
|
||||
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
|
||||
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
|
||||
Sampled token rank tensor, (num tokens)
|
||||
"""
|
||||
assert token_ids.dtype == torch.int64
|
||||
# Find the topK values.
|
||||
topk_logprobs, topk_indices = torch.topk(logprobs,
|
||||
num_logprobs,
|
||||
dim=-1)
|
||||
|
||||
# Get with the logprob of the prompt or sampled token.
|
||||
token_ids = token_ids.unsqueeze(-1)
|
||||
token_logprobs = logprobs.gather(-1, token_ids)
|
||||
|
||||
# Compute the ranks of the actual token.
|
||||
token_ranks = (logprobs >= token_logprobs).sum(-1)
|
||||
|
||||
# Concatenate together with the topk.
|
||||
indices = torch.cat((token_ids, topk_indices), dim=1)
|
||||
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
|
||||
|
||||
# Use int32 to reduce the tensor size.
|
||||
indices = indices.to(torch.int32)
|
||||
|
||||
return LogprobsTensors(indices, logprobs, token_ranks)
|
||||
|
||||
def apply_penalties(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
if not sampling_metadata.no_penalties:
|
||||
assert sampling_metadata.prompt_token_ids is not None
|
||||
logits = apply_all_penalties(
|
||||
logits,
|
||||
sampling_metadata.prompt_token_ids,
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.frequency_penalties,
|
||||
sampling_metadata.repetition_penalties,
|
||||
sampling_metadata.output_token_ids,
|
||||
)
|
||||
return logits
|
||||
|
||||
def apply_allowed_token_ids(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
if sampling_metadata.allowed_token_ids_mask is not None:
|
||||
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
|
||||
float("-inf"))
|
||||
return logits
|
||||
|
||||
def apply_bad_words(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
if sampling_metadata.bad_words_token_ids:
|
||||
apply_bad_words(
|
||||
logits,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
sampling_metadata.output_token_ids,
|
||||
)
|
||||
return logits
|
||||
0
vllm_vacc/vllm/v1/spec_decode/__init__.py
Normal file
0
vllm_vacc/vllm/v1/spec_decode/__init__.py
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/spec_decode/__pycache__/eagle.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/spec_decode/__pycache__/eagle.cpython-312.pyc
Normal file
Binary file not shown.
789
vllm_vacc/vllm/v1/spec_decode/eagle.py
Normal file
789
vllm_vacc/vllm/v1/spec_decode/eagle.py
Normal file
@@ -0,0 +1,789 @@
|
||||
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import (CompilationLevel, VllmConfig,
|
||||
get_layers_from_vllm_config)
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import supports_multimodal
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
|
||||
TreeAttentionMetadataBuilder)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
from vacc_tools.trace_logger import get_trace_api
|
||||
trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
|
||||
get_trace_api("deepseek")
|
||||
)
|
||||
|
||||
# @trace_time('prepare_eagle_input_python')
|
||||
def prepare_eagle_input_python(
|
||||
out_ptr,
|
||||
cu_query_lens_ptr,
|
||||
cu_num_tokens_ptr,
|
||||
# BLOCK_SIZE
|
||||
):
|
||||
"""
|
||||
Python实现版本的prepare_eagle_input_kernel
|
||||
|
||||
参数:
|
||||
out_ptr: 输出张量
|
||||
cu_query_lens_ptr: 每个查询的起始索引张量
|
||||
cu_num_tokens_ptr: 每个查询的token数量累计张量
|
||||
BLOCK_SIZE: 块大小
|
||||
"""
|
||||
cu_query_lens_ptr_list = cu_query_lens_ptr
|
||||
cu_num_tokens_ptr_list = cu_num_tokens_ptr
|
||||
num_queries = len(cu_num_tokens_ptr) - 1
|
||||
|
||||
# out_ptr_list = np.zeros(cu_query_lens_ptr_list.shape, cu_query_lens_ptr_list.dtype)
|
||||
for pid in range(num_queries):
|
||||
start_pos = cu_num_tokens_ptr_list[pid]#.item()
|
||||
end_pos = cu_num_tokens_ptr_list[pid + 1]#.item()
|
||||
num_tokens = end_pos - start_pos
|
||||
|
||||
index_start = cu_query_lens_ptr_list[pid]#.item()
|
||||
|
||||
# offset = np.array([i for i in range(num_tokens)], dtype=cu_num_tokens_ptr_list.dtype)
|
||||
# values = index_start + offset
|
||||
# 存储到输出张量
|
||||
# out_ptr[start_pos + offset] = values
|
||||
|
||||
for i in range(num_tokens):
|
||||
out_ptr[start_pos + i] = index_start + i
|
||||
|
||||
return
|
||||
import torch
|
||||
num_queries = len(cu_num_tokens_ptr) - 1
|
||||
|
||||
for pid in range(num_queries):
|
||||
# [start_pos, end_pos)
|
||||
start_pos = cu_num_tokens_ptr[pid].item()
|
||||
end_pos = cu_num_tokens_ptr[pid + 1].item()
|
||||
num_tokens = end_pos - start_pos
|
||||
|
||||
index_start = cu_query_lens_ptr[pid].item()
|
||||
|
||||
num_blocks = (num_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE
|
||||
|
||||
for i in range(num_blocks):
|
||||
offset_start = i * BLOCK_SIZE
|
||||
offset_end = min(offset_start + BLOCK_SIZE, num_tokens)
|
||||
|
||||
# 创建当前块的偏移量
|
||||
offset = torch.arange(offset_start, offset_end, device=out_ptr.device, dtype=out_ptr.dtype)
|
||||
|
||||
# 计算要存储的值
|
||||
values = index_start + offset
|
||||
|
||||
# 存储到输出张量
|
||||
out_ptr[start_pos + offset] = values
|
||||
|
||||
class EagleProposer:
|
||||
|
||||
# @trace_time('EagleProposer_propose')
|
||||
def propose(
|
||||
self,
|
||||
# [num_tokens]
|
||||
target_token_ids: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_positions: torch.Tensor,
|
||||
# [num_tokens, hidden_size]
|
||||
target_hidden_states: torch.Tensor,
|
||||
# [batch_size]
|
||||
next_token_ids: torch.Tensor,
|
||||
last_token_indices: Optional[torch.Tensor],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
mm_embeds: Optional[list[torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
|
||||
if last_token_indices is None:
|
||||
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
||||
|
||||
if self.method == "eagle3":
|
||||
assert isinstance(self.model, Eagle3LlamaForCausalLM)
|
||||
target_hidden_states = self.model.combine_hidden_states(
|
||||
target_hidden_states)
|
||||
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||
# Shift the input ids by one token.
|
||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
# self.input_ids[last_token_indices] = next_token_ids
|
||||
if isinstance(last_token_indices, list) and len(last_token_indices) == 1:
|
||||
self.input_ids[last_token_indices[0] : last_token_indices[0]+1] = next_token_ids
|
||||
else:
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
# FIXME: need to consider multiple kv_cache_groups
|
||||
ubatch_id = dbo_current_ubatch_id()
|
||||
attn_metadata_builder = \
|
||||
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0)
|
||||
# FIXME: support hybrid kv for draft model (remove separate indexer)
|
||||
if self.draft_indexer_metadata_builder:
|
||||
draft_indexer_metadata = (
|
||||
self.draft_indexer_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0,
|
||||
))
|
||||
else:
|
||||
draft_indexer_metadata = None
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
for layer_name in self.indexer_layer_names:
|
||||
assert draft_indexer_metadata is not None
|
||||
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
|
||||
|
||||
num_input_tokens = num_tokens
|
||||
# copy inputs to buffer for cudagraph
|
||||
# self.positions[:num_tokens] = target_positions
|
||||
# self.hidden_states[:num_tokens] = target_hidden_states
|
||||
if self.is_multimodal_model:
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
inputs_embeds = self.model.get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=mm_embeds or None,
|
||||
)
|
||||
self.inputs_embeds[:num_tokens] = inputs_embeds
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=target_positions,
|
||||
hidden_states=target_hidden_states, #self.hidden_states[:num_input_tokens],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = last_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
# sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
|
||||
if isinstance(last_token_indices, list):
|
||||
if len(last_token_indices) == last_hidden_states.shape[0]:
|
||||
sample_hidden_states = last_hidden_states
|
||||
elif len(last_token_indices) == 1:
|
||||
sample_hidden_states = last_hidden_states[last_token_indices[0] : last_token_indices[0] + 1]
|
||||
else:
|
||||
sample_hidden_states = last_hidden_states.new_empty(last_token_indices.shape + last_hidden_states.shape[1:])
|
||||
torch.ops.aten.index(hidden_states, [torch.tensor(last_token_indices, dtype=torch.int32)], out=sample_hidden_states)
|
||||
else:
|
||||
assert isinstance(last_token_indices, torch.Tensor)
|
||||
if last_token_indices.shape[0] == last_hidden_states.shape[0]:
|
||||
sample_hidden_states = last_hidden_states
|
||||
else:
|
||||
sample_hidden_states = last_hidden_states.new_empty(last_token_indices.shape + last_hidden_states.shape[1:])
|
||||
torch.ops.aten.index(hidden_states, [last_token_indices], out=sample_hidden_states)
|
||||
|
||||
|
||||
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
if self.num_speculative_tokens == 1:
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
return draft_token_ids.view(-1, 1)
|
||||
else:
|
||||
raise ValueError(f'not support self.num_speculative_tokens > 1, but get {self.num_speculative_tokens}')
|
||||
|
||||
'''
|
||||
positions = target_positions[last_token_indices]
|
||||
if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
|
||||
hidden_states = None #self.hidden_states[last_token_indices]
|
||||
else:
|
||||
hidden_states = None #hidden_states[last_token_indices]
|
||||
|
||||
if isinstance(attn_metadata, TreeAttentionMetadata):
|
||||
# Draft using tree attention.
|
||||
draft_token_ids_list = self.propose_tree(
|
||||
batch_size=batch_size,
|
||||
logits=logits,
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
# [batch_size, num_tree_tokens]
|
||||
return torch.cat(draft_token_ids_list, dim=1)
|
||||
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
if self.allowed_attn_types is not None and \
|
||||
not isinstance(attn_metadata, self.allowed_attn_types):
|
||||
raise ValueError(
|
||||
f"Unsupported attention metadata type for speculative "
|
||||
"decoding with num_speculative_tokens > 1: "
|
||||
f"{type(attn_metadata)}. Supported types are: "
|
||||
f"{self.allowed_attn_types}")
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
|
||||
if self.use_cuda_graph and \
|
||||
batch_size <= self.cudagraph_batch_sizes[-1]:
|
||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
||||
else:
|
||||
input_batch_size = batch_size
|
||||
|
||||
common_attn_metadata.num_actual_tokens = batch_size
|
||||
common_attn_metadata.max_query_len = 1
|
||||
common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
|
||||
self.token_arange_np[:batch_size + 1]).clone()
|
||||
for token_index in range(self.num_speculative_tokens - 1):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
input_ids = draft_token_ids_list[-1].int()
|
||||
positions += 1
|
||||
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length. Since it is complex
|
||||
# to remove such requests from the batch, we keep them in the batch
|
||||
# but adjust the position ids and slot mappings to avoid the
|
||||
# out-of-range access during the model execution. The draft tokens
|
||||
# generated with this adjustment should be ignored.
|
||||
exceeds_max_model_len = positions >= self.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
||||
positions)
|
||||
|
||||
# Increment the sequence lengths.
|
||||
common_attn_metadata.seq_lens += 1
|
||||
common_attn_metadata.seq_lens_cpu += 1
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
|
||||
1)
|
||||
|
||||
common_attn_metadata.num_computed_tokens_cpu = \
|
||||
common_attn_metadata.seq_lens_cpu - 1
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_numbers = clamped_positions // self.block_size
|
||||
block_ids = common_attn_metadata.block_table_tensor.gather(
|
||||
dim=1, index=block_numbers.view(-1, 1))
|
||||
block_ids = block_ids.view(-1)
|
||||
common_attn_metadata.slot_mapping = (
|
||||
block_ids * self.block_size +
|
||||
clamped_positions % self.block_size)
|
||||
# Mask out the slot mappings that exceed the max model length.
|
||||
# Otherwise, the KV cache will be inadvertently updated with the
|
||||
# padding tokens.
|
||||
common_attn_metadata.slot_mapping.masked_fill_(
|
||||
exceeds_max_model_len, PADDING_SLOT_ID)
|
||||
|
||||
# Rebuild attention metadata
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=token_index + 1)
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
# self.hidden_states[:batch_size] = hidden_states
|
||||
if self.is_multimodal_model:
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
self.inputs_embeds[:batch_size] = inputs_embeds
|
||||
inputs_embeds = self.inputs_embeds[:input_batch_size]
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
input_ids = self.input_ids[:input_batch_size]
|
||||
|
||||
# Run the model.
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch_size):
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=self.positions[:input_batch_size],
|
||||
# hidden_states=self.hidden_states[:input_batch_size],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = ret_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size])
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
# [batch_size, num_speculative_tokens]
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
return draft_token_ids
|
||||
'''
|
||||
|
||||
def prepare_next_token_ids_padded(self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
discard_request_indices: torch.Tensor,
|
||||
num_discarded_requests: int) -> \
|
||||
tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function is used to prepare the inputs for speculative decoding.
|
||||
It calculates the next token ids and the number of valid sampled tokens
|
||||
for each request, considering the "discarded" requests whose next token
|
||||
is not sampled and comes from `request.get_token_id()` instead.
|
||||
It also accounts for the rejected tokens in `sampled_token_ids`.
|
||||
This function must use device functions to operate on the inputs, and
|
||||
should not introduce any blocking CPU-GPU synchronization.
|
||||
"""
|
||||
# TODO(Ben): Combine this into a custom fused kernel
|
||||
|
||||
# Precompute get_token_id for when there is no valid next token
|
||||
num_reqs = gpu_input_batch.num_reqs
|
||||
self.backup_next_token_ids.np[:num_reqs] = np.array([
|
||||
requests[gpu_input_batch.req_ids[i]].get_token_id(
|
||||
common_attn_metadata.seq_lens_cpu[i])
|
||||
for i in range(num_reqs)
|
||||
])
|
||||
self.backup_next_token_ids.copy_to_gpu(num_reqs)
|
||||
|
||||
# Mask out the sampled tokens indices that should not be sampled.
|
||||
discard_sampled_tokens_req_indices = \
|
||||
discard_request_indices[:num_discarded_requests]
|
||||
|
||||
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
|
||||
valid_sampled_token_ids_gpu.index_fill_(
|
||||
0, discard_sampled_tokens_req_indices, -1)
|
||||
|
||||
# Generate a mask for all valid tokens within those requests
|
||||
max_gen_len = sampled_token_ids.shape[-1]
|
||||
if max_gen_len == 1:
|
||||
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
|
||||
dtype=torch.bool)
|
||||
else:
|
||||
valid_mask = (
|
||||
(valid_sampled_token_ids_gpu != -1) &
|
||||
(valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))
|
||||
|
||||
# Count the number of valid tokens in each request
|
||||
valid_sampled_tokens_count = valid_mask.sum(dim=1)
|
||||
|
||||
# Get the rightmost valid index per row
|
||||
last_valid_indices = valid_sampled_tokens_count - 1
|
||||
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
|
||||
|
||||
# Get last valid token from each row
|
||||
# (assume undefined state where there is no valid token)
|
||||
selected_tokens = torch.gather(
|
||||
valid_sampled_token_ids_gpu, 1,
|
||||
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Use last token if valid, pre-computed backup if not
|
||||
batch_size = valid_sampled_token_ids_gpu.shape[0]
|
||||
next_token_ids = torch.where(
|
||||
last_valid_indices != -1, selected_tokens,
|
||||
self.backup_next_token_ids.gpu[:batch_size])
|
||||
|
||||
return next_token_ids, valid_sampled_tokens_count
|
||||
|
||||
# @trace_time('prepare_inputs_padded')
|
||||
def prepare_inputs_padded(self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
spec_decode_metadata: SpecDecodeMetadata,
|
||||
valid_sampled_tokens_count: torch.Tensor) -> \
|
||||
tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function is used to prepare the inputs for speculative decoding
|
||||
It updates the common_attn_metadata for speculative decoding,
|
||||
but does not consider the rejected tokens. Instead, all tokens
|
||||
are included as inputs to the speculator, with the rejected tokens
|
||||
used as padding and filtered out later by `token_indices_to_sample`.
|
||||
No blocking CPU operations should be introduced in this function.
|
||||
"""
|
||||
num_draft_tokens_gpu = torch.cat([
|
||||
spec_decode_metadata.cu_num_draft_tokens[0:1],
|
||||
spec_decode_metadata.cu_num_draft_tokens[1:] -
|
||||
spec_decode_metadata.cu_num_draft_tokens[:-1]
|
||||
])
|
||||
|
||||
num_rejected_tokens_gpu = torch.where(
|
||||
num_draft_tokens_gpu > 0,
|
||||
num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
|
||||
torch.zeros_like(num_draft_tokens_gpu))
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
|
||||
new_query_len_per_req = (query_start_loc_cpu[1:] -
|
||||
query_start_loc_cpu[:-1])
|
||||
|
||||
total_num_tokens = query_start_loc_cpu[-1].item()
|
||||
token_indices = self.arange[:total_num_tokens]
|
||||
|
||||
spec_common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
|
||||
num_computed_tokens_cpu=common_attn_metadata.
|
||||
num_computed_tokens_cpu,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
num_actual_tokens=total_num_tokens,
|
||||
max_query_len=new_query_len_per_req.max().item(),
|
||||
max_seq_len=max(common_attn_metadata.seq_lens_cpu),
|
||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
|
||||
causal=True,
|
||||
)
|
||||
|
||||
token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \
|
||||
- num_rejected_tokens_gpu
|
||||
|
||||
return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
||||
|
||||
# @trace_time('prepare_inputs')
|
||||
def prepare_inputs(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampled_token_ids: list[list[int]],
|
||||
num_draft_tokens: list[int],
|
||||
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
|
||||
"""
|
||||
This function is used to prepare the inputs for speculative decoding.
|
||||
It updates to the common_attn_metadata to account for the rejected
|
||||
tokens (and newly sampled tokens). It also returns the token indices
|
||||
of the tokens that should be fed to the speculator.
|
||||
"""
|
||||
# E.g.
|
||||
# common_attn_metadata.query_start_loc{_cpu}:
|
||||
# [0, q1, q1 + q2, q1 + q2 + q3]
|
||||
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
|
||||
# num_rejected_tokens: [n1, n2, n3]
|
||||
# This function computes the intermediate values:
|
||||
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
|
||||
# And returns:
|
||||
# common_attn_metadata.query_start_loc{_cpu}:
|
||||
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
|
||||
# common_attn_metadata.seq_lens{_cpu}:
|
||||
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
|
||||
# token_indices: [0, 1, ..., q1 - n1 - 1,
|
||||
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
|
||||
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
|
||||
|
||||
num_rejected_tokens = [
|
||||
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
|
||||
for i, n in enumerate(num_draft_tokens)
|
||||
]
|
||||
# num_rejected_tokens = torch.tensor(num_rejected_tokens,
|
||||
# dtype=torch.int32)
|
||||
|
||||
# device = common_attn_metadata.query_start_loc.device
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
|
||||
# new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
|
||||
# - num_rejected_tokens
|
||||
# new_seq_lens_cpu = [i-j for i,j in zip(common_attn_metadata.seq_lens_cpu, num_rejected_tokens)]
|
||||
|
||||
|
||||
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
|
||||
# new_query_len_per_req = (query_start_loc_cpu[1:] -
|
||||
# query_start_loc_cpu[:-1])
|
||||
# new_query_len_per_req = [query_start_loc_cpu[i+1] - query_start_loc_cpu[i] for i in range(len(query_start_loc_cpu)-1)]
|
||||
new_query_len_per_req = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]).tolist() # [2] *(bs+1)
|
||||
|
||||
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
|
||||
# new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
|
||||
new_num_tokens_per_req = [i-j for i,j in zip(new_query_len_per_req, num_rejected_tokens)]
|
||||
|
||||
new_num_tokens_per_req_np = np.array(new_num_tokens_per_req)
|
||||
|
||||
# common_attn_metadata.seq_lens_cpu is list[int], length is max_seq_num,
|
||||
# seq_lens_cpu come from VACCModelRunner _prepare_inputs, 预先加了 k + 1
|
||||
# if seq=[31,63] bs=2, max_seq_num=4, seq_lens_cpu=[31,63,0,0]
|
||||
# new_seq_lens_cpu need all real seq,
|
||||
# if num_rejected_tokens=[0,1] new_seq_lens_cpu = [30,31,62] means bs=1 接受了, bs=2拒绝了 只有一个recover_token
|
||||
# if num_rejected_tokens=[1,0] new_seq_lens_cpu = [30,62,63] means bs=2 接受了, bs=1拒绝了 只有一个recover_token
|
||||
# if num_rejected_tokens=[0,0] new_seq_lens_cpu = [30,31,62,63] means 都接受了
|
||||
# if num_rejected_tokens=[1,1] new_seq_lens_cpu = [30,62] means 都拒绝了
|
||||
new_seq_lens_cpu = []
|
||||
for i in range(len(num_rejected_tokens)):
|
||||
for j in range(new_num_tokens_per_req[i]):
|
||||
new_seq_lens_cpu.append(common_attn_metadata.seq_lens_cpu[i] - new_num_tokens_per_req[i] + 1 - num_rejected_tokens[i] + j)
|
||||
|
||||
# [q1 - n1, q2 - n2, q3 - n3] ->
|
||||
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
|
||||
new_query_start_loc_cpu = torch.zeros(
|
||||
query_start_loc_cpu.shape,
|
||||
dtype=torch.int32,
|
||||
pin_memory=is_pin_memory_available())
|
||||
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
|
||||
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
|
||||
|
||||
total_num_tokens = new_query_start_loc_np[-1]
|
||||
# Example assuming num_tokens_per_req_np = [2, 4, 3]
|
||||
# this implies that `new_query_start_locs` is:
|
||||
# [0, 2, 6, 9] ->
|
||||
# [0, 0, 2, 2, 2, 2, 6, 6, 6]
|
||||
# _r1_ ____r2____ ___r3__
|
||||
new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
|
||||
new_num_tokens_per_req_np)
|
||||
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
|
||||
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
|
||||
# _r1_ ____r2____ ___r3__
|
||||
token_offests = self.token_arange_np[:total_num_tokens] \
|
||||
- new_query_start_locs_expanded
|
||||
|
||||
# Expand starting positions to match token pattern
|
||||
# [0, q1, q1 + q2] ->
|
||||
# [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
|
||||
# _r1_ _____r2_______ ___________r3____________
|
||||
old_query_start_locs_expanded = np.repeat(
|
||||
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
|
||||
# Final token indices are:
|
||||
# [0, 1, // req 1
|
||||
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
|
||||
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
|
||||
token_indices_np = token_offests + old_query_start_locs_expanded
|
||||
token_indices = token_indices_np.tolist()
|
||||
# token_indices = torch.from_numpy(token_indices_np).to(
|
||||
# device, non_blocking=True)
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# print('token_indices', token_indices, common_attn_metadata.slot_mapping.shape)
|
||||
# [0] or [0,1] bs1
|
||||
# [0, 2, 4, 6] ... or [0, 1, 2, 3, 4, 6, 7] for bs4
|
||||
|
||||
# opt slot_mapping slice
|
||||
#common_attn_metadata.slot_mapping[token_indices] : copy + copy + index_out
|
||||
if len(token_indices) == common_attn_metadata.slot_mapping.shape[0]: # no need slice
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
elif len(token_indices) == 1:
|
||||
slot_mapping = common_attn_metadata.slot_mapping[token_indices[0] : token_indices[0] + 1]
|
||||
else:
|
||||
slot_mapping = common_attn_metadata.slot_mapping.new_empty(len(token_indices))
|
||||
torch.ops.aten.index(common_attn_metadata.slot_mapping, [torch.tensor(token_indices, dtype=torch.int32)], out=slot_mapping)
|
||||
|
||||
spec_common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=new_query_start_loc_cpu,
|
||||
# query_start_loc=new_query_start_loc_cpu.to(device,
|
||||
# non_blocking=True),
|
||||
seq_lens=new_seq_lens_cpu, #.to(device, non_blocking=True),
|
||||
query_start_loc_cpu=new_query_start_loc_cpu,
|
||||
seq_lens_cpu=new_seq_lens_cpu,
|
||||
num_computed_tokens_cpu=common_attn_metadata.
|
||||
num_computed_tokens_cpu,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
num_actual_tokens=total_num_tokens,
|
||||
max_query_len=max(new_query_len_per_req),#new_query_len_per_req.max().item(),
|
||||
max_seq_len=None, #max(new_seq_lens_cpu),#new_seq_lens_cpu.max().item(),
|
||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||
slot_mapping=slot_mapping, #common_attn_metadata.slot_mapping[token_indices],
|
||||
causal=True,
|
||||
)
|
||||
|
||||
return spec_common_attn_metadata, token_indices
|
||||
|
||||
|
||||
|
||||
# @trace_time('EagleProposer_prepare_inputs')
|
||||
@staticmethod
|
||||
def prepare_inputs_9_2(
|
||||
self,
|
||||
# [batch_size + 1]
|
||||
cu_target_query_lens: torch.Tensor,
|
||||
# [batch_size]
|
||||
num_rejected_tokens: torch.Tensor,
|
||||
num_tokens: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# cu_target_query_lens: [0, a, a + b, a + b + c]
|
||||
# num_rejected_tokens: [n1, n2, n3]
|
||||
# num_tokens_per_req: [a - n1, b - n2, c - n3]
|
||||
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
|
||||
# token_indices: [0, 1, ..., a - n1 - 1,
|
||||
# a, a + 1, ..., a + b - n2 - 1,
|
||||
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
|
||||
|
||||
# [0, a, a + b, a + b + c] -> [a, b, c]
|
||||
# torch
|
||||
# query_len_per_req = (cu_target_query_lens[1:] -
|
||||
# cu_target_query_lens[:-1])
|
||||
# [a, b, c] -> [a - n1, b - n2, c - n3]
|
||||
# num_tokens_per_req = query_len_per_req - num_rejected_tokens
|
||||
|
||||
# list
|
||||
num_tokens_per_req = [cu_target_query_lens[i+1] - cu_target_query_lens[i] - num_rejected_tokens[i] for i in range(len(cu_target_query_lens)-1)]
|
||||
|
||||
|
||||
# [a - n1, b - n2, c - n3] ->
|
||||
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
|
||||
# torch style
|
||||
# cu_num_tokens = torch.zeros_like(cu_target_query_lens)
|
||||
# torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
|
||||
|
||||
# list style
|
||||
cu_num_tokens = [0] * len(cu_target_query_lens)
|
||||
for i in range(len(cu_target_query_lens)-1):
|
||||
cu_num_tokens[i+1] = cu_num_tokens[i] + num_tokens_per_req[i]
|
||||
|
||||
# token_indices = torch.empty(
|
||||
# num_tokens,
|
||||
# dtype=torch.int32,
|
||||
# device=cu_target_query_lens.device,
|
||||
# )
|
||||
# batch_size = num_rejected_tokens.shape[0]
|
||||
# BLOCK_SIZE = 1024
|
||||
# prepare_eagle_input_kernel[(batch_size, )](
|
||||
# token_indices,
|
||||
# cu_target_query_lens,
|
||||
# cu_num_tokens,
|
||||
# BLOCK_SIZE=BLOCK_SIZE,
|
||||
# )
|
||||
token_indices = [0] * num_tokens
|
||||
prepare_eagle_input_python(
|
||||
token_indices,
|
||||
cu_target_query_lens,
|
||||
cu_num_tokens
|
||||
)
|
||||
|
||||
|
||||
return cu_num_tokens, token_indices
|
||||
|
||||
def EagleProposer_init_(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
runner=None,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
assert self.speculative_config is not None
|
||||
self.draft_model_config = self.speculative_config.draft_model_config
|
||||
self.method = self.speculative_config.method
|
||||
|
||||
self.runner = runner
|
||||
self.device = device
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.num_speculative_tokens = (
|
||||
self.speculative_config.num_speculative_tokens)
|
||||
self.max_num_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens)
|
||||
self.token_arange_np = np.arange(self.max_num_tokens)
|
||||
# We need to get the hidden size from the draft model config because
|
||||
# the draft model's hidden size can be different from the target model's
|
||||
# hidden size (e.g., Llama 3.3 70B).
|
||||
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||
|
||||
self.is_multimodal_model = vllm_config.model_config \
|
||||
.is_multimodal_model
|
||||
|
||||
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
|
||||
self.draft_indexer_metadata_builder: Optional[
|
||||
AttentionMetadataBuilder] = None
|
||||
self.attn_layer_names: list[str] = []
|
||||
self.indexer_layer_names: list[str] = []
|
||||
|
||||
self.use_cuda_graph = False
|
||||
self.cudagraph_batch_sizes = []
|
||||
|
||||
|
||||
# persistent buffers for cuda graph
|
||||
self.input_ids = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.positions = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device=device)
|
||||
# self.hidden_states = torch.zeros(
|
||||
# (self.max_num_tokens, self.hidden_size),
|
||||
# dtype=self.dtype,
|
||||
# device=device)
|
||||
# We need +1 here because the arange is used to set query_start_loc,
|
||||
# which has one more element than batch_size.
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
|
||||
self.arange = torch.arange(max_num_slots_for_arange,
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
|
||||
if self.is_multimodal_model:
|
||||
self.inputs_embeds = torch.zeros(
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
dtype=self.dtype,
|
||||
device=device)
|
||||
else:
|
||||
self.inputs_embeds = None
|
||||
|
||||
self.backup_next_token_ids = CpuGpuBuffer(
|
||||
max_batch_size,
|
||||
dtype=torch.int32,
|
||||
pin_memory=is_pin_memory_available(),
|
||||
device=device,
|
||||
with_numpy=True)
|
||||
|
||||
# Determine allowed attention backends once during initialization.
|
||||
self.allowed_attn_types: Optional[tuple] = None
|
||||
|
||||
# Parse the speculative token tree.
|
||||
spec_token_tree = self.speculative_config.speculative_token_tree
|
||||
self.tree_choices: list[tuple[int,
|
||||
...]] = ast.literal_eval(spec_token_tree)
|
||||
tree_depth = len(self.tree_choices[-1])
|
||||
# Precompute per-level properties of the tree.
|
||||
num_drafts_per_level = [0] * tree_depth
|
||||
for node in self.tree_choices:
|
||||
num_drafts_per_level[len(node) - 1] += 1
|
||||
self.cu_drafts_per_level = [num_drafts_per_level[0]]
|
||||
self.child_drafts_per_level = [num_drafts_per_level[0]]
|
||||
for level in range(1, tree_depth):
|
||||
self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] +
|
||||
num_drafts_per_level[level])
|
||||
self.child_drafts_per_level.append(num_drafts_per_level[level] //
|
||||
num_drafts_per_level[level - 1])
|
||||
# Precompute draft position offsets in flattened tree.
|
||||
self.tree_draft_pos_offsets = torch.arange(
|
||||
1,
|
||||
len(self.tree_choices) + 1,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
).repeat(max_batch_size, 1)
|
||||
|
||||
0
vllm_vacc/vllm/v1/worker/__init__.py
Normal file
0
vllm_vacc/vllm/v1/worker/__init__.py
Normal file
BIN
vllm_vacc/vllm/v1/worker/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/worker/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/worker/__pycache__/block_table.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/worker/__pycache__/block_table.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_vacc/vllm/v1/worker/__pycache__/vacc_worker.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/v1/worker/__pycache__/vacc_worker.cpython-312.pyc
Normal file
Binary file not shown.
68
vllm_vacc/vllm/v1/worker/block_table.py
Normal file
68
vllm_vacc/vllm/v1/worker/block_table.py
Normal file
@@ -0,0 +1,68 @@
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
|
||||
from vllm.distributed import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv
|
||||
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockTable:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
max_num_reqs: int,
|
||||
max_num_blocks_per_req: int,
|
||||
max_num_batched_tokens: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.block_size = block_size
|
||||
self.max_num_reqs = max_num_reqs
|
||||
# self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||
max_num_blocks_per_req = (max_num_blocks_per_req + env_blk_grp_size//16 - 1) // (env_blk_grp_size//16) * (env_blk_grp_size//16)
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
|
||||
# self.block_table = torch.zeros(
|
||||
# (max_num_reqs, max_num_blocks_per_req),
|
||||
# device=self.device,
|
||||
# dtype=torch.int32,
|
||||
# )
|
||||
self.block_table = self._make_buffer(max_num_reqs,
|
||||
max_num_blocks_per_req,
|
||||
dtype=torch.int32)
|
||||
# self.block_table_cpu = torch.zeros(
|
||||
# (max_num_reqs, max_num_blocks_per_req),
|
||||
# device="cpu",
|
||||
# dtype=torch.int32,
|
||||
# pin_memory=pin_memory,
|
||||
# )
|
||||
# self.block_table_np = self.block_table_cpu.numpy()
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
# self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
|
||||
# dtype=torch.int32,
|
||||
# device="cpu",
|
||||
# pin_memory=self.pin_memory)
|
||||
# self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
# self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
|
||||
# dtype=torch.int32,
|
||||
# device=self.device)
|
||||
self.slot_mapping = self._make_buffer(self.max_num_batched_tokens,
|
||||
dtype=torch.int32)
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
491
vllm_vacc/vllm/v1/worker/gpu_input_batch.py
Normal file
491
vllm_vacc/vllm/v1/worker/gpu_input_batch.py
Normal file
@@ -0,0 +1,491 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Datastructures defining a GPU input batch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
|
||||
LogitsProcessors,
|
||||
MoveDirectionality)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestState:
|
||||
|
||||
req_id: str
|
||||
prompt_token_ids: Optional[list[int]]
|
||||
mm_features: list[MultiModalFeatureSpec]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
generator: Optional[torch.Generator]
|
||||
|
||||
block_ids: tuple[list[int], ...]
|
||||
num_computed_tokens: int
|
||||
output_token_ids: list[int]
|
||||
|
||||
mrope_positions: Optional[torch.Tensor] = None
|
||||
mrope_position_delta: Optional[int] = None
|
||||
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
prompt_embeds: Optional[torch.Tensor] = None
|
||||
deepstack_input_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
self.prompt_token_ids, self.prompt_embeds)
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return self.num_prompt_tokens + len(self.output_token_ids)
|
||||
|
||||
# Temporary back-compatibility for plugins that define model runner
|
||||
@property
|
||||
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
|
||||
"removed in v0.13. Please use `mm_kwargs` instead.")
|
||||
def mm_inputs(self) -> list[MultiModalKwargsItems]:
|
||||
return [
|
||||
MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
|
||||
if f.data is not None
|
||||
]
|
||||
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
if self.prompt_token_ids is None:
|
||||
raise ValueError(
|
||||
f"Tried to access token index {idx}, but that token was "
|
||||
"provided via prompt_embeds, and its ID is unknown.")
|
||||
return self.prompt_token_ids[idx]
|
||||
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
|
||||
return self.output_token_ids[idx - self.num_prompt_tokens]
|
||||
else:
|
||||
return -1
|
||||
class InputBatch:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
logitsprocs: Optional[LogitsProcessors] = None,
|
||||
is_spec_decode: bool = False,
|
||||
is_pooling_model: bool = False,
|
||||
num_speculative_tokens: int = 0,
|
||||
):
|
||||
self.is_pooling_model = is_pooling_model
|
||||
self.is_spec_decode = is_spec_decode
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self._req_ids: list[Optional[str]] = []
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
|
||||
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
||||
# Find a way to reduce the CPU memory usage.
|
||||
# This buffer is not directly transferred to the GPU, so it does not
|
||||
# need to be pinned.
|
||||
self.token_ids_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=False,
|
||||
)
|
||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=bool,
|
||||
pin_memory=False)
|
||||
# Store prompt embeddings per request to avoid OOM from large upfront
|
||||
# allocation if max_model_len is big.
|
||||
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
|
||||
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
|
||||
#patch to add req_deepstack_input_embeds
|
||||
self.req_deepstack_input_embeds: dict[int, torch.Tensor] = {}
|
||||
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, ),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.num_computed_tokens_cpu = \
|
||||
self.num_computed_tokens_cpu_tensor.numpy()
|
||||
|
||||
# Block table.
|
||||
self.block_table = MultiGroupBlockTable(
|
||||
max_num_reqs=max_num_reqs,
|
||||
max_model_len=max_model_len,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
block_sizes=block_sizes,
|
||||
num_speculative_tokens=num_speculative_tokens,
|
||||
)
|
||||
|
||||
# Sampling-related.
|
||||
self.temperature = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
||||
self.greedy_reqs: set[str] = set()
|
||||
self.random_reqs: set[str] = set()
|
||||
|
||||
self.top_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
||||
self.top_p_reqs: set[str] = set()
|
||||
|
||||
self.top_k = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
||||
self.top_k_reqs: set[str] = set()
|
||||
|
||||
# IDs of requests which do not support spec decoding
|
||||
self.spec_decode_unsupported_reqs: set[str] = set()
|
||||
|
||||
# Frequency penalty related data structures
|
||||
# self.frequency_penalties = torch.empty((max_num_reqs, ),
|
||||
# dtype=torch.float,
|
||||
# device=device)
|
||||
self.frequency_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.frequency_penalties_cpu = \
|
||||
self.frequency_penalties_cpu_tensor.numpy()
|
||||
self.frequency_penalties_reqs: set[str] = set()
|
||||
|
||||
# Presence penalty related data structures
|
||||
# self.presence_penalties = torch.empty((max_num_reqs, ),
|
||||
# dtype=torch.float,
|
||||
# device=device)
|
||||
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
|
||||
)
|
||||
self.presence_penalties_reqs: set[str] = set()
|
||||
|
||||
# Repetition penalty related data structures
|
||||
# self.repetition_penalties = torch.empty((max_num_reqs, ),
|
||||
# dtype=torch.float,
|
||||
# device=device)
|
||||
self.repetition_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.repetition_penalties_cpu = \
|
||||
self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_reqs: set[str] = set()
|
||||
|
||||
# Speculative decoding
|
||||
self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ),
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.num_accepted_tokens_cpu = \
|
||||
self.num_accepted_tokens_cpu_tensor.numpy()
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
||||
dtype=np.int32)
|
||||
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
||||
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
||||
|
||||
# req_index -> generator
|
||||
# NOTE(woosuk): The indices of the requests that do not have their own
|
||||
# generator should not be included in the dictionary.
|
||||
self.generators: dict[int, torch.Generator] = {}
|
||||
|
||||
self.num_logprobs: dict[str, int] = {}
|
||||
# NOTE(rob): num_prompt_logprobs only includes reqs
|
||||
# that are currently in the prefill phase.
|
||||
self.num_prompt_logprobs: dict[str, int] = {}
|
||||
|
||||
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
||||
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
||||
|
||||
# Internal representation of per-step batch state changes, used for
|
||||
# reordering persistent batch and generating logitsprocs batch state
|
||||
# updates. Should reset each step.
|
||||
self.batch_update_builder = BatchUpdateBuilder()
|
||||
|
||||
# TODO convert this to LogitsProcessor
|
||||
self.has_allowed_token_ids: set[str] = set()
|
||||
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
||||
# the value is False. Since we use masked_fill_ to set -inf.
|
||||
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
||||
|
||||
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
|
||||
dtype=bool)
|
||||
|
||||
self.req_output_token_ids: list[Optional[list[int]]] = []
|
||||
|
||||
# Store provided logitsprocs. If none are provided, initialize empty
|
||||
# data structure
|
||||
self.logitsprocs = logitsprocs or LogitsProcessors()
|
||||
|
||||
# This is updated each time the batch constituents change.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
|
||||
self.pooling_params: dict[str, PoolingParams] = {}
|
||||
|
||||
# Cached reference to the GPU tensor of previously sampled tokens
|
||||
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
||||
self.prev_req_id_to_index: Optional[dict[str, int]] = None
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
) -> int:
|
||||
req_index = self._register_add_request(request)
|
||||
|
||||
req_id = request.req_id
|
||||
if req_index == len(self._req_ids):
|
||||
self._req_ids.append(req_id)
|
||||
self.req_output_token_ids.append(request.output_token_ids)
|
||||
else:
|
||||
self._req_ids[req_index] = req_id
|
||||
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||
|
||||
self.req_id_to_index[req_id] = req_index
|
||||
|
||||
# Copy the prompt token ids and output token ids.
|
||||
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
request.prompt_token_ids, request.prompt_embeds)
|
||||
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
||||
start_idx = num_prompt_tokens
|
||||
end_idx = start_idx + len(request.output_token_ids)
|
||||
if request.prompt_token_ids is not None:
|
||||
self.token_ids_cpu[
|
||||
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
||||
self.is_token_ids[req_index, :num_prompt_tokens] = True
|
||||
else:
|
||||
self.is_token_ids[req_index, :num_prompt_tokens] = False
|
||||
if request.prompt_embeds is not None:
|
||||
self.req_prompt_embeds[req_index] = request.prompt_embeds
|
||||
#patch to add req_deepstack_input_embeds
|
||||
if request.deepstack_input_embeds is not None:
|
||||
self.req_deepstack_input_embeds[req_index] = request.deepstack_input_embeds
|
||||
self.token_ids_cpu[req_index,
|
||||
start_idx:end_idx] = request.output_token_ids
|
||||
self.is_token_ids[req_index, start_idx:end_idx] = True
|
||||
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
|
||||
# NOTE(woosuk): This may include spec decode tokens.
|
||||
self.num_tokens[req_index] = request.num_tokens
|
||||
# Number of tokens without spec decode tokens.
|
||||
self.num_tokens_no_spec[req_index] = request.num_tokens
|
||||
|
||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
if sampling_params := request.sampling_params:
|
||||
if (self.is_spec_decode
|
||||
and is_spec_decode_unsupported(sampling_params)):
|
||||
self.spec_decode_unsupported_reqs.add(req_id)
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Should avoid division by zero later when apply_temperature.
|
||||
self.temperature_cpu[req_index] = 0.0
|
||||
self.greedy_reqs.add(req_id)
|
||||
else:
|
||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||
self.random_reqs.add(req_id)
|
||||
|
||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||
if sampling_params.top_p < 1:
|
||||
self.top_p_reqs.add(req_id)
|
||||
top_k = sampling_params.top_k
|
||||
if 0 < top_k < self.vocab_size:
|
||||
self.top_k_reqs.add(req_id)
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k_cpu[req_index] = top_k
|
||||
self.frequency_penalties_cpu[
|
||||
req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[
|
||||
req_index] = sampling_params.presence_penalty
|
||||
if sampling_params.presence_penalty != 0.0:
|
||||
self.presence_penalties_reqs.add(req_id)
|
||||
self.repetition_penalties_cpu[
|
||||
req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
|
||||
# NOTE(woosuk): self.generators should not include the requests that
|
||||
# do not have their own generator.
|
||||
if request.generator is not None:
|
||||
self.generators[req_index] = request.generator
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
self.num_logprobs[req_id] = (self.vocab_size
|
||||
if sampling_params.logprobs == -1
|
||||
else sampling_params.logprobs)
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[req_id] = (
|
||||
self.vocab_size if sampling_params.prompt_logprobs == -1
|
||||
else sampling_params.prompt_logprobs)
|
||||
|
||||
if sampling_params.allowed_token_ids:
|
||||
self.has_allowed_token_ids.add(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is None:
|
||||
# Lazy allocation for this tensor, which can be large.
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device)
|
||||
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device="cpu")
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||
sampling_params.allowed_token_ids] = False
|
||||
|
||||
if sampling_params.bad_words_token_ids:
|
||||
self.bad_words_token_ids[
|
||||
req_index] = sampling_params.bad_words_token_ids
|
||||
elif pooling_params := request.pooling_params:
|
||||
self.pooling_params[req_id] = pooling_params
|
||||
self.logits_processing_needs_token_ids[req_index] = (
|
||||
pooling_params.requires_token_ids)
|
||||
else:
|
||||
raise NotImplementedError("Unrecognized request type")
|
||||
|
||||
# Speculative decoding: by default 1 token is generated.
|
||||
self.num_accepted_tokens_cpu[req_index] = 1
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
lora_id = request.lora_request.lora_int_id
|
||||
if lora_id not in self.lora_id_to_request_ids:
|
||||
self.lora_id_to_request_ids[lora_id] = set()
|
||||
|
||||
self.request_lora_mapping[req_index] = lora_id
|
||||
self.lora_id_to_request_ids[lora_id].add(request.req_id)
|
||||
self.lora_id_to_lora_request[lora_id] = request.lora_request
|
||||
else:
|
||||
# No LoRA
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
return req_index
|
||||
|
||||
|
||||
def _make_sampling_metadata(self) -> SamplingMetadata:
|
||||
num_reqs = self.num_reqs
|
||||
# if not self.all_greedy:
|
||||
# temperature = copy_slice(self.temperature_cpu_tensor,
|
||||
# self.temperature, num_reqs)
|
||||
# else:
|
||||
# temperature = None
|
||||
# if not self.no_top_p:
|
||||
# copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
||||
# if not self.no_top_k:
|
||||
# copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
|
||||
|
||||
# if not self.no_penalties:
|
||||
# # Since syncing these tensors is expensive only copy them
|
||||
# # if necessary i.e. if there are requests which require
|
||||
# # penalties to be applied during sampling.
|
||||
# copy_slice(self.frequency_penalties_cpu_tensor,
|
||||
# self.frequency_penalties, num_reqs)
|
||||
# copy_slice(self.presence_penalties_cpu_tensor,
|
||||
# self.presence_penalties, num_reqs)
|
||||
# copy_slice(self.repetition_penalties_cpu_tensor,
|
||||
# self.repetition_penalties, num_reqs)
|
||||
|
||||
needs_prompt_token_ids = (
|
||||
not self.no_penalties
|
||||
or self.logits_processing_needs_token_ids[:num_reqs].any())
|
||||
if needs_prompt_token_ids:
|
||||
# The prompt tokens are used only for applying penalties or
|
||||
# step pooling during the sampling/pooling process.
|
||||
# Hence copy these tensors only when there are requests which
|
||||
# need penalties/step_pooler to be applied.
|
||||
prompt_token_ids = self._make_prompt_token_ids_tensor()
|
||||
else:
|
||||
prompt_token_ids = None
|
||||
|
||||
allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||
if not self.no_allowed_token_ids:
|
||||
assert self.allowed_token_ids_mask is not None
|
||||
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
|
||||
self.allowed_token_ids_mask, num_reqs)
|
||||
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=self.temperature_cpu_tensor[:num_reqs].to(self.device),
|
||||
all_greedy=self.all_greedy,
|
||||
all_random=self.all_random,
|
||||
# top_p=None if self.no_top_p else self.top_p_cpu_tensor[:num_reqs].to(self.device),
|
||||
# top_k=None if self.no_top_k else torch.tensor([40 for _ in range(num_reqs)]).to(torch.int32).to(self.device),
|
||||
top_p=torch.tensor([1 for _ in range(num_reqs)]).to(torch.float32).to(self.device) if self.no_top_p else self.top_p_cpu_tensor[:num_reqs].to(self.device),
|
||||
top_k=torch.tensor([40 for _ in range(num_reqs)]).to(torch.int32).to(self.device),
|
||||
generators=self.generators,
|
||||
max_num_logprobs=self.max_num_logprobs,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
frequency_penalties=self.frequency_penalties_cpu_tensor[:num_reqs].tolist(),
|
||||
presence_penalties=self.presence_penalties_cpu_tensor[:num_reqs].tolist(),
|
||||
repetition_penalties=self.repetition_penalties_cpu_tensor[:num_reqs].tolist(),
|
||||
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
||||
no_penalties=self.no_penalties,
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=self.bad_words_token_ids,
|
||||
logitsprocs=self.logitsprocs,
|
||||
temperature_cpu = self.temperature_cpu_tensor[:num_reqs],
|
||||
top_p_cpu = self.top_p_cpu_tensor[:num_reqs],
|
||||
top_k_cpu = torch.tensor([40 for _ in range(num_reqs)]).to(torch.int32),
|
||||
)
|
||||
1605
vllm_vacc/vllm/v1/worker/vacc_model_runner.py
Normal file
1605
vllm_vacc/vllm/v1/worker/vacc_model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
340
vllm_vacc/vllm/v1/worker/vacc_worker.py
Normal file
340
vllm_vacc/vllm/v1/worker/vacc_worker.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""A VACC worker class."""
|
||||
import gc
|
||||
import os
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from importlib import util
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
|
||||
get_dtype_size)
|
||||
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput, AsyncModelRunnerOutput
|
||||
# from vllm.v1.worker.cpu_model_runner import VACCModelRunner
|
||||
from vllm_vacc.vllm.v1.worker.vacc_model_runner import VACCModelRunner
|
||||
from vllm.v1.worker.gpu_worker import Worker
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
# from vllm.worker.cache_engine import CacheEngine
|
||||
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
from vllm.utils import GiB_bytes
|
||||
|
||||
TP_GROUP_ID = 1234
|
||||
def generate_rank_info_list():
|
||||
global TP_GROUP_ID
|
||||
from vllm.distributed import get_tp_group
|
||||
# generate ran
|
||||
get_tp_group().generate_rank_device_infos()
|
||||
get_tp_group().generate_group_id(TP_GROUP_ID)
|
||||
|
||||
def generate_tp_group_id():
|
||||
global TP_GROUP_ID
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
workspace_path = Path.cwd()
|
||||
|
||||
bootinfo_config = f'{workspace_path}/.bootinfos'
|
||||
bootinfo_inited = os.path.exists(bootinfo_config)
|
||||
|
||||
current_bootinfos = "default"
|
||||
if bootinfo_inited:
|
||||
try:
|
||||
with open(bootinfo_config) as w:
|
||||
current_bootinfos = w.readline()
|
||||
except Exception as e:
|
||||
print("[WARN] bootinfo load fail ", e)
|
||||
|
||||
if current_bootinfos is not None:
|
||||
unique_value = uuid.uuid5(uuid.NAMESPACE_URL, current_bootinfos).int
|
||||
|
||||
int32_value = unique_value & 0xFFFFFFFF
|
||||
if int32_value >= 2**31:
|
||||
int32_value -= 2**32
|
||||
TP_GROUP_ID = int32_value
|
||||
# print("current_bootinfos:", current_bootinfos, TP_GROUP_ID)
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
vllm_config: VllmConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
backend: str = "vccl",
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank, backend)
|
||||
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.decode_context_parallel_size)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
generate_tp_group_id()
|
||||
generate_rank_info_list()
|
||||
|
||||
|
||||
def get_cache_block_size(
|
||||
cache_config: CacheConfig,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
) -> int:
|
||||
head_size = model_config.get_head_size()
|
||||
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
num_attention_layers = model_config.get_num_layers_by_block_type(
|
||||
parallel_config, LayerBlockType.attention)
|
||||
|
||||
if cache_config.cache_dtype == "auto":
|
||||
dtype = model_config.dtype
|
||||
else:
|
||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
|
||||
key_cache_entry = num_heads * head_size
|
||||
|
||||
# For MLA there is no value cache, since the latent vector
|
||||
# is joint keys and values.
|
||||
value_cache_entry = key_cache_entry if not model_config.use_mla else 0
|
||||
total = num_attention_layers * cache_config.block_size * \
|
||||
(key_cache_entry + value_cache_entry)
|
||||
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
return dtype_size * total
|
||||
|
||||
|
||||
class VACCWorker(Worker):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False):
|
||||
super().__init__(vllm_config,
|
||||
local_rank,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
is_driver_worker=is_driver_worker)
|
||||
|
||||
self.parallel_config.disable_custom_all_reduce = True
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.device_config.device.type == "vacc":
|
||||
try:
|
||||
self.device = torch.device(f"vacc:{self.local_rank}")
|
||||
torch.vacc.set_device(self.device)
|
||||
gc.collect()
|
||||
torch.vacc.empty_cache()
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"device init fail: {e} ",
|
||||
f"self.device: {self.device}, check /dev/* or VACC_VISIBLE_DEVICES")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
init_worker_distributed_environment(self.vllm_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
# Construct the model runner
|
||||
self.model_runner: VACCModelRunner = VACCModelRunner(
|
||||
self.vllm_config, self.device)
|
||||
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
logger.warning("sleep mode is not supported on VACC, ignore it.")
|
||||
pass
|
||||
|
||||
def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
||||
logger.warning("sleep mode is not supported on VACC, ignore it.")
|
||||
pass
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Get the size of the KV cache block size in bytes.
|
||||
"""
|
||||
return get_cache_block_size(self.cache_config,
|
||||
self.model_config,
|
||||
self.parallel_config)
|
||||
def determine_available_memory(self) -> int:
|
||||
"""Determine the number of available KV blocks.
|
||||
|
||||
Swapping is not yet supported, so always return num_cpu_blocks=0.
|
||||
|
||||
We configure num_gpu_blocks to be equal to max_num_seqs.
|
||||
"""
|
||||
available_kv_cache_memory_min, num_gpu_blocks = self.determine_available_memory_block()
|
||||
return available_kv_cache_memory_min
|
||||
|
||||
|
||||
def determine_available_memory_block(self) -> int:
|
||||
"""Determine the number of available KV blocks.
|
||||
|
||||
Swapping is not yet supported, so always return num_cpu_blocks=0.
|
||||
|
||||
We configure num_gpu_blocks to be equal to max_num_seqs.
|
||||
"""
|
||||
available_kv_cache_memory= int(os.getenv("VLLM_VACC_KVCACHE_SPACE", "16")) * GiB_bytes
|
||||
|
||||
max_seq_num = int(os.getenv("MAX_SEQ_NUM", 4))
|
||||
max_num_gpu_blocks=0
|
||||
if available_kv_cache_memory ==0:
|
||||
torch.vacc.empty_cache()
|
||||
torch.vacc.reset_peak_memory_stats()
|
||||
total_memory = torch.vacc.mem_get_info()[1]
|
||||
self.model_runner.profile_run()
|
||||
torch.vacc.synchronize()
|
||||
peak_memory = torch.vacc.max_memory_allocated()
|
||||
torch.vacc.empty_cache()
|
||||
torch_allocated_bytes = torch.vacc.memory_stats(
|
||||
)["allocated_bytes.all.current"]
|
||||
total_allocated_bytes = torch.vacc.mem_get_info(
|
||||
)[1] - torch.vacc.mem_get_info()[0]
|
||||
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
|
||||
if non_torch_allocations > 0:
|
||||
peak_memory += non_torch_allocations
|
||||
available_kv_cache_memory=total_memory*self.cache_config.gpu_memory_utilization - peak_memory
|
||||
|
||||
if self.model_config.hf_config.model_type == "deepseek_v3":
|
||||
assert self.model_config.max_model_len <= 65536*self.vllm_config.parallel_config.pipeline_parallel_size, f"unsupported max model len, should less equal 65536 but got {self.model_config.max_model_len}"
|
||||
# Rules:
|
||||
# 1. always reserve N * 8K blocks
|
||||
# 2. no less than (MAX_SEQ_NUM + 1) * 8K blocks
|
||||
minimum_num_gpu_blocks_required = (max_seq_num + 1) * env_blk_grp_size // self.cache_config.block_size
|
||||
max_model_len = (self.model_config.max_model_len + env_blk_grp_size - 1) // env_blk_grp_size * env_blk_grp_size
|
||||
max_num_gpu_blocks = max_model_len // self.cache_config.block_size
|
||||
|
||||
# limited by available_kv_cache_memory
|
||||
cache_block_size = self.get_cache_block_size_bytes()
|
||||
if cache_block_size == 0:
|
||||
num_gpu_blocks = 0
|
||||
num_cpu_blocks = 0
|
||||
else:
|
||||
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
|
||||
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
||||
cache_block_size)
|
||||
assert num_gpu_blocks >= minimum_num_gpu_blocks_required, \
|
||||
f"num_gpu_blocks should >= {minimum_num_gpu_blocks_required} please increase VLLM_VACC_KVCACHE_SPACE"
|
||||
torch.vacc.empty_cache()
|
||||
# if self.model_runner.lora_manager:
|
||||
# self.model_runner.remove_all_loras()
|
||||
gc.collect()
|
||||
if max_num_gpu_blocks != 0:
|
||||
num_gpu_blocks = min(max_num_gpu_blocks, num_gpu_blocks)
|
||||
|
||||
num_gpu_blocks = max(num_gpu_blocks, minimum_num_gpu_blocks_required)
|
||||
available_kv_cache_memory_min = num_gpu_blocks * cache_block_size
|
||||
|
||||
return available_kv_cache_memory_min, num_gpu_blocks
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
# self.model_runner.warming_up_model()
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
if forward_pass and not get_pp_group().is_first_rank:
|
||||
intermediate_tensors = IntermediateTensors(
|
||||
get_pp_group().recv_tensor_dict(
|
||||
all_gather_group=get_tp_group()))
|
||||
|
||||
output = self.model_runner.execute_model(scheduler_output,
|
||||
intermediate_tensors)
|
||||
|
||||
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
|
||||
return output
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
get_pp_group().send_tensor_dict(output.tensors,
|
||||
all_gather_group=get_tp_group())
|
||||
return None
|
||||
|
||||
assert isinstance(output, ModelRunnerOutput)
|
||||
return output if self.is_driver_worker else None
|
||||
|
||||
def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
|
||||
"""Return VACCs id binding based on NUMA nodes.
|
||||
"""
|
||||
rank_to_cpus = self.local_omp_cpuid
|
||||
# Setup OpenMP thread affinity based on NUMA nodes automatically
|
||||
world_size = self.vllm_config.parallel_config.world_size
|
||||
libnuma_found = util.find_spec("numa") is not None
|
||||
psutil_found = util.find_spec("psutil") is not None
|
||||
if libnuma_found and psutil_found:
|
||||
import psutil
|
||||
from numa import info
|
||||
cpu_count = psutil.cpu_count(logical=False)
|
||||
cpus_allow_list = psutil.Process().cpu_affinity()
|
||||
numa_size = info.get_num_configured_nodes()
|
||||
cpu_count_per_numa = cpu_count // numa_size
|
||||
num_of_reserved_cpu = min(envs.VLLM_VACC_NUM_OF_RESERVED_VACC,
|
||||
cpu_count_per_numa // 2)
|
||||
|
||||
# check allow node_to_cpus list
|
||||
node_to_cpus = []
|
||||
for i in range(numa_size):
|
||||
node_intersect = set(
|
||||
info.node_to_cpus(i)).intersection(cpus_allow_list)
|
||||
if bool(node_intersect):
|
||||
node_to_cpus.append(list(node_intersect))
|
||||
|
||||
if world_size > len(node_to_cpus):
|
||||
logger.error(
|
||||
"Auto thread-binding failed due to "
|
||||
"world size: %d is larger than "
|
||||
"allowed NUMA nodes number: %d."
|
||||
"Please try to bind threads manually.", world_size,
|
||||
len(node_to_cpus))
|
||||
else:
|
||||
end = cpu_count_per_numa - num_of_reserved_cpu
|
||||
rank_to_cpus_list = node_to_cpus[self.rank][:end]
|
||||
rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
|
||||
logger.info("auto thread-binding list: %s", rank_to_cpus)
|
||||
else:
|
||||
logger.warning(
|
||||
"Auto thread-binding is not supported due to "
|
||||
"the lack of package numa and psutil,"
|
||||
"fallback to no thread-binding. To get better performance,"
|
||||
"please try to manually bind threads.")
|
||||
return rank_to_cpus
|
||||
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
context = allocator.use_memory_pool(tag="kv_cache")
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
Reference in New Issue
Block a user